@@ -1039,10 +1039,10 @@ def __post_init__(self):
             strategy = fleet.DistributedStrategy()
             assert self.data_parallel_config == "", "data_parallle_config is not supported in hybrid parallel"
             if self.pipeline_parallel_degree > 1:
-                if " " in self.pipeline_parallel_config:
-                    pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
-                else:
+                if "," in self.pipeline_parallel_config:
                     pipeline_parallel_config = set(self.pipeline_parallel_config.split(","))
+                else:
+                    pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
                 for x in pipeline_parallel_config:
                     if len(x) > 0:
                         if x not in [
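The same delimiter-precedence flip is applied six times in this commit: the old code split on a space whenever one was present and only fell back to commas, while the new code prefers the comma. A minimal sketch of the behavior difference (standalone Python for illustration; the function names are hypothetical and not from this commit):

    # Illustration only, not code from this commit.
    def parse_old(cfg):
        # Before: space wins, so "a,b c" -> {"a,b", "c"}
        return set(cfg.split(" ")) if " " in cfg else set(cfg.split(","))

    def parse_new(cfg):
        # After: comma wins, so "a,b c" -> {"a", "b c"}
        return set(cfg.split(",")) if "," in cfg else set(cfg.split(" "))

    # Pure comma- or pure space-separated strings parse identically either way;
    # only mixed strings change, and a token like "b c" would then fail the
    # allowed-options check in the loop that follows each split.
    assert parse_old("enable_dp_comm_overlap enable_release_grads") == parse_new(
        "enable_dp_comm_overlap,enable_release_grads"
    )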
@@ -1116,10 +1116,10 @@ def __post_init__(self):
             if self.tensor_parallel_degree > 1:
                 strategy.tensor_parallel_configs = {"tensor_init_seed": self.seed}

-                if " " in self.tensor_parallel_config:
-                    mp_config = set(self.tensor_parallel_config.split(" "))
-                else:
+                if "," in self.tensor_parallel_config:
                     mp_config = set(self.tensor_parallel_config.split(","))
+                else:
+                    mp_config = set(self.tensor_parallel_config.split(" "))

                 for x in mp_config:
                     if len(x) > 0:
@@ -1225,10 +1225,10 @@ def is_segment_parallel_supported():
            strategy.hybrid_configs = hybrid_configs

            if self.sharding_parallel_degree > 1:
-                if " " in self.sharding_parallel_config:
-                    sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
-                else:
+                if "," in self.sharding_parallel_config:
                     sharding_parallel_config = set(self.sharding_parallel_config.split(","))
+                else:
+                    sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
                for x in sharding_parallel_config:
                    if len(x) > 0:
                        if x not in [
@@ -1384,10 +1384,10 @@ def is_segment_parallel_supported():

            # navie-pp: pipeline_parallel_degree > 1 and gradient_accumulation_steps == 1
            if self.pipeline_parallel_degree > 1 and self.gradient_accumulation_steps > 1:
-                if " " in self.pipeline_parallel_config:
-                    pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
-                else:
+                if "," in self.pipeline_parallel_config:
                     pipeline_parallel_config = set(self.pipeline_parallel_config.split(","))
+                else:
+                    pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
                for x in pipeline_parallel_config:
                    if len(x) > 0:
                        if x not in [
@@ -1437,10 +1437,10 @@ def is_segment_parallel_supported():
            if self.tensor_parallel_degree > 1:
                mp_optimization = strategy.mp_optimization

-                if " " in self.tensor_parallel_config:
-                    mp_config = set(self.tensor_parallel_config.split(" "))
-                else:
+                if "," in self.tensor_parallel_config:
                     mp_config = set(self.tensor_parallel_config.split(","))
+                else:
+                    mp_config = set(self.tensor_parallel_config.split(" "))

                for x in mp_config:
                    if len(x) > 0:
@@ -1473,10 +1473,10 @@ def is_segment_parallel_supported():
            elif ShardingOption.FULL_SHARD in self.sharding:
                sharding.stage = 3

-            if " " in self.sharding_parallel_config:
-                sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
-            else:
+            if "," in self.sharding_parallel_config:
                 sharding_parallel_config = set(self.sharding_parallel_config.split(","))
+            else:
+                sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
            for x in sharding_parallel_config:
                if len(x) > 0:
                    if x not in [
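Since all six hunks perform the identical transformation, the pattern could be factored into a single helper in a follow-up. A minimal sketch, assuming only the standard library; the name split_parallel_config is hypothetical, not something this commit introduces:

    # Hypothetical consolidation of the repeated comma-first, space-fallback split.
    def split_parallel_config(config):
        delimiter = "," if "," in config else " "
        return set(config.split(delimiter))

    # Each call site in the diff would then reduce to one line, e.g.:
    # pipeline_parallel_config = split_parallel_config(self.pipeline_parallel_config)

This would keep the precedence decision in one place, so a later change (for example, stripping whitespace around tokens) would not need to be repeated six times.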