
Commit 235f09f

[PIR] add pir grad merge test (#9074)
* add test
* change task name
* Update ci_case_auto.sh
* Update ci_case_auto.sh
* Update ci_case_auto.sh
* Update ci_case_auto.sh
* Update ci_case_auto.sh
1 parent 0854287 commit 235f09f


scripts/distribute/ci_case_auto.sh

Lines changed: 99 additions & 0 deletions
@@ -57,6 +57,7 @@ function llama_case_list_auto() {
     llama_convert_hybrid_ckpt_to_auto_parallel_bs2_fp32_DP2-MP1-PP1
     llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP1-SP
     llama_align_dygraph_dy2st_pir_auto_bs2_bf16_DP2-MP2-PP2-SP
+    llama_align_dygraph_dy2st_pir_auto_grad_merge_bs2_fp32_DP1-MP1-PP1
 }

 function llm_gpt_case_list_auto() {
@@ -1254,6 +1255,104 @@ function llama_align_dygraph_dy2st_auto_bs2_bf16_DP2-MP1-PP1() {
     echo "=========== $FUNCNAME run end ==========="
 }

+function llama_align_dygraph_dy2st_pir_auto_grad_merge_bs2_fp32_DP1-MP1-PP1() {
+    echo "=========== $FUNCNAME run begin ==========="
+    export FLAGS_call_stack_level=3
+    export NVIDIA_TF32_OVERRIDE=0
+    export FLAGS_max_inplace_grad_add=3
+
+    task_name="llama_align_dygraph_dy2st_pir_auto_grad_merge_bs2_fp32_DP2"
+    case_out_dir="output/$task_name"
+    case_log_dir="output/$task_name""_log"
+
+    loss1=0
+    loss2=0
+    use_pir=1
+    max_step=12
+
+    for to_static in "0" "1"; do
+        export FLAGS_enable_pir_api=${use_pir}
+        export FLAGS_enable_pir_in_executor=${use_pir}
+        rm -rf $case_out_dir
+        rm -rf $case_log_dir
+        rm -rf ${log_path}/$FUNCNAME
+
+        /usr/bin/python -u -m paddle.distributed.launch \
+            --gpus "0" \
+            --log_dir $case_log_dir \
+            run_pretrain_auto.py \
+            --model_type "llama" \
+            --model_name_or_path "facebook/llama-7b" \
+            --tokenizer_name_or_path "facebook/llama-7b" \
+            --input_dir "./data" \
+            --output_dir $case_out_dir \
+            --split 949,50,1 \
+            --weight_decay 0.01 \
+            --warmup_ratio 0.01 \
+            --warmup_steps 30 \
+            --max_grad_norm 0.0 \
+            --learning_rate 3e-05 \
+            --min_learning_rate 3e-06 \
+            --max_steps $max_step \
+            --logging_steps 1 \
+            --eval_steps 1000 \
+            --save_steps 50000 \
+            --continue_training 0 \
+            --do_train true \
+            --do_eval false \
+            --do_predict false \
+            --disable_tqdm true \
+            --skip_profile_timer true \
+            --save_total_limit 2 \
+            --device gpu \
+            --disable_tqdm true \
+            --dataloader_num_workers 1 \
+            --distributed_dataloader 0 \
+            --enable_auto_parallel 1 \
+            --per_device_train_batch_size 1 \
+            --gradient_accumulation_steps 2 \
+            --per_device_eval_batch_size 2 \
+            --recompute false \
+            --recompute_use_reentrant true \
+            --recompute_granularity full \
+            --pp_recompute_interval 0 \
+            --fp16 0 \
+            --fp16_opt_level "O2" \
+            --fuse_attention_ffn true \
+            --fuse_attention_qkv false \
+            --fuse_sequence_parallel_allreduce false \
+            --use_flash_attention 0 \
+            --use_fused_rope false \
+            --use_fused_rms_norm 0 \
+            --max_seq_length 2048 \
+            --sep_parallel_degree 1 \
+            --sequence_parallel false \
+            --pipeline_parallel_degree 1 \
+            --sharding_parallel_degree 1 \
+            --tensor_parallel_degree 1 \
+            --virtual_pp_degree 1 \
+            --sharding "" \
+            --to_static ${to_static} \
+            --num_hidden_layers 2 \
+            --data_parallel_config "gradient_sync_after_accumulate" \
+            >>${log_path}/$FUNCNAME 2>&1
+
+        loss=$(grep "global_step: $max_step" "$case_log_dir/workerlog.0" | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}')
+        if [ $to_static -eq 0 ];then
+            loss1=($loss)
+        else
+            loss2=($loss)
+        fi
+        echo "result: to_static=$to_static loss=$loss"
+    done
+
+    ips=-1
+    mem=-1
+    ips_base=-1
+    mem_base=-1
+    check_result $FUNCNAME ${loss1} ${loss2} ${ips_base} ${ips} ${mem_base} ${mem}
+}
+
 function llama_convert_hybrid_ckpt_to_auto_parallel_bs2_fp32_DP2-MP1-PP1() {
     echo "=========== $FUNCNAME run begin ==========="
     export PYTHONPATH=$root_path/:$PYTHONPATH
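The new case runs the same fp32 pretrain config twice, once in dygraph mode (to_static=0) and once under dy2st with PIR enabled (to_static=1), with gradient_accumulation_steps 2 and data_parallel_config "gradient_sync_after_accumulate" so that gradient merging is exercised, then reads the step-12 loss out of workerlog.0 with the grep/awk pipeline shown above. Below is a minimal, self-contained sketch of that extraction step against a hypothetical log line; the real line format is whatever run_pretrain_auto.py writes to its worker logs.

    #!/usr/bin/env bash
    # Sketch of the loss-extraction pipeline used by the new case.
    # The log line is hypothetical; only the "global_step: N, loss: X," shape matters here.
    max_step=12
    case_log_dir=$(mktemp -d)
    echo "[timestamp] global_step: 12, loss: 9.83215, lr: 3e-05, ips: 1.0" \
        > "$case_log_dir/workerlog.0"

    # Keep everything after "loss: " on the matching line, then cut at the next comma.
    loss=$(grep "global_step: $max_step" "$case_log_dir/workerlog.0" \
        | awk -F 'loss: ' '{print $2}' \
        | awk -F ',' '{print $1}')
    echo "parsed loss: $loss"   # prints 9.83215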

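The pass/fail decision is delegated to the script's existing check_result helper, which receives the dygraph loss (loss1) and the dy2st+PIR loss (loss2) alongside placeholder ips/mem values of -1. The helper below is a hypothetical stand-in, not the repository's check_result implementation; it only illustrates the kind of tolerance comparison the case relies on, since the two training modes are expected to produce (near-)identical losses.

    # Hypothetical comparison helper (illustrative only; assumes an absolute tolerance).
    compare_loss() {
        local name=$1 loss_dygraph=$2 loss_dy2st=$3 atol=${4:-1e-6}
        # awk exits 0 when |loss_dygraph - loss_dy2st| <= atol, 1 otherwise.
        if awk -v a="$loss_dygraph" -v b="$loss_dy2st" -v t="$atol" \
            'BEGIN { d = a - b; if (d < 0) d = -d; exit !(d <= t) }'; then
            echo "$name ... pass: |$loss_dygraph - $loss_dy2st| <= $atol"
        else
            echo "$name ... FAIL: $loss_dygraph vs $loss_dy2st"
            return 1
        fi
    }

    compare_loss llama_pir_grad_merge_demo 9.83215 9.83215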