[AutoParallel] Add test for PIR recompute #9621

Merged · 1 commit · Dec 17, 2024
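This change extends the auto-parallel CI case llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP1-SP so that PIR recompute is exercised: the training launch is wrapped in an inner loop over use_recompute, the hard-coded --recompute false becomes --recompute ${use_recompute}, and throughput (interval_tokens_per_second_per_device) and peak memory (max_memory_reserved) are now scraped from the worker log. The loss baselines are left unchanged, so every combination of to_static and use_recompute must reproduce the same loss.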
scripts/distribute/ci_case_auto.sh (152 changes: 78 additions & 74 deletions)
@@ -667,80 +667,84 @@ function llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP1-SP()
     case_log_dir="output/$task_name""_log"
 
     for to_static in "0" "1"; do
-        rm -rf $case_out_dir
-        rm -rf $case_log_dir
-        python -u -m paddle.distributed.launch \
-            --gpus "0,1,2,3" \
-            --log_dir $case_log_dir \
-            run_pretrain_auto.py \
-            --model_type "llama" \
-            --model_name_or_path "facebook/llama-7b" \
-            --tokenizer_name_or_path "facebook/llama-7b" \
-            --input_dir "./data" \
-            --output_dir $case_out_dir \
-            --split 949,50,1 \
-            --weight_decay 0.01 \
-            --warmup_ratio 0.01 \
-            --max_grad_norm 0.0 \
-            --learning_rate 3e-05 \
-            --min_learning_rate 3e-06 \
-            --max_steps 10 \
-            --logging_steps 10 \
-            --eval_steps 1000 \
-            --save_steps 50000 \
-            --continue_training 0 \
-            --do_train true \
-            --do_eval false \
-            --do_predict false \
-            --disable_tqdm true \
-            --skip_profile_timer true \
-            --save_total_limit 2 \
-            --device gpu \
-            --disable_tqdm true \
-            --dataloader_num_workers 1 \
-            --enable_auto_parallel 1 \
-            --per_device_train_batch_size 1 \
-            --gradient_accumulation_steps 1 \
-            --per_device_eval_batch_size 2 \
-            --recompute false \
-            --bf16 1 \
-            --fp16_opt_level "O2" \
-            --amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \
-            --amp_custom_white_list "lookup_table" "lookup_table_v2" \
-            --amp_master_grad 1 \
-            --fuse_attention_ffn false \
-            --fuse_attention_qkv false \
-            --fuse_sequence_parallel_allreduce false \
-            --use_flash_attention 0 \
-            --use_fused_rope false \
-            --use_fused_rms_norm 0 \
-            --max_seq_length 4096 \
-            --sep_parallel_degree 1 \
-            --sequence_parallel true \
-            --pipeline_parallel_degree 1 \
-            --sharding_parallel_degree 1 \
-            --tensor_parallel_degree 2 \
-            --virtual_pp_degree 1 \
-            --sharding "" \
-            --to_static ${to_static} \
-            --num_hidden_layers 4 \
-            >>${log_path}/$FUNCNAME 2>&1
-        loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
-        loss_md5=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss_md5: ' '{print $2}' | awk -F ',' '{print $1}'`
-        ips=-1
-        mem=-1
-        echo "result: to_static=$to_static loss=$loss ips=$ips mem=$mem"
-        loss_base=9.16783295
-        loss_md5_base=8ea72495fba4e1b9ba004b4431e27218
-        if [ $IS_A100 -ne 0 ] && [ $to_static -eq 0 ]; then
-            loss_base=9.37966919
-        elif [ $IS_A100 -ne 0 ] && [ $to_static -eq 1 ]; then
-            loss_base=9.38012543
-        fi
-        ips_base=-1
-        mem_base=-1
-        check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
-        # check_md5_result $FUNCNAME ${loss_md5_base} ${loss_md5}
+        for use_recompute in "1" "0"; do
+            rm -rf $case_out_dir
+            rm -rf $case_log_dir
+            python -u -m paddle.distributed.launch \
+                --gpus "0,1,2,3" \
+                --log_dir $case_log_dir \
+                run_pretrain_auto.py \
+                --model_type "llama" \
+                --model_name_or_path "facebook/llama-7b" \
+                --tokenizer_name_or_path "facebook/llama-7b" \
+                --input_dir "./data" \
+                --output_dir $case_out_dir \
+                --split 949,50,1 \
+                --weight_decay 0.01 \
+                --warmup_ratio 0.01 \
+                --max_grad_norm 0.0 \
+                --learning_rate 3e-05 \
+                --min_learning_rate 3e-06 \
+                --max_steps 10 \
+                --logging_steps 10 \
+                --eval_steps 1000 \
+                --save_steps 50000 \
+                --continue_training 0 \
+                --do_train true \
+                --do_eval false \
+                --do_predict false \
+                --disable_tqdm true \
+                --skip_profile_timer true \
+                --save_total_limit 2 \
+                --device gpu \
+                --disable_tqdm true \
+                --dataloader_num_workers 1 \
+                --enable_auto_parallel 1 \
+                --per_device_train_batch_size 1 \
+                --gradient_accumulation_steps 1 \
+                --per_device_eval_batch_size 2 \
+                --recompute ${use_recompute} \
+                --bf16 1 \
+                --fp16_opt_level "O2" \
+                --amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \
+                --amp_custom_white_list "lookup_table" "lookup_table_v2" \
+                --amp_master_grad 1 \
+                --fuse_attention_ffn false \
+                --fuse_attention_qkv false \
+                --fuse_sequence_parallel_allreduce false \
+                --use_flash_attention 0 \
+                --use_fused_rope false \
+                --use_fused_rms_norm 0 \
+                --max_seq_length 4096 \
+                --sep_parallel_degree 1 \
+                --sequence_parallel true \
+                --pipeline_parallel_degree 1 \
+                --sharding_parallel_degree 1 \
+                --tensor_parallel_degree 2 \
+                --virtual_pp_degree 1 \
+                --sharding "" \
+                --to_static ${to_static} \
+                --num_hidden_layers 4 \
+                >>${log_path}/$FUNCNAME 2>&1
+            loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
+            loss_md5=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss_md5: ' '{print $2}' | awk -F ',' '{print $1}'`
+            ips=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'interval_tokens_per_second_per_device: ' '{print $2}' | awk -F ',' '{print $1}'`
+            mem=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'max_memory_reserved: ' '{print $2}' | awk -F ',' '{print $1}'`
+            echo "result: to_static=$to_static loss=$loss ips=$ips mem=$mem"
+            loss_base=9.16783295
+            loss_md5_base=8ea72495fba4e1b9ba004b4431e27218
+            if [ $IS_A100 -ne 0 ] && [ $to_static -eq 0 ]; then
+                loss_base=9.37966919
+            elif [ $IS_A100 -ne 0 ] && [ $to_static -eq 1 ]; then
+                loss_base=9.38012543
+            fi
+            ips=-1
+            mem=-1
+            ips_base=-1
+            mem_base=-1
+            check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
+            # check_md5_result $FUNCNAME ${loss_md5_base} ${loss_md5}
+        done
     done
     echo "=========== $FUNCNAME run end ==========="
 }
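Stripped of the launch flags, the structural change is easy to see: the single loop over to_static becomes a 2x2 test matrix over (to_static, use_recompute), and every cell is checked against the same loss baseline. A minimal sketch of that control flow, with echo standing in for the paddle.distributed.launch invocation:

```bash
#!/usr/bin/env bash
# Sketch of the new test matrix; echo stands in for the real launch.
for to_static in "0" "1"; do
    for use_recompute in "1" "0"; do
        # Each of the four cells trains for 10 steps; recompute (and
        # dygraph-to-static conversion) must not change the numerics, so
        # all cells share one loss_base per hardware type.
        echo "cell: to_static=${to_static} use_recompute=${use_recompute}"
    done
done
```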
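The metric scraping relies on grep plus awk field splits over the last logged step. The sketch below replays that pipeline on a fabricated log line; the exact layout of workerlog.0 is an assumption inferred from the field separators used above, not copied from a real run:

```bash
#!/usr/bin/env bash
# Fabricated stand-in for the final line of workerlog.0 (format assumed
# from the awk separators in ci_case_auto.sh).
sample='global_step: 10, loss: 9.16783295, interval_tokens_per_second_per_device: 1234.5, max_memory_reserved: 20.1, loss_md5: 8ea72495fba4e1b9ba004b4431e27218,'

# Same pattern as the CI case: split on "<key>: ", then trim at the comma.
loss=$(echo "$sample" | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}')
ips=$(echo "$sample" | grep 'global_step: 10' | awk -F 'interval_tokens_per_second_per_device: ' '{print $2}' | awk -F ',' '{print $1}')
mem=$(echo "$sample" | grep 'global_step: 10' | awk -F 'max_memory_reserved: ' '{print $2}' | awk -F ',' '{print $1}')

echo "loss=$loss ips=$ips mem=$mem"   # -> loss=9.16783295 ips=1234.5 mem=20.1
```

Note that in the test itself the scraped ips and mem are reset to -1 before check_result, so they are reported in the echo but presumably not gated; only the loss is compared against a baseline.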