Commit 445fd0e

add a100 test ground truth

1 parent 84615ea commit 445fd0e

File tree: 2 files changed, +142 -4 lines changed

paddlenlp/transformers/llama/modeling_auto.py

Lines changed: 8 additions & 4 deletions
@@ -181,6 +181,7 @@ def scaled_dot_product_attention(
 
     attn_output = paddle.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose([0, 2, 1, 3])
+    # [bsz, q_len, num_heads, head_dim] -> [bsz, q_len, num_heads * head_dim]
     attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
     return (attn_output, attn_weights) if output_attentions else attn_output
 
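The comment added in this hunk documents the transpose-then-reshape that folds the per-head outputs back into the hidden dimension. A minimal shape walk-through of that step in Python, with arbitrary toy sizes (bsz=2, num_heads=4, q_len=8, head_dim=16 are illustrative, not the model's):

import paddle

bsz, num_heads, q_len, head_dim = 2, 4, 8, 16
attn_output = paddle.randn([bsz, num_heads, q_len, head_dim])  # matmul output layout

attn_output = attn_output.transpose([0, 2, 1, 3])  # -> [bsz, q_len, num_heads, head_dim]
attn_output = attn_output.reshape([bsz, q_len, num_heads * head_dim])
print(attn_output.shape)  # [2, 8, 64]
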
@@ -399,9 +400,10 @@ def forward(
         alibi: Optional[paddle.Tensor] = None,
     ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
-        # [bs, seq_len, num_head * head_dim] -> [seq_len / n, bs, num_head * head_dim] (n is model parallelism)
+        # [bs, seq_len, num_head * head_dim] or [seq_len / n, bs, num_head * head_dim] (if sequence_parallel)
         # enter tp region
         if self.config.sequence_parallel:
+            # [seq_len / n, bs, num_head * head_dim] -> [seq_len, bs, num_head * head_dim] (if sequence_parallel)
             hidden_states = dist.reshard(
                 hidden_states,
                 get_mesh(self.ipp),
@@ -422,6 +424,8 @@ def forward(
         value_states = self.v_proj(hidden_states).reshape(shape=target_key_value_shape)
 
         if self.config.sequence_parallel:
+            # [seq_len, bs, num_head * head_dim] -> [bs, seq_len, num_head * head_dim] (if sequence_parallel)
+            # FA and RoPE do not support a sequence-first layout
             query_states = paddle.transpose(query_states, [1, 0, 2, 3])
             key_states = paddle.transpose(key_states, [1, 0, 2, 3])
             value_states = paddle.transpose(value_states, [1, 0, 2, 3])
@@ -526,12 +530,12 @@ def forward(
         else:
             attn_output = outputs
 
-        # if sequence_parallel is true, out shape are [q_len / n, bs, num_head * head_dim]
-        # else their shape are [bs, q_len, num_head * head_dim], n is mp parallelism.
+        # [bs, q_len, num_head * head_dim]
         attn_output = self.o_proj(attn_output)
 
         # enter sp region
         if self.config.sequence_parallel:
+            # [bs, q_len, num_head * head_dim] -> [q_len / n, bs, num_head * head_dim]
            attn_output = paddle.transpose(attn_output, [1, 0, 2])
            attn_output = dist.reshard(
                attn_output,
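
Read together, the comments in the three hunks above trace the layout round trip under sequence_parallel: gather the sequence-sharded input to [seq_len, bs, ...] on entering the tp region, transpose to batch-first because flash attention and RoPE expect it, then transpose back and split to [q_len / n, bs, ...] on re-entering the sp region. A sketch of that flow using plain transposes, where the dist.reshard gather/split along the sequence axis is stubbed out with concat and slicing (n=2 and the 3-D shapes are simplifying assumptions):

import paddle

n, bs, seq_len, hidden = 2, 2, 8, 64
shard = paddle.randn([seq_len // n, bs, hidden])  # sequence-parallel shard

x = paddle.concat([shard, shard], axis=0)  # stand-in for the tp gather -> [seq_len, bs, hidden]
x = x.transpose([1, 0, 2])                 # -> [bs, seq_len, hidden], batch-first for FA/RoPE
# ... attention and o_proj keep [bs, seq_len, hidden] ...
x = x.transpose([1, 0, 2])                 # back to [seq_len, bs, hidden]
out = x[: seq_len // n]                    # stand-in for the sp split -> [seq_len / n, bs, hidden]
print(out.shape)                           # [4, 2, 64]
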
@@ -595,7 +599,7 @@ def forward(
             cache (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states
         """
 
-        # [bs * seq_len, embed_dim] -> [seq_len * bs / n, embed_dim] (sequence_parallel)
+        # [bs, seq_len, embed_dim] or [seq_len / n, bs, embed_dim] (if sequence_parallel)
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

scripts/distribute/ci_case_auto.sh

Lines changed: 134 additions & 0 deletions
@@ -28,6 +28,15 @@ export llm_gpt_case_path=$root_path/llm/gpt-3/auto_parallel
 
 unset CUDA_VISIBLE_DEVICES
 
+function is_a100() {
+    if [ $(nvidia-smi | grep A100 | wc -l) -ne 0 ]; then
+        echo 1
+    else
+        echo 0
+    fi
+}
+
+
 function gpt_case_list_auto() {
     gpt_auto_recompute_bs16_fp32_DP1-MP1-PP1
     gpt_auto_recompute_bs16_fp16_o2_DP1-MP1-PP8
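
The new is_a100 helper echoes 1 when nvidia-smi lists an A100 and 0 otherwise. Note that the per-case guards below compare the echoed value numerically: a bare [ $(is_a100) ] would always be true, since both "0" and "1" are non-empty strings. The same detection logic, sketched in Python (assuming nvidia-smi is on PATH on the CI host):

import subprocess

def is_a100() -> bool:
    # mirrors: nvidia-smi | grep A100 | wc -l  (nonzero means an A100 is present)
    out = subprocess.run(["nvidia-smi"], capture_output=True, text=True)
    return "A100" in out.stdout
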
@@ -108,6 +117,11 @@ function gpt_auto_recompute_bs16_fp32_DP1-MP1-PP1() {
     loss_base=10.507633305
     ips_base=3518
     mem_base=11750.6
+    if [ $(is_a100) -ne 0 ]; then
+        loss_base=10.530449009
+        ips_base=16763
+        mem_base=11750.6
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
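
Every case in this file follows the pattern above: set the default ground truth, override it when running on an A100, then call check_result. check_result itself is defined elsewhere in ci_case_auto.sh; the sketch below is a hypothetical reading of the comparison it performs, where the function name, the relative tolerance, and treating -1 as "skip this metric" (as the llama cases below do for ips and mem) are all assumptions:

def check_metric(name: str, base: float, actual: float, rtol: float = 1e-2) -> None:
    if base == -1:  # assumed sentinel: metric not checked
        return
    if abs(actual - base) > rtol * abs(base):
        raise AssertionError(f"{name}: expected {base}, got {actual}")

check_metric("loss", 10.507633305, 10.5076)  # within the assumed tolerance
check_metric("ips", -1, -1)                  # skipped via the -1 sentinel
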
@@ -144,6 +158,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP1-MP1-PP8() {
     loss_base=10.570028400
     ips_base=35050
     mem_base=1988.9
+    if [ $(is_a100) -ne 0 ]; then
+        loss_base=10.559662151
+        ips_base=83918
+        mem_base=2022.7
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }

@@ -181,6 +200,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP1-MP1-PP8_pir() {
     loss_base=10.570028400
     ips_base=35050
     mem_base=1988.9
+    if [ $(is_a100) -ne 0 ]; then
+        loss_base=10.559662151
+        ips_base=83918
+        mem_base=2022.7
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }

@@ -217,6 +241,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP1-MP2-PP4() {
     loss_base=10.700293922
     ips_base=32518
     mem_base=1535.7
+    if [ $(is_a100) -ne 0 ]; then
+        loss_base=10.679453373
+        ips_base=79116
+        mem_base=1488.2
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }

@@ -253,6 +282,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2() {
     loss_base=10.672543240
     ips_base=18681
     mem_base=2135.7
+    if [ $(is_a100) -ne 0 ]; then
+        loss_base=10.651049423
+        ips_base=41174
+        mem_base=2064.5
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }

@@ -290,6 +324,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_pir() {
     loss_base=10.672543240
     ips_base=18681
     mem_base=2135.7
+    if [ $(is_a100) -ne 0 ]; then
+        loss_base=10.651049423
+        ips_base=41174
+        mem_base=2064.5
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
@@ -326,6 +365,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage1() {
     loss_base=10.720068359
     ips_base=15232
     mem_base=1999.2
+    if [ $(is_a100) -ne 0 ]; then
+        loss_base=10.657777309
+        ips_base=30027
+        mem_base=2002.0
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }

@@ -363,6 +407,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage1_pir() {
     loss_base=10.720068359
     ips_base=15232
     mem_base=1999.2
+    if [ $(is_a100) -ne 0 ]; then
+        loss_base=10.657777309
+        ips_base=30027
+        mem_base=2002.0
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }

@@ -399,6 +448,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage2() {
     loss_base=10.720078850
     ips_base=15571
     mem_base=1999.2
+    if [ $(is_a100) -ne 0 ]; then
+        loss_base=10.657803535
+        ips_base=29166
+        mem_base=2002.0
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }

@@ -435,6 +489,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage3() {
     loss_base=10.681921577
     ips_base=13813
     mem_base=1747.6
+    if [ $(is_a100) -ne 0 ]; then
+        loss_base=10.662137604
+        ips_base=24700
+        mem_base=1750.5
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }

@@ -471,6 +530,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage1() {
     loss_base=10.579057693
     ips_base=19822
     mem_base=1709.8
+    if [ $(is_a100) -ne 0 ]; then
+        loss_base=10.586785984
+        ips_base=42813
+        mem_base=1743.8
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
@@ -508,6 +572,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage1_pir() {
     loss_base=10.579057693
     ips_base=19822
     mem_base=1709.8
+    if [ $(is_a100) -ne 0 ]; then
+        loss_base=10.586785984
+        ips_base=42813
+        mem_base=1743.8
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }

@@ -544,6 +613,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage2() {
     loss_base=10.579057693
     ips_base=20170
     mem_base=1709.8
+    if [ $(is_a100) -ne 0 ]; then
+        loss_base=10.586785984
+        ips_base=42995
+        mem_base=1743.8
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }

@@ -580,6 +654,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage3() {
     loss_base=10.585316849
     ips_base=15742
     mem_base=1591.6
+    if [ $(is_a100) -ne 0 ]; then
+        loss_base=10.555718899
+        ips_base=34688
+        mem_base=1625.6
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }

@@ -616,6 +695,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage1() {
     loss_base=10.672568035
     ips_base=19461
     mem_base=1384.7
+    if [ $(is_a100) -ne 0 ]; then
+        loss_base=10.651032448
+        ips_base=42435
+        mem_base=1377.5
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }

@@ -652,6 +736,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage2() {
     loss_base=10.672568035
     ips_base=19652
     mem_base=1384.7
+    if [ $(is_a100) -ne 0 ]; then
+        loss_base=10.651032448
+        ips_base=43008
+        mem_base=1377.5
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
@@ -689,6 +778,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage2_pir() {
     loss_base=10.672568035
     ips_base=19652
     mem_base=1384.7
+    if [ $(is_a100) -ne 0 ]; then
+        loss_base=10.651032448
+        ips_base=43008
+        mem_base=1377.5
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }

@@ -725,6 +819,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage3() {
     loss_base=10.696336079
     ips_base=16613
     mem_base=1280.5
+    if [ $(is_a100) -ne 0 ]; then
+        loss_base=10.705118465
+        ips_base=37104
+        mem_base=1217.3
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }

@@ -762,6 +861,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage3_pir() {
     loss_base=10.696336079
     ips_base=16613
     mem_base=1280.5
+    if [ $(is_a100) -ne 0 ]; then
+        loss_base=10.705118465
+        ips_base=37104
+        mem_base=1217.3
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
@@ -908,6 +1012,9 @@ function llama_static_auto_recompute_bs8_fp32_DP1-MP1-PP1() {
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
     loss_base=9.52110565
+    if [ $(is_a100) -ne 0 ]; then
+        loss_base=9.44003963
+    fi
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}

@@ -974,6 +1081,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP1-PP1() {
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
     loss_base=9.42011833
+    if [ $(is_a100) -ne 0 ]; then
+        loss_base=9.44003963
+    fi
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}

@@ -1040,6 +1150,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP1() {
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
     loss_base=9.44299471
+    if [ $(is_a100) -ne 0 ]; then
+        loss_base=9.45633757
+    fi
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}

@@ -1106,6 +1219,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2() {
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
     loss_base=9.45936012
+    if [ $(is_a100) -ne 0 ]; then
+        loss_base=9.46121407
+    fi
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}

@@ -1174,6 +1290,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
     loss_base=9.46707726
+    if [ $(is_a100) -ne 0 ]; then
+        loss_base=9.44474411
+    fi
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}

@@ -1243,6 +1362,9 @@ function llama_static_auto_recompute_bs16_fp16_DP2-MP2-PP2-VPP2-Sharding2_stage2
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
     loss_base=10.0859375
+    if [ $(is_a100) -ne 0 ]; then
+        loss_base=10.125
+    fi
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}

@@ -1310,6 +1432,9 @@ function llama_dygraph_auto_bs8_fp32_DP2() {
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
     loss_base=9.53389835
+    if [ $(is_a100) -ne 0 ]; then
+        loss_base=9.54253578
+    fi
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}

@@ -1377,6 +1502,9 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2() {
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
     loss_base=9.39066124
+    if [ $(is_a100) -ne 0 ]; then
+        loss_base=9.41613197
+    fi
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}

@@ -1444,6 +1572,9 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2() {
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
     loss_base=9.38235474
+    if [ $(is_a100) -ne 0 ]; then
+        loss_base=9.4053154
+    fi
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}

@@ -1512,6 +1643,9 @@ function llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2() {
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
     loss_base=9.38256836
+    if [ $(is_a100) -ne 0 ]; then
+        loss_base=9.4055137
+    fi
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
