
Commit 0ce0743

fix commas
1 parent fbe613b commit 0ce0743

4 files changed: +39 -43 lines changed


legacy/examples/RLHF/trainer_utils.py

Lines changed: 5 additions & 5 deletions
@@ -340,11 +340,11 @@ def full_training_step(self: Trainer, inputs: Dict[str, paddle.Tensor], **kwargs
 fused_allreduce_gradients(list(model.parameters()), None)

 # Pipeline parallel mode, handle gradient reduce here to overlap
-pipeline_parallel_config = (
-    set(args.pipeline_parallel_config.split(" ")) if args.pipeline_parallel_degree > 1 else set()
-)
-enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config
-enable_release_grads = "enable_release_grads" in pipeline_parallel_config
+enable_dp_comm_overlap = False
+enable_release_grads = False
+if args.pipeline_parallel_degree > 1:
+    enable_dp_comm_overlap = "enable_dp_comm_overlap" in args.pipeline_parallel_config
+    enable_release_grads = "enable_release_grads" in args.pipeline_parallel_config

 # Case 3: Pipeline parallel mode, overlap with dp
 if isinstance(self.optimizer, HybridParallelOptimizer) and not self.do_grad_scaling:
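
Taken on its own, this hunk replaces the space-split set lookup with a direct substring check on the raw config string, guarded by the pipeline-parallel degree. A minimal standalone sketch (plain Python, not PaddleNLP code) of why that matters once the flags are comma-separated:

# Sketch only: how the old and new membership checks behave on a
# comma-separated flag string such as the one below.
flags = "enable_dp_comm_overlap,enable_release_grads"

# Old style: split on spaces into a set, then exact-token lookup.
# With no spaces present, the whole string becomes one token, so this misses.
old_hit = "enable_release_grads" in set(flags.split(" "))   # False

# New style (as in this hunk): substring check on the raw string,
# which works for both "," and " " separators.
new_hit = "enable_release_grads" in flags                    # True

print(old_hit, new_hit)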

llm/alignment/ppo/trainer_utils.py

Lines changed: 5 additions & 5 deletions
@@ -340,11 +340,11 @@ def full_training_step(self: Trainer, inputs: Dict[str, paddle.Tensor], **kwargs
 fused_allreduce_gradients(list(model.parameters()), None)

 # Pipeline parallel mode, handle gradient reduce here to overlap
-pipeline_parallel_config = (
-    set(args.pipeline_parallel_config.split(" ")) if args.pipeline_parallel_degree > 1 else set()
-)
-enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config
-enable_release_grads = "enable_release_grads" in pipeline_parallel_config
+enable_dp_comm_overlap = False
+enable_release_grads = False
+if args.pipeline_parallel_degree > 1:
+    enable_dp_comm_overlap = "enable_dp_comm_overlap" in args.pipeline_parallel_config
+    enable_release_grads = "enable_release_grads" in args.pipeline_parallel_config

 # Case 3: Pipeline parallel mode, overlap with dp
 if isinstance(self.optimizer, HybridParallelOptimizer) and not self.do_grad_scaling:

paddlenlp/trainer/trainer.py

Lines changed: 10 additions & 15 deletions
@@ -1083,17 +1083,13 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
 fused_allreduce_gradients_no_sync(list(model.parameters()), None)

 # Pipeline parallel mode, handle gradient reduce here to overlap
-pipeline_parallel_config = (
-    set(args.pipeline_parallel_config.split(" ")) if args.pipeline_parallel_degree > 1 else set()
-)
-sharding_parallel_config = (
-    set(args.sharding_parallel_config.split(" ")) if args.sharding_parallel_degree > 1 else set()
-)
-enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config
-enable_release_grads = (
-    "enable_release_grads" in pipeline_parallel_config
-    or "enable_release_grads" in sharding_parallel_config
-)
+enable_dp_comm_overlap = "enable_dp_comm_overlap" in args.pipeline_parallel_config
+
+enable_release_grads = False
+if args.sharding_parallel_degree > 1:
+    enable_release_grads = "enable_release_grads" in args.sharding_parallel_config
+if args.pipeline_parallel_degree > 1:
+    enable_release_grads = "enable_release_grads" in args.pipeline_parallel_config

 # Case 3: Pipeline parallel mode, overlap with dp
 if isinstance(self.optimizer, HybridParallelOptimizer) and not self.do_grad_scaling:

@@ -1992,8 +1988,7 @@ def get_expected_keys(inputs, keys):
     "please upgrade your paddle (using nightly version)."
 )

-sharding_parallel_config = set(self.args.sharding_parallel_config.split(" "))
-if level == "os_g" and "enable_stage2_overlap" in sharding_parallel_config:
+if level == "os_g" and "enable_stage2_overlap" in self.args.sharding_parallel_config:
     model._set_reduce_overlap(True)
     optimizer._set_broadcast_overlap(True, model)

@@ -2133,9 +2128,9 @@ def compute_loss(self, model, inputs, return_outputs=False):
 def _enable_delay_scale_loss(self):
     key = "enable_delay_scale_loss"
     if self.args.pipeline_parallel_degree > 1:
-        return key in self.args.pipeline_parallel_config.split(" ")
+        return key in self.args.pipeline_parallel_config
     elif self.args.tensor_parallel_degree > 1:
-        return key in self.args.tensor_parallel_config.split(" ")
+        return key in self.args.tensor_parallel_config
     else:
         return False
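
The trainer.py hunks apply the same idea to _enable_delay_scale_loss and to the stage2-overlap check. A minimal sketch of the updated _enable_delay_scale_loss logic, using a hypothetical Args dataclass in place of the real TrainingArguments:

# Sketch only: mirrors the post-change control flow, not the Trainer API.
from dataclasses import dataclass

@dataclass
class Args:
    pipeline_parallel_degree: int = 1
    tensor_parallel_degree: int = 1
    pipeline_parallel_config: str = ""
    tensor_parallel_config: str = ""

def enable_delay_scale_loss(args: Args) -> bool:
    key = "enable_delay_scale_loss"
    if args.pipeline_parallel_degree > 1:
        # Substring check on the raw string, so "," and " " separators both work.
        return key in args.pipeline_parallel_config
    elif args.tensor_parallel_degree > 1:
        return key in args.tensor_parallel_config
    return False

# True: the key appears in the comma-separated pipeline config string.
print(enable_delay_scale_loss(Args(2, 1, "enable_delay_scale_loss,enable_release_grads", "")))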

paddlenlp/trainer/training_args.py

Lines changed: 19 additions & 18 deletions
@@ -1039,10 +1039,10 @@ def __post_init__(self):
 strategy = fleet.DistributedStrategy()
 assert self.data_parallel_config == "", "data_parallle_config is not supported in hybrid parallel"
 if self.pipeline_parallel_degree > 1:
-    if " " in self.pipeline_parallel_config:
-        pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
-    else:
+    if "," in self.pipeline_parallel_config:
         pipeline_parallel_config = set(self.pipeline_parallel_config.split(","))
+    else:
+        pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
     for x in pipeline_parallel_config:
         if len(x) > 0:
             if x not in [

@@ -1116,10 +1116,10 @@ def __post_init__(self):
 if self.tensor_parallel_degree > 1:
     strategy.tensor_parallel_configs = {"tensor_init_seed": self.seed}

-    if " " in self.tensor_parallel_config:
-        mp_config = set(self.tensor_parallel_config.split(" "))
-    else:
+    if "," in self.tensor_parallel_config:
         mp_config = set(self.tensor_parallel_config.split(","))
+    else:
+        mp_config = set(self.tensor_parallel_config.split(" "))

     for x in mp_config:
         if len(x) > 0:

@@ -1225,10 +1225,11 @@ def is_segment_parallel_supported():
 strategy.hybrid_configs = hybrid_configs

 if self.sharding_parallel_degree > 1:
-    if " " in self.sharding_parallel_config:
-        sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
-    else:
+    if "," in self.sharding_parallel_config:
         sharding_parallel_config = set(self.sharding_parallel_config.split(","))
+    else:
+        sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
+
     for x in sharding_parallel_config:
         if len(x) > 0:
             if x not in [

@@ -1384,10 +1385,10 @@ def is_segment_parallel_supported():

 # navie-pp: pipeline_parallel_degree > 1 and gradient_accumulation_steps == 1
 if self.pipeline_parallel_degree > 1 and self.gradient_accumulation_steps > 1:
-    if " " in self.pipeline_parallel_config:
-        pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
-    else:
+    if "," in self.pipeline_parallel_config:
         pipeline_parallel_config = set(self.pipeline_parallel_config.split(","))
+    else:
+        pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
     for x in pipeline_parallel_config:
         if len(x) > 0:
             if x not in [

@@ -1437,10 +1438,10 @@ def is_segment_parallel_supported():
 if self.tensor_parallel_degree > 1:
     mp_optimization = strategy.mp_optimization

-    if " " in self.tensor_parallel_config:
-        mp_config = set(self.tensor_parallel_config.split(" "))
-    else:
+    if "," in self.tensor_parallel_config:
         mp_config = set(self.tensor_parallel_config.split(","))
+    else:
+        mp_config = set(self.tensor_parallel_config.split(" "))

     for x in mp_config:
         if len(x) > 0:

@@ -1473,10 +1474,10 @@ def is_segment_parallel_supported():
 elif ShardingOption.FULL_SHARD in self.sharding:
     sharding.stage = 3

-if " " in self.sharding_parallel_config:
-    sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
-else:
+if "," in self.sharding_parallel_config:
     sharding_parallel_config = set(self.sharding_parallel_config.split(","))
+else:
+    sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
 for x in sharding_parallel_config:
     if len(x) > 0:
         if x not in [
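
In training_args.py the separator detection is reordered so the comma branch is checked first. A minimal sketch of that selection as a hypothetical standalone helper (not part of TrainingArguments):

# Sketch only: the comma-first splitting the hunks above switch to.
def parse_parallel_config(config: str) -> set:
    # Prefer "," as the separator; fall back to " " for legacy space-separated strings.
    if "," in config:
        return set(config.split(","))
    return set(config.split(" "))

print(parse_parallel_config("enable_dp_comm_overlap,enable_release_grads"))
print(parse_parallel_config("enable_dp_comm_overlap enable_release_grads"))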
