@@ -28,6 +28,15 @@ export llm_gpt_case_path=$root_path/llm/gpt-3/auto_parallel
28
28
29
29
unset CUDA_VISIBLE_DEVICES
30
30
31
#######################################
# Report whether nvidia-smi lists an A100 GPU on this machine.
# Outputs:
#   Prints "1" to stdout when at least one line of `nvidia-smi` output
#   contains "A100", otherwise prints "0".
# Usage note:
#   Always compare the printed value, e.g. [ "$(is_a100)" = "1" ].
#   A bare `[ $(is_a100) ]` is true for both "0" and "1", because a
#   one-argument test only checks that the string is non-empty.
#######################################
function is_a100() {
    # grep -c prints the number of matching lines ("0" when there are none,
    # including when nvidia-smi produces no output), so the operand of -ne
    # is always a number.
    if [ "$(nvidia-smi | grep -c A100)" -ne 0 ]; then
        echo 1
    else
        echo 0
    fi
}
38
+
39
+
31
40
function gpt_case_list_auto() {
32
41
gpt_auto_recompute_bs16_fp32_DP1-MP1-PP1
33
42
gpt_auto_recompute_bs16_fp16_o2_DP1-MP1-PP8
@@ -108,6 +117,11 @@ function gpt_auto_recompute_bs16_fp32_DP1-MP1-PP1() {
108
117
loss_base=10.507633305
109
118
ips_base=3518
110
119
mem_base=11750.6
120
# Switch to the A100 baselines only when is_a100 prints 1. The value is
# compared explicitly: a bare `[ $(is_a100) ]` is true for both "0" and "1"
# (any non-empty string), which applied the A100 baselines on every machine.
if [ "$(is_a100)" = "1" ]; then
    loss_base=10.530449009
    ips_base=16763
    mem_base=11750.6
fi
111
125
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
112
126
echo " =========== $FUNCNAME run end ==========="
113
127
}
@@ -144,6 +158,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP1-MP1-PP8() {
144
158
loss_base=10.570028400
145
159
ips_base=35050
146
160
mem_base=1988.9
161
# Switch to the A100 baselines only when is_a100 prints 1. The value is
# compared explicitly: a bare `[ $(is_a100) ]` is true for both "0" and "1"
# (any non-empty string), which applied the A100 baselines on every machine.
if [ "$(is_a100)" = "1" ]; then
    loss_base=10.559662151
    ips_base=83918
    mem_base=2022.7
fi
147
166
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
148
167
echo " =========== $FUNCNAME run end ==========="
149
168
}
@@ -181,6 +200,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP1-MP1-PP8_pir() {
181
200
loss_base=10.570028400
182
201
ips_base=35050
183
202
mem_base=1988.9
203
# Switch to the A100 baselines only when is_a100 prints 1. The value is
# compared explicitly: a bare `[ $(is_a100) ]` is true for both "0" and "1"
# (any non-empty string), which applied the A100 baselines on every machine.
if [ "$(is_a100)" = "1" ]; then
    loss_base=10.559662151
    ips_base=83918
    mem_base=2022.7
fi
184
208
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
185
209
echo " =========== $FUNCNAME run end ==========="
186
210
}
@@ -217,6 +241,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP1-MP2-PP4() {
217
241
loss_base=10.700293922
218
242
ips_base=32518
219
243
mem_base=1535.7
244
# Switch to the A100 baselines only when is_a100 prints 1. The value is
# compared explicitly: a bare `[ $(is_a100) ]` is true for both "0" and "1"
# (any non-empty string), which applied the A100 baselines on every machine.
if [ "$(is_a100)" = "1" ]; then
    loss_base=10.679453373
    ips_base=79116
    mem_base=1488.2
fi
220
249
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
221
250
echo " =========== $FUNCNAME run end ==========="
222
251
}
@@ -253,6 +282,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2() {
253
282
loss_base=10.672543240
254
283
ips_base=18681
255
284
mem_base=2135.7
285
# Switch to the A100 baselines only when is_a100 prints 1. The value is
# compared explicitly: a bare `[ $(is_a100) ]` is true for both "0" and "1"
# (any non-empty string), which applied the A100 baselines on every machine.
if [ "$(is_a100)" = "1" ]; then
    loss_base=10.651049423
    ips_base=41174
    mem_base=2064.5
fi
256
290
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
257
291
echo " =========== $FUNCNAME run end ==========="
258
292
}
@@ -290,6 +324,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_pir() {
290
324
loss_base=10.672543240
291
325
ips_base=18681
292
326
mem_base=2135.7
327
# Switch to the A100 baselines only when is_a100 prints 1. The value is
# compared explicitly: a bare `[ $(is_a100) ]` is true for both "0" and "1"
# (any non-empty string), which applied the A100 baselines on every machine.
if [ "$(is_a100)" = "1" ]; then
    loss_base=10.651049423
    ips_base=41174
    mem_base=2064.5
fi
293
332
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
294
333
echo " =========== $FUNCNAME run end ==========="
295
334
}
@@ -326,6 +365,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage1() {
326
365
loss_base=10.720068359
327
366
ips_base=15232
328
367
mem_base=1999.2
368
# Switch to the A100 baselines only when is_a100 prints 1. The value is
# compared explicitly: a bare `[ $(is_a100) ]` is true for both "0" and "1"
# (any non-empty string), which applied the A100 baselines on every machine.
if [ "$(is_a100)" = "1" ]; then
    loss_base=10.657777309
    ips_base=30027
    mem_base=2002.0
fi
329
373
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
330
374
echo " =========== $FUNCNAME run end ==========="
331
375
}
@@ -363,6 +407,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage1_pir() {
363
407
loss_base=10.720068359
364
408
ips_base=15232
365
409
mem_base=1999.2
410
# Switch to the A100 baselines only when is_a100 prints 1. The value is
# compared explicitly: a bare `[ $(is_a100) ]` is true for both "0" and "1"
# (any non-empty string), which applied the A100 baselines on every machine.
if [ "$(is_a100)" = "1" ]; then
    loss_base=10.657777309
    ips_base=30027
    mem_base=2002.0
fi
366
415
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
367
416
echo " =========== $FUNCNAME run end ==========="
368
417
}
@@ -399,6 +448,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage2() {
399
448
loss_base=10.720078850
400
449
ips_base=15571
401
450
mem_base=1999.2
451
# Switch to the A100 baselines only when is_a100 prints 1. The value is
# compared explicitly: a bare `[ $(is_a100) ]` is true for both "0" and "1"
# (any non-empty string), which applied the A100 baselines on every machine.
if [ "$(is_a100)" = "1" ]; then
    loss_base=10.657803535
    ips_base=29166
    mem_base=2002.0
fi
402
456
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
403
457
echo " =========== $FUNCNAME run end ==========="
404
458
}
@@ -435,6 +489,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage3() {
435
489
loss_base=10.681921577
436
490
ips_base=13813
437
491
mem_base=1747.6
492
# Switch to the A100 baselines only when is_a100 prints 1. The value is
# compared explicitly: a bare `[ $(is_a100) ]` is true for both "0" and "1"
# (any non-empty string), which applied the A100 baselines on every machine.
if [ "$(is_a100)" = "1" ]; then
    loss_base=10.662137604
    ips_base=24700
    mem_base=1750.5
fi
438
497
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
439
498
echo " =========== $FUNCNAME run end ==========="
440
499
}
@@ -471,6 +530,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage1() {
471
530
loss_base=10.579057693
472
531
ips_base=19822
473
532
mem_base=1709.8
533
# Switch to the A100 baselines only when is_a100 prints 1. The value is
# compared explicitly: a bare `[ $(is_a100) ]` is true for both "0" and "1"
# (any non-empty string), which applied the A100 baselines on every machine.
if [ "$(is_a100)" = "1" ]; then
    loss_base=10.586785984
    ips_base=42813
    mem_base=1743.8
fi
474
538
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
475
539
echo " =========== $FUNCNAME run end ==========="
476
540
}
@@ -508,6 +572,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage1_pir() {
508
572
loss_base=10.579057693
509
573
ips_base=19822
510
574
mem_base=1709.8
575
# Switch to the A100 baselines only when is_a100 prints 1. The value is
# compared explicitly: a bare `[ $(is_a100) ]` is true for both "0" and "1"
# (any non-empty string), which applied the A100 baselines on every machine.
if [ "$(is_a100)" = "1" ]; then
    loss_base=10.586785984
    ips_base=42813
    mem_base=1743.8
fi
511
580
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
512
581
echo " =========== $FUNCNAME run end ==========="
513
582
}
@@ -544,6 +613,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage2() {
544
613
loss_base=10.579057693
545
614
ips_base=20170
546
615
mem_base=1709.8
616
# Switch to the A100 baselines only when is_a100 prints 1. The value is
# compared explicitly: a bare `[ $(is_a100) ]` is true for both "0" and "1"
# (any non-empty string), which applied the A100 baselines on every machine.
if [ "$(is_a100)" = "1" ]; then
    loss_base=10.586785984
    ips_base=42995
    mem_base=1743.8
fi
547
621
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
548
622
echo " =========== $FUNCNAME run end ==========="
549
623
}
@@ -580,6 +654,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage3() {
580
654
loss_base=10.585316849
581
655
ips_base=15742
582
656
mem_base=1591.6
657
# Switch to the A100 baselines only when is_a100 prints 1. The value is
# compared explicitly: a bare `[ $(is_a100) ]` is true for both "0" and "1"
# (any non-empty string), which applied the A100 baselines on every machine.
if [ "$(is_a100)" = "1" ]; then
    loss_base=10.555718899
    ips_base=34688
    mem_base=1625.6
fi
583
662
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
584
663
echo " =========== $FUNCNAME run end ==========="
585
664
}
@@ -616,6 +695,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage1() {
616
695
loss_base=10.672568035
617
696
ips_base=19461
618
697
mem_base=1384.7
698
# Switch to the A100 baselines only when is_a100 prints 1. The value is
# compared explicitly: a bare `[ $(is_a100) ]` is true for both "0" and "1"
# (any non-empty string), which applied the A100 baselines on every machine.
if [ "$(is_a100)" = "1" ]; then
    loss_base=10.651032448
    ips_base=42435
    mem_base=1377.5
fi
619
703
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
620
704
echo " =========== $FUNCNAME run end ==========="
621
705
}
@@ -652,6 +736,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage2() {
652
736
loss_base=10.672568035
653
737
ips_base=19652
654
738
mem_base=1384.7
739
# Switch to the A100 baselines only when is_a100 prints 1. The value is
# compared explicitly: a bare `[ $(is_a100) ]` is true for both "0" and "1"
# (any non-empty string), which applied the A100 baselines on every machine.
if [ "$(is_a100)" = "1" ]; then
    loss_base=10.651032448
    ips_base=43008
    mem_base=1377.5
fi
655
744
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
656
745
echo " =========== $FUNCNAME run end ==========="
657
746
}
@@ -689,6 +778,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage2_pir() {
689
778
loss_base=10.672568035
690
779
ips_base=19652
691
780
mem_base=1384.7
781
# Switch to the A100 baselines only when is_a100 prints 1. The value is
# compared explicitly: a bare `[ $(is_a100) ]` is true for both "0" and "1"
# (any non-empty string), which applied the A100 baselines on every machine.
if [ "$(is_a100)" = "1" ]; then
    loss_base=10.651032448
    ips_base=43008
    mem_base=1377.5
fi
692
786
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
693
787
echo " =========== $FUNCNAME run end ==========="
694
788
}
@@ -725,6 +819,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage3() {
725
819
loss_base=10.696336079
726
820
ips_base=16613
727
821
mem_base=1280.5
822
# Switch to the A100 baselines only when is_a100 prints 1. The value is
# compared explicitly: a bare `[ $(is_a100) ]` is true for both "0" and "1"
# (any non-empty string), which applied the A100 baselines on every machine.
if [ "$(is_a100)" = "1" ]; then
    loss_base=10.705118465
    ips_base=37104
    mem_base=1217.3
fi
728
827
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
729
828
echo " =========== $FUNCNAME run end ==========="
730
829
}
@@ -762,6 +861,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage3_pir() {
762
861
loss_base=10.696336079
763
862
ips_base=16613
764
863
mem_base=1280.5
864
# Switch to the A100 baselines only when is_a100 prints 1. The value is
# compared explicitly: a bare `[ $(is_a100) ]` is true for both "0" and "1"
# (any non-empty string), which applied the A100 baselines on every machine.
if [ "$(is_a100)" = "1" ]; then
    loss_base=10.705118465
    ips_base=37104
    mem_base=1217.3
fi
765
869
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
766
870
echo " =========== $FUNCNAME run end ==========="
767
871
}
@@ -908,6 +1012,9 @@ function llama_static_auto_recompute_bs8_fp32_DP1-MP1-PP1() {
908
1012
mem=-1
909
1013
echo " result: loss=$loss ips=$ips mem=$mem "
910
1014
loss_base=9.52110565
1015
# Use the A100 loss baseline only when is_a100 prints 1. The value is
# compared explicitly: a bare `[ $(is_a100) ]` is true for both "0" and "1"
# (any non-empty string), which applied the A100 baseline on every machine.
if [ "$(is_a100)" = "1" ]; then
    loss_base=9.44003963
fi
911
1018
ips_base=-1
912
1019
mem_base=-1
913
1020
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -974,6 +1081,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP1-PP1() {
974
1081
mem=-1
975
1082
echo " result: loss=$loss ips=$ips mem=$mem "
976
1083
loss_base=9.42011833
1084
# Use the A100 loss baseline only when is_a100 prints 1. The value is
# compared explicitly: a bare `[ $(is_a100) ]` is true for both "0" and "1"
# (any non-empty string), which applied the A100 baseline on every machine.
if [ "$(is_a100)" = "1" ]; then
    loss_base=9.44003963
fi
977
1087
ips_base=-1
978
1088
mem_base=-1
979
1089
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1040,6 +1150,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP1() {
1040
1150
mem=-1
1041
1151
echo " result: loss=$loss ips=$ips mem=$mem "
1042
1152
loss_base=9.44299471
1153
# Use the A100 loss baseline only when is_a100 prints 1. The value is
# compared explicitly: a bare `[ $(is_a100) ]` is true for both "0" and "1"
# (any non-empty string), which applied the A100 baseline on every machine.
if [ "$(is_a100)" = "1" ]; then
    loss_base=9.45633757
fi
1043
1156
ips_base=-1
1044
1157
mem_base=-1
1045
1158
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1106,6 +1219,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2() {
1106
1219
mem=-1
1107
1220
echo " result: loss=$loss ips=$ips mem=$mem "
1108
1221
loss_base=9.45936012
1222
# Use the A100 loss baseline only when is_a100 prints 1. The value is
# compared explicitly: a bare `[ $(is_a100) ]` is true for both "0" and "1"
# (any non-empty string), which applied the A100 baseline on every machine.
if [ "$(is_a100)" = "1" ]; then
    loss_base=9.46121407
fi
1109
1225
ips_base=-1
1110
1226
mem_base=-1
1111
1227
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1174,6 +1290,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2
1174
1290
mem=-1
1175
1291
echo " result: loss=$loss ips=$ips mem=$mem "
1176
1292
loss_base=9.46707726
1293
# Use the A100 loss baseline only when is_a100 prints 1. The value is
# compared explicitly: a bare `[ $(is_a100) ]` is true for both "0" and "1"
# (any non-empty string), which applied the A100 baseline on every machine.
if [ "$(is_a100)" = "1" ]; then
    loss_base=9.44474411
fi
1177
1296
ips_base=-1
1178
1297
mem_base=-1
1179
1298
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1243,6 +1362,9 @@ function llama_static_auto_recompute_bs16_fp16_DP2-MP2-PP2-VPP2-Sharding2_stage2
1243
1362
mem=-1
1244
1363
echo " result: loss=$loss ips=$ips mem=$mem "
1245
1364
loss_base=10.0859375
1365
# Use the A100 loss baseline only when is_a100 prints 1. The value is
# compared explicitly: a bare `[ $(is_a100) ]` is true for both "0" and "1"
# (any non-empty string), which applied the A100 baseline on every machine.
if [ "$(is_a100)" = "1" ]; then
    loss_base=10.125
fi
1246
1368
ips_base=-1
1247
1369
mem_base=-1
1248
1370
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1310,6 +1432,9 @@ function llama_dygraph_auto_bs8_fp32_DP2() {
1310
1432
mem=-1
1311
1433
echo " result: loss=$loss ips=$ips mem=$mem "
1312
1434
loss_base=9.53389835
1435
# Use the A100 loss baseline only when is_a100 prints 1. The value is
# compared explicitly: a bare `[ $(is_a100) ]` is true for both "0" and "1"
# (any non-empty string), which applied the A100 baseline on every machine.
if [ "$(is_a100)" = "1" ]; then
    loss_base=9.54253578
fi
1313
1438
ips_base=-1
1314
1439
mem_base=-1
1315
1440
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1377,6 +1502,9 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2() {
1377
1502
mem=-1
1378
1503
echo " result: loss=$loss ips=$ips mem=$mem "
1379
1504
loss_base=9.39066124
1505
# Use the A100 loss baseline only when is_a100 prints 1. The value is
# compared explicitly: a bare `[ $(is_a100) ]` is true for both "0" and "1"
# (any non-empty string), which applied the A100 baseline on every machine.
if [ "$(is_a100)" = "1" ]; then
    loss_base=9.41613197
fi
1380
1508
ips_base=-1
1381
1509
mem_base=-1
1382
1510
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1444,6 +1572,9 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2() {
1444
1572
mem=-1
1445
1573
echo " result: loss=$loss ips=$ips mem=$mem "
1446
1574
loss_base=9.38235474
1575
# Use the A100 loss baseline only when is_a100 prints 1. The value is
# compared explicitly: a bare `[ $(is_a100) ]` is true for both "0" and "1"
# (any non-empty string), which applied the A100 baseline on every machine.
if [ "$(is_a100)" = "1" ]; then
    loss_base=9.4053154
fi
1447
1578
ips_base=-1
1448
1579
mem_base=-1
1449
1580
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1512,6 +1643,9 @@ function llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2() {
1512
1643
mem=-1
1513
1644
echo " result: loss=$loss ips=$ips mem=$mem "
1514
1645
loss_base=9.38256836
1646
# Use the A100 loss baseline only when is_a100 prints 1. The value is
# compared explicitly: a bare `[ $(is_a100) ]` is true for both "0" and "1"
# (any non-empty string), which applied the A100 baseline on every machine.
if [ "$(is_a100)" = "1" ]; then
    loss_base=9.4055137
fi
1515
1649
ips_base=-1
1516
1650
mem_base=-1
1517
1651
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
0 commit comments