Commit f1fd9b8

[Hackathon 8th No.18] FengWu paper reproduction (#1090)
* Add FengWu predictor
* fix reviewer issue
* gpu_id default set to 0
* Add convert script
* Add docs and fix some bugs
* update docs
* fix docs
* fix docs
* resolve reviewer issues
1 parent 96eb5ba commit f1fd9b8

9 files changed

+508
-0
lines changed

README.md

Lines changed: 2 additions & 0 deletions
@@ -112,6 +112,8 @@ PaddleScience is a PaddlePaddle-based scientific comput…
 | Weather forecasting | [FourCastNet weather forecasting](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/fourcastnet) | Data-driven | FourCastNet | Supervised learning | [ERA5](https://app.globus.org/file-manager?origin_id=945b3c9e-0f8c-11ed-8daf-9f359c660fbd&origin_path=%2F~%2Fdata%2F) | [Paper](https://arxiv.org/pdf/2202.11214.pdf) |
 | Weather forecasting | [NowCastNet weather forecasting](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/nowcastnet) | Data-driven | NowCastNet | Supervised learning | [MRMS](https://app.globus.org/file-manager?origin_id=945b3c9e-0f8c-11ed-8daf-9f359c660fbd&origin_path=%2F~%2Fdata%2F) | [Paper](https://www.nature.com/articles/s41586-023-06184-4) |
 | Weather forecasting | [GraphCast weather forecasting](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/graphcast) | Data-driven | GraphCastNet | Supervised learning | - | [Paper](https://arxiv.org/abs/2212.12794) |
+| Weather forecasting | [FengWu weather forecasting](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/fengwu) | Data-driven | Transformer | Supervised learning | - | [Paper](https://arxiv.org/pdf/2304.02948) |
+| Weather forecasting | [Pangu-Weather weather forecasting](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/pangu_weather) | Data-driven | Transformer | Supervised learning | - | [Paper](https://arxiv.org/pdf/2211.02556) |
 | Air pollutants | [UNet pollutant diffusion](https://aistudio.baidu.com/projectdetail/5663515?channel=0&channelType=0&sUid=438690&shared=1&ts=1698221963752) | Data-driven | UNet | Supervised learning | [Data](https://aistudio.baidu.com/datasetdetail/198102) | - |
 | Weather forecasting | [DGMR weather forecasting](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/dgmr.md) | Data-driven | DGMR | Supervised learning | [UK dataset](https://huggingface.co/datasets/openclimatefix/nimrod-uk-1km) | [Paper](https://arxiv.org/pdf/2104.00954.pdf) |
 | Seismic waveform inversion | [VelocityGAN seismic waveform inversion](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/examples/velocity_gan.md) | Data-driven | VelocityGAN | Supervised learning | [OpenFWI](https://openfwi-lanl.github.io/docs/data.html#vel) | [Paper](https://arxiv.org/abs/1809.10262v6) |

docs/index.md

Lines changed: 2 additions & 0 deletions
@@ -148,6 +148,8 @@
 | Weather forecasting | [FourCastNet weather forecasting](./zh/examples/fourcastnet.md) | Data-driven | FourCastNet | Supervised learning | [ERA5](https://app.globus.org/file-manager?origin_id=945b3c9e-0f8c-11ed-8daf-9f359c660fbd&origin_path=%2F~%2Fdata%2F) | [Paper](https://arxiv.org/pdf/2202.11214.pdf) |
 | Weather forecasting | [NowCastNet weather forecasting](./zh/examples/nowcastnet.md) | Data-driven | NowCastNet | Supervised learning | [MRMS](https://app.globus.org/file-manager?origin_id=945b3c9e-0f8c-11ed-8daf-9f359c660fbd&origin_path=%2F~%2Fdata%2F) | [Paper](https://www.nature.com/articles/s41586-023-06184-4) |
 | Weather forecasting | [GraphCast weather forecasting](./zh/examples/graphcast.md) | Data-driven | GraphCastNet | Supervised learning | - | [Paper](https://arxiv.org/abs/2212.12794) |
+| Weather forecasting | [FengWu weather forecasting](./zh/examples/fengwu.md) | Data-driven | Transformer | Supervised learning | - | [Paper](https://arxiv.org/pdf/2304.02948) |
+| Weather forecasting | [Pangu-Weather weather forecasting](./zh/examples/pangu_weather.md) | Data-driven | Transformer | Supervised learning | - | [Paper](https://arxiv.org/pdf/2211.02556) |
 | Air pollutants | [UNet pollutant diffusion](https://aistudio.baidu.com/projectdetail/5663515?channel=0&channelType=0&sUid=438690&shared=1&ts=1698221963752) | Data-driven | UNet | Supervised learning | [Data](https://aistudio.baidu.com/datasetdetail/198102) | - |
 | Weather forecasting | [DGMR weather forecasting](./zh/examples/dgmr.md) | Data-driven | DGMR | Supervised learning | [UK dataset](https://huggingface.co/datasets/openclimatefix/nimrod-uk-1km) | [Paper](https://arxiv.org/pdf/2104.00954.pdf) |
 | Seismic waveform inversion | [VelocityGAN seismic waveform inversion](./zh/examples/velocity_gan.md) | Data-driven | VelocityGAN | Supervised learning | [OpenFWI](https://openfwi-lanl.github.io/docs/data.html#vel) | [Paper](https://arxiv.org/abs/1809.10262v6) |

docs/zh/examples/fengwu.md

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
# FengWu

=== "Model training command"

    Not available yet

=== "Model evaluation command"

    Not available yet

=== "Model export command"

    Not available yet

=== "Model inference command"

    ``` sh
    # Download sample input data
    wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/Fengwu/input1.npy -P ./data
    wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/Fengwu/input2.npy -P ./data

    # Download the pretrained model weights
    wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/Fengwu/fengwu_v2.onnx -P ./inference

    # Run inference
    python predict.py
    ```

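## 1. Background

With global climate change intensifying and extreme weather becoming more frequent in recent years, expectations for both the timeliness and the accuracy of weather forecasts keep rising, and improving them has long been a central problem in the field. The AI large model FengWu is built on multi-modal, multi-task deep learning. It delivers skillful forecasts of core atmospheric variables at high resolution for more than 10 days and outperforms GraphCast, the model released by DeepMind, on 80% of the evaluated metrics. FengWu also needs only 30 seconds to generate a high-accuracy global forecast for the next 10 days, which makes it far more efficient than traditional models.
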
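## 2. Model principle

This section only briefly introduces the principle of the FengWu weather model. For the detailed theory, see [FengWu: Pushing the Skillful Global Medium-range Weather Forecast beyond 10 Days Lead](https://arxiv.org/pdf/2304.02948).

The overall architecture of the model is shown below:

<figure markdown>
  ![result](https://paddle-org.bj.bcebos.com/paddlescience/docs/fengwu/model_architecture.png){ loading=lazy style="margin:0 auto;"}
  <figcaption>Model architecture</figcaption>
</figure>

The model treats the climate variables as inputs of different modalities. The `Modal-Customized Encoder` encodes the features of each modality, the Transformer-based `Cross-modal Fuser` fuses the encoded features into a joint representation, and the `Modal-Customized Decoder` then predicts each climate variable from that joint representation.

This example runs inference with pretrained weights. The inference procedure is described next.
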
## 3. Model construction

This example implements `FengWuPredictor` for inference with the ONNX model:

``` py linenums="74" title="examples/fengwu/predict.py"
--8<--
examples/fengwu/predict.py:74:130
--8<--
```

``` yaml linenums="28" title="examples/fengwu/conf/fengwu.yaml"
--8<--
examples/fengwu/conf/fengwu.yaml:28:46
--8<--
```

Here, `input_file` and `input_next_file` are the model's input weather data at the start time and 6 hours later, respectively.
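
The way two consecutive frames could drive a multi-step forecast can be sketched as follows. This is a minimal illustration, not the code in `predict.py`: it assumes the predictor follows the inference recipe commonly used with the reference FengWu ONNX model (normalize both frames with `data_mean.npy` and `data_std.npy`, stack them along the channel axis, and feed each prediction back as the newest frame). The names `rollout`, `step_fn` and `persistence_step` are hypothetical, `step_fn` stands in for the ONNX session call, and the grid is shrunk so the sketch runs without the model.

```python
import numpy as np

N_CHANNELS = 69  # 4 surface variables + 5 upper-air variables x 13 levels


def rollout(frame_t0, frame_t1, mean, std, step_fn, n_steps=56):
    """Autoregressive forecast from two consecutive frames of shape (C, H, W).

    step_fn is a hypothetical stand-in for the ONNX model: it maps a normalized
    (1, 2*C, H, W) array to an array whose first C channels are the normalized
    next frame.
    """
    mean = mean[:, None, None]
    std = std[:, None, None]
    n_ch = frame_t0.shape[0]

    # Normalize each frame per channel, then stack the two frames.
    x0 = (frame_t0 - mean) / std
    x1 = (frame_t1 - mean) / std
    model_in = np.concatenate([x0, x1], axis=0)[None].astype(np.float32)

    forecasts = []
    for _ in range(n_steps):
        next_frame = step_fn(model_in)[:, :n_ch]
        # Undo the normalization before storing the forecast.
        forecasts.append(next_frame[0] * std + mean)
        # Slide the window: drop the oldest frame, append the prediction.
        model_in = np.concatenate([model_in[:, n_ch:], next_frame], axis=1)
    return forecasts


def persistence_step(model_in):
    """Toy model for demonstration only: repeat the most recent frame."""
    n_ch = model_in.shape[1] // 2
    return model_in[:, n_ch:]


rng = np.random.default_rng(0)
shape = (N_CHANNELS, 8, 16)  # tiny grid; the real fields are 721 x 1440
t0 = rng.normal(size=shape)
t1 = rng.normal(size=shape)
mean = rng.normal(size=N_CHANNELS)
std = rng.uniform(0.5, 2.0, size=N_CHANNELS)

outputs = rollout(t0, t1, mean, std, persistence_step)
print(len(outputs), outputs[0].shape)
```

With the toy `persistence_step`, every forecast simply reproduces the second input frame, which makes the normalization round trip easy to check. With a real model, `step_fn` would wrap the inference session instead.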
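
## 4. Visualizing the results

Inference produces 56 npy files, which hold the weather data at 6-hour intervals over the 14 days following the forecast start time. To visualize them, first convert the data from npy to NetCDF format, then view it with ncvue.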
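
The relation between 56 files, the 6-hour step and the 14-day horizon can be checked with a few lines of Python. This is only a sketch: `lead_times` is a hypothetical helper, and the initial time is a made-up value assumed to be the timestamp of the later of the two input frames.

```python
from datetime import datetime, timedelta

STEP_HOURS = 6
N_STEPS = 56


def lead_times(n_steps=N_STEPS, step_hours=STEP_HOURS):
    """Lead time of each forecast step, counted from the last input frame."""
    return [timedelta(hours=step_hours * (i + 1)) for i in range(n_steps)]


leads = lead_times()

# Hypothetical initial time, used only to show how valid times follow.
init_time = datetime(2020, 1, 1, 6)
valid_times = [init_time + lead for lead in leads]

print(leads[0], leads[-1])
print(valid_times[0], valid_times[-1])
```

The last lead time is 56 × 6 h = 336 h, which is exactly 14 days.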

1. Install the dependencies
    ```sh
    pip install cdsapi netCDF4 ncvue
    ```

2. Convert the data with the script
    ```sh
    python convert_data.py
    ```

3. Open the converted NetCDF file with ncvue. See the [official ncvue documentation](https://github.com/mcuntz/ncvue) for details.

## 5. Complete code

``` py linenums="1" title="examples/fengwu/predict.py"
--8<--
examples/fengwu/predict.py
--8<--
```

## 6. Results

The figure below shows the model's mean sea level pressure forecast for 6 hours ahead. Other variables can be inspected with ncvue.

<figure markdown>
  ![result](https://paddle-org.bj.bcebos.com/paddlescience/docs/fengwu/image.png){ loading=lazy style="margin:0 auto;"}
  <figcaption>Mean sea level pressure, 6 hours ahead</figcaption>
</figure>

## 7. References

- [FengWu: Pushing the Skillful Global Medium-range Weather Forecast beyond 10 Days Lead](https://arxiv.org/pdf/2304.02948)

examples/fengwu/conf/fengwu.yaml

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
defaults:
  - ppsci_default
  - INFER: infer_default
  - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
  - _self_

hydra:
  run:
    # fixed output directory for this example
    dir: ./outputs_fengwu
  job:
    name: ${mode} # name of logfile
    chdir: false # keep current working directory unchanged
  callbacks:
    init_callback:
      _target_: ppsci.utils.callbacks.InitCallback
  sweep:
    # output directory for multirun
    dir: ${hydra.run.dir}
    subdir: ./

# general settings
mode: infer # running mode: infer
seed: 2023
output_dir: ${hydra:run.dir}
log_freq: 20

# inference settings
INFER:
  pretrained_model_path: null
  export_path: inference/fengwu_v2
  onnx_path: ${INFER.export_path}.onnx
  device: gpu
  engine: onnx
  precision: fp32
  ir_optim: false
  min_subgraph_size: 30
  gpu_mem: 100
  gpu_id: 0
  max_batch_size: 1
  num_cpu_threads: 10
  batch_size: 1
  mean_path: ./data_mean.npy
  std_path: ./data_std.npy
  input_file: './data/input1.npy'
  input_next_file: './data/input2.npy'

examples/fengwu/convert_data.py

Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ref: https://github.com/HaxyMoly/Pangu-Weather-ReadyToGo/blob/main/forecast_decode_functions.py

import os
from os import path as osp
from typing import Dict

import hydra
import netCDF4 as nc
import numpy as np

from ppsci.utils import logger


def convert_surface_data_to_nc(
    surface_file: str, file_name: str, output_dir: str
) -> None:
    """Write the four surface channels of one prediction file to NetCDF."""
    surface_data = np.load(surface_file)
    # The first four channels are the surface variables
    u_component_of_wind_10m = surface_data[0]
    v_component_of_wind_10m = surface_data[1]
    temperature_2m = surface_data[2]
    mean_sea_level_pressure = surface_data[3]

    with nc.Dataset(
        os.path.join(output_dir, file_name), "w", format="NETCDF4_CLASSIC"
    ) as nc_file:
        # Create dimensions
        nc_file.createDimension("longitude", 1440)
        nc_file.createDimension("latitude", 721)

        # Create variables
        nc_lon = nc_file.createVariable("longitude", np.float32, ("longitude",))
        nc_lat = nc_file.createVariable("latitude", np.float32, ("latitude",))
        nc_msl = nc_file.createVariable(
            "mean_sea_level_pressure", np.float32, ("latitude", "longitude")
        )
        nc_u10 = nc_file.createVariable(
            "u_component_of_wind_10m", np.float32, ("latitude", "longitude")
        )
        nc_v10 = nc_file.createVariable(
            "v_component_of_wind_10m", np.float32, ("latitude", "longitude")
        )
        nc_t2m = nc_file.createVariable(
            "temperature_2m", np.float32, ("latitude", "longitude")
        )

        # Set variable attributes
        nc_lon.units = "degrees_east"
        nc_lat.units = "degrees_north"
        nc_msl.units = "Pa"
        nc_u10.units = "m/s"
        nc_v10.units = "m/s"
        nc_t2m.units = "K"

        # Write data to variables
        nc_lon[:] = np.linspace(0.125, 359.875, 1440)
        nc_lat[:] = np.linspace(90, -90, 721)
        nc_msl[:] = mean_sea_level_pressure
        nc_u10[:] = u_component_of_wind_10m
        nc_v10[:] = v_component_of_wind_10m
        nc_t2m[:] = temperature_2m

    logger.info(
        f"Converted surface data file {surface_file} to nc format and saved it to {output_dir}/{file_name}."
    )


def convert_upper_data_to_nc(upper_file: str, file_name: str, output_dir: str) -> None:
    """Write the upper-air channels of one prediction file to NetCDF."""
    # Load the saved numpy arrays
    upper_data = np.load(upper_file)

    # The first `st` channels are surface data; each upper-air variable
    # occupies `level` consecutive channels after them
    st = 4
    level = 13

    geopotential = upper_data[st : st + level]
    specific_humidity = upper_data[st + level : st + 2 * level]
    u_component_of_wind = upper_data[st + 2 * level : st + 3 * level]
    v_component_of_wind = upper_data[st + 3 * level : st + 4 * level]
    temperature = upper_data[st + 4 * level :]

    with nc.Dataset(
        os.path.join(output_dir, file_name), "w", format="NETCDF4_CLASSIC"
    ) as nc_file:
        # Create dimensions
        nc_file.createDimension("longitude", 1440)
        nc_file.createDimension("latitude", 721)
        nc_file.createDimension("level", level)

        # Create variables
        nc_lon = nc_file.createVariable("longitude", np.float32, ("longitude",))
        nc_lat = nc_file.createVariable("latitude", np.float32, ("latitude",))
        nc_geopotential = nc_file.createVariable(
            "geopotential", np.float32, ("level", "latitude", "longitude")
        )
        nc_specific_humidity = nc_file.createVariable(
            "specific_humidity", np.float32, ("level", "latitude", "longitude")
        )
        nc_temperature = nc_file.createVariable(
            "temperature", np.float32, ("level", "latitude", "longitude")
        )
        nc_u_component_of_wind = nc_file.createVariable(
            "u_component_of_wind", np.float32, ("level", "latitude", "longitude")
        )
        nc_v_component_of_wind = nc_file.createVariable(
            "v_component_of_wind", np.float32, ("level", "latitude", "longitude")
        )

        # Set variable attributes
        nc_lon.units = "degrees_east"
        nc_lat.units = "degrees_north"
        nc_geopotential.units = "m"
        nc_specific_humidity.units = "kg/kg"
        nc_temperature.units = "K"
        nc_u_component_of_wind.units = "m/s"
        nc_v_component_of_wind.units = "m/s"
        # Write data to variables
        nc_lon[:] = np.linspace(0.125, 359.875, 1440)
        nc_lat[:] = np.linspace(90, -90, 721)
        nc_geopotential[:] = geopotential
        nc_specific_humidity[:] = specific_humidity
        nc_temperature[:] = temperature
        nc_u_component_of_wind[:] = u_component_of_wind
        nc_v_component_of_wind[:] = v_component_of_wind

    logger.info(
        f"Converted upper data file {upper_file} to nc format and saved it to {output_dir}/{file_name}."
    )


def convert(cfg: Dict):
    output_dir = cfg.output_dir

    # os.listdir returns plain file names, so iterate over them directly
    for file_name in os.listdir(output_dir):
        # Only the npy prediction files are converted
        if not file_name.endswith("npy"):
            continue

        convert_surface_data_to_nc(
            osp.join(output_dir, file_name),
            osp.basename(file_name) + "_surface.nc",
            output_dir,
        )
        convert_upper_data_to_nc(
            osp.join(output_dir, file_name),
            osp.basename(file_name) + "_upper.nc",
            output_dir,
        )


@hydra.main(version_base=None, config_path="./conf", config_name="fengwu.yaml")
def main(cfg: Dict):
    if cfg.mode == "infer":
        convert(cfg)
    else:
        raise ValueError(f"cfg.mode should be in ['infer'], but got '{cfg.mode}'")


if __name__ == "__main__":
    main()
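
The channel layout that `convert_data.py` relies on can be summarized in a small lookup table. The sketch below is illustrative only: `channel_slices` is a hypothetical helper, and it assumes each prediction array has exactly 69 channels (4 surface variables followed by 5 upper-air variables on 13 levels each), whereas the script itself takes "everything after the fourth block" as temperature.

```python
SURFACE_VARS = [
    "u_component_of_wind_10m",
    "v_component_of_wind_10m",
    "temperature_2m",
    "mean_sea_level_pressure",
]
UPPER_VARS = [
    "geopotential",
    "specific_humidity",
    "u_component_of_wind",
    "v_component_of_wind",
    "temperature",
]
N_LEVELS = 13


def channel_slices(surface_vars=SURFACE_VARS, upper_vars=UPPER_VARS, n_levels=N_LEVELS):
    """Map each variable name to the slice of channels that holds it."""
    layout = {}
    # Surface variables occupy one channel each, in order.
    for i, name in enumerate(surface_vars):
        layout[name] = slice(i, i + 1)
    # Each upper-air variable occupies n_levels consecutive channels.
    offset = len(surface_vars)
    for j, name in enumerate(upper_vars):
        start = offset + j * n_levels
        layout[name] = slice(start, start + n_levels)
    return layout


layout = channel_slices()
total_channels = max(s.stop for s in layout.values())
print(total_channels)
for name, s in layout.items():
    print(f"{name}: channels {s.start}..{s.stop - 1}")
```

Given a loaded array `data`, `data[layout["geopotential"]]` would then select the same 13 channels as `upper_data[st : st + level]` in the script.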

examples/fengwu/data_mean.npy

680 Bytes
Binary file not shown.

examples/fengwu/data_std.npy

680 Bytes
Binary file not shown.
