Skip to content

【AutoParallel】Add split_backward for vpp #8479

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions paddlenlp/trainer/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,7 @@
"enable_overlap_p2p_comm, overlap p2p communication with computation. \n"
"enable_clear_every_step_cache, clear every step cache for pipeline parallel. \n"
"disable_batch_p2p_comm, disable batched send/recv in pipeline parallel mode. \n"
"enable_split_backward, only can be used in StaticGraph-AutoParallel! split the `backward` program into `backward_b` and `backward_w` to decrease the bubble in VPP pipeline mode when `acc_step == pp_degree`. it increase the memory! \n"
)
},
)
Expand Down Expand Up @@ -1304,6 +1305,7 @@
# "enable_sharding_comm_overlap", # no implemenation for auto_parallel
# "enable_timer", # no implemenation for auto_parallel
# "disable_batch_p2p_comm", # no implemenation for auto_parallel
"enable_split_backward",
]:
raise ValueError(
f"Found unknown pipeline mode config {x}, accpet config is enable_send_recv_overlap."
Expand All @@ -1312,6 +1314,7 @@
pipeline = strategy.pipeline
pipeline.enable = True
pipeline.enable_send_recv_overlap = "enable_send_recv_overlap" in pipeline_parallel_config
pipeline.split_backward = "enable_split_backward" in pipeline_parallel_config

Check warning on line 1317 in paddlenlp/trainer/training_args.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/training_args.py#L1317

Added line #L1317 was not covered by tests
pipeline.accumulate_steps = self.gradient_accumulation_steps
pipeline.micro_batch_size = self.per_device_train_batch_size
pipeline.schedule_mode = self.pipeline_schedule_mode
Expand Down
118 changes: 112 additions & 6 deletions scripts/distribute/ci_case_auto.sh
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ function llama_case_list_auto() {
llama_dygraph_auto_bs8_fp32_DP2-MP2
llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2
llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2
llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2-VPP3_split_bw
llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2

llama_static_auto_recompute_bs8_fp32_DP1-MP1-PP1
Expand Down Expand Up @@ -1668,6 +1669,12 @@ function llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2() {
export FLAGS_call_stack_level=3
export NVIDIA_TF32_OVERRIDE=0

export FLAGS_cudnn_deterministic=1
export FLAGS_embedding_deterministic=1

export CUDA_DEVICE_MAX_CONNECTIONS=1
export PARALLEL_CROSS_ENTROPY=true

task_name="llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2"
case_out_dir="output/$task_name"
case_log_dir="output/$task_name""_log"
Expand Down Expand Up @@ -1724,7 +1731,7 @@ function llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2() {
--max_seq_length 4096 \
--sep_parallel_degree 1 \
--sequence_parallel false \
--pipeline_parallel_degree 2 \
--pipeline_parallel_degree 4 \
--sharding_parallel_degree 2 \
--tensor_parallel_degree 1 \
--virtual_pp_degree 3 \
Expand All @@ -1741,12 +1748,111 @@ function llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2() {
--skip_memory_metrics 0 \
>>${log_path}/$FUNCNAME 2>&1
loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 30' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
ips=`cat $case_log_dir/workerlog.0 | grep 'global_step: 30' | awk -F 'interval_samples_per_second: ' '{print $2}' | awk -F ',' '{print $1}'`
mem=`cat $case_log_dir/workerlog.0 | grep 'global_step: 30' | awk -F 'current_memory_reserved: ' '{print $2}' | awk -F ',' '{print $1}'`
ips=`cat $case_log_dir/workerlog.0 | grep 'global_step: 30' | awk -F 'interval_tokens_per_second_per_device: ' '{print $2}' | awk -F ',' '{print $1}'`
mem=`cat $case_log_dir/workerlog.0 | grep 'global_step: 30' | awk -F 'max_memory_reserved: ' '{print $2}' | awk -F ',' '{print $1}'`
echo "result: loss=$loss ips=$ips mem=$mem"
loss_base=7.5364624
ips_base=5442.5208
mem_base=22.387750148773193
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
}

function llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2-VPP3_split_bw() {
# CI case: LLaMA-2-13B dy2st (to_static) auto-parallel pretrain with
# pipeline_parallel_degree=4, sharding stage2 degree=2, VPP schedule
# (virtual_pp_degree=3), exercising the new `enable_split_backward`
# pipeline_parallel_config added in this PR.
# Only A100 supports this case.
if [ $IS_A100 -eq 0 ]; then
return
fi
echo "=========== $FUNCNAME run begin ==========="
export PYTHONPATH=$root_path/:$PYTHONPATH
export FLAGS_call_stack_level=3
export NVIDIA_TF32_OVERRIDE=0

# Deterministic kernels so the loss baseline asserted below is reproducible.
export FLAGS_cudnn_deterministic=1
export FLAGS_embedding_deterministic=1

export CUDA_DEVICE_MAX_CONNECTIONS=1
export PARALLEL_CROSS_ENTROPY=true

# Fresh per-case output and log directories.
task_name="llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2-VPP3_split_bw"
case_out_dir="output/$task_name"
case_log_dir="output/$task_name""_log"
rm -rf $case_out_dir
rm -rf $case_log_dir

# 8-GPU launch; stdout/stderr of the launcher is appended to the CI log,
# per-rank logs land in $case_log_dir (parsed below from workerlog.0).
python -u -m paddle.distributed.launch \
--gpus "0,1,2,3,4,5,6,7" \
--log_dir "output/$task_name""_log" \
./run_pretrain_auto.py \
--model_name_or_path "meta-llama/Llama-2-13b" \
--tokenizer_name_or_path "meta-llama/Llama-2-13b" \
--input_dir "./data" \
--output_dir "./output" \
--split 949,50,1 \
--weight_decay 0.01 \
--warmup_ratio 0.01 \
--max_grad_norm 1.0 \
--learning_rate 3e-05 \
--min_learning_rate 3e-06 \
--max_steps 30 \
--logging_steps 10 \
--eval_steps 1000 \
--save_steps 50000 \
--continue_training 0 \
--do_train true \
--do_eval false \
--do_predict false \
--disable_tqdm true \
--skip_profile_timer true \
--save_total_limit 2 \
--device gpu \
--disable_tqdm true \
--dataloader_num_workers 1 \
--distributed_dataloader 0 \
--enable_auto_parallel 1 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 4 \
--per_device_eval_batch_size 1 \
--recompute false \
--recompute_use_reentrant true \
--recompute_granularity full \
--pp_recompute_interval 0 \
--bf16 true \
--fp16_opt_level "O2" \
--amp_master_grad true \
--fuse_attention_ffn false \
--fuse_attention_qkv true \
--fused_linear_param_grad_add 1 \
--fuse_sequence_parallel_allreduce false \
--use_flash_attention true \
--use_fused_rope true \
--use_fused_rms_norm true \
--max_seq_length 4096 \
--sep_parallel_degree 1 \
--sequence_parallel false \
--pipeline_parallel_degree 4 \
--sharding_parallel_degree 2 \
--tensor_parallel_degree 1 \
--virtual_pp_degree 3 \
--pipeline_schedule_mode "VPP" \
--sharding "stage2" \
--pipeline_parallel_config "enable_send_recv_overlap enable_split_backward" \
--data_parallel_config "enable_allreduce_avg_in_gradinent_scale gradient_sync_after_accumulate" \
--sharding_parallel_config "enable_stage2_overlap" \
--tensor_parallel_config "enable_mp_async_allreduce" \
--to_static 1 \
--amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \
--amp_custom_white_list "lookup_table" "lookup_table_v2" \
--num_hidden_layers 12 \
--skip_memory_metrics 0 \
>>${log_path}/$FUNCNAME 2>&1
# Extract metrics from the step-30 log line on rank 0: final loss,
# interval tokens/sec/device, and max reserved memory.
loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 30' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
ips=`cat $case_log_dir/workerlog.0 | grep 'global_step: 30' | awk -F 'interval_tokens_per_second_per_device: ' '{print $2}' | awk -F ',' '{print $1}'`
mem=`cat $case_log_dir/workerlog.0 | grep 'global_step: 30' | awk -F 'max_memory_reserved: ' '{print $2}' | awk -F ',' '{print $1}'`
echo "result: loss=$loss ips=$ips mem=$mem"
# NOTE(review): the first baseline triple below is dead code — each variable
# is immediately overwritten by the second triple (which is what check_result
# actually sees). Looks like leftover values from an earlier revision of the
# diff; confirm and remove the stale triple.
loss_base=7.52383575
ips_base=12.4135
mem_base=29.140248775482178
loss_base=7.5364624
ips_base=5864.2898
mem_base=23.745134115219116
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
}
Expand Down
1 change: 1 addition & 0 deletions scripts/distribute/run_ci.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ target_lists_for_llama=(
"paddlenlp/trainer/auto_trainer.py"
"paddlenlp/transformers/llama/modeling_auto_static.py"
"paddlenlp/transformers/llama/modeling_auto.py"
"paddlenlp/transformers/llama/modeling.py"
"scripts/distribute"
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"data_parallel_config": "enable_allreduce_avg_in_gradinent_scale gradient_sync_after_accumulate",
"sharding_parallel_config": "enable_stage2_overlap",
"tensor_parallel_config": "enable_mp_async_allreduce",
"pipeline_parallel_config": "enable_send_recv_overlap",
"pipeline_parallel_config": "enable_send_recv_overlap enable_split_backward",
"pipeline_schedule_mode": "VPP",
"virtual_pp_degree": 5,
"sequence_parallel": 0,
Expand Down
Loading