@@ -27,6 +27,15 @@ export llama_data_path=/llama_data
27
27
28
28
# Remove any inherited CUDA_VISIBLE_DEVICES value from the environment.
# NOTE(review): presumably so the cases below see every GPU on the host — confirm.
unset CUDA_VISIBLE_DEVICES
29
29
30
#######################################
# Report whether the `nvidia-smi` listing mentions an A100 GPU.
#
# The result is meant to be consumed as `if [ $(is_a100) ]; then ...`.
# A one-argument `[` test is true for ANY non-empty string, so printing
# "0" for the negative case would still select the A100 branch. The
# negative case therefore prints nothing at all.
#
# Arguments: none
# Outputs:   "1" on stdout when at least one line of `nvidia-smi` output
#            contains "A100"; no output otherwise (including when
#            `nvidia-smi` fails or is unavailable).
# Returns:   always 0, so `var=$(is_a100)` is safe under `set -e`.
#######################################
function is_a100() {
    local a100_lines
    # grep -c exits non-zero when the count is 0; that is an expected outcome
    # here, not an error, so keep it from tripping `set -e`.
    a100_lines=$(nvidia-smi | grep -c A100) || true
    if [ "${a100_lines:-0}" -ne 0 ]; then
        echo 1
    fi
}
37
+
38
+
30
39
function gpt_case_list_auto() {
31
40
gpt_auto_recompute_bs16_fp32_DP1-MP1-PP1
32
41
gpt_auto_recompute_bs16_fp16_o2_DP1-MP1-PP8
@@ -100,6 +109,11 @@ function gpt_auto_recompute_bs16_fp32_DP1-MP1-PP1() {
100
109
loss_base=10.507633305
101
110
ips_base=3518
102
111
mem_base=11750.6
112
+ if [ $( is_a100) ]; then
113
+ loss_base=10.530449009
114
+ ips_base=16763
115
+ mem_base=11750.6
116
+ fi
103
117
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
104
118
echo " =========== $FUNCNAME run end ==========="
105
119
}
@@ -136,6 +150,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP1-MP1-PP8() {
136
150
loss_base=10.570028400
137
151
ips_base=35050
138
152
mem_base=1988.9
153
+ if [ $( is_a100) ]; then
154
+ loss_base=10.559662151
155
+ ips_base=83918
156
+ mem_base=2022.7
157
+ fi
139
158
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
140
159
echo " =========== $FUNCNAME run end ==========="
141
160
}
@@ -173,6 +192,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP1-MP1-PP8_pir() {
173
192
loss_base=10.570028400
174
193
ips_base=35050
175
194
mem_base=1988.9
195
+ if [ $( is_a100) ]; then
196
+ loss_base=10.559662151
197
+ ips_base=83918
198
+ mem_base=2022.7
199
+ fi
176
200
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
177
201
echo " =========== $FUNCNAME run end ==========="
178
202
}
@@ -209,6 +233,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP1-MP2-PP4() {
209
233
loss_base=10.700293922
210
234
ips_base=32518
211
235
mem_base=1535.7
236
+ if [ $( is_a100) ]; then
237
+ loss_base=10.679453373
238
+ ips_base=79116
239
+ mem_base=1488.2
240
+ fi
212
241
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
213
242
echo " =========== $FUNCNAME run end ==========="
214
243
}
@@ -245,6 +274,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2() {
245
274
loss_base=10.672543240
246
275
ips_base=18681
247
276
mem_base=2135.7
277
+ if [ $( is_a100) ]; then
278
+ loss_base=10.651049423
279
+ ips_base=41174
280
+ mem_base=2064.5
281
+ fi
248
282
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
249
283
echo " =========== $FUNCNAME run end ==========="
250
284
}
@@ -282,6 +316,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_pir() {
282
316
loss_base=10.672543240
283
317
ips_base=18681
284
318
mem_base=2135.7
319
+ if [ $( is_a100) ]; then
320
+ loss_base=10.651049423
321
+ ips_base=41174
322
+ mem_base=2064.5
323
+ fi
285
324
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
286
325
echo " =========== $FUNCNAME run end ==========="
287
326
}
@@ -318,6 +357,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage1() {
318
357
loss_base=10.720068359
319
358
ips_base=15232
320
359
mem_base=1999.2
360
+ if [ $( is_a100) ]; then
361
+ loss_base=10.657777309
362
+ ips_base=30027
363
+ mem_base=2002.0
364
+ fi
321
365
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
322
366
echo " =========== $FUNCNAME run end ==========="
323
367
}
@@ -355,6 +399,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage1_pir() {
355
399
loss_base=10.720068359
356
400
ips_base=15232
357
401
mem_base=1999.2
402
+ if [ $( is_a100) ]; then
403
+ loss_base=10.657777309
404
+ ips_base=30027
405
+ mem_base=2002.0
406
+ fi
358
407
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
359
408
echo " =========== $FUNCNAME run end ==========="
360
409
}
@@ -391,6 +440,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage2() {
391
440
loss_base=10.720078850
392
441
ips_base=15571
393
442
mem_base=1999.2
443
+ if [ $( is_a100) ]; then
444
+ loss_base=10.657803535
445
+ ips_base=29166
446
+ mem_base=2002.0
447
+ fi
394
448
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
395
449
echo " =========== $FUNCNAME run end ==========="
396
450
}
@@ -427,6 +481,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP4-MP2-Sharding4_stage3() {
427
481
loss_base=10.681921577
428
482
ips_base=13813
429
483
mem_base=1747.6
484
+ if [ $( is_a100) ]; then
485
+ loss_base=10.662137604
486
+ ips_base=24700
487
+ mem_base=1750.5
488
+ fi
430
489
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
431
490
echo " =========== $FUNCNAME run end ==========="
432
491
}
@@ -463,6 +522,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage1() {
463
522
loss_base=10.579057693
464
523
ips_base=19822
465
524
mem_base=1709.8
525
+ if [ $( is_a100) ]; then
526
+ loss_base=10.586785984
527
+ ips_base=42813
528
+ mem_base=1743.8
529
+ fi
466
530
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
467
531
echo " =========== $FUNCNAME run end ==========="
468
532
}
@@ -500,6 +564,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage1_pir() {
500
564
loss_base=10.579057693
501
565
ips_base=19822
502
566
mem_base=1709.8
567
+ if [ $( is_a100) ]; then
568
+ loss_base=10.586785984
569
+ ips_base=42813
570
+ mem_base=1743.8
571
+ fi
503
572
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
504
573
echo " =========== $FUNCNAME run end ==========="
505
574
}
@@ -536,6 +605,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage2() {
536
605
loss_base=10.579057693
537
606
ips_base=20170
538
607
mem_base=1709.8
608
+ if [ $( is_a100) ]; then
609
+ loss_base=10.586785984
610
+ ips_base=42995
611
+ mem_base=1743.8
612
+ fi
539
613
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
540
614
echo " =========== $FUNCNAME run end ==========="
541
615
}
@@ -572,6 +646,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP1-PP4_Sharding2_stage3() {
572
646
loss_base=10.585316849
573
647
ips_base=15742
574
648
mem_base=1591.6
649
+ if [ $( is_a100) ]; then
650
+ loss_base=10.555718899
651
+ ips_base=34688
652
+ mem_base=1625.6
653
+ fi
575
654
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
576
655
echo " =========== $FUNCNAME run end ==========="
577
656
}
@@ -608,6 +687,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage1() {
608
687
loss_base=10.672568035
609
688
ips_base=19461
610
689
mem_base=1384.7
690
+ if [ $( is_a100) ]; then
691
+ loss_base=10.651032448
692
+ ips_base=42435
693
+ mem_base=1377.5
694
+ fi
611
695
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
612
696
echo " =========== $FUNCNAME run end ==========="
613
697
}
@@ -644,6 +728,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage2() {
644
728
loss_base=10.672568035
645
729
ips_base=19652
646
730
mem_base=1384.7
731
+ if [ $( is_a100) ]; then
732
+ loss_base=10.651032448
733
+ ips_base=43008
734
+ mem_base=1377.5
735
+ fi
647
736
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
648
737
echo " =========== $FUNCNAME run end ==========="
649
738
}
@@ -681,6 +770,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage2_pir() {
681
770
loss_base=10.672568035
682
771
ips_base=19652
683
772
mem_base=1384.7
773
+ if [ $( is_a100) ]; then
774
+ loss_base=10.651032448
775
+ ips_base=43008
776
+ mem_base=1377.5
777
+ fi
684
778
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
685
779
echo " =========== $FUNCNAME run end ==========="
686
780
}
@@ -717,6 +811,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage3() {
717
811
loss_base=10.696336079
718
812
ips_base=16613
719
813
mem_base=1280.5
814
+ if [ $( is_a100) ]; then
815
+ loss_base=10.705118465
816
+ ips_base=37104
817
+ mem_base=1217.3
818
+ fi
720
819
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
721
820
echo " =========== $FUNCNAME run end ==========="
722
821
}
@@ -754,6 +853,11 @@ function gpt_auto_recompute_bs16_fp16_o2_DP2-MP2-PP2_Sharding2_stage3_pir() {
754
853
loss_base=10.696336079
755
854
ips_base=16613
756
855
mem_base=1280.5
856
+ if [ $( is_a100) ]; then
857
+ loss_base=10.705118465
858
+ ips_base=37104
859
+ mem_base=1217.3
860
+ fi
757
861
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
758
862
echo " =========== $FUNCNAME run end ==========="
759
863
}
@@ -900,6 +1004,9 @@ function llama_static_auto_recompute_bs8_fp32_DP1-MP1-PP1() {
900
1004
mem=-1
901
1005
echo " result: loss=$loss ips=$ips mem=$mem "
902
1006
loss_base=9.52110565
1007
+ if [ $( is_a100) ]; then
1008
+ loss_base=9.44003963
1009
+ fi
903
1010
ips_base=-1
904
1011
mem_base=-1
905
1012
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -966,6 +1073,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP1-PP1() {
966
1073
mem=-1
967
1074
echo " result: loss=$loss ips=$ips mem=$mem "
968
1075
loss_base=9.42011833
1076
+ if [ $( is_a100) ]; then
1077
+ loss_base=9.44003963
1078
+ fi
969
1079
ips_base=-1
970
1080
mem_base=-1
971
1081
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1032,6 +1142,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP1() {
1032
1142
mem=-1
1033
1143
echo " result: loss=$loss ips=$ips mem=$mem "
1034
1144
loss_base=9.44299471
1145
+ if [ $( is_a100) ]; then
1146
+ loss_base=9.45633757
1147
+ fi
1035
1148
ips_base=-1
1036
1149
mem_base=-1
1037
1150
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1098,6 +1211,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2() {
1098
1211
mem=-1
1099
1212
echo " result: loss=$loss ips=$ips mem=$mem "
1100
1213
loss_base=9.45936012
1214
+ if [ $( is_a100) ]; then
1215
+ loss_base=9.46121407
1216
+ fi
1101
1217
ips_base=-1
1102
1218
mem_base=-1
1103
1219
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1166,6 +1282,9 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2
1166
1282
mem=-1
1167
1283
echo " result: loss=$loss ips=$ips mem=$mem "
1168
1284
loss_base=9.46707726
1285
+ if [ $( is_a100) ]; then
1286
+ loss_base=9.44474411
1287
+ fi
1169
1288
ips_base=-1
1170
1289
mem_base=-1
1171
1290
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1235,6 +1354,9 @@ function llama_static_auto_recompute_bs16_fp16_DP2-MP2-PP2-VPP2-Sharding2_stage2
1235
1354
mem=-1
1236
1355
echo " result: loss=$loss ips=$ips mem=$mem "
1237
1356
loss_base=10.0859375
1357
+ if [ $( is_a100) ]; then
1358
+ loss_base=10.125
1359
+ fi
1238
1360
ips_base=-1
1239
1361
mem_base=-1
1240
1362
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1302,6 +1424,9 @@ function llama_dygraph_auto_bs8_fp32_DP2() {
1302
1424
mem=-1
1303
1425
echo " result: loss=$loss ips=$ips mem=$mem "
1304
1426
loss_base=9.53389835
1427
+ if [ $( is_a100) ]; then
1428
+ loss_base=9.54253578
1429
+ fi
1305
1430
ips_base=-1
1306
1431
mem_base=-1
1307
1432
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1369,6 +1494,9 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2() {
1369
1494
mem=-1
1370
1495
echo " result: loss=$loss ips=$ips mem=$mem "
1371
1496
loss_base=9.39066124
1497
+ if [ $( is_a100) ]; then
1498
+ loss_base=9.41613197
1499
+ fi
1372
1500
ips_base=-1
1373
1501
mem_base=-1
1374
1502
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1436,6 +1564,9 @@ function llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2() {
1436
1564
mem=-1
1437
1565
echo " result: loss=$loss ips=$ips mem=$mem "
1438
1566
loss_base=9.38235474
1567
+ if [ $( is_a100) ]; then
1568
+ loss_base=9.4053154
1569
+ fi
1439
1570
ips_base=-1
1440
1571
mem_base=-1
1441
1572
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
@@ -1504,6 +1635,9 @@ function llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2() {
1504
1635
mem=-1
1505
1636
echo " result: loss=$loss ips=$ips mem=$mem "
1506
1637
loss_base=9.38256836
1638
+ if [ $( is_a100) ]; then
1639
+ loss_base=9.4055137
1640
+ fi
1507
1641
ips_base=-1
1508
1642
mem_base=-1
1509
1643
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
0 commit comments