Commit fdfc7bf

add a100 test ground truth
1 parent 79cb8b6 commit fdfc7bf

File tree

2 files changed (+142, -4 lines)

paddlenlp/transformers/llama/modeling_auto.py

Lines changed: 8 additions & 4 deletions
@@ -181,6 +181,7 @@ def scaled_dot_product_attention(
 
     attn_output = paddle.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose([0, 2, 1, 3])
+    # [bsz, q_len, num_heads, head_dim] -> [bsz, q_len, num_heads * head_dim]
     attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
     return (attn_output, attn_weights) if output_attentions else attn_output
 
@@ -399,9 +400,10 @@ def forward(
         alibi: Optional[paddle.Tensor] = None,
     ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
-        # [bs, seq_len, num_head * head_dim] -> [seq_len / n, bs, num_head * head_dim] (n is model parallelism)
+        # [bs, seq_len, num_head * head_dim] or [seq_len / n, bs, num_head * head_dim] (if sequence_parallel)
         # enter tp region
         if self.config.sequence_parallel:
+            # [seq_len / n, bs, num_head * head_dim] -> [seq_len, bs, num_head * head_dim] (if sequence_parallel)
             hidden_states = dist.reshard(
                 hidden_states,
                 get_mesh(self.ipp),
@@ -422,6 +424,8 @@ def forward(
         value_states = self.v_proj(hidden_states).reshape(shape=target_key_value_shape)
 
         if self.config.sequence_parallel:
+            # [seq_len, bs, num_head * head_dim] -> [bs, seq_len, num_head * head_dim] (if sequence_parallel)
+            # FA and RoPE do not support a sequence-first layout
             query_states = paddle.transpose(query_states, [1, 0, 2, 3])
             key_states = paddle.transpose(key_states, [1, 0, 2, 3])
             value_states = paddle.transpose(value_states, [1, 0, 2, 3])
@@ -526,12 +530,12 @@ def forward(
         else:
             attn_output = outputs
 
-        # if sequence_parallel is true, out shape are [q_len / n, bs, num_head * head_dim]
-        # else their shape are [bs, q_len, num_head * head_dim], n is mp parallelism.
+        # [bs, q_len, num_head * head_dim]
         attn_output = self.o_proj(attn_output)
 
         # enter sp region
         if self.config.sequence_parallel:
+            # [bs, q_len, num_head * head_dim] -> [q_len / n, bs, num_head * head_dim]
             attn_output = paddle.transpose(attn_output, [1, 0, 2])
             attn_output = dist.reshard(
                 attn_output,
@@ -595,7 +599,7 @@ def forward(
             cache (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states
         """
 
-        # [bs * seq_len, embed_dim] -> [seq_len * bs / n, embed_dim] (sequence_parallel)
+        # [bs, seq_len, embed_dim] or [seq_len / n, bs, embed_dim] (if sequence_parallel)
         residual = hidden_states
 
         hidden_states = self.input_layernorm(hidden_states)
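
The comments added above track one layout invariant: inside the sp (sequence-parallel) region activations are sequence-first, [seq_len / n, bs, ...], while attention itself runs batch-first. Below is a minimal sketch of that shape flow with illustrative sizes; the dist.reshard and process-mesh plumbing is omitted so it runs on a single device, and the variable names simply mirror the diff.

    # Minimal single-device sketch of the sequence-parallel layout flow.
    import paddle

    bs, seq_len, num_head, head_dim = 2, 8, 4, 16

    # sp region: sequence-first, [seq_len, bs, num_head * head_dim]
    # (already gathered from [seq_len / n, bs, ...] by the reshard in forward)
    hidden_states = paddle.ones([seq_len, bs, num_head * head_dim])

    # q/k/v projections keep sequence first; split heads out of the last dim
    query_states = hidden_states.reshape([seq_len, bs, num_head, head_dim])

    # FA and RoPE want batch first, hence the [1, 0, 2, 3] transpose
    query_states = paddle.transpose(query_states, [1, 0, 2, 3])
    assert query_states.shape == [bs, seq_len, num_head, head_dim]

    # after attention + o_proj the output is [bs, q_len, num_head * head_dim];
    # transpose back to sequence first before re-entering the sp region
    attn_output = paddle.ones([bs, seq_len, num_head * head_dim])
    attn_output = paddle.transpose(attn_output, [1, 0, 2])
    assert attn_output.shape == [seq_len, bs, num_head * head_dim]
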

scripts/distribute/ci_case_auto.sh

Lines changed: 134 additions & 0 deletions
@@ -27,6 +27,15 @@ export llama_data_path=/llama_data
 
 unset CUDA_VISIBLE_DEVICES
 
+# echo 1 if nvidia-smi reports an A100, else 0 (callers compare the result numerically)
+function is_a100() {
+    if [ $(nvidia-smi | grep A100 | wc -l) -ne 0 ]; then
+        echo 1
+    else
+        echo 0
+    fi
+}
+
 function gpt_case_list_auto() {
     gpt_auto_recompute_bs16_fp32_DP1-MP1-PP1
     gpt_auto_recompute_bs16_fp16_o2_DP1-MP1-PP8
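
Every case function below gates its A100 ground truth behind this helper, overriding loss_base/ips_base/mem_base when an A100 is detected. A note on the guard: is_a100 echoes the string 0 or 1, and POSIX test treats any non-empty string as true, so the check must compare numerically with -eq 1; a bare [ $(is_a100) ] would take the A100 branch on every GPU. A minimal, self-contained demonstration of the pitfall:

    #!/bin/bash
    # POSIX `test` only checks non-emptiness for a bare string, so `[ "$flag" ]`
    # succeeds even when flag=0; the numeric -eq comparison is what discriminates.
    flag=0
    if [ "$flag" ]; then
        echo "string test: taken even though flag=0"
    fi
    if [ "$flag" -eq 1 ]; then
        echo "numeric test: taken"
    else
        echo "numeric test: skipped, as intended for flag=0"
    fi
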
@@ -100,6 +109,11 @@ function gpt_auto_recompute_bs16_fp32_DP1-MP1-PP1() {
     loss_base=10.507633305
     ips_base=3518
     mem_base=11750.6
+    if [ $(is_a100) -eq 1 ]; then
+        loss_base=10.530449009
+        ips_base=16763
+        mem_base=11750.6
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
@@ -136,6 +150,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP1-MP1-PP8() {
     loss_base=10.570028400
     ips_base=35050
     mem_base=1988.9
+    if [ $(is_a100) -eq 1 ]; then
+        loss_base=10.559662151
+        ips_base=83918
+        mem_base=2022.7
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
@@ -173,6 +192,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP1-MP1-PP8_pir() {
     loss_base=10.570028400
     ips_base=35050
     mem_base=1988.9
+    if [ $(is_a100) -eq 1 ]; then
+        loss_base=10.559662151
+        ips_base=83918
+        mem_base=2022.7
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
@@ -209,6 +233,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP1-MP2-PP4() {
     loss_base=10.700293922
     ips_base=32518
     mem_base=1535.7
+    if [ $(is_a100) -eq 1 ]; then
+        loss_base=10.679453373
+        ips_base=79116
+        mem_base=1488.2
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
@@ -245,6 +274,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2() {
     loss_base=10.672543240
     ips_base=18681
     mem_base=2135.7
+    if [ $(is_a100) -eq 1 ]; then
+        loss_base=10.651049423
+        ips_base=41174
+        mem_base=2064.5
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
@@ -282,6 +316,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_pir() {
     loss_base=10.672543240
     ips_base=18681
     mem_base=2135.7
+    if [ $(is_a100) -eq 1 ]; then
+        loss_base=10.651049423
+        ips_base=41174
+        mem_base=2064.5
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
@@ -318,6 +357,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage1() {
     loss_base=10.720068359
     ips_base=15232
     mem_base=1999.2
+    if [ $(is_a100) -eq 1 ]; then
+        loss_base=10.657777309
+        ips_base=30027
+        mem_base=2002.0
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
@@ -355,6 +399,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage1_pir() {
     loss_base=10.720068359
     ips_base=15232
     mem_base=1999.2
+    if [ $(is_a100) -eq 1 ]; then
+        loss_base=10.657777309
+        ips_base=30027
+        mem_base=2002.0
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
@@ -391,6 +440,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage2() {
     loss_base=10.720078850
     ips_base=15571
     mem_base=1999.2
+    if [ $(is_a100) -eq 1 ]; then
+        loss_base=10.657803535
+        ips_base=29166
+        mem_base=2002.0
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
@@ -427,6 +481,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage3() {
     loss_base=10.681921577
     ips_base=13813
     mem_base=1747.6
+    if [ $(is_a100) -eq 1 ]; then
+        loss_base=10.662137604
+        ips_base=24700
+        mem_base=1750.5
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
@@ -463,6 +522,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage1() {
     loss_base=10.579057693
     ips_base=19822
     mem_base=1709.8
+    if [ $(is_a100) -eq 1 ]; then
+        loss_base=10.586785984
+        ips_base=42813
+        mem_base=1743.8
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
@@ -500,6 +564,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage1_pir() {
     loss_base=10.579057693
     ips_base=19822
     mem_base=1709.8
+    if [ $(is_a100) -eq 1 ]; then
+        loss_base=10.586785984
+        ips_base=42813
+        mem_base=1743.8
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
@@ -536,6 +605,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage2() {
     loss_base=10.579057693
     ips_base=20170
     mem_base=1709.8
+    if [ $(is_a100) -eq 1 ]; then
+        loss_base=10.586785984
+        ips_base=42995
+        mem_base=1743.8
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
@@ -572,6 +646,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage3() {
     loss_base=10.585316849
     ips_base=15742
     mem_base=1591.6
+    if [ $(is_a100) -eq 1 ]; then
+        loss_base=10.555718899
+        ips_base=34688
+        mem_base=1625.6
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
@@ -608,6 +687,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage1() {
     loss_base=10.672568035
     ips_base=19461
     mem_base=1384.7
+    if [ $(is_a100) -eq 1 ]; then
+        loss_base=10.651032448
+        ips_base=42435
+        mem_base=1377.5
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
@@ -644,6 +728,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage2() {
     loss_base=10.672568035
     ips_base=19652
     mem_base=1384.7
+    if [ $(is_a100) -eq 1 ]; then
+        loss_base=10.651032448
+        ips_base=43008
+        mem_base=1377.5
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
@@ -681,6 +770,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage2_pir() {
     loss_base=10.672568035
     ips_base=19652
     mem_base=1384.7
+    if [ $(is_a100) -eq 1 ]; then
+        loss_base=10.651032448
+        ips_base=43008
+        mem_base=1377.5
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
@@ -717,6 +811,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage3() {
     loss_base=10.696336079
     ips_base=16613
     mem_base=1280.5
+    if [ $(is_a100) -eq 1 ]; then
+        loss_base=10.705118465
+        ips_base=37104
+        mem_base=1217.3
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
@@ -754,6 +853,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage3_pir() {
     loss_base=10.696336079
     ips_base=16613
     mem_base=1280.5
+    if [ $(is_a100) -eq 1 ]; then
+        loss_base=10.705118465
+        ips_base=37104
+        mem_base=1217.3
+    fi
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
@@ -900,6 +1004,9 @@ function llama_static_auto_recompute_bs8_fp32_DP1-MP1-PP1() {
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
     loss_base=9.52110565
+    if [ $(is_a100) -eq 1 ]; then
+        loss_base=9.44003963
+    fi
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -966,6 +1073,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP1-PP1() {
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
     loss_base=9.42011833
+    if [ $(is_a100) -eq 1 ]; then
+        loss_base=9.44003963
+    fi
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1032,6 +1142,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP1() {
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
     loss_base=9.44299471
+    if [ $(is_a100) -eq 1 ]; then
+        loss_base=9.45633757
+    fi
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1098,6 +1211,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2() {
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
     loss_base=9.45936012
+    if [ $(is_a100) -eq 1 ]; then
+        loss_base=9.46121407
+    fi
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1166,6 +1282,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
     loss_base=9.46707726
+    if [ $(is_a100) -eq 1 ]; then
+        loss_base=9.44474411
+    fi
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1235,6 +1354,9 @@ function llama_static_auto_recompute_bs16_fp16_DP2-MP2-PP2-VPP2-Sharding2_stage2
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
     loss_base=10.0859375
+    if [ $(is_a100) -eq 1 ]; then
+        loss_base=10.125
+    fi
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1302,6 +1424,9 @@ function llama_dygraph_auto_bs8_fp32_DP2() {
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
     loss_base=9.53389835
+    if [ $(is_a100) -eq 1 ]; then
+        loss_base=9.54253578
+    fi
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1369,6 +1494,9 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2() {
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
     loss_base=9.39066124
+    if [ $(is_a100) -eq 1 ]; then
+        loss_base=9.41613197
+    fi
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1436,6 +1564,9 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2() {
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
     loss_base=9.38235474
+    if [ $(is_a100) -eq 1 ]; then
+        loss_base=9.4053154
+    fi
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1504,6 +1635,9 @@ function llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2() {
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
     loss_base=9.38256836
+    if [ $(is_a100) -eq 1 ]; then
+        loss_base=9.4055137
+    fi
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
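
Each case ends by handing its observed metrics and baselines to check_result, which is defined earlier in ci_case_auto.sh and is not part of this diff. As a hedged sketch of the contract the baselines above assume (the helper name and tolerance here are illustrative guesses, not the script's actual implementation): a base value of -1 marks a metric as unchecked, and everything else is compared against the measured value within a tolerance.

    #!/bin/bash
    # Hypothetical check_result-style helper; the real one lives elsewhere in
    # ci_case_auto.sh and may differ. Shown only to illustrate how the -1
    # sentinels and ground-truth values above would be consumed.
    function check_metric() {
        local name=$1 base=$2 actual=$3 tol=${4:-0.01}   # 1% tolerance is a guess
        if [ "$base" == "-1" ]; then
            return 0                                     # metric skipped
        fi
        local ok
        ok=$(awk -v a="$actual" -v b="$base" -v t="$tol" \
            'BEGIN { d = (a - b) / b; if (d < 0) d = -d; if (d <= t) print 1; else print 0 }')
        if [ "$ok" -ne 1 ]; then
            echo "$name check failed: expected $base, got $actual"
            exit 1
        fi
    }

    # e.g.: check_metric loss "$loss_base" "$loss"; check_metric ips "$ips_base" "$ips"
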
