
Commit 8333a28

fix commas
1 parent 5a508e5 commit 8333a28

4 files changed: +34 -50 lines changed

legacy/examples/RLHF/trainer_utils.py

Lines changed: 5 additions & 5 deletions
@@ -340,11 +340,11 @@ def full_training_step(self: Trainer, inputs: Dict[str, paddle.Tensor], **kwargs
             fused_allreduce_gradients(list(model.parameters()), None)
 
         # Pipeline parallel mode, handle gradient reduce here to overlap
-        pipeline_parallel_config = (
-            set(args.pipeline_parallel_config.split(" ")) if args.pipeline_parallel_degree > 1 else set()
-        )
-        enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config
-        enable_release_grads = "enable_release_grads" in pipeline_parallel_config
+        enable_dp_comm_overlap = False
+        enable_release_grads = False
+        if args.pipeline_parallel_degree > 1:
+            enable_dp_comm_overlap = "enable_dp_comm_overlap" in args.pipeline_parallel_config
+            enable_release_grads = "enable_release_grads" in args.pipeline_parallel_config
 
         # Case 3: Pipeline parallel mode, overlap with dp
         if isinstance(self.optimizer, HybridParallelOptimizer) and not self.do_grad_scaling:

llm/alignment/ppo/trainer_utils.py

Lines changed: 5 additions & 5 deletions
@@ -340,11 +340,11 @@ def full_training_step(self: Trainer, inputs: Dict[str, paddle.Tensor], **kwargs
             fused_allreduce_gradients(list(model.parameters()), None)
 
         # Pipeline parallel mode, handle gradient reduce here to overlap
-        pipeline_parallel_config = (
-            set(args.pipeline_parallel_config.split(" ")) if args.pipeline_parallel_degree > 1 else set()
-        )
-        enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config
-        enable_release_grads = "enable_release_grads" in pipeline_parallel_config
+        enable_dp_comm_overlap = False
+        enable_release_grads = False
+        if args.pipeline_parallel_degree > 1:
+            enable_dp_comm_overlap = "enable_dp_comm_overlap" in args.pipeline_parallel_config
+            enable_release_grads = "enable_release_grads" in args.pipeline_parallel_config
 
         # Case 3: Pipeline parallel mode, overlap with dp
         if isinstance(self.optimizer, HybridParallelOptimizer) and not self.do_grad_scaling:

paddlenlp/trainer/trainer.py

Lines changed: 10 additions & 15 deletions
@@ -1083,17 +1083,13 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
             fused_allreduce_gradients_no_sync(list(model.parameters()), None)
 
         # Pipeline parallel mode, handle gradient reduce here to overlap
-        pipeline_parallel_config = (
-            set(args.pipeline_parallel_config.split(" ")) if args.pipeline_parallel_degree > 1 else set()
-        )
-        sharding_parallel_config = (
-            set(args.sharding_parallel_config.split(" ")) if args.sharding_parallel_degree > 1 else set()
-        )
-        enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config
-        enable_release_grads = (
-            "enable_release_grads" in pipeline_parallel_config
-            or "enable_release_grads" in sharding_parallel_config
-        )
+        enable_dp_comm_overlap = "enable_dp_comm_overlap" in args.pipeline_parallel_config
+
+        enable_release_grads = False
+        if args.sharding_parallel_degree > 1:
+            enable_release_grads = "enable_release_grads" in args.sharding_parallel_config
+        if args.pipeline_parallel_degree > 1:
+            enable_release_grads = "enable_release_grads" in args.pipeline_parallel_config
 
         # Case 3: Pipeline parallel mode, overlap with dp
         if isinstance(self.optimizer, HybridParallelOptimizer) and not self.do_grad_scaling:

@@ -1992,8 +1988,7 @@ def get_expected_keys(inputs, keys):
                     "please upgrade your paddle (using nightly version)."
                 )
 
-            sharding_parallel_config = set(self.args.sharding_parallel_config.split(" "))
-            if level == "os_g" and "enable_stage2_overlap" in sharding_parallel_config:
+            if level == "os_g" and "enable_stage2_overlap" in self.args.sharding_parallel_config:
                 model._set_reduce_overlap(True)
                 optimizer._set_broadcast_overlap(True, model)
 

@@ -2133,9 +2128,9 @@ def compute_loss(self, model, inputs, return_outputs=False):
     def _enable_delay_scale_loss(self):
         key = "enable_delay_scale_loss"
         if self.args.pipeline_parallel_degree > 1:
-            return key in self.args.pipeline_parallel_config.split(" ")
+            return key in self.args.pipeline_parallel_config
         elif self.args.tensor_parallel_degree > 1:
-            return key in self.args.tensor_parallel_config.split(" ")
+            return key in self.args.tensor_parallel_config
         else:
             return False
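For reference, a minimal standalone sketch (not part of the commit; the config string below is made up) of the membership test used after this change, which now operates directly on the raw config string rather than on a parsed set:

# Hypothetical stand-in for args.pipeline_parallel_config as passed on the CLI;
# the options may arrive comma- or space-separated.
pipeline_parallel_config = "enable_delay_scale_loss,enable_release_grads"

# With the change above, a flag is detected by substring membership on the raw
# string, so either separator is matched without splitting first.
enable_delay_scale_loss = "enable_delay_scale_loss" in pipeline_parallel_config
enable_release_grads = "enable_release_grads" in pipeline_parallel_config

print(enable_delay_scale_loss, enable_release_grads)  # True True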

paddlenlp/trainer/training_args.py

Lines changed: 14 additions & 25 deletions
@@ -1022,6 +1022,13 @@ def __post_init__(self):
                 logger.warning("set amp_master_grad to false since amp is disabled.")
                 self.amp_master_grad = False
 
+        def split_parallel_config(parallel_config):
+            if "," in parallel_config:
+                parallel_config = set(parallel_config.split(","))
+            else:
+                parallel_config = set(parallel_config.split(" "))
+            return parallel_config
+
         # use_hybrid_parallel
         if self.use_hybrid_parallel:
 
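For illustration, a standalone sketch of the helper introduced above (reproduced outside the class; the example option strings are made up): split_parallel_config returns the same set of option names whether the config is comma- or space-separated.

def split_parallel_config(parallel_config):
    # Prefer the comma separator when one is present, otherwise fall back to spaces.
    if "," in parallel_config:
        parallel_config = set(parallel_config.split(","))
    else:
        parallel_config = set(parallel_config.split(" "))
    return parallel_config

print(split_parallel_config("enable_dp_comm_overlap enable_release_grads"))
print(split_parallel_config("enable_dp_comm_overlap,enable_release_grads"))
# Both print {'enable_dp_comm_overlap', 'enable_release_grads'}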
@@ -1039,10 +1046,7 @@ def __post_init__(self):
             strategy = fleet.DistributedStrategy()
             assert self.data_parallel_config == "", "data_parallle_config is not supported in hybrid parallel"
             if self.pipeline_parallel_degree > 1:
-                if " " in self.pipeline_parallel_config:
-                    pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
-                else:
-                    pipeline_parallel_config = set(self.pipeline_parallel_config.split(","))
+                pipeline_parallel_config = split_parallel_config(self.pipeline_parallel_config)
                 for x in pipeline_parallel_config:
                     if len(x) > 0:
                         if x not in [

@@ -1116,10 +1120,7 @@ def __post_init__(self):
             if self.tensor_parallel_degree > 1:
                 strategy.tensor_parallel_configs = {"tensor_init_seed": self.seed}
 
-                if " " in self.tensor_parallel_config:
-                    mp_config = set(self.tensor_parallel_config.split(" "))
-                else:
-                    mp_config = set(self.tensor_parallel_config.split(","))
+                mp_config = split_parallel_config(self.tensor_parallel_config)
 
                 for x in mp_config:
                     if len(x) > 0:

@@ -1225,10 +1226,8 @@ def is_segment_parallel_supported():
             strategy.hybrid_configs = hybrid_configs
 
             if self.sharding_parallel_degree > 1:
-                if " " in self.sharding_parallel_config:
-                    sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
-                else:
-                    sharding_parallel_config = set(self.sharding_parallel_config.split(","))
+                sharding_parallel_config = split_parallel_config(self.sharding_parallel_config)
+
                 for x in sharding_parallel_config:
                     if len(x) > 0:
                         if x not in [

@@ -1384,10 +1383,7 @@ def is_segment_parallel_supported():
 
         # navie-pp: pipeline_parallel_degree > 1 and gradient_accumulation_steps == 1
         if self.pipeline_parallel_degree > 1 and self.gradient_accumulation_steps > 1:
-            if " " in self.pipeline_parallel_config:
-                pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
-            else:
-                pipeline_parallel_config = set(self.pipeline_parallel_config.split(","))
+            pipeline_parallel_config = split_parallel_config(self.pipeline_parallel_config)
             for x in pipeline_parallel_config:
                 if len(x) > 0:
                     if x not in [

@@ -1436,11 +1432,7 @@ def is_segment_parallel_supported():
 
         if self.tensor_parallel_degree > 1:
             mp_optimization = strategy.mp_optimization
-
-            if " " in self.tensor_parallel_config:
-                mp_config = set(self.tensor_parallel_config.split(" "))
-            else:
-                mp_config = set(self.tensor_parallel_config.split(","))
+            mp_config = split_parallel_config(self.tensor_parallel_config)
 
             for x in mp_config:
                 if len(x) > 0:

@@ -1473,10 +1465,7 @@ def is_segment_parallel_supported():
             elif ShardingOption.FULL_SHARD in self.sharding:
                 sharding.stage = 3
 
-            if " " in self.sharding_parallel_config:
-                sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
-            else:
-                sharding_parallel_config = set(self.sharding_parallel_config.split(","))
+            sharding_parallel_config = split_parallel_config(self.sharding_parallel_config)
             for x in sharding_parallel_config:
                 if len(x) > 0:
                     if x not in [
