@@ -1022,6 +1022,13 @@ def __post_init__(self):
             logger.warning("set amp_master_grad to false since amp is disabled.")
             self.amp_master_grad = False

+        def split_parallel_config(parallel_config):
+            if "," in parallel_config:
+                parallel_config = set(parallel_config.split(","))
+            else:
+                parallel_config = set(parallel_config.split(" "))
+            return parallel_config
+
         # use_hybrid_parallel
         if self.use_hybrid_parallel:
@@ -1039,10 +1046,7 @@ def __post_init__(self):
             strategy = fleet.DistributedStrategy()
             assert self.data_parallel_config == "", "data_parallle_config is not supported in hybrid parallel"
             if self.pipeline_parallel_degree > 1:
-                if " " in self.pipeline_parallel_config:
-                    pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
-                else:
-                    pipeline_parallel_config = set(self.pipeline_parallel_config.split(","))
+                pipeline_parallel_config = split_parallel_config(self.pipeline_parallel_config)
                 for x in pipeline_parallel_config:
                     if len(x) > 0:
                         if x not in [
@@ -1116,10 +1120,7 @@ def __post_init__(self):
             if self.tensor_parallel_degree > 1:
                 strategy.tensor_parallel_configs = {"tensor_init_seed": self.seed}

-                if " " in self.tensor_parallel_config:
-                    mp_config = set(self.tensor_parallel_config.split(" "))
-                else:
-                    mp_config = set(self.tensor_parallel_config.split(","))
+                mp_config = split_parallel_config(self.tensor_parallel_config)

                 for x in mp_config:
                     if len(x) > 0:
@@ -1225,10 +1226,8 @@ def is_segment_parallel_supported():
             strategy.hybrid_configs = hybrid_configs

             if self.sharding_parallel_degree > 1:
-                if " " in self.sharding_parallel_config:
-                    sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
-                else:
-                    sharding_parallel_config = set(self.sharding_parallel_config.split(","))
+                sharding_parallel_config = split_parallel_config(self.sharding_parallel_config)
+
                 for x in sharding_parallel_config:
                     if len(x) > 0:
                         if x not in [
@@ -1384,10 +1383,7 @@ def is_segment_parallel_supported():

         # navie-pp: pipeline_parallel_degree > 1 and gradient_accumulation_steps == 1
         if self.pipeline_parallel_degree > 1 and self.gradient_accumulation_steps > 1:
-            if " " in self.pipeline_parallel_config:
-                pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
-            else:
-                pipeline_parallel_config = set(self.pipeline_parallel_config.split(","))
+            pipeline_parallel_config = split_parallel_config(self.pipeline_parallel_config)
             for x in pipeline_parallel_config:
                 if len(x) > 0:
                     if x not in [
@@ -1436,11 +1432,7 @@ def is_segment_parallel_supported():

         if self.tensor_parallel_degree > 1:
             mp_optimization = strategy.mp_optimization
-
-            if " " in self.tensor_parallel_config:
-                mp_config = set(self.tensor_parallel_config.split(" "))
-            else:
-                mp_config = set(self.tensor_parallel_config.split(","))
+            mp_config = split_parallel_config(self.tensor_parallel_config)

             for x in mp_config:
                 if len(x) > 0:
@@ -1473,10 +1465,7 @@ def is_segment_parallel_supported():
         elif ShardingOption.FULL_SHARD in self.sharding:
             sharding.stage = 3

-            if " " in self.sharding_parallel_config:
-                sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
-            else:
-                sharding_parallel_config = set(self.sharding_parallel_config.split(","))
+            sharding_parallel_config = split_parallel_config(self.sharding_parallel_config)
             for x in sharding_parallel_config:
                 if len(x) > 0:
                     if x not in [
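
For context: the diff replaces five copies of the same comma-or-space splitting logic with one nested `split_parallel_config` helper that prefers "," as the delimiter and falls back to whitespace. Below is a minimal self-contained sketch of its behavior; the option names in the asserts are illustrative assumptions, not taken from this diff.

def split_parallel_config(parallel_config):
    # Mirror of the nested helper added in the diff: prefer commas,
    # fall back to whitespace, and return the options as a set.
    if "," in parallel_config:
        parallel_config = set(parallel_config.split(","))
    else:
        parallel_config = set(parallel_config.split(" "))
    return parallel_config

# Both delimiter styles yield the same set (option names are hypothetical):
assert split_parallel_config("enable_opt_a,enable_opt_b") == {"enable_opt_a", "enable_opt_b"}
assert split_parallel_config("enable_opt_a enable_opt_b") == {"enable_opt_a", "enable_opt_b"}
# An empty string yields {""}, which is why every call site in the diff
# keeps its `if len(x) > 0` guard before validating each option.
assert split_parallel_config("") == {""}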