Skip to content

Commit 194a593

Browse files
【Hackathon 8th No.11】DrivAerNet++ 论文复现 (#1062)
* Update ReduceOnPlateau lr_scheduler.md * support drivaernetplusplus * amend older error arch.md * amend some question * update regpintnet.py * amend regpointnet and drivaernetplusplus.py * update code and format markdown * update code and format markdown * update code and format markdown * update md * update code and format markdown * Update drivaernetplusplus_dataset.py * Update examples/drivaernetplusplus/drivaernetplusplus.py Co-authored-by: HydrogenSulfate <490868991@qq.com> * Update examples/drivaernetplusplus/drivaernetplusplus.py Co-authored-by: HydrogenSulfate <490868991@qq.com> * Update examples/drivaernetplusplus/drivaernetplusplus.py Co-authored-by: HydrogenSulfate <490868991@qq.com> * move data_augmentation to load dataset file * sysc the apply_augmentation in markdown file * sysc the apply_augmentation in markdown file * pre-commit file drivaernetplusplus.py * Update docs/zh/examples/drivaernetplusplus.md Co-authored-by: HydrogenSulfate <490868991@qq.com> * update drivaernetplusplus.md * update mkdocs.yml * Update mkdocs.yml * Update mkdocs.yml * update drivaernetplusplus_dataset.py * remove ARGS about dataset setting --------- Co-authored-by: HydrogenSulfate <490868991@qq.com>
1 parent 1de0dd5 commit 194a593

File tree

10 files changed

+1585
-0
lines changed

10 files changed

+1585
-0
lines changed

docs/zh/api/arch.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
- LNO
3939
- TGCN
4040
- RegDGCNN
41+
- RegPointNet
4142
- IFMMLP
4243
show_root_heading: true
4344
heading_level: 3

docs/zh/api/data/dataset.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,6 @@
3333
- CGCNNDataset
3434
- PEMSDataset
3535
- DrivAerNetDataset
36+
- DrivAerNetPlusPlusDataset
3637
- IFMMoeDataset
3738
show_root_heading: true

docs/zh/examples/drivaernetplusplus.md

Lines changed: 832 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
defaults:
2+
- ppsci_default
3+
- TRAIN: train_default
4+
- TRAIN/ema: ema_default
5+
- TRAIN/swa: swa_default
6+
- EVAL: eval_default
7+
- INFER: infer_default
8+
- hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
9+
- _self_
10+
11+
hydra:
12+
run:
13+
dir: outputs_drivaernetplusplus/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
14+
job:
15+
name: ${mode}
16+
chdir: false
17+
callbacks:
18+
init_callback:
19+
_target_: ppsci.utils.callbacks.InitCallback
20+
sweep:
21+
dir: ${hydra.run.dir}
22+
subdir: ./
23+
24+
# general settings
25+
mode: eval
26+
seed: 1
27+
output_dir: ${hydra:run.dir}
28+
log_freq: 100
29+
30+
# model settings
31+
MODEL:
32+
input_keys: ["vertices"]
33+
output_keys: ["cd_value"]
34+
weight_keys: ["weight_keys"]
35+
dropout: 0.0
36+
emb_dims: 1024
37+
channels: [6, 64, 128, 256, 512, 1024]
38+
linear_sizes: [128, 64, 32, 16]
39+
k: 40
40+
output_channels: 1
41+
42+
# training settings
43+
TRAIN:
44+
iters_per_epoch: 5399
45+
epochs: 200
46+
num_points: 100000
47+
num_workers: 32
48+
eval_during_train: True
49+
train_ids_file: "train_design_ids.txt"
50+
eval_ids_file: "val_design_ids.txt"
51+
batch_size: 32
52+
scheduler:
53+
mode: "min"
54+
patience: 20
55+
factor: 0.1
56+
verbose: True
57+
58+
# evaluation settings
59+
EVAL:
60+
num_points: 100000
61+
batch_size: 1
62+
pretrained_model_path: "https://dataset.bj.bcebos.com/PaddleScience/DNNFluid-Car/DrivAer%2B%2B/DragPrediction_DrivAerNet_PointNet_r2_batchsize16_200epochs_100kpoints_tsne_NeurIPS_best_model.pdparams"
63+
eval_with_no_grad: True
64+
ids_file: "test_design_ids.txt"
65+
num_workers: 8
66+
67+
# optimizer settings
68+
optimizer:
69+
weight_decay: 0.0001
70+
lr: 0.001
71+
optimizer: "adam"
72+
73+
# dataset settings
74+
dataset_path: "data/DrivAerNetPlusPlus_Processed_Point_Clouds_100k_paddle"
75+
aero_coeff: "data/DrivAerNetPlusPlus_Drag_8k.csv"
76+
subset_dir: "data/subset_dir"
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import warnings
17+
from functools import partial
18+
19+
import hydra
20+
import paddle
21+
from omegaconf import DictConfig
22+
23+
import ppsci
24+
25+
26+
def train(cfg: DictConfig):
27+
# set model
28+
model = ppsci.arch.RegPointNet(
29+
input_keys=cfg.MODEL.input_keys,
30+
label_keys=cfg.MODEL.output_keys,
31+
weight_keys=cfg.MODEL.weight_keys,
32+
args=cfg.MODEL,
33+
)
34+
35+
train_dataloader_cfg = {
36+
"dataset": {
37+
"name": "DrivAerNetPlusPlusDataset",
38+
"root_dir": cfg.dataset_path,
39+
"input_keys": cfg.MODEL.input_keys,
40+
"label_keys": cfg.MODEL.output_keys,
41+
"weight_keys": cfg.MODEL.weight_keys,
42+
"subset_dir": cfg.subset_dir,
43+
"ids_file": cfg.TRAIN.train_ids_file,
44+
"csv_file": cfg.aero_coeff,
45+
"num_points": cfg.TRAIN.num_points,
46+
},
47+
"batch_size": cfg.TRAIN.batch_size,
48+
"num_workers": cfg.TRAIN.num_workers,
49+
}
50+
51+
drivaernetplusplus_constraint = ppsci.constraint.SupervisedConstraint(
52+
train_dataloader_cfg,
53+
ppsci.loss.MSELoss("mean"),
54+
name="DrivAerNetplusplus_constraint",
55+
)
56+
57+
constraint = {drivaernetplusplus_constraint.name: drivaernetplusplus_constraint}
58+
59+
valid_dataloader_cfg = {
60+
"dataset": {
61+
"name": "DrivAerNetPlusPlusDataset",
62+
"root_dir": cfg.dataset_path,
63+
"input_keys": cfg.MODEL.input_keys,
64+
"label_keys": cfg.MODEL.output_keys,
65+
"weight_keys": cfg.MODEL.weight_keys,
66+
"subset_dir": cfg.subset_dir,
67+
"ids_file": cfg.TRAIN.eval_ids_file,
68+
"csv_file": cfg.aero_coeff,
69+
"num_points": cfg.TRAIN.num_points,
70+
},
71+
"batch_size": cfg.TRAIN.batch_size,
72+
"num_workers": cfg.TRAIN.num_workers,
73+
}
74+
75+
drivaernetplusplus_valid = ppsci.validate.SupervisedValidator(
76+
valid_dataloader_cfg,
77+
loss=ppsci.loss.MSELoss("mean"),
78+
metric={"MSE": ppsci.metric.MSE()},
79+
name="DrivAerNetplusplus_valid",
80+
)
81+
82+
validator = {drivaernetplusplus_valid.name: drivaernetplusplus_valid}
83+
84+
# set optimizer
85+
lr_scheduler = ppsci.optimizer.lr_scheduler.ReduceOnPlateau(
86+
epochs=cfg.TRAIN.epochs,
87+
iters_per_epoch=(
88+
cfg.TRAIN.iters_per_epoch
89+
// (paddle.distributed.get_world_size() * cfg.TRAIN.batch_size)
90+
+ 1
91+
),
92+
learning_rate=cfg.optimizer.lr,
93+
mode=cfg.TRAIN.scheduler.mode,
94+
patience=cfg.TRAIN.scheduler.patience,
95+
factor=cfg.TRAIN.scheduler.factor,
96+
verbose=cfg.TRAIN.scheduler.verbose,
97+
)()
98+
99+
optimizer = (
100+
ppsci.optimizer.Adam(lr_scheduler, weight_decay=cfg.optimizer.weight_decay)(
101+
model
102+
)
103+
if cfg.optimizer.optimizer == "adam"
104+
else ppsci.optimizer.SGD(lr_scheduler, weight_decay=cfg.optimizer.weight_decay)(
105+
model
106+
)
107+
)
108+
109+
# initialize solver
110+
solver = ppsci.solver.Solver(
111+
model=model,
112+
iters_per_epoch=(
113+
cfg.TRAIN.iters_per_epoch
114+
// (paddle.distributed.get_world_size() * cfg.TRAIN.batch_size)
115+
+ 1
116+
),
117+
constraint=constraint,
118+
output_dir=cfg.output_dir,
119+
optimizer=optimizer,
120+
lr_scheduler=lr_scheduler,
121+
epochs=cfg.TRAIN.epochs,
122+
validator=validator,
123+
eval_during_train=cfg.TRAIN.eval_during_train,
124+
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
125+
)
126+
127+
lr_scheduler.step = partial(lr_scheduler.step, metrics=solver.cur_metric)
128+
solver.lr_scheduler = lr_scheduler
129+
130+
# train model
131+
solver.train()
132+
133+
solver.eval()
134+
135+
136+
def evaluate(cfg: DictConfig):
137+
# set model
138+
model = ppsci.arch.RegPointNet(
139+
input_keys=cfg.MODEL.input_keys,
140+
label_keys=cfg.MODEL.output_keys,
141+
weight_keys=cfg.MODEL.weight_keys,
142+
args=cfg.MODEL,
143+
)
144+
145+
valid_dataloader_cfg = {
146+
"dataset": {
147+
"name": "DrivAerNetPlusPlusDataset",
148+
"root_dir": cfg.dataset_path,
149+
"input_keys": cfg.MODEL.input_keys,
150+
"label_keys": cfg.MODEL.output_keys,
151+
"weight_keys": cfg.MODEL.weight_keys,
152+
"subset_dir": cfg.subset_dir,
153+
"ids_file": cfg.EVAL.ids_file,
154+
"csv_file": cfg.aero_coeff,
155+
"num_points": cfg.EVAL.num_points,
156+
},
157+
"batch_size": cfg.EVAL.batch_size,
158+
"num_workers": cfg.EVAL.num_workers,
159+
}
160+
161+
drivaernetplusplus_valid = ppsci.validate.SupervisedValidator(
162+
valid_dataloader_cfg,
163+
loss=ppsci.loss.MSELoss("mean"),
164+
metric={
165+
"MSE": ppsci.metric.MSE(),
166+
"MAE": ppsci.metric.MAE(),
167+
"Max AE": ppsci.metric.MaxAE(),
168+
"R²": ppsci.metric.R2Score(),
169+
},
170+
name="DrivAerNetPlusPlus_valid",
171+
)
172+
173+
validator = {drivaernetplusplus_valid.name: drivaernetplusplus_valid}
174+
175+
solver = ppsci.solver.Solver(
176+
model=model,
177+
validator=validator,
178+
pretrained_model_path=cfg.EVAL.pretrained_model_path,
179+
eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
180+
)
181+
182+
# evaluate model
183+
solver.eval()
184+
185+
186+
@hydra.main(
187+
version_base=None, config_path="./conf", config_name="drivaernetplusplus.yaml"
188+
)
189+
def main(cfg: DictConfig):
190+
warnings.filterwarnings("ignore")
191+
if cfg.mode == "train":
192+
train(cfg)
193+
elif cfg.mode == "eval":
194+
evaluate(cfg)
195+
else:
196+
raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")
197+
198+
199+
if __name__ == "__main__":
200+
main()

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ nav:
6767
- Darcy2D: zh/examples/darcy2d.md
6868
- DeepCFD: zh/examples/deepcfd.md
6969
- DrivAerNet: zh/examples/drivaernet.md
70+
- DrivAerNetPlusPlus: zh/examples/drivaernetplusplus.md
7071
- LDC2D_steady: zh/examples/ldc2d_steady.md
7172
- LDC2D_unsteady: zh/examples/ldc2d_unsteady.md
7273
- Labelfree_DNN_surrogate: zh/examples/labelfree_DNN_surrogate.md

ppsci/arch/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
from ppsci.arch.moflow_net import MoFlowNet, MoFlowProp # isort:skip
6161
from ppsci.utils import logger # isort:skip
6262
from ppsci.arch.regdgcnn import RegDGCNN # isort:skip
63+
from ppsci.arch.regpointnet import RegPointNet # isort:skip
6364
from ppsci.arch.ifm_mlp import IFMMLP # isort:skip
6465

6566
__all__ = [
@@ -110,6 +111,7 @@
110111
"VelocityDiscriminator",
111112
"VelocityGenerator",
112113
"RegDGCNN",
114+
"RegPointNet",
113115
"IFMMLP",
114116
]
115117

0 commit comments

Comments
 (0)