
Commit 67bf18d

add test
1 parent 5e1f01f commit 67bf18d

File tree

1 file changed (+78 -74 lines)


scripts/distribute/ci_case_auto.sh

Lines changed: 78 additions & 74 deletions
@@ -667,80 +667,84 @@ function llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP1-SP() {
     case_log_dir="output/$task_name""_log"

     for to_static in "0" "1"; do
-        rm -rf $case_out_dir
-        rm -rf $case_log_dir
-        python -u -m paddle.distributed.launch \
-            --gpus "0,1,2,3" \
-            --log_dir $case_log_dir \
-            run_pretrain_auto.py \
-            --model_type "llama" \
-            --model_name_or_path "facebook/llama-7b" \
-            --tokenizer_name_or_path "facebook/llama-7b" \
-            --input_dir "./data" \
-            --output_dir $case_out_dir \
-            --split 949,50,1 \
-            --weight_decay 0.01 \
-            --warmup_ratio 0.01 \
-            --max_grad_norm 0.0 \
-            --learning_rate 3e-05 \
-            --min_learning_rate 3e-06 \
-            --max_steps 10 \
-            --logging_steps 10 \
-            --eval_steps 1000 \
-            --save_steps 50000 \
-            --continue_training 0 \
-            --do_train true \
-            --do_eval false \
-            --do_predict false \
-            --disable_tqdm true \
-            --skip_profile_timer true \
-            --save_total_limit 2 \
-            --device gpu \
-            --disable_tqdm true \
-            --dataloader_num_workers 1 \
-            --enable_auto_parallel 1 \
-            --per_device_train_batch_size 1 \
-            --gradient_accumulation_steps 1 \
-            --per_device_eval_batch_size 2 \
-            --recompute false \
-            --bf16 1\
-            --fp16_opt_level "O2" \
-            --amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \
-            --amp_custom_white_list "lookup_table" "lookup_table_v2" \
-            --amp_master_grad 1 \
-            --fuse_attention_ffn false \
-            --fuse_attention_qkv false \
-            --fuse_sequence_parallel_allreduce false \
-            --use_flash_attention 0 \
-            --use_fused_rope false \
-            --use_fused_rms_norm 0 \
-            --max_seq_length 4096 \
-            --sep_parallel_degree 1 \
-            --sequence_parallel true \
-            --pipeline_parallel_degree 1 \
-            --sharding_parallel_degree 1 \
-            --tensor_parallel_degree 2 \
-            --virtual_pp_degree 1 \
-            --sharding "" \
-            --to_static ${to_static} \
-            --num_hidden_layers 4 \
-            >>${log_path}/$FUNCNAME 2>&1
-        loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
-        loss_md5=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss_md5: ' '{print $2}' | awk -F ',' '{print $1}'`
-        ips=-1
-        mem=-1
-        echo "result: to_static=$to_static loss=$loss ips=$ips mem=$mem"
-        loss_base=9.16783295
-        loss_md5_base=8ea72495fba4e1b9ba004b4431e27218
-        if [ $IS_A100 -ne 0 ] && [ $to_static -eq 0 ];then
-            loss_base=9.37966919
-        elif [ $IS_A100 -ne 0 ] && [ $to_static -eq 1 ];then
-            loss_base=9.38012543
-        fi
-        ips_base=-1
-        mem_base=-1
-        check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
-        # check_md5_result $FUNCNAME ${loss_md5_base} ${loss_md5}
+        for use_recompute in "1" "0"; do
+            rm -rf $case_out_dir
+            rm -rf $case_log_dir
+            python -u -m paddle.distributed.launch \
+                --gpus "0,1,2,3" \
+                --log_dir $case_log_dir \
+                run_pretrain_auto.py \
+                --model_type "llama" \
+                --model_name_or_path "facebook/llama-7b" \
+                --tokenizer_name_or_path "facebook/llama-7b" \
+                --input_dir "./data" \
+                --output_dir $case_out_dir \
+                --split 949,50,1 \
+                --weight_decay 0.01 \
+                --warmup_ratio 0.01 \
+                --max_grad_norm 0.0 \
+                --learning_rate 3e-05 \
+                --min_learning_rate 3e-06 \
+                --max_steps 10 \
+                --logging_steps 10 \
+                --eval_steps 1000 \
+                --save_steps 50000 \
+                --continue_training 0 \
+                --do_train true \
+                --do_eval false \
+                --do_predict false \
+                --disable_tqdm true \
+                --skip_profile_timer true \
+                --save_total_limit 2 \
+                --device gpu \
+                --disable_tqdm true \
+                --dataloader_num_workers 1 \
+                --enable_auto_parallel 1 \
+                --per_device_train_batch_size 1 \
+                --gradient_accumulation_steps 1 \
+                --per_device_eval_batch_size 2 \
+                --recompute ${use_recompute} \
+                --bf16 1\
+                --fp16_opt_level "O2" \
+                --amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \
+                --amp_custom_white_list "lookup_table" "lookup_table_v2" \
+                --amp_master_grad 1 \
+                --fuse_attention_ffn false \
+                --fuse_attention_qkv false \
+                --fuse_sequence_parallel_allreduce false \
+                --use_flash_attention 0 \
+                --use_fused_rope false \
+                --use_fused_rms_norm 0 \
+                --max_seq_length 4096 \
+                --sep_parallel_degree 1 \
+                --sequence_parallel true \
+                --pipeline_parallel_degree 1 \
+                --sharding_parallel_degree 1 \
+                --tensor_parallel_degree 2 \
+                --virtual_pp_degree 1 \
+                --sharding "" \
+                --to_static ${to_static} \
+                --num_hidden_layers 4 \
+                >>${log_path}/$FUNCNAME 2>&1
+            loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
+            loss_md5=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss_md5: ' '{print $2}' | awk -F ',' '{print $1}'`
+            ips=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'interval_tokens_per_second_per_device: ' '{print $2}' | awk -F ',' '{print $1}'`
+            mem=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'max_memory_reserved: ' '{print $2}' | awk -F ',' '{print $1}'`
+            echo "result: to_static=$to_static loss=$loss ips=$ips mem=$mem"
+            loss_base=9.16783295
+            loss_md5_base=8ea72495fba4e1b9ba004b4431e27218
+            if [ $IS_A100 -ne 0 ] && [ $to_static -eq 0 ];then
+                loss_base=9.37966919
+            elif [ $IS_A100 -ne 0 ] && [ $to_static -eq 1 ];then
+                loss_base=9.38012543
+            fi
+            ips=-1
+            mem=-1
+            ips_base=-1
+            mem_base=-1
+            check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
+            # check_md5_result $FUNCNAME ${loss_md5_base} ${loss_md5}
+        done
     done
     echo "=========== $FUNCNAME run end ==========="
 }
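For reference, below is a minimal sketch of the grep/awk pipeline this commit adds for pulling loss, ips, and mem out of the step-10 line of workerlog.0. The sample log line (variable sample_line, and its field values) is an assumption made up for illustration only, not output from a real run:

# Sketch of the workerlog.0 parsing used above; the sample line is assumed, not real trainer output.
sample_line='global_step: 10, loss: 9.16783295, loss_md5: 8ea72495fba4e1b9ba004b4431e27218, interval_tokens_per_second_per_device: 1234.5, max_memory_reserved: 20480'
loss=$(echo "$sample_line" | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}')
ips=$(echo "$sample_line" | grep 'global_step: 10' | awk -F 'interval_tokens_per_second_per_device: ' '{print $2}' | awk -F ',' '{print $1}')
mem=$(echo "$sample_line" | grep 'global_step: 10' | awk -F 'max_memory_reserved: ' '{print $2}' | awk -F ',' '{print $1}')
echo "loss=$loss ips=$ips mem=$mem"   # prints: loss=9.16783295 ips=1234.5 mem=20480

Note that in the new test body the extracted ips and mem are echoed but then reset to -1 (matching ips_base and mem_base) before check_result is called, so only the loss is compared against a baseline.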

0 commit comments
