Commit d8e1a6b
【AutoParallel】Add split_backward for vpp (#8479)
* add split_backward for vpp
* polish
* add test case
* polish
* update test case
* change the config
* polish
1 parent a90f163 commit d8e1a6b

4 files changed: +117, -7 lines


paddlenlp/trainer/training_args.py

Lines changed: 3 additions & 0 deletions
@@ -634,6 +634,7 @@ class TrainingArguments:
                 "enable_overlap_p2p_comm, overlap p2p communication with computation. \n"
                 "enable_clear_every_step_cache, clear every step cache for pipeline parallel. \n"
                 "disable_batch_p2p_comm, disable batched send/recv in pipeline parallel mode. \n"
+                "enable_split_backward, can only be used in StaticGraph-AutoParallel! Splits the `backward` program into `backward_b` and `backward_w` to decrease the bubble in VPP pipeline mode when `acc_step == pp_degree`. Note that it increases memory usage! \n"
             )
         },
     )
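For intuition, here is a toy NumPy sketch of the general split-backward idea (an illustration only, not PaddleNLP's actual AutoParallel pass) for a single linear layer y = x @ W: the input-gradient half (backward_b) sits on the critical path because the previous pipeline stage is waiting on it, while the weight-gradient half (backward_w) touches only saved tensors and can be deferred into the pipeline bubble, which is also why those tensors live longer and memory usage goes up.

import numpy as np

# Toy illustration only: for y = x @ W, backward factors into two
# independent pieces that a scheduler may run at different times.
def backward_b(grad_y, W):
    # Input gradient: produced immediately and sent to the previous stage.
    return grad_y @ W.T

def backward_w(grad_y, x):
    # Weight gradient: needs only saved tensors, so it can be postponed
    # to fill the VPP bubble (at the cost of holding grad_y and x longer).
    return x.T @ grad_y

x, W = np.random.randn(4, 8), np.random.randn(8, 16)
grad_y = np.random.randn(4, 16)   # upstream gradient of y = x @ W
grad_x = backward_b(grad_y, W)    # critical path: unblocks the previous stage
grad_W = backward_w(grad_y, x)    # deferrable: scheduled into the bubble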
@@ -1304,6 +1305,7 @@ def is_segment_parallel_supported():
                 # "enable_sharding_comm_overlap", # no implemenation for auto_parallel
                 # "enable_timer", # no implemenation for auto_parallel
                 # "disable_batch_p2p_comm", # no implemenation for auto_parallel
+                "enable_split_backward",
             ]:
                 raise ValueError(
                     f"Found unknown pipeline mode config {x}, accpet config is enable_send_recv_overlap."
@@ -1312,6 +1314,7 @@ def is_segment_parallel_supported():
             pipeline = strategy.pipeline
             pipeline.enable = True
             pipeline.enable_send_recv_overlap = "enable_send_recv_overlap" in pipeline_parallel_config
+            pipeline.split_backward = "enable_split_backward" in pipeline_parallel_config
             pipeline.accumulate_steps = self.gradient_accumulation_steps
             pipeline.micro_batch_size = self.per_device_train_batch_size
             pipeline.schedule_mode = self.pipeline_schedule_mode
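A minimal standalone sketch of the parsing wired up above (PipelineConfig and apply_pipeline_parallel_config are hypothetical stand-ins, not PaddleNLP APIs, and the known-token set is reduced to the two flags relevant here): the space-separated pipeline_parallel_config string is tokenized, unknown tokens raise a ValueError, and enable_split_backward toggles pipeline.split_backward.

from dataclasses import dataclass

# Reduced token set for illustration; the real check accepts more flags.
KNOWN_TOKENS = {"enable_send_recv_overlap", "enable_split_backward"}

@dataclass
class PipelineConfig:  # hypothetical stand-in for strategy.pipeline
    enable_send_recv_overlap: bool = False
    split_backward: bool = False

def apply_pipeline_parallel_config(raw: str) -> PipelineConfig:
    tokens = set(raw.split())
    unknown = tokens - KNOWN_TOKENS
    if unknown:
        raise ValueError(f"Found unknown pipeline mode config {unknown}")
    return PipelineConfig(
        enable_send_recv_overlap="enable_send_recv_overlap" in tokens,
        split_backward="enable_split_backward" in tokens,
    )

cfg = apply_pipeline_parallel_config("enable_send_recv_overlap enable_split_backward")
assert cfg.split_backward  # mirrors pipeline.split_backward above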

scripts/distribute/ci_case_auto.sh

Lines changed: 112 additions & 6 deletions
@@ -60,6 +60,7 @@ function llama_case_list_auto() {
     llama_dygraph_auto_bs8_fp32_DP2-MP2
     llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2
     llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2
+    llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2-VPP3_split_bw
     llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2

     llama_static_auto_recompute_bs8_fp32_DP1-MP1-PP1
@@ -1668,6 +1669,12 @@ function llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2() {
     export FLAGS_call_stack_level=3
     export NVIDIA_TF32_OVERRIDE=0

+    export FLAGS_cudnn_deterministic=1
+    export FLAGS_embedding_deterministic=1
+
+    export CUDA_DEVICE_MAX_CONNECTIONS=1
+    export PARALLEL_CROSS_ENTROPY=true
+
     task_name="llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2"
     case_out_dir="output/$task_name"
     case_log_dir="output/$task_name""_log"
@@ -1724,7 +1731,7 @@ function llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2() {
         --max_seq_length 4096 \
         --sep_parallel_degree 1 \
         --sequence_parallel false \
-        --pipeline_parallel_degree 2 \
+        --pipeline_parallel_degree 4 \
         --sharding_parallel_degree 2 \
         --tensor_parallel_degree 1 \
         --virtual_pp_degree 3 \
@@ -1741,12 +1748,111 @@ function llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2() {
         --skip_memory_metrics 0 \
         >>${log_path}/$FUNCNAME 2>&1
     loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 30' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
-    ips=`cat $case_log_dir/workerlog.0 | grep 'global_step: 30' | awk -F 'interval_samples_per_second: ' '{print $2}' | awk -F ',' '{print $1}'`
-    mem=`cat $case_log_dir/workerlog.0 | grep 'global_step: 30' | awk -F 'current_memory_reserved: ' '{print $2}' | awk -F ',' '{print $1}'`
+    ips=`cat $case_log_dir/workerlog.0 | grep 'global_step: 30' | awk -F 'interval_tokens_per_second_per_device: ' '{print $2}' | awk -F ',' '{print $1}'`
+    mem=`cat $case_log_dir/workerlog.0 | grep 'global_step: 30' | awk -F 'max_memory_reserved: ' '{print $2}' | awk -F ',' '{print $1}'`
+    echo "result: loss=$loss ips=$ips mem=$mem"
+    loss_base=7.5364624
+    ips_base=5442.5208
+    mem_base=22.387750148773193
+    check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
+    echo "=========== $FUNCNAME run end ==========="
+}
+
+function llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2-VPP3_split_bw() {
+    # Only A100 supports this case.
+    if [ $IS_A100 -eq 0 ]; then
+        return
+    fi
+    echo "=========== $FUNCNAME run begin ==========="
+    export PYTHONPATH=$root_path/:$PYTHONPATH
+    export FLAGS_call_stack_level=3
+    export NVIDIA_TF32_OVERRIDE=0
+
+    export FLAGS_cudnn_deterministic=1
+    export FLAGS_embedding_deterministic=1
+
+    export CUDA_DEVICE_MAX_CONNECTIONS=1
+    export PARALLEL_CROSS_ENTROPY=true
+
+    task_name="llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2-VPP3_split_bw"
+    case_out_dir="output/$task_name"
+    case_log_dir="output/$task_name""_log"
+    rm -rf $case_out_dir
+    rm -rf $case_log_dir
+
+    python -u -m paddle.distributed.launch \
+        --gpus "0,1,2,3,4,5,6,7" \
+        --log_dir "output/$task_name""_log" \
+        ./run_pretrain_auto.py \
+        --model_name_or_path "meta-llama/Llama-2-13b" \
+        --tokenizer_name_or_path "meta-llama/Llama-2-13b" \
+        --input_dir "./data" \
+        --output_dir "./output" \
+        --split 949,50,1 \
+        --weight_decay 0.01 \
+        --warmup_ratio 0.01 \
+        --max_grad_norm 1.0 \
+        --learning_rate 3e-05 \
+        --min_learning_rate 3e-06 \
+        --max_steps 30 \
+        --logging_steps 10 \
+        --eval_steps 1000 \
+        --save_steps 50000 \
+        --continue_training 0 \
+        --do_train true \
+        --do_eval false \
+        --do_predict false \
+        --disable_tqdm true \
+        --skip_profile_timer true \
+        --save_total_limit 2 \
+        --device gpu \
+        --disable_tqdm true \
+        --dataloader_num_workers 1 \
+        --distributed_dataloader 0 \
+        --enable_auto_parallel 1 \
+        --per_device_train_batch_size 1 \
+        --gradient_accumulation_steps 4 \
+        --per_device_eval_batch_size 1 \
+        --recompute false \
+        --recompute_use_reentrant true \
+        --recompute_granularity full \
+        --pp_recompute_interval 0 \
+        --bf16 true \
+        --fp16_opt_level "O2" \
+        --amp_master_grad true \
+        --fuse_attention_ffn false \
+        --fuse_attention_qkv true \
+        --fused_linear_param_grad_add 1 \
+        --fuse_sequence_parallel_allreduce false \
+        --use_flash_attention true \
+        --use_fused_rope true \
+        --use_fused_rms_norm true \
+        --max_seq_length 4096 \
+        --sep_parallel_degree 1 \
+        --sequence_parallel false \
+        --pipeline_parallel_degree 4 \
+        --sharding_parallel_degree 2 \
+        --tensor_parallel_degree 1 \
+        --virtual_pp_degree 3 \
+        --pipeline_schedule_mode "VPP" \
+        --sharding "stage2" \
+        --pipeline_parallel_config "enable_send_recv_overlap enable_split_backward" \
+        --data_parallel_config "enable_allreduce_avg_in_gradinent_scale gradient_sync_after_accumulate" \
+        --sharding_parallel_config "enable_stage2_overlap" \
+        --tensor_parallel_config "enable_mp_async_allreduce" \
+        --to_static 1 \
+        --amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \
+        --amp_custom_white_list "lookup_table" "lookup_table_v2" \
+        --num_hidden_layers 12 \
+        --skip_memory_metrics 0 \
+        >>${log_path}/$FUNCNAME 2>&1
+    loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 30' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
+    ips=`cat $case_log_dir/workerlog.0 | grep 'global_step: 30' | awk -F 'interval_tokens_per_second_per_device: ' '{print $2}' | awk -F ',' '{print $1}'`
+    mem=`cat $case_log_dir/workerlog.0 | grep 'global_step: 30' | awk -F 'max_memory_reserved: ' '{print $2}' | awk -F ',' '{print $1}'`
     echo "result: loss=$loss ips=$ips mem=$mem"
-    loss_base=7.52383575
-    ips_base=12.4135
-    mem_base=29.140248775482178
+    loss_base=7.5364624
+    ips_base=5864.2898
+    mem_base=23.745134115219116
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
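The grep/awk pipelines in both functions assume workerlog.0 carries a comma-separated "key: value" metrics line at global_step: 30. A Python equivalent of the same extraction (the sample line is hypothetical; the field names are exactly the awk separators used above):

import re

# Hypothetical sample of the trainer's metrics line; real logs carry more fields.
line = ("global_step: 30, loss: 7.5364624, "
        "interval_tokens_per_second_per_device: 5864.2898, "
        "max_memory_reserved: 23.745134115219116,")

def extract(field: str, text: str) -> float:
    # Same effect as: awk -F '<field>: ' '{print $2}' | awk -F ',' '{print $1}'
    m = re.search(rf"{re.escape(field)}: ([^,]+)", text)
    if m is None:
        raise ValueError(f"{field} not found in log line")
    return float(m.group(1))

loss = extract("loss", line)
ips = extract("interval_tokens_per_second_per_device", line)
mem = extract("max_memory_reserved", line)
print(f"result: loss={loss} ips={ips} mem={mem}")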

scripts/distribute/run_ci.sh

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ target_lists_for_llama=(
     "paddlenlp/trainer/auto_trainer.py"
     "paddlenlp/transformers/llama/modeling_auto_static.py"
     "paddlenlp/transformers/llama/modeling_auto.py"
+    "paddlenlp/transformers/llama/modeling.py"
     "scripts/distribute"
 )

tests/test_tipc/static/auto_parallel/llama2/pretrain_config_llama2_13b/pretrain-llama2_13b.json

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
     "data_parallel_config": "enable_allreduce_avg_in_gradinent_scale gradient_sync_after_accumulate",
     "sharding_parallel_config": "enable_stage2_overlap",
     "tensor_parallel_config": "enable_mp_async_allreduce",
-    "pipeline_parallel_config": "enable_send_recv_overlap",
+    "pipeline_parallel_config": "enable_send_recv_overlap enable_split_backward",
     "pipeline_schedule_mode": "VPP",
     "virtual_pp_degree": 5,
     "sequence_parallel": 0,
