diff --git a/README.md b/README.md
index 65ee734668..71d0cc5557 100644
--- a/README.md
+++ b/README.md
@@ -112,6 +112,8 @@ PaddleScience 是一个基于深度学习框架 PaddlePaddle 开发的科学计
| 天气预报 | [FourCastNet 气象预报](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/fourcastnet) | 数据驱动 | FourCastNet | 监督学习 | [ERA5](https://app.globus.org/file-manager?origin_id=945b3c9e-0f8c-11ed-8daf-9f359c660fbd&origin_path=%2F~%2Fdata%2F) | [Paper](https://arxiv.org/pdf/2202.11214.pdf) |
| 天气预报 | [NowCastNet 气象预报](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/nowcastnet) | 数据驱动 | NowCastNet | 监督学习 | [MRMS](https://app.globus.org/file-manager?origin_id=945b3c9e-0f8c-11ed-8daf-9f359c660fbd&origin_path=%2F~%2Fdata%2F) | [Paper](https://www.nature.com/articles/s41586-023-06184-4) |
| 天气预报 | [GraphCast 气象预报](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/graphcast) | 数据驱动 | GraphCastNet | 监督学习 | - | [Paper](https://arxiv.org/abs/2212.12794) |
+| 天气预报 | [FengWu 气象预报](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/fengwu) | 数据驱动 | Transformer | 监督学习 | - | [Paper](https://arxiv.org/pdf/2304.02948) |
+| 天气预报 | [Pangu-Weather 气象预报](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/pangu_weather) | 数据驱动 | Transformer | 监督学习 | - | [Paper](https://arxiv.org/pdf/2211.02556) |
| 大气污染物 | [UNet 污染物扩散](https://aistudio.baidu.com/projectdetail/5663515?channel=0&channelType=0&sUid=438690&shared=1&ts=1698221963752) | 数据驱动 | UNet | 监督学习 | [Data](https://aistudio.baidu.com/datasetdetail/198102) | - |
| 天气预报 | [DGMR 气象预报](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/dgmr.md) | 数据驱动 | DGMR | 监督学习 | [UK dataset](https://huggingface.co/datasets/openclimatefix/nimrod-uk-1km) | [Paper](https://arxiv.org/pdf/2104.00954.pdf) |
| 地震波形反演 | [VelocityGAN 地震波形反演](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/velocity_gan.md) | 数据驱动 | VelocityGAN | 监督学习 | [OpenFWI](https://openfwi-lanl.github.io/docs/data.html#vel) | [Paper](https://arxiv.org/abs/1809.10262v6) |
diff --git a/docs/index.md b/docs/index.md
index 602c238c2c..732aeca6f0 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -148,6 +148,8 @@
| 天气预报 | [FourCastNet 气象预报](./zh/examples/fourcastnet.md) | 数据驱动 | FourCastNet | 监督学习 | [ERA5](https://app.globus.org/file-manager?origin_id=945b3c9e-0f8c-11ed-8daf-9f359c660fbd&origin_path=%2F~%2Fdata%2F) | [Paper](https://arxiv.org/pdf/2202.11214.pdf) |
| 天气预报 | [NowCastNet 气象预报](./zh/examples/nowcastnet.md) | 数据驱动 | NowCastNet | 监督学习 | [MRMS](https://app.globus.org/file-manager?origin_id=945b3c9e-0f8c-11ed-8daf-9f359c660fbd&origin_path=%2F~%2Fdata%2F) | [Paper](https://www.nature.com/articles/s41586-023-06184-4) |
| 天气预报 | [GraphCast 气象预报](./zh/examples/graphcast.md) | 数据驱动 | GraphCastNet | 监督学习 | - | [Paper](https://arxiv.org/abs/2212.12794) |
+| 天气预报 | [FengWu 气象预报](./zh/examples/fengwu.md) | 数据驱动 | Transformer | 监督学习 | - | [Paper](https://arxiv.org/pdf/2304.02948) |
+| 天气预报 | [Pangu-Weather 气象预报](./zh/examples/pangu_weather.md) | 数据驱动 | Transformer | 监督学习 | - | [Paper](https://arxiv.org/pdf/2211.02556) |
| 大气污染物 | [UNet 污染物扩散](https://aistudio.baidu.com/projectdetail/5663515?channel=0&channelType=0&sUid=438690&shared=1&ts=1698221963752) | 数据驱动 | UNet | 监督学习 | [Data](https://aistudio.baidu.com/datasetdetail/198102) | - |
| 天气预报 | [DGMR 气象预报](./zh/examples/dgmr.md) | 数据驱动 | DGMR | 监督学习 | [UK dataset](https://huggingface.co/datasets/openclimatefix/nimrod-uk-1km) | [Paper](https://arxiv.org/pdf/2104.00954.pdf) |
| 地震波形反演 | [VelocityGAN 地震波形反演](./zh/examples/velocity_gan.md) | 数据驱动 | VelocityGAN | 监督学习 | [OpenFWI](https://openfwi-lanl.github.io/docs/data.html#vel) | [Paper](https://arxiv.org/abs/1809.10262v6) |
diff --git a/docs/zh/examples/fengwu.md b/docs/zh/examples/fengwu.md
new file mode 100644
index 0000000000..2466b62924
--- /dev/null
+++ b/docs/zh/examples/fengwu.md
@@ -0,0 +1,101 @@
+# FengWu
+
+=== "模型训练命令"
+
+ 暂无
+
+=== "模型评估命令"
+
+ 暂无
+
+=== "模型导出命令"
+
+ 暂无
+
+=== "模型推理命令"
+
+ ``` sh
+ # Download sample input data
+ wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/Fengwu/input1.npy -P ./data
+ wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/Fengwu/input2.npy -P ./data
+
+ # Download pretrain model weight
+ wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/Fengwu/fengwu_v2.onnx -P ./inference
+
+ # inference
+ python predict.py
+ ```
+
+## 1. 背景简介
+
+随着近年来全球气候变化加剧,极端天气频发,各界对天气预报的时效和精度的期待更是与日俱增。如何提高天气预报的时效和准确度,一直是业内的重点课题。AI大模型“风乌”基于多模态和多任务深度学习方法构建,实现在高分辨率上对核心大气变量进行超过10天的有效预报,并在80%的评估指标上超越DeepMind发布的模型GraphCast。同时,“风乌”仅需30秒即可生成未来10天全球高精度预报结果,在效率上大幅优于传统模型。
+
+## 2. 模型原理
+
+本章节仅对风乌气象大模型的原理进行简单地介绍,详细的理论推导请阅读 [FengWu: Pushing the Skillful Global Medium-range Weather Forecast beyond 10 Days Lead](https://arxiv.org/pdf/2304.02948)。
+
+模型的总体结构如图所示:
+
+
+ { loading=lazy style="margin:0 auto;"}
+ 模型结构
+
+
+模型将气候变量作为不同模态的输入。在 `Modal-Customized Encoder` 中将多个模态的特征进行编码,并使用基于 Transformer 的 `Cross-modal Fuser` 对编码后的特征进行融合,得到联合表示,最后在 `Modal-Customized Decoder` 中从联合表示中分别预测气候变量。
+
+模型使用预训练权重推理,接下来将介绍模型的推理过程。
+
+## 3. 模型构建
+
+在该案例中,实现了 FengWuPredictor用于ONNX模型的推理:
+
+``` py linenums="74" title="examples/fengwu/predict.py"
+--8<--
+examples/fengwu/predict.py:74:130
+--8<--
+```
+
+``` yaml linenums="28" title="examples/fengwu/conf/fengwu.yaml"
+--8<--
+examples/fengwu/conf/fengwu.yaml:28:46
+--8<--
+```
+
+其中,`input_file` 和 `input_next_file` 分别代表网络模型输入的开始时刻气象数据和6小时后的气象数据。
+
+## 4. 结果可视化
+
+模型推理结果包含 56 个 npy 文件,表示从预测时间点开始,未来 14 天内每隔6小时的气象数据。结果可视化需要先将数据从 npy 转换为 NetCDF 格式,然后采用 ncvue 进行查看。
+
+1. 安装相关依赖
+```python
+pip install cdsapi netCDF4 ncvue
+```
+
+2. 使用脚本进行数据转换
+```python
+python convert_data.py
+```
+
+3. 使用 ncvue 打开转换后的 NetCDF 文件, ncvue 具体说明见[ncvue官方文档](https://github.com/mcuntz/ncvue)
+
+## 5. 完整代码
+
+``` py linenums="1" title="examples/fengwu/predict.py"
+--8<--
+examples/fengwu/predict.py
+--8<--
+```
+
+## 6. 结果展示
+
+下图展示了模型的未来6小时平均海平面气压预测结果,更多指标可以使用 ncvue 查看。
+
+
+ { loading=lazy style="margin:0 auto;"}
+ 未来6小时平均海平面气压
+
+
+## 7. 参考资料
+
+- [FengWu: Pushing the Skillful Global Medium-range Weather Forecast beyond 10 Days Lead](https://arxiv.org/pdf/2304.02948)
diff --git a/examples/fengwu/conf/fengwu.yaml b/examples/fengwu/conf/fengwu.yaml
new file mode 100644
index 0000000000..2af18f3a7e
--- /dev/null
+++ b/examples/fengwu/conf/fengwu.yaml
@@ -0,0 +1,46 @@
+defaults:
+ - ppsci_default
+ - INFER: infer_default
+ - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
+ - _self_
+
+hydra:
+ run:
+ # dynamic output directory according to running time and override name
+ dir: ./outputs_fengwu
+ job:
+ name: ${mode} # name of logfile
+ chdir: false # keep current working directory unchanged
+ callbacks:
+ init_callback:
+ _target_: ppsci.utils.callbacks.InitCallback
+ sweep:
+ # output directory for multirun
+ dir: ${hydra.run.dir}
+ subdir: ./
+
+# general settings
+mode: infer # running mode: infer
+seed: 2023
+output_dir: ${hydra:run.dir}
+log_freq: 20
+
+# inference settings
+INFER:
+ pretrained_model_path: null
+ export_path: inference/fengwu_v2
+ onnx_path: ${INFER.export_path}.onnx
+ device: gpu
+ engine: onnx
+ precision: fp32
+ ir_optim: false
+ min_subgraph_size: 30
+ gpu_mem: 100
+ gpu_id: 0
+ max_batch_size: 1
+ num_cpu_threads: 10
+ batch_size: 1
+ mean_path: ./data_mean.npy
+ std_path: ./data_std.npy
+ input_file: './data/input1.npy'
+ input_next_file: './data/input2.npy'
diff --git a/examples/fengwu/convert_data.py b/examples/fengwu/convert_data.py
new file mode 100644
index 0000000000..90c98c40ff
--- /dev/null
+++ b/examples/fengwu/convert_data.py
@@ -0,0 +1,172 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ref: https://github.com/HaxyMoly/Pangu-Weather-ReadyToGo/blob/main/forecast_decode_functions.py
+
+import os
+from os import path as osp
+from typing import Dict
+
+import hydra
+import netCDF4 as nc
+import numpy as np
+
+from ppsci.utils import logger
+
+
+def convert_surface_data_to_nc(
+ surface_file: str, file_name: str, output_dir: str
+) -> None:
+ surface_data = np.load(surface_file)
+ u_component_of_wind_10m = surface_data[0]
+ v_component_of_wind_10m = surface_data[1]
+ temperature_2m = surface_data[2]
+ mean_sea_level_pressure = surface_data[3]
+
+ with nc.Dataset(
+ os.path.join(output_dir, file_name), "w", format="NETCDF4_CLASSIC"
+ ) as nc_file:
+ # Create dimensions
+ nc_file.createDimension("longitude", 1440)
+ nc_file.createDimension("latitude", 721)
+
+ # Create variables
+ nc_lon = nc_file.createVariable("longitude", np.float32, ("longitude",))
+ nc_lat = nc_file.createVariable("latitude", np.float32, ("latitude",))
+ nc_msl = nc_file.createVariable(
+ "mean_sea_level_pressure", np.float32, ("latitude", "longitude")
+ )
+ nc_u10 = nc_file.createVariable(
+ "u_component_of_wind_10m", np.float32, ("latitude", "longitude")
+ )
+ nc_v10 = nc_file.createVariable(
+ "v_component_of_wind_10m", np.float32, ("latitude", "longitude")
+ )
+ nc_t2m = nc_file.createVariable(
+ "temperature_2m", np.float32, ("latitude", "longitude")
+ )
+
+ # Set variable attributes
+ nc_lon.units = "degrees_east"
+ nc_lat.units = "degrees_north"
+ nc_msl.units = "Pa"
+ nc_u10.units = "m/s"
+ nc_v10.units = "m/s"
+ nc_t2m.units = "K"
+
+ # Write data to variables
+ nc_lon[:] = np.linspace(0.125, 359.875, 1440)
+ nc_lat[:] = np.linspace(90, -90, 721)
+ nc_msl[:] = mean_sea_level_pressure
+ nc_u10[:] = u_component_of_wind_10m
+ nc_v10[:] = v_component_of_wind_10m
+ nc_t2m[:] = temperature_2m
+
+ logger.info(
+ f"Convert output surface data file {surface_file} as nc format and save to {output_dir}/{file_name}."
+ )
+
+
+def convert_upper_data_to_nc(upper_file: str, file_name: str, output_dir: str) -> None:
+ # Load the saved numpy arrays
+ upper_data = np.load(upper_file)
+
+ # surface data offset
+ st = 4
+ level = 13
+
+ geopotential = upper_data[st : st + level]
+ specific_humidity = upper_data[st + level : st + 2 * level]
+ u_component_of_wind = upper_data[st + 2 * level : st + 3 * level]
+ v_component_of_wind = upper_data[st + 3 * level : st + 4 * level]
+ temperature = upper_data[st + 4 * level :]
+
+ with nc.Dataset(
+ os.path.join(output_dir, file_name), "w", format="NETCDF4_CLASSIC"
+ ) as nc_file:
+ # Create dimensions
+ nc_file.createDimension("longitude", 1440)
+ nc_file.createDimension("latitude", 721)
+ nc_file.createDimension("level", level)
+
+ # Create variables
+ nc_lon = nc_file.createVariable("longitude", np.float32, ("longitude",))
+ nc_lat = nc_file.createVariable("latitude", np.float32, ("latitude",))
+ nc_geopotential = nc_file.createVariable(
+ "geopotential", np.float32, ("level", "latitude", "longitude")
+ )
+ nc_specific_humidity = nc_file.createVariable(
+ "specific_humidity", np.float32, ("level", "latitude", "longitude")
+ )
+ nc_temperature = nc_file.createVariable(
+ "temperature", np.float32, ("level", "latitude", "longitude")
+ )
+ nc_u_component_of_wind = nc_file.createVariable(
+ "u_component_of_wind", np.float32, ("level", "latitude", "longitude")
+ )
+ nc_v_component_of_wind = nc_file.createVariable(
+ "v_component_of_wind", np.float32, ("level", "latitude", "longitude")
+ )
+
+ # Set variable attributes
+ nc_lon.units = "degrees_east"
+ nc_lat.units = "degrees_north"
+ nc_geopotential.units = "m"
+ nc_specific_humidity.units = "kg/kg"
+ nc_temperature.units = "K"
+ nc_u_component_of_wind.units = "m/s"
+ nc_v_component_of_wind.units = "m/s"
+ # Write data to variables
+ nc_lon[:] = np.linspace(0.125, 359.875, 1440)
+ nc_lat[:] = np.linspace(90, -90, 721)
+ nc_geopotential[:] = geopotential
+ nc_specific_humidity[:] = specific_humidity
+ nc_temperature[:] = temperature
+ nc_u_component_of_wind[:] = u_component_of_wind
+ nc_v_component_of_wind[:] = v_component_of_wind
+
+ logger.info(
+ f"Convert output upper data file {upper_file} as nc format and save to {output_dir}/{file_name}."
+ )
+
+
+def convert(cfg: Dict):
+ output_dir = cfg.output_dir
+
+ for _, file_name in os.listdir(output_dir):
+ if not file_name.endwiths("npy"):
+ continue
+
+ convert_surface_data_to_nc(
+ osp.join(output_dir, file_name),
+ osp.basename(file_name) + "_surface.nc",
+ output_dir,
+ )
+ convert_upper_data_to_nc(
+ osp.join(output_dir, file_name),
+ osp.basename(file_name) + "_upper.nc",
+ output_dir,
+ )
+
+
+@hydra.main(version_base=None, config_path="./conf", config_name="fengwu.yaml")
+def main(cfg: Dict):
+ if cfg.mode == "infer":
+ convert(cfg)
+ else:
+ raise ValueError(f"cfg.mode should in ['infer'], but got '{cfg.mode}'")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/fengwu/data_mean.npy b/examples/fengwu/data_mean.npy
new file mode 100644
index 0000000000..b79998952c
Binary files /dev/null and b/examples/fengwu/data_mean.npy differ
diff --git a/examples/fengwu/data_std.npy b/examples/fengwu/data_std.npy
new file mode 100644
index 0000000000..805d54884a
Binary files /dev/null and b/examples/fengwu/data_std.npy differ
diff --git a/examples/fengwu/predict.py b/examples/fengwu/predict.py
new file mode 100644
index 0000000000..d5b756edbe
--- /dev/null
+++ b/examples/fengwu/predict.py
@@ -0,0 +1,184 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from os import path as osp
+from typing import List
+
+import hydra
+import numpy as np
+import paddle
+from omegaconf import DictConfig
+from packaging import version
+
+from deploy.python_infer import base
+from ppsci.utils import logger
+
+
+class FengWuPredictor(base.Predictor):
+ """General predictor for FengWu model.
+
+ Args:
+ cfg (DictConfig): Running configuration.
+ """
+
+ # 14 day with time-interval of siz hours
+ PREDICT_TIMESTAMP = int(14 * 24 / 6)
+ # Where 69 represents 69 atmospheric features, The first four variables are surface variables in the order of ['u10', 'v10', 't2m', 'msl'],
+ # followed by non-surface variables in the order of ['z', 'q', 'u', 'v', 't']. Each data has 13 levels, which are ordered as
+ # [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000].
+ # Therefore, the order of the 69 variables is [u10, v10, t2m, msl, z50, z100, ..., z1000, q50, q100, ..., q1000, t50, t100, ..., t1000].
+ NUM_ATMOSPHERIC_FEATURES = 69
+
+ def __init__(
+ self,
+ cfg: DictConfig,
+ ):
+ assert cfg.INFER.engine == "onnx", "FengWu engine only supports 'onnx'."
+
+ super().__init__(
+ pdmodel_path=None,
+ pdiparams_path=None,
+ device=cfg.INFER.device,
+ engine=cfg.INFER.engine,
+ precision=cfg.INFER.precision,
+ onnx_path=cfg.INFER.onnx_path,
+ ir_optim=cfg.INFER.ir_optim,
+ min_subgraph_size=cfg.INFER.min_subgraph_size,
+ gpu_mem=cfg.INFER.gpu_mem,
+ gpu_id=cfg.INFER.gpu_id,
+ max_batch_size=cfg.INFER.max_batch_size,
+ num_cpu_threads=cfg.INFER.num_cpu_threads,
+ )
+ self.log_freq = cfg.log_freq
+
+ # get input names
+ self.input_names = [
+ input_node.name for input_node in self.predictor.get_inputs()
+ ]
+
+ # get output names
+ self.output_names = [
+ output_node.name for output_node in self.predictor.get_outputs()
+ ]
+
+ # load mean and std data
+ self.data_mean = np.load(cfg.INFER.mean_path)[:, np.newaxis, np.newaxis]
+ self.data_std = np.load(cfg.INFER.std_path)[:, np.newaxis, np.newaxis]
+
+ def _preprocess_data(
+ self, input_data_prev: np.ndarray, input_data_next: np.ndarray
+ ) -> np.ndarray:
+ input_data_prev_after_norm = (
+ input_data_prev.astype("float32") - self.data_mean
+ ) / self.data_std
+ input_data_next_after_norm = (
+ input_data_next.astype("float32") - self.data_mean
+ ) / self.data_std
+ input_data = np.concatenate(
+ (input_data_prev_after_norm, input_data_next_after_norm), axis=0
+ )[np.newaxis, :, :, :]
+ input_data = input_data.astype(np.float32)
+
+ return input_data
+
+ def predict(
+ self,
+ input_data_prev: np.ndarray,
+ input_data_next: np.ndarray,
+ batch_size: int = 1,
+ ) -> List[np.ndarray]:
+ """Predicts the output of the yinglong model for the given input.
+
+ Args:
+ input_data_prev(np.ndarray): Atomospheric data at the first time moment.
+ input_data_next(np.ndarray): Atmospheric data six later.
+ batch_size (int, optional): Batch size, now only support 1. Defaults to 1.
+
+ Returns:
+ List[np.ndarray]: Prediction for next 56 hours.
+ """
+ if batch_size != 1:
+ raise ValueError(
+ f"FengWuPredictor only support batch_size=1, but got {batch_size}"
+ )
+
+ # process data
+ input_data = self._preprocess_data(input_data_prev, input_data_next)
+
+ output_data_list = []
+ # prepare input dict
+ for _ in range(self.PREDICT_TIMESTAMP):
+ input_dict = {
+ self.input_names[0]: input_data,
+ }
+
+ # run predictor
+ output_data = self.predictor.run(None, input_dict)[0]
+ input_data = np.concatenate(
+ (
+ input_data[:, self.NUM_ATMOSPHERIC_FEATURES :],
+ output_data[:, : self.NUM_ATMOSPHERIC_FEATURES],
+ ),
+ axis=1,
+ )
+ output_data = (
+ output_data[0, : self.NUM_ATMOSPHERIC_FEATURES] * self.data_std
+ ) + self.data_mean
+
+ output_data_list.append(output_data)
+
+ return output_data_list
+
+
+def inference(cfg: DictConfig):
+ # log paddlepaddle's version
+ if version.Version(paddle.__version__) != version.Version("0.0.0"):
+ paddle_version = paddle.__version__
+ if version.Version(paddle.__version__) < version.Version("2.6.0"):
+ logger.warning(
+ f"Detected paddlepaddle version is '{paddle_version}', "
+ "currently it is recommended to use release 2.6 or develop version."
+ )
+ else:
+ paddle_version = f"develop({paddle.version.commit[:7]})"
+
+ logger.info(f"Using paddlepaddle {paddle_version}")
+
+ # create predictor
+ predictor = FengWuPredictor(cfg)
+
+ # load data
+ input_data_prev = np.load(cfg.INFER.input_file).astype(np.float32)
+ input_data_next = np.load(cfg.INFER.input_next_file).astype(np.float32)
+
+ # run predictor
+ output_data_list = predictor.predict(input_data_prev, input_data_next)
+
+ # save predict data
+ for i in range(FengWuPredictor.PREDICT_TIMESTAMP):
+ output_save_path = osp.join(cfg.output_dir, f"output_{i}.npy")
+ np.save(output_save_path, output_data_list[i])
+ logger.info(f"Save output with timestamp:{i} to {output_save_path}.")
+
+
+@hydra.main(version_base=None, config_path="./conf", config_name="fengwu.yaml")
+def main(cfg: DictConfig):
+ if cfg.mode == "infer":
+ inference(cfg)
+ else:
+ raise ValueError(f"cfg.mode should in ['infer'], but got '{cfg.mode}'")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/mkdocs.yml b/mkdocs.yml
index c18b188b59..d4105b1c14 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -105,6 +105,7 @@ nav:
- TGCN: zh/examples/tgcn.md
- IOPS: zh/examples/iops.md
- Pang-Weather: zh/examples/pangu_weather.md
+ - FengWu: zh/examples/fengwu.md
- 化学科学(AI for Chemistry):
- Moflow: zh/examples/moflow.md
- IFM: zh/examples/ifm.md