support auto parallel in dit and largedit #551


Merged: 35 commits merged on Oct 10, 2024.

Commits (35):
89bd742
support auto parallel in dit and largedit
jeff41404 May 23, 2024
98be16e
Merge branch 'develop' into add_auto_parallel_in_dit
jeff41404 May 23, 2024
1eaba53
fix import error
jeff41404 May 27, 2024
0a64e2d
fix issue of flatten
jeff41404 May 30, 2024
63544b3
use transpose instead of einsum to work around issue of spmd of einsu…
jeff41404 Jun 3, 2024
4ca8ca0
fix issue of incorrectly shard adaln_input and rope
jeff41404 Jun 5, 2024
781203d
fix bug
jeff41404 Jun 6, 2024
e86ff8e
delete print
jeff41404 Jun 17, 2024
9c0f0e3
fix printing issues and clean up useless code
jeff41404 Jun 20, 2024
95a6a9a
add infrastructure of pipeline parallel and dynamic to static
jeff41404 Jun 24, 2024
39f6a28
fix the issues that block dynamic to static operation
jeff41404 Jun 26, 2024
28e76d8
delete pass in train_image_generation_trainer_auto.py
jeff41404 Jul 26, 2024
fbd226b
resolve conflict in trainer_args.py
jeff41404 Jul 26, 2024
26074be
Merge branch 'develop' into add_auto_parallel_in_dit
jeff41404 Jul 26, 2024
b4553d9
delete dist.reshard in gaussian_diffusion.py
jeff41404 Jul 26, 2024
04e4ed2
Merge branch 'develop' into add_auto_parallel_in_dit
jeff41404 Aug 15, 2024
ccf58d5
support pir and use non fused rms_norm
jeff41404 Sep 13, 2024
2439bff
Merge branch 'develop' into add_auto_parallel_in_dit
jeff41404 Sep 13, 2024
240179d
temporarily remove modifications to ppdiffusers/ppdiffusers/models
jeff41404 Sep 14, 2024
ce3f2f7
Merge branch 'develop' into add_auto_parallel_in_dit
jerrywgz Sep 19, 2024
6986ab4
Merge branch 'develop' into add_auto_parallel_in_dit
jeff41404 Sep 23, 2024
45f9084
fix issue of bf16 loss error in d2s pir
jeff41404 Sep 23, 2024
f1d49fc
Merge branch 'develop' into add_auto_parallel_in_dit
jeff41404 Sep 24, 2024
215916b
Merge branch 'develop' into add_auto_parallel_in_dit
jeff41404 Sep 24, 2024
e0e9c40
Merge branch 'develop' into add_auto_parallel_in_dit
jeff41404 Sep 24, 2024
14060ba
set use_fused_rms_norm to true
jeff41404 Sep 25, 2024
c6d3d6d
Merge branch 'develop' into add_auto_parallel_in_dit
jeff41404 Sep 26, 2024
7df5901
clean up annotations and debug print information, remove _extract_int…
jeff41404 Sep 27, 2024
c9809d9
Merge branch 'develop' into add_auto_parallel_in_dit
jeff41404 Sep 29, 2024
9d298ef
add documents
jeff41404 Oct 8, 2024
c33fd6e
Merge branch 'develop' into add_auto_parallel_in_dit
jeff41404 Oct 8, 2024
a9f4d93
modify shape of pos_embed to solve problem of split optimizer state
jeff41404 Oct 9, 2024
686464a
Merge branch 'develop' into add_auto_parallel_in_dit
jeff41404 Oct 9, 2024
3756cee
Merge branch 'develop' into add_auto_parallel_in_dit
nemonameless Oct 10, 2024
b45f034
Merge branch 'develop' into add_auto_parallel_in_dit
nemonameless Oct 10, 2024
@@ -0,0 +1,97 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
export FLAGS_use_cuda_managed_memory=true
export FLAGS_enable_pir_api=true

TRAINING_MODEL_RESUME="None"
TRAINER_INSTANCES='127.0.0.1'
MASTER='127.0.0.1:8080'
TRAINERS_NUM=1 # nnodes, machine num
TRAINING_GPUS_PER_NODE=8 # nproc_per_node
DP_DEGREE=1 # data_parallel_degree
MP_DEGREE=1 # tensor_parallel_degree
PP_DEGREE=1 # pipeline_parallel_degree
SHARDING_DEGREE=8 # sharding_parallel_degree

# real dp_parallel_degree = nnodes * nproc_per_node / tensor_parallel_degree / sharding_parallel_degree
# Please make sure: nnodes * nproc_per_node >= tensor_parallel_degree * sharding_parallel_degree

config_file=config/DiT_XL_patch2.json
OUTPUT_DIR=./output_trainer/DiT_XL_patch2_auto_trainer
feature_path=./data/fastdit_imagenet256

per_device_train_batch_size=32
gradient_accumulation_steps=1

resolution=256
num_workers=8
max_steps=7000000
logging_steps=1
save_steps=5000
image_logging_steps=-1
seed=0

max_grad_norm=-1

USE_AMP=True
FP16_OPT_LEVEL="O2"

enable_tensorboard=True
recompute=True
enable_xformers=True
to_static=0 # whether to use dynamic-to-static training

TRAINING_PYTHON="python -m paddle.distributed.launch --master ${MASTER} --nnodes ${TRAINERS_NUM} --nproc_per_node ${TRAINING_GPUS_PER_NODE} --ips ${TRAINER_INSTANCES}"
${TRAINING_PYTHON} train_image_generation_trainer_auto.py \
--do_train \
--feature_path ${feature_path} \
--output_dir ${OUTPUT_DIR} \
--per_device_train_batch_size ${per_device_train_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate 1e-4 \
--weight_decay 0.0 \
--resolution ${resolution} \
--max_steps ${max_steps} \
--lr_scheduler_type "constant" \
--warmup_steps 0 \
--image_logging_steps ${image_logging_steps} \
--logging_dir ${OUTPUT_DIR}/tb_log \
--logging_steps ${logging_steps} \
--save_steps ${save_steps} \
--save_total_limit 50 \
--dataloader_num_workers ${num_workers} \
--vae_name_or_path stabilityai/sd-vae-ft-mse \
--config_file ${config_file} \
--num_inference_steps 25 \
--use_ema True \
--max_grad_norm ${max_grad_norm} \
--overwrite_output_dir True \
--disable_tqdm True \
--fp16_opt_level ${FP16_OPT_LEVEL} \
--seed ${seed} \
--recompute ${recompute} \
--enable_xformers_memory_efficient_attention ${enable_xformers} \
--bf16 ${USE_AMP} \
--amp_master_grad 1 \
--dp_degree ${DP_DEGREE} \
--tensor_parallel_degree ${MP_DEGREE} \
--pipeline_parallel_degree ${PP_DEGREE} \
--sharding_parallel_degree ${SHARDING_DEGREE} \
--sharding "stage1" \
--sharding_parallel_config "enable_stage1_overlap enable_stage1_tensor_fusion" \
--hybrid_parallel_topo_order "sharding_first" \
--sep_parallel_degree 1 \
--enable_auto_parallel 1 \
--to_static $to_static \
# --fp16 ${USE_AMP} \
@@ -0,0 +1,97 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
export FLAGS_use_cuda_managed_memory=true
export FLAGS_enable_pir_api=true

TRAINING_MODEL_RESUME="None"
TRAINER_INSTANCES='127.0.0.1'
MASTER='127.0.0.1:8080'
TRAINERS_NUM=1 # nnodes, machine num
TRAINING_GPUS_PER_NODE=8 # nproc_per_node
DP_DEGREE=1 # data_parallel_degree
MP_DEGREE=4 # tensor_parallel_degree
PP_DEGREE=1 # pipeline_parallel_degree
SHARDING_DEGREE=2 # sharding_parallel_degree

# real dp_parallel_degree = nnodes * nproc_per_node / tensor_parallel_degree / sharding_parallel_degree
# Please make sure: nnodes * nproc_per_node >= tensor_parallel_degree * sharding_parallel_degree

config_file=config/LargeDiT_3B_patch2.json
OUTPUT_DIR=./output_trainer/LargeDiT_3B_patch2_auto_trainer
feature_path=./data/fastdit_imagenet256

per_device_train_batch_size=32
gradient_accumulation_steps=1

resolution=256
num_workers=8
max_steps=7000000
logging_steps=1
save_steps=5000
image_logging_steps=-1
seed=0

max_grad_norm=2.0

USE_AMP=True
FP16_OPT_LEVEL="O2"

enable_tensorboard=True
recompute=True
enable_xformers=True
to_static=0 # whether to use dynamic-to-static training

TRAINING_PYTHON="python -m paddle.distributed.launch --master ${MASTER} --nnodes ${TRAINERS_NUM} --nproc_per_node ${TRAINING_GPUS_PER_NODE} --ips ${TRAINER_INSTANCES}"
${TRAINING_PYTHON} train_image_generation_trainer_auto.py \
--do_train \
--feature_path ${feature_path} \
--output_dir ${OUTPUT_DIR} \
--per_device_train_batch_size ${per_device_train_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate 1e-4 \
--resolution ${resolution} \
--weight_decay 0.0 \
--max_steps ${max_steps} \
--lr_scheduler_type "constant" \
--warmup_steps 0 \
--image_logging_steps ${image_logging_steps} \
--logging_dir ${OUTPUT_DIR}/tb_log \
--logging_steps ${logging_steps} \
--save_steps ${save_steps} \
--save_total_limit 50 \
--dataloader_num_workers ${num_workers} \
--vae_name_or_path stabilityai/sd-vae-ft-mse \
--config_file ${config_file} \
--num_inference_steps 25 \
--use_ema True \
--max_grad_norm ${max_grad_norm} \
--overwrite_output_dir True \
--disable_tqdm True \
--fp16_opt_level ${FP16_OPT_LEVEL} \
--seed ${seed} \
--recompute ${recompute} \
--enable_xformers_memory_efficient_attention ${enable_xformers} \
--bf16 ${USE_AMP} \
--amp_master_grad 1 \
--dp_degree ${DP_DEGREE} \
--tensor_parallel_degree ${MP_DEGREE} \
--pipeline_parallel_degree ${PP_DEGREE} \
--sharding_parallel_degree ${SHARDING_DEGREE} \
--sharding "stage1" \
--sharding_parallel_config "enable_stage1_overlap enable_stage1_tensor_fusion" \
--hybrid_parallel_topo_order "sharding_first" \
--sep_parallel_degree 1 \
--enable_auto_parallel 1 \
--to_static $to_static \
# --fp16 ${USE_AMP} \
180 changes: 180 additions & 0 deletions ppdiffusers/examples/class_conditional_image_generation/DiT/README.md
@@ -161,6 +161,186 @@ python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" \
--ckpt_every ${save_steps} \
```

### 1.5 Training with Distributed Auto Parallel

#### 1.5.1 Hardware Requirements
Same as 1.3.1; GPU memory usage differs little from the previous manual parallelism approach.

#### 1.5.2 Single-node multi-GPU training of DiT
You can run `sh 0_run_train_dit_trainer_auto.sh` directly, or run the following:

```bash
export FLAGS_enable_pir_api=true

TRAINING_MODEL_RESUME="None"
TRAINER_INSTANCES='127.0.0.1'
MASTER='127.0.0.1:8080'
TRAINERS_NUM=1 # nnodes, machine num
TRAINING_GPUS_PER_NODE=8 # nproc_per_node
DP_DEGREE=1 # data_parallel_degree
MP_DEGREE=1 # tensor_parallel_degree
PP_DEGREE=1 # pipeline_parallel_degree
SHARDING_DEGREE=8 # sharding_parallel_degree

# real dp_parallel_degree = nnodes * nproc_per_node / tensor_parallel_degree / sharding_parallel_degree
# Please make sure: nnodes * nproc_per_node >= tensor_parallel_degree * sharding_parallel_degree

config_file=config/DiT_XL_patch2.json
OUTPUT_DIR=./output_trainer/DiT_XL_patch2_auto_trainer
feature_path=./data/fastdit_imagenet256

per_device_train_batch_size=32
gradient_accumulation_steps=1

resolution=256
num_workers=8
max_steps=7000000
logging_steps=1
save_steps=5000
image_logging_steps=-1
seed=0

max_grad_norm=-1

USE_AMP=True
FP16_OPT_LEVEL="O2"

enable_tensorboard=True
recompute=True
enable_xformers=True
to_static=0 # whether to use dynamic-to-static training

TRAINING_PYTHON="python -m paddle.distributed.launch --master ${MASTER} --nnodes ${TRAINERS_NUM} --nproc_per_node ${TRAINING_GPUS_PER_NODE} --ips ${TRAINER_INSTANCES}"
${TRAINING_PYTHON} train_image_generation_trainer_auto.py \
--do_train \
--feature_path ${feature_path} \
--output_dir ${OUTPUT_DIR} \
--per_device_train_batch_size ${per_device_train_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate 1e-4 \
--weight_decay 0.0 \
--resolution ${resolution} \
--max_steps ${max_steps} \
--lr_scheduler_type "constant" \
--warmup_steps 0 \
--image_logging_steps ${image_logging_steps} \
--logging_dir ${OUTPUT_DIR}/tb_log \
--logging_steps ${logging_steps} \
--save_steps ${save_steps} \
--save_total_limit 50 \
--dataloader_num_workers ${num_workers} \
--vae_name_or_path stabilityai/sd-vae-ft-mse \
--config_file ${config_file} \
--num_inference_steps 25 \
--use_ema True \
--max_grad_norm ${max_grad_norm} \
--overwrite_output_dir True \
--disable_tqdm True \
--fp16_opt_level ${FP16_OPT_LEVEL} \
--seed ${seed} \
--recompute ${recompute} \
--enable_xformers_memory_efficient_attention ${enable_xformers} \
--bf16 ${USE_AMP} \
--amp_master_grad 1 \
--dp_degree ${DP_DEGREE} \
--tensor_parallel_degree ${MP_DEGREE} \
--pipeline_parallel_degree ${PP_DEGREE} \
--sharding_parallel_degree ${SHARDING_DEGREE} \
--sharding "stage1" \
--sharding_parallel_config "enable_stage1_overlap enable_stage1_tensor_fusion" \
--hybrid_parallel_topo_order "sharding_first" \
--sep_parallel_degree 1 \
--enable_auto_parallel 1 \
--to_static $to_static
```
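
The comments in the script pin down how the launch topology relates to the parallel degrees: the effective data-parallel degree is `nnodes * nproc_per_node / tensor_parallel_degree / sharding_parallel_degree`, and the product of the configured degrees must evenly divide the number of GPUs. A quick way to sanity-check a new combination before launching is a few lines of Python; the function below is only an illustrative sketch of that formula, not part of the trainer.

```python
# Minimal sanity-check sketch for the degree arithmetic in the script comments above.
# Formula (from the comments): real dp_degree = nnodes * nproc_per_node
#                              / tensor_parallel_degree / sharding_parallel_degree
# (pipeline_parallel_degree is 1 in both example scripts, so it is left out here.)

def real_dp_degree(nnodes: int, nproc_per_node: int, mp_degree: int, sharding_degree: int) -> int:
    total_gpus = nnodes * nproc_per_node
    divisor = mp_degree * sharding_degree
    if total_gpus < divisor or total_gpus % divisor != 0:
        raise ValueError(
            f"nnodes * nproc_per_node ({total_gpus}) must be a multiple of "
            f"tensor_parallel_degree * sharding_parallel_degree ({divisor})"
        )
    return total_gpus // divisor

# DiT_XL_patch2 script: 1 node x 8 GPUs, MP_DEGREE=1, SHARDING_DEGREE=8 -> dp = 1
print(real_dp_degree(1, 8, 1, 8))
# LargeDiT_3B_patch2 script: 1 node x 8 GPUs, MP_DEGREE=4, SHARDING_DEGREE=2 -> dp = 1
print(real_dp_degree(1, 8, 4, 2))
```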

#### 1.5.3 Single-node multi-GPU training of LargeDiT
You can run `sh 4_run_train_largedit_3b_trainer_auto.sh` directly, or run the following:

```bash
export FLAGS_enable_pir_api=true

TRAINING_MODEL_RESUME="None"
TRAINER_INSTANCES='127.0.0.1'
MASTER='127.0.0.1:8080'
TRAINERS_NUM=1 # nnodes, machine num
TRAINING_GPUS_PER_NODE=8 # nproc_per_node
DP_DEGREE=1 # data_parallel_degree
MP_DEGREE=4 # tensor_parallel_degree
PP_DEGREE=1 # pipeline_parallel_degree
SHARDING_DEGREE=2 # sharding_parallel_degree

# real dp_parallel_degree = nnodes * nproc_per_node / tensor_parallel_degree / sharding_parallel_degree
# Please make sure: nnodes * nproc_per_node >= tensor_parallel_degree * sharding_parallel_degree

config_file=config/LargeDiT_3B_patch2.json
OUTPUT_DIR=./output_trainer/LargeDiT_3B_patch2_auto_trainer
feature_path=./data/fastdit_imagenet256

per_device_train_batch_size=32
gradient_accumulation_steps=1

resolution=256
num_workers=8
max_steps=7000000
logging_steps=1
save_steps=5000
image_logging_steps=-1
seed=0

max_grad_norm=2.0

USE_AMP=True
FP16_OPT_LEVEL="O2"

enable_tensorboard=True
recompute=True
enable_xformers=True
to_static=0 # whether to use dynamic-to-static training

TRAINING_PYTHON="python -m paddle.distributed.launch --master ${MASTER} --nnodes ${TRAINERS_NUM} --nproc_per_node ${TRAINING_GPUS_PER_NODE} --ips ${TRAINER_INSTANCES}"
${TRAINING_PYTHON} train_image_generation_trainer_auto.py \
--do_train \
--feature_path ${feature_path} \
--output_dir ${OUTPUT_DIR} \
--per_device_train_batch_size ${per_device_train_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate 1e-4 \
--resolution ${resolution} \
--weight_decay 0.0 \
--max_steps ${max_steps} \
--lr_scheduler_type "constant" \
--warmup_steps 0 \
--image_logging_steps ${image_logging_steps} \
--logging_dir ${OUTPUT_DIR}/tb_log \
--logging_steps ${logging_steps} \
--save_steps ${save_steps} \
--save_total_limit 50 \
--dataloader_num_workers ${num_workers} \
--vae_name_or_path stabilityai/sd-vae-ft-mse \
--config_file ${config_file} \
--num_inference_steps 25 \
--use_ema True \
--max_grad_norm ${max_grad_norm} \
--overwrite_output_dir True \
--disable_tqdm True \
--fp16_opt_level ${FP16_OPT_LEVEL} \
--seed ${seed} \
--recompute ${recompute} \
--enable_xformers_memory_efficient_attention ${enable_xformers} \
--bf16 ${USE_AMP} \
--amp_master_grad 1 \
--dp_degree ${DP_DEGREE} \
--tensor_parallel_degree ${MP_DEGREE} \
--pipeline_parallel_degree ${PP_DEGREE} \
--sharding_parallel_degree ${SHARDING_DEGREE} \
--sharding "stage1" \
--sharding_parallel_config "enable_stage1_overlap enable_stage1_tensor_fusion" \
--hybrid_parallel_topo_order "sharding_first" \
--sep_parallel_degree 1 \
--enable_auto_parallel 1 \
--to_static $to_static
```
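
With `--enable_auto_parallel 1`, these degrees are handed to Paddle's auto-parallel engine rather than to the manual hybrid-parallel trainer, and for LargeDiT-3B the 8 GPUs form a 2D layout (4-way tensor parallel × 2-way sharding). The sketch below shows, assuming Paddle's public auto-parallel API (`paddle.distributed.ProcessMesh` / `shard_tensor`, Paddle ≥ 2.6), how such a mesh and a sharded weight can be described; it is an illustration only, not the code path of `train_image_generation_trainer_auto.py`.

```python
# Illustrative sketch only (assumes the Paddle >= 2.6 auto-parallel API; not the trainer's code).
# Run under the launcher, e.g.: python -m paddle.distributed.launch --nproc_per_node 8 sketch.py
import paddle
import paddle.distributed as dist

# 2D process mesh over 8 ranks: a 2-way sharding axis and a 4-way tensor-parallel axis.
mesh = dist.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["sharding", "mp"])

# A linear weight that is column-sharded along the "mp" axis and replicated on "sharding".
w = paddle.randn([4096, 4096])
w_dist = dist.shard_tensor(w, mesh, [dist.Replicate(), dist.Shard(1)])
print(w_dist.shape)  # logical shape stays [4096, 4096]; each mp rank holds a 4096 x 1024 slice
```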

## 2 Model Inference

@@ -4,5 +4,6 @@
"in_channels": 4,
"num_layers": 32,
"num_attention_heads": 32,
"attention_head_dim": 96
"attention_head_dim": 96,
"qk_norm": false
}
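
The only model-config change is the new `qk_norm` switch, left `false` here so the model keeps its original attention behavior. Presumably the flag controls whether the per-head query and key activations are normalized before the attention dot product; the snippet below is a generic, hypothetical sketch of that pattern (layer names and shapes are assumptions, and this is not the ppdiffusers implementation).

```python
# Hypothetical sketch of what a qk_norm flag typically toggles in attention
# (illustrative only; not the ppdiffusers implementation).
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

class AttentionSketch(nn.Layer):
    def __init__(self, dim: int, num_heads: int, qk_norm: bool = False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        # With qk_norm enabled, q and k are normalized per head before the dot product,
        # a common trick for stabilizing attention logits at large model scale.
        self.q_norm = nn.LayerNorm(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = nn.LayerNorm(self.head_dim) if qk_norm else nn.Identity()

    def forward(self, x):
        b, n, _ = x.shape
        qkv = self.qkv(x).reshape([b, n, 3, self.num_heads, self.head_dim])
        q, k, v = paddle.unbind(qkv.transpose([2, 0, 3, 1, 4]), axis=0)  # [b, heads, n, head_dim]
        q, k = self.q_norm(q), self.k_norm(k)
        scores = paddle.matmul(q, k, transpose_y=True) * self.head_dim**-0.5
        out = paddle.matmul(F.softmax(scores, axis=-1), v)
        out = out.transpose([0, 2, 1, 3]).reshape([b, n, self.num_heads * self.head_dim])
        return self.proj(out)

# Example sized like the config above: 32 heads x head_dim 96 -> hidden size 3072.
layer = AttentionSketch(dim=3072, num_heads=32, qk_norm=False)
y = layer(paddle.randn([2, 16, 3072]))
```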