
Commit ae7f9da

Merge branch 'develop' into develop
2 parents 78bbade + 998ba83 commit ae7f9da

14 files changed (+2554, -8 lines)

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
export FLAGS_use_cuda_managed_memory=true
export FLAGS_enable_pir_api=true

TRAINING_MODEL_RESUME="None"
TRAINER_INSTANCES='127.0.0.1'
MASTER='127.0.0.1:8080'
TRAINERS_NUM=1 # nnodes, machine num
TRAINING_GPUS_PER_NODE=8 # nproc_per_node
DP_DEGREE=1 # data_parallel_degree
MP_DEGREE=1 # tensor_parallel_degree
PP_DEGREE=1 # pipeline_parallel_degree
SHARDING_DEGREE=8 # sharding_parallel_degree

# real dp_parallel_degree = nnodes * nproc_per_node / tensor_parallel_degree / sharding_parallel_degree
# Please make sure: nnodes * nproc_per_node >= tensor_parallel_degree * sharding_parallel_degree

config_file=config/DiT_XL_patch2.json
OUTPUT_DIR=./output_trainer/DiT_XL_patch2_auto_trainer
feature_path=./data/fastdit_imagenet256

per_device_train_batch_size=32
gradient_accumulation_steps=1

resolution=256
num_workers=8
max_steps=7000000
logging_steps=1
save_steps=5000
image_logging_steps=-1
seed=0

max_grad_norm=-1

USE_AMP=True
FP16_OPT_LEVEL="O2"

enable_tensorboard=True
recompute=True
enable_xformers=True
to_static=0 # whether we use dynamic to static training

TRAINING_PYTHON="python -m paddle.distributed.launch --master ${MASTER} --nnodes ${TRAINERS_NUM} --nproc_per_node ${TRAINING_GPUS_PER_NODE} --ips ${TRAINER_INSTANCES}"
${TRAINING_PYTHON} train_image_generation_trainer_auto.py \
--do_train \
--feature_path ${feature_path} \
--output_dir ${OUTPUT_DIR} \
--per_device_train_batch_size ${per_device_train_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate 1e-4 \
--weight_decay 0.0 \
--resolution ${resolution} \
--max_steps ${max_steps} \
--lr_scheduler_type "constant" \
--warmup_steps 0 \
--image_logging_steps ${image_logging_steps} \
--logging_dir ${OUTPUT_DIR}/tb_log \
--logging_steps ${logging_steps} \
--save_steps ${save_steps} \
--save_total_limit 50 \
--dataloader_num_workers ${num_workers} \
--vae_name_or_path stabilityai/sd-vae-ft-mse \
--config_file ${config_file} \
--num_inference_steps 25 \
--use_ema True \
--max_grad_norm ${max_grad_norm} \
--overwrite_output_dir True \
--disable_tqdm True \
--fp16_opt_level ${FP16_OPT_LEVEL} \
--seed ${seed} \
--recompute ${recompute} \
--enable_xformers_memory_efficient_attention ${enable_xformers} \
--bf16 ${USE_AMP} \
--amp_master_grad 1 \
--dp_degree ${DP_DEGREE} \
--tensor_parallel_degree ${MP_DEGREE} \
--pipeline_parallel_degree ${PP_DEGREE} \
--sharding_parallel_degree ${SHARDING_DEGREE} \
--sharding "stage1" \
--sharding_parallel_config "enable_stage1_overlap enable_stage1_tensor_fusion" \
--hybrid_parallel_topo_order "sharding_first" \
--sep_parallel_degree 1 \
--enable_auto_parallel 1 \
--to_static $to_static \
# --fp16 ${USE_AMP} \
Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
export FLAGS_use_cuda_managed_memory=true
export FLAGS_enable_pir_api=true

TRAINING_MODEL_RESUME="None"
TRAINER_INSTANCES='127.0.0.1'
MASTER='127.0.0.1:8080'
TRAINERS_NUM=1 # nnodes, machine num
TRAINING_GPUS_PER_NODE=8 # nproc_per_node
DP_DEGREE=1 # data_parallel_degree
MP_DEGREE=4 # tensor_parallel_degree
PP_DEGREE=1 # pipeline_parallel_degree
SHARDING_DEGREE=2 # sharding_parallel_degree

# real dp_parallel_degree = nnodes * nproc_per_node / tensor_parallel_degree / sharding_parallel_degree
# Please make sure: nnodes * nproc_per_node >= tensor_parallel_degree * sharding_parallel_degree

config_file=config/LargeDiT_3B_patch2.json
OUTPUT_DIR=./output_trainer/LargeDiT_3B_patch2_auto_trainer
feature_path=./data/fastdit_imagenet256

per_device_train_batch_size=32
gradient_accumulation_steps=1

resolution=256
num_workers=8
max_steps=7000000
logging_steps=1
save_steps=5000
image_logging_steps=-1
seed=0

max_grad_norm=2.0

USE_AMP=True
FP16_OPT_LEVEL="O2"

enable_tensorboard=True
recompute=True
enable_xformers=True
to_static=0 # whether we use dynamic to static training

TRAINING_PYTHON="python -m paddle.distributed.launch --master ${MASTER} --nnodes ${TRAINERS_NUM} --nproc_per_node ${TRAINING_GPUS_PER_NODE} --ips ${TRAINER_INSTANCES}"
${TRAINING_PYTHON} train_image_generation_trainer_auto.py \
--do_train \
--feature_path ${feature_path} \
--output_dir ${OUTPUT_DIR} \
--per_device_train_batch_size ${per_device_train_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate 1e-4 \
--resolution ${resolution} \
--weight_decay 0.0 \
--max_steps ${max_steps} \
--lr_scheduler_type "constant" \
--warmup_steps 0 \
--image_logging_steps ${image_logging_steps} \
--logging_dir ${OUTPUT_DIR}/tb_log \
--logging_steps ${logging_steps} \
--save_steps ${save_steps} \
--save_total_limit 50 \
--dataloader_num_workers ${num_workers} \
--vae_name_or_path stabilityai/sd-vae-ft-mse \
--config_file ${config_file} \
--num_inference_steps 25 \
--use_ema True \
--max_grad_norm ${max_grad_norm} \
--overwrite_output_dir True \
--disable_tqdm True \
--fp16_opt_level ${FP16_OPT_LEVEL} \
--seed ${seed} \
--recompute ${recompute} \
--enable_xformers_memory_efficient_attention ${enable_xformers} \
--bf16 ${USE_AMP} \
--amp_master_grad 1 \
--dp_degree ${DP_DEGREE} \
--tensor_parallel_degree ${MP_DEGREE} \
--pipeline_parallel_degree ${PP_DEGREE} \
--sharding_parallel_degree ${SHARDING_DEGREE} \
--sharding "stage1" \
--sharding_parallel_config "enable_stage1_overlap enable_stage1_tensor_fusion" \
--hybrid_parallel_topo_order "sharding_first" \
--sep_parallel_degree 1 \
--enable_auto_parallel 1 \
--to_static $to_static \
# --fp16 ${USE_AMP} \

ppdiffusers/examples/class_conditional_image_generation/DiT/README.md

Lines changed: 180 additions & 0 deletions
@@ -161,6 +161,186 @@ python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" \
--ckpt_every ${save_steps} \
```

### 1.5 Training with Distributed Auto Parallel

#### 1.5.1 Hardware requirements
Same as 1.3.1; GPU memory usage differs little from the earlier manual parallelism approach.

#### 1.5.2 Single-node multi-GPU DiT training
You can run `sh 0_run_train_dit_trainer_auto.sh` directly, or launch training as follows (a short note on the parallel-degree arithmetic follows the code block):

```bash
export FLAGS_enable_pir_api=true

TRAINING_MODEL_RESUME="None"
TRAINER_INSTANCES='127.0.0.1'
MASTER='127.0.0.1:8080'
TRAINERS_NUM=1 # nnodes, machine num
TRAINING_GPUS_PER_NODE=8 # nproc_per_node
DP_DEGREE=1 # data_parallel_degree
MP_DEGREE=1 # tensor_parallel_degree
PP_DEGREE=1 # pipeline_parallel_degree
SHARDING_DEGREE=8 # sharding_parallel_degree

# real dp_parallel_degree = nnodes * nproc_per_node / tensor_parallel_degree / sharding_parallel_degree
# Please make sure: nnodes * nproc_per_node >= tensor_parallel_degree * sharding_parallel_degree

config_file=config/DiT_XL_patch2.json
OUTPUT_DIR=./output_trainer/DiT_XL_patch2_auto_trainer
feature_path=./data/fastdit_imagenet256

per_device_train_batch_size=32
gradient_accumulation_steps=1

resolution=256
num_workers=8
max_steps=7000000
logging_steps=1
save_steps=5000
image_logging_steps=-1
seed=0

max_grad_norm=-1

USE_AMP=True
FP16_OPT_LEVEL="O2"

enable_tensorboard=True
recompute=True
enable_xformers=True
to_static=0 # whether we use dynamic to static training

TRAINING_PYTHON="python -m paddle.distributed.launch --master ${MASTER} --nnodes ${TRAINERS_NUM} --nproc_per_node ${TRAINING_GPUS_PER_NODE} --ips ${TRAINER_INSTANCES}"
${TRAINING_PYTHON} train_image_generation_trainer_auto.py \
--do_train \
--feature_path ${feature_path} \
--output_dir ${OUTPUT_DIR} \
--per_device_train_batch_size ${per_device_train_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate 1e-4 \
--weight_decay 0.0 \
--resolution ${resolution} \
--max_steps ${max_steps} \
--lr_scheduler_type "constant" \
--warmup_steps 0 \
--image_logging_steps ${image_logging_steps} \
--logging_dir ${OUTPUT_DIR}/tb_log \
--logging_steps ${logging_steps} \
--save_steps ${save_steps} \
--save_total_limit 50 \
--dataloader_num_workers ${num_workers} \
--vae_name_or_path stabilityai/sd-vae-ft-mse \
--config_file ${config_file} \
--num_inference_steps 25 \
--use_ema True \
--max_grad_norm ${max_grad_norm} \
--overwrite_output_dir True \
--disable_tqdm True \
--fp16_opt_level ${FP16_OPT_LEVEL} \
--seed ${seed} \
--recompute ${recompute} \
--enable_xformers_memory_efficient_attention ${enable_xformers} \
--bf16 ${USE_AMP} \
--amp_master_grad 1 \
--dp_degree ${DP_DEGREE} \
--tensor_parallel_degree ${MP_DEGREE} \
--pipeline_parallel_degree ${PP_DEGREE} \
--sharding_parallel_degree ${SHARDING_DEGREE} \
--sharding "stage1" \
--sharding_parallel_config "enable_stage1_overlap enable_stage1_tensor_fusion" \
--hybrid_parallel_topo_order "sharding_first" \
--sep_parallel_degree 1 \
--enable_auto_parallel 1 \
--to_static $to_static
```
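
The data-parallel degree is not set directly; it follows from the launch layout, as the two comment lines in the block state. Below is a minimal shell sketch of that arithmetic, reusing the variable names from the block above; the check itself is illustrative and not part of the original script.

```bash
# Degrees used by the DiT-XL run above: 1 node x 8 GPUs,
# tensor parallel 1, sharding 8  ->  dp_parallel_degree = 8 / (1 * 8) = 1
TRAINERS_NUM=1
TRAINING_GPUS_PER_NODE=8
MP_DEGREE=1
SHARDING_DEGREE=8

total_cards=$((TRAINERS_NUM * TRAINING_GPUS_PER_NODE))
if ((total_cards < MP_DEGREE * SHARDING_DEGREE)); then
    echo "need nnodes * nproc_per_node >= tensor_parallel_degree * sharding_parallel_degree" >&2
    exit 1
fi
echo "real dp_parallel_degree = $((total_cards / MP_DEGREE / SHARDING_DEGREE))"  # prints 1
```

Under stage1 sharding each of the 8 ranks still consumes its own micro-batch, so the effective global batch size of this run should be per_device_train_batch_size * 8 = 256; treat this as an assumption to verify against the trainer logs rather than a guarantee.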

#### 1.5.3 Single-node multi-GPU LargeDiT training
You can run `sh 4_run_train_largedit_3b_trainer_auto.sh` directly, or launch training as follows (an illustrative reduced-GPU variant is sketched after the code block):

```bash
export FLAGS_enable_pir_api=true

TRAINING_MODEL_RESUME="None"
TRAINER_INSTANCES='127.0.0.1'
MASTER='127.0.0.1:8080'
TRAINERS_NUM=1 # nnodes, machine num
TRAINING_GPUS_PER_NODE=8 # nproc_per_node
DP_DEGREE=1 # data_parallel_degree
MP_DEGREE=4 # tensor_parallel_degree
PP_DEGREE=1 # pipeline_parallel_degree
SHARDING_DEGREE=2 # sharding_parallel_degree

# real dp_parallel_degree = nnodes * nproc_per_node / tensor_parallel_degree / sharding_parallel_degree
# Please make sure: nnodes * nproc_per_node >= tensor_parallel_degree * sharding_parallel_degree

config_file=config/LargeDiT_3B_patch2.json
OUTPUT_DIR=./output_trainer/LargeDiT_3B_patch2_auto_trainer
feature_path=./data/fastdit_imagenet256

per_device_train_batch_size=32
gradient_accumulation_steps=1

resolution=256
num_workers=8
max_steps=7000000
logging_steps=1
save_steps=5000
image_logging_steps=-1
seed=0

max_grad_norm=2.0

USE_AMP=True
FP16_OPT_LEVEL="O2"

enable_tensorboard=True
recompute=True
enable_xformers=True
to_static=0 # whether we use dynamic to static training

TRAINING_PYTHON="python -m paddle.distributed.launch --master ${MASTER} --nnodes ${TRAINERS_NUM} --nproc_per_node ${TRAINING_GPUS_PER_NODE} --ips ${TRAINER_INSTANCES}"
${TRAINING_PYTHON} train_image_generation_trainer_auto.py \
--do_train \
--feature_path ${feature_path} \
--output_dir ${OUTPUT_DIR} \
--per_device_train_batch_size ${per_device_train_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate 1e-4 \
--resolution ${resolution} \
--weight_decay 0.0 \
--max_steps ${max_steps} \
--lr_scheduler_type "constant" \
--warmup_steps 0 \
--image_logging_steps ${image_logging_steps} \
--logging_dir ${OUTPUT_DIR}/tb_log \
--logging_steps ${logging_steps} \
--save_steps ${save_steps} \
--save_total_limit 50 \
--dataloader_num_workers ${num_workers} \
--vae_name_or_path stabilityai/sd-vae-ft-mse \
--config_file ${config_file} \
--num_inference_steps 25 \
--use_ema True \
--max_grad_norm ${max_grad_norm} \
--overwrite_output_dir True \
--disable_tqdm True \
--fp16_opt_level ${FP16_OPT_LEVEL} \
--seed ${seed} \
--recompute ${recompute} \
--enable_xformers_memory_efficient_attention ${enable_xformers} \
--bf16 ${USE_AMP} \
--amp_master_grad 1 \
--dp_degree ${DP_DEGREE} \
--tensor_parallel_degree ${MP_DEGREE} \
--pipeline_parallel_degree ${PP_DEGREE} \
--sharding_parallel_degree ${SHARDING_DEGREE} \
--sharding "stage1" \
--sharding_parallel_config "enable_stage1_overlap enable_stage1_tensor_fusion" \
--hybrid_parallel_topo_order "sharding_first" \
--sep_parallel_degree 1 \
--enable_auto_parallel 1 \
--to_static $to_static
```
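
For the 3B model the eight cards are split between tensor parallelism and sharding (MP_DEGREE=4, SHARDING_DEGREE=2), so the same formula again yields a data-parallel degree of 1 * 8 / 4 / 2 = 1. If fewer cards are available, the product of the degrees has to match the card count. The fragment below is a hypothetical single-node 4-GPU variant, not part of the original scripts, and memory headroom for the 3B model at this setting is not guaranteed.

```bash
# Hypothetical 4-GPU layout for LargeDiT 3B: keep tensor parallel at 4 and
# drop sharding to 1, so 1 node * 4 GPUs / (MP 4 * sharding 1) = dp 1.
TRAINERS_NUM=1
TRAINING_GPUS_PER_NODE=4
DP_DEGREE=1
MP_DEGREE=4
PP_DEGREE=1
SHARDING_DEGREE=1
# All other variables and the ${TRAINING_PYTHON} launch command stay as in the block above.
```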

## 2 Model Inference

ppdiffusers/examples/class_conditional_image_generation/DiT/config/LargeDiT_3B_patch2.json

Lines changed: 2 additions & 1 deletion
@@ -4,5 +4,6 @@
 "in_channels": 4,
 "num_layers": 32,
 "num_attention_heads": 32,
-"attention_head_dim": 96
+"attention_head_dim": 96,
+"qk_norm": false
 }
