diff --git a/README.md b/README.md index 65ee734668..71d0cc5557 100644 --- a/README.md +++ b/README.md @@ -112,6 +112,8 @@ PaddleScience 是一个基于深度学习框架 PaddlePaddle 开发的科学计 | 天气预报 | [FourCastNet 气象预报](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/fourcastnet) | 数据驱动 | FourCastNet | 监督学习 | [ERA5](https://app.globus.org/file-manager?origin_id=945b3c9e-0f8c-11ed-8daf-9f359c660fbd&origin_path=%2F~%2Fdata%2F) | [Paper](https://arxiv.org/pdf/2202.11214.pdf) | | 天气预报 | [NowCastNet 气象预报](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/nowcastnet) | 数据驱动 | NowCastNet | 监督学习 | [MRMS](https://app.globus.org/file-manager?origin_id=945b3c9e-0f8c-11ed-8daf-9f359c660fbd&origin_path=%2F~%2Fdata%2F) | [Paper](https://www.nature.com/articles/s41586-023-06184-4) | | 天气预报 | [GraphCast 气象预报](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/graphcast) | 数据驱动 | GraphCastNet | 监督学习 | - | [Paper](https://arxiv.org/abs/2212.12794) | +| 天气预报 | [FengWu 气象预报](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/fengwu) | 数据驱动 | Transformer | 监督学习 | - | [Paper](https://arxiv.org/pdf/2304.02948) | +| 天气预报 | [Pangu-Weather 气象预报](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/pangu_weather) | 数据驱动 | Transformer | 监督学习 | - | [Paper](https://arxiv.org/pdf/2211.02556) | | 大气污染物 | [UNet 污染物扩散](https://aistudio.baidu.com/projectdetail/5663515?channel=0&channelType=0&sUid=438690&shared=1&ts=1698221963752) | 数据驱动 | UNet | 监督学习 | [Data](https://aistudio.baidu.com/datasetdetail/198102) | - | | 天气预报 | [DGMR 气象预报](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/dgmr.md) | 数据驱动 | DGMR | 监督学习 | [UK dataset](https://huggingface.co/datasets/openclimatefix/nimrod-uk-1km) | [Paper](https://arxiv.org/pdf/2104.00954.pdf) | | 地震波形反演 | [VelocityGAN 地震波形反演](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/velocity_gan.md) | 数据驱动 | VelocityGAN | 监督学习 | [OpenFWI](https://openfwi-lanl.github.io/docs/data.html#vel) | [Paper](https://arxiv.org/abs/1809.10262v6) | diff --git a/docs/index.md b/docs/index.md index 602c238c2c..732aeca6f0 100644 --- a/docs/index.md +++ b/docs/index.md @@ -148,6 +148,8 @@ | 天气预报 | [FourCastNet 气象预报](./zh/examples/fourcastnet.md) | 数据驱动 | FourCastNet | 监督学习 | [ERA5](https://app.globus.org/file-manager?origin_id=945b3c9e-0f8c-11ed-8daf-9f359c660fbd&origin_path=%2F~%2Fdata%2F) | [Paper](https://arxiv.org/pdf/2202.11214.pdf) | | 天气预报 | [NowCastNet 气象预报](./zh/examples/nowcastnet.md) | 数据驱动 | NowCastNet | 监督学习 | [MRMS](https://app.globus.org/file-manager?origin_id=945b3c9e-0f8c-11ed-8daf-9f359c660fbd&origin_path=%2F~%2Fdata%2F) | [Paper](https://www.nature.com/articles/s41586-023-06184-4) | | 天气预报 | [GraphCast 气象预报](./zh/examples/graphcast.md) | 数据驱动 | GraphCastNet | 监督学习 | - | [Paper](https://arxiv.org/abs/2212.12794) | +| 天气预报 | [FengWu 气象预报](./zh/examples/fengwu.md) | 数据驱动 | Transformer | 监督学习 | - | [Paper](https://arxiv.org/pdf/2304.02948) | +| 天气预报 | [Pangu-Weather 气象预报](./zh/examples/pangu_weather.md) | 数据驱动 | Transformer | 监督学习 | - | [Paper](https://arxiv.org/pdf/2211.02556) | | 大气污染物 | [UNet 污染物扩散](https://aistudio.baidu.com/projectdetail/5663515?channel=0&channelType=0&sUid=438690&shared=1&ts=1698221963752) | 数据驱动 | UNet | 监督学习 | [Data](https://aistudio.baidu.com/datasetdetail/198102) | - | | 天气预报 | [DGMR 气象预报](./zh/examples/dgmr.md) | 数据驱动 | DGMR | 监督学习 | [UK dataset](https://huggingface.co/datasets/openclimatefix/nimrod-uk-1km) | [Paper](https://arxiv.org/pdf/2104.00954.pdf) | | 地震波形反演 | [VelocityGAN 地震波形反演](./zh/examples/velocity_gan.md) | 数据驱动 | VelocityGAN | 监督学习 | [OpenFWI](https://openfwi-lanl.github.io/docs/data.html#vel) | [Paper](https://arxiv.org/abs/1809.10262v6) | diff --git a/docs/zh/examples/fengwu.md b/docs/zh/examples/fengwu.md new file mode 100644 index 0000000000..2466b62924 --- /dev/null +++ b/docs/zh/examples/fengwu.md @@ -0,0 +1,101 @@ +# FengWu + +=== "模型训练命令" + + 暂无 + +=== "模型评估命令" + + 暂无 + +=== "模型导出命令" + + 暂无 + +=== "模型推理命令" + + ``` sh + # Download sample input data + wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/Fengwu/input1.npy -P ./data + wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/Fengwu/input2.npy -P ./data + + # Download pretrain model weight + wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/Fengwu/fengwu_v2.onnx -P ./inference + + # inference + python predict.py + ``` + +## 1. 背景简介 + +随着近年来全球气候变化加剧,极端天气频发,各界对天气预报的时效和精度的期待更是与日俱增。如何提高天气预报的时效和准确度,一直是业内的重点课题。AI大模型“风乌”基于多模态和多任务深度学习方法构建,实现在高分辨率上对核心大气变量进行超过10天的有效预报,并在80%的评估指标上超越DeepMind发布的模型GraphCast。同时,“风乌”仅需30秒即可生成未来10天全球高精度预报结果,在效率上大幅优于传统模型。 + +## 2. 模型原理 + +本章节仅对风乌气象大模型的原理进行简单地介绍,详细的理论推导请阅读 [FengWu: Pushing the Skillful Global Medium-range Weather Forecast beyond 10 Days Lead](https://arxiv.org/pdf/2304.02948)。 + +模型的总体结构如图所示: + +
+ ![result](https://paddle-org.bj.bcebos.com/paddlescience/docs/fengwu/model_architecture.png){ loading=lazy style="margin:0 auto;"} +
模型结构
+
+ +模型将气候变量作为不同模态的输入。在 `Modal-Customized Encoder` 中将多个模态的特征进行编码,并使用基于 Transformer 的 `Cross-modal Fuser` 对编码后的特征进行融合,得到联合表示,最后在 `Modal-Customized Decoder` 中从联合表示中分别预测气候变量。 + +模型使用预训练权重推理,接下来将介绍模型的推理过程。 + +## 3. 模型构建 + +在该案例中,实现了 FengWuPredictor用于ONNX模型的推理: + +``` py linenums="74" title="examples/fengwu/predict.py" +--8<-- +examples/fengwu/predict.py:74:130 +--8<-- +``` + +``` yaml linenums="28" title="examples/fengwu/conf/fengwu.yaml" +--8<-- +examples/fengwu/conf/fengwu.yaml:28:46 +--8<-- +``` + +其中,`input_file` 和 `input_next_file` 分别代表网络模型输入的开始时刻气象数据和6小时后的气象数据。 + +## 4. 结果可视化 + +模型推理结果包含 56 个 npy 文件,表示从预测时间点开始,未来 14 天内每隔6小时的气象数据。结果可视化需要先将数据从 npy 转换为 NetCDF 格式,然后采用 ncvue 进行查看。 + +1. 安装相关依赖 +```python +pip install cdsapi netCDF4 ncvue +``` + +2. 使用脚本进行数据转换 +```python +python convert_data.py +``` + +3. 使用 ncvue 打开转换后的 NetCDF 文件, ncvue 具体说明见[ncvue官方文档](https://github.com/mcuntz/ncvue) + +## 5. 完整代码 + +``` py linenums="1" title="examples/fengwu/predict.py" +--8<-- +examples/fengwu/predict.py +--8<-- +``` + +## 6. 结果展示 + +下图展示了模型的未来6小时平均海平面气压预测结果,更多指标可以使用 ncvue 查看。 + +
+ ![result](https://paddle-org.bj.bcebos.com/paddlescience/docs/fengwu/image.png){ loading=lazy style="margin:0 auto;"} +
未来6小时平均海平面气压
+
+ +## 7. 参考资料 + +- [FengWu: Pushing the Skillful Global Medium-range Weather Forecast beyond 10 Days Lead](https://arxiv.org/pdf/2304.02948) diff --git a/examples/fengwu/conf/fengwu.yaml b/examples/fengwu/conf/fengwu.yaml new file mode 100644 index 0000000000..2af18f3a7e --- /dev/null +++ b/examples/fengwu/conf/fengwu.yaml @@ -0,0 +1,46 @@ +defaults: + - ppsci_default + - INFER: infer_default + - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default + - _self_ + +hydra: + run: + # dynamic output directory according to running time and override name + dir: ./outputs_fengwu + job: + name: ${mode} # name of logfile + chdir: false # keep current working directory unchanged + callbacks: + init_callback: + _target_: ppsci.utils.callbacks.InitCallback + sweep: + # output directory for multirun + dir: ${hydra.run.dir} + subdir: ./ + +# general settings +mode: infer # running mode: infer +seed: 2023 +output_dir: ${hydra:run.dir} +log_freq: 20 + +# inference settings +INFER: + pretrained_model_path: null + export_path: inference/fengwu_v2 + onnx_path: ${INFER.export_path}.onnx + device: gpu + engine: onnx + precision: fp32 + ir_optim: false + min_subgraph_size: 30 + gpu_mem: 100 + gpu_id: 0 + max_batch_size: 1 + num_cpu_threads: 10 + batch_size: 1 + mean_path: ./data_mean.npy + std_path: ./data_std.npy + input_file: './data/input1.npy' + input_next_file: './data/input2.npy' diff --git a/examples/fengwu/convert_data.py b/examples/fengwu/convert_data.py new file mode 100644 index 0000000000..90c98c40ff --- /dev/null +++ b/examples/fengwu/convert_data.py @@ -0,0 +1,172 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ref: https://github.com/HaxyMoly/Pangu-Weather-ReadyToGo/blob/main/forecast_decode_functions.py + +import os +from os import path as osp +from typing import Dict + +import hydra +import netCDF4 as nc +import numpy as np + +from ppsci.utils import logger + + +def convert_surface_data_to_nc( + surface_file: str, file_name: str, output_dir: str +) -> None: + surface_data = np.load(surface_file) + u_component_of_wind_10m = surface_data[0] + v_component_of_wind_10m = surface_data[1] + temperature_2m = surface_data[2] + mean_sea_level_pressure = surface_data[3] + + with nc.Dataset( + os.path.join(output_dir, file_name), "w", format="NETCDF4_CLASSIC" + ) as nc_file: + # Create dimensions + nc_file.createDimension("longitude", 1440) + nc_file.createDimension("latitude", 721) + + # Create variables + nc_lon = nc_file.createVariable("longitude", np.float32, ("longitude",)) + nc_lat = nc_file.createVariable("latitude", np.float32, ("latitude",)) + nc_msl = nc_file.createVariable( + "mean_sea_level_pressure", np.float32, ("latitude", "longitude") + ) + nc_u10 = nc_file.createVariable( + "u_component_of_wind_10m", np.float32, ("latitude", "longitude") + ) + nc_v10 = nc_file.createVariable( + "v_component_of_wind_10m", np.float32, ("latitude", "longitude") + ) + nc_t2m = nc_file.createVariable( + "temperature_2m", np.float32, ("latitude", "longitude") + ) + + # Set variable attributes + nc_lon.units = "degrees_east" + nc_lat.units = "degrees_north" + nc_msl.units = "Pa" + nc_u10.units = "m/s" + nc_v10.units = "m/s" + nc_t2m.units = "K" + + # Write data to variables + nc_lon[:] = np.linspace(0.125, 359.875, 1440) + nc_lat[:] = np.linspace(90, -90, 721) + nc_msl[:] = mean_sea_level_pressure + nc_u10[:] = u_component_of_wind_10m + nc_v10[:] = v_component_of_wind_10m + nc_t2m[:] = temperature_2m + + logger.info( + f"Convert output surface data file {surface_file} as nc format and save to {output_dir}/{file_name}." + ) + + +def convert_upper_data_to_nc(upper_file: str, file_name: str, output_dir: str) -> None: + # Load the saved numpy arrays + upper_data = np.load(upper_file) + + # surface data offset + st = 4 + level = 13 + + geopotential = upper_data[st : st + level] + specific_humidity = upper_data[st + level : st + 2 * level] + u_component_of_wind = upper_data[st + 2 * level : st + 3 * level] + v_component_of_wind = upper_data[st + 3 * level : st + 4 * level] + temperature = upper_data[st + 4 * level :] + + with nc.Dataset( + os.path.join(output_dir, file_name), "w", format="NETCDF4_CLASSIC" + ) as nc_file: + # Create dimensions + nc_file.createDimension("longitude", 1440) + nc_file.createDimension("latitude", 721) + nc_file.createDimension("level", level) + + # Create variables + nc_lon = nc_file.createVariable("longitude", np.float32, ("longitude",)) + nc_lat = nc_file.createVariable("latitude", np.float32, ("latitude",)) + nc_geopotential = nc_file.createVariable( + "geopotential", np.float32, ("level", "latitude", "longitude") + ) + nc_specific_humidity = nc_file.createVariable( + "specific_humidity", np.float32, ("level", "latitude", "longitude") + ) + nc_temperature = nc_file.createVariable( + "temperature", np.float32, ("level", "latitude", "longitude") + ) + nc_u_component_of_wind = nc_file.createVariable( + "u_component_of_wind", np.float32, ("level", "latitude", "longitude") + ) + nc_v_component_of_wind = nc_file.createVariable( + "v_component_of_wind", np.float32, ("level", "latitude", "longitude") + ) + + # Set variable attributes + nc_lon.units = "degrees_east" + nc_lat.units = "degrees_north" + nc_geopotential.units = "m" + nc_specific_humidity.units = "kg/kg" + nc_temperature.units = "K" + nc_u_component_of_wind.units = "m/s" + nc_v_component_of_wind.units = "m/s" + # Write data to variables + nc_lon[:] = np.linspace(0.125, 359.875, 1440) + nc_lat[:] = np.linspace(90, -90, 721) + nc_geopotential[:] = geopotential + nc_specific_humidity[:] = specific_humidity + nc_temperature[:] = temperature + nc_u_component_of_wind[:] = u_component_of_wind + nc_v_component_of_wind[:] = v_component_of_wind + + logger.info( + f"Convert output upper data file {upper_file} as nc format and save to {output_dir}/{file_name}." + ) + + +def convert(cfg: Dict): + output_dir = cfg.output_dir + + for _, file_name in os.listdir(output_dir): + if not file_name.endwiths("npy"): + continue + + convert_surface_data_to_nc( + osp.join(output_dir, file_name), + osp.basename(file_name) + "_surface.nc", + output_dir, + ) + convert_upper_data_to_nc( + osp.join(output_dir, file_name), + osp.basename(file_name) + "_upper.nc", + output_dir, + ) + + +@hydra.main(version_base=None, config_path="./conf", config_name="fengwu.yaml") +def main(cfg: Dict): + if cfg.mode == "infer": + convert(cfg) + else: + raise ValueError(f"cfg.mode should in ['infer'], but got '{cfg.mode}'") + + +if __name__ == "__main__": + main() diff --git a/examples/fengwu/data_mean.npy b/examples/fengwu/data_mean.npy new file mode 100644 index 0000000000..b79998952c Binary files /dev/null and b/examples/fengwu/data_mean.npy differ diff --git a/examples/fengwu/data_std.npy b/examples/fengwu/data_std.npy new file mode 100644 index 0000000000..805d54884a Binary files /dev/null and b/examples/fengwu/data_std.npy differ diff --git a/examples/fengwu/predict.py b/examples/fengwu/predict.py new file mode 100644 index 0000000000..d5b756edbe --- /dev/null +++ b/examples/fengwu/predict.py @@ -0,0 +1,184 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from os import path as osp +from typing import List + +import hydra +import numpy as np +import paddle +from omegaconf import DictConfig +from packaging import version + +from deploy.python_infer import base +from ppsci.utils import logger + + +class FengWuPredictor(base.Predictor): + """General predictor for FengWu model. + + Args: + cfg (DictConfig): Running configuration. + """ + + # 14 day with time-interval of siz hours + PREDICT_TIMESTAMP = int(14 * 24 / 6) + # Where 69 represents 69 atmospheric features, The first four variables are surface variables in the order of ['u10', 'v10', 't2m', 'msl'], + # followed by non-surface variables in the order of ['z', 'q', 'u', 'v', 't']. Each data has 13 levels, which are ordered as + # [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]. + # Therefore, the order of the 69 variables is [u10, v10, t2m, msl, z50, z100, ..., z1000, q50, q100, ..., q1000, t50, t100, ..., t1000]. + NUM_ATMOSPHERIC_FEATURES = 69 + + def __init__( + self, + cfg: DictConfig, + ): + assert cfg.INFER.engine == "onnx", "FengWu engine only supports 'onnx'." + + super().__init__( + pdmodel_path=None, + pdiparams_path=None, + device=cfg.INFER.device, + engine=cfg.INFER.engine, + precision=cfg.INFER.precision, + onnx_path=cfg.INFER.onnx_path, + ir_optim=cfg.INFER.ir_optim, + min_subgraph_size=cfg.INFER.min_subgraph_size, + gpu_mem=cfg.INFER.gpu_mem, + gpu_id=cfg.INFER.gpu_id, + max_batch_size=cfg.INFER.max_batch_size, + num_cpu_threads=cfg.INFER.num_cpu_threads, + ) + self.log_freq = cfg.log_freq + + # get input names + self.input_names = [ + input_node.name for input_node in self.predictor.get_inputs() + ] + + # get output names + self.output_names = [ + output_node.name for output_node in self.predictor.get_outputs() + ] + + # load mean and std data + self.data_mean = np.load(cfg.INFER.mean_path)[:, np.newaxis, np.newaxis] + self.data_std = np.load(cfg.INFER.std_path)[:, np.newaxis, np.newaxis] + + def _preprocess_data( + self, input_data_prev: np.ndarray, input_data_next: np.ndarray + ) -> np.ndarray: + input_data_prev_after_norm = ( + input_data_prev.astype("float32") - self.data_mean + ) / self.data_std + input_data_next_after_norm = ( + input_data_next.astype("float32") - self.data_mean + ) / self.data_std + input_data = np.concatenate( + (input_data_prev_after_norm, input_data_next_after_norm), axis=0 + )[np.newaxis, :, :, :] + input_data = input_data.astype(np.float32) + + return input_data + + def predict( + self, + input_data_prev: np.ndarray, + input_data_next: np.ndarray, + batch_size: int = 1, + ) -> List[np.ndarray]: + """Predicts the output of the yinglong model for the given input. + + Args: + input_data_prev(np.ndarray): Atomospheric data at the first time moment. + input_data_next(np.ndarray): Atmospheric data six later. + batch_size (int, optional): Batch size, now only support 1. Defaults to 1. + + Returns: + List[np.ndarray]: Prediction for next 56 hours. + """ + if batch_size != 1: + raise ValueError( + f"FengWuPredictor only support batch_size=1, but got {batch_size}" + ) + + # process data + input_data = self._preprocess_data(input_data_prev, input_data_next) + + output_data_list = [] + # prepare input dict + for _ in range(self.PREDICT_TIMESTAMP): + input_dict = { + self.input_names[0]: input_data, + } + + # run predictor + output_data = self.predictor.run(None, input_dict)[0] + input_data = np.concatenate( + ( + input_data[:, self.NUM_ATMOSPHERIC_FEATURES :], + output_data[:, : self.NUM_ATMOSPHERIC_FEATURES], + ), + axis=1, + ) + output_data = ( + output_data[0, : self.NUM_ATMOSPHERIC_FEATURES] * self.data_std + ) + self.data_mean + + output_data_list.append(output_data) + + return output_data_list + + +def inference(cfg: DictConfig): + # log paddlepaddle's version + if version.Version(paddle.__version__) != version.Version("0.0.0"): + paddle_version = paddle.__version__ + if version.Version(paddle.__version__) < version.Version("2.6.0"): + logger.warning( + f"Detected paddlepaddle version is '{paddle_version}', " + "currently it is recommended to use release 2.6 or develop version." + ) + else: + paddle_version = f"develop({paddle.version.commit[:7]})" + + logger.info(f"Using paddlepaddle {paddle_version}") + + # create predictor + predictor = FengWuPredictor(cfg) + + # load data + input_data_prev = np.load(cfg.INFER.input_file).astype(np.float32) + input_data_next = np.load(cfg.INFER.input_next_file).astype(np.float32) + + # run predictor + output_data_list = predictor.predict(input_data_prev, input_data_next) + + # save predict data + for i in range(FengWuPredictor.PREDICT_TIMESTAMP): + output_save_path = osp.join(cfg.output_dir, f"output_{i}.npy") + np.save(output_save_path, output_data_list[i]) + logger.info(f"Save output with timestamp:{i} to {output_save_path}.") + + +@hydra.main(version_base=None, config_path="./conf", config_name="fengwu.yaml") +def main(cfg: DictConfig): + if cfg.mode == "infer": + inference(cfg) + else: + raise ValueError(f"cfg.mode should in ['infer'], but got '{cfg.mode}'") + + +if __name__ == "__main__": + main() diff --git a/mkdocs.yml b/mkdocs.yml index c18b188b59..d4105b1c14 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -105,6 +105,7 @@ nav: - TGCN: zh/examples/tgcn.md - IOPS: zh/examples/iops.md - Pang-Weather: zh/examples/pangu_weather.md + - FengWu: zh/examples/fengwu.md - 化学科学(AI for Chemistry): - Moflow: zh/examples/moflow.md - IFM: zh/examples/ifm.md