Commit 88b1f3a

fix split in trainer
1 parent 3ebe938 commit 88b1f3a

File tree

4 files changed: +28 additions, -41 deletions


legacy/examples/RLHF/trainer_utils.py

Lines changed: 2 additions & 5 deletions
@@ -340,11 +340,8 @@ def full_training_step(self: Trainer, inputs: Dict[str, paddle.Tensor], **kwargs
            fused_allreduce_gradients(list(model.parameters()), None)

        # Pipeline parallel mode, handle gradient reduce here to overlap
-        pipeline_parallel_config = (
-            set(args.pipeline_parallel_config.split(" ")) if args.pipeline_parallel_degree > 1 else set()
-        )
-        enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config
-        enable_release_grads = "enable_release_grads" in pipeline_parallel_config
+        enable_dp_comm_overlap = "enable_dp_comm_overlap" in args.pipeline_parallel_config
+        enable_release_grads = "enable_release_grads" in args.pipeline_parallel_config

        # Case 3: Pipeline parallel mode, overlap with dp
        if isinstance(self.optimizer, HybridParallelOptimizer) and not self.do_grad_scaling:
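
The functional change in both trainer_utils.py copies is the membership test: instead of splitting args.pipeline_parallel_config on spaces into a set (and only when pipeline_parallel_degree > 1), the flags are now looked up directly with `in` on the raw value. Assuming the config can arrive as a space-separated string, a comma-separated string, or a list of flags, a minimal standalone sketch of the difference (the has_flag helper is hypothetical, not part of the codebase):

def has_flag(cfg, flag):
    # Mirrors the new check: substring membership for strings,
    # element membership for lists.
    return flag in cfg

# All three forms report the flag with the new check.
assert has_flag("enable_dp_comm_overlap enable_release_grads", "enable_dp_comm_overlap")
assert has_flag("enable_dp_comm_overlap,enable_release_grads", "enable_release_grads")
assert has_flag(["enable_dp_comm_overlap"], "enable_dp_comm_overlap")

# The removed code split only on spaces, so a comma-separated value
# collapsed into a single token and the flag was never detected.
assert "enable_release_grads" not in set("enable_dp_comm_overlap,enable_release_grads".split(" "))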

llm/alignment/ppo/trainer_utils.py

Lines changed: 2 additions & 5 deletions
@@ -340,11 +340,8 @@ def full_training_step(self: Trainer, inputs: Dict[str, paddle.Tensor], **kwargs
            fused_allreduce_gradients(list(model.parameters()), None)

        # Pipeline parallel mode, handle gradient reduce here to overlap
-        pipeline_parallel_config = (
-            set(args.pipeline_parallel_config.split(" ")) if args.pipeline_parallel_degree > 1 else set()
-        )
-        enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config
-        enable_release_grads = "enable_release_grads" in pipeline_parallel_config
+        enable_dp_comm_overlap = "enable_dp_comm_overlap" in args.pipeline_parallel_config
+        enable_release_grads = "enable_release_grads" in args.pipeline_parallel_config

        # Case 3: Pipeline parallel mode, overlap with dp
        if isinstance(self.optimizer, HybridParallelOptimizer) and not self.do_grad_scaling:

paddlenlp/trainer/trainer.py

Lines changed: 6 additions & 13 deletions
@@ -1083,16 +1083,10 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
                    fused_allreduce_gradients_no_sync(list(model.parameters()), None)

                # Pipeline parallel mode, handle gradient reduce here to overlap
-                pipeline_parallel_config = (
-                    set(args.pipeline_parallel_config.split(" ")) if args.pipeline_parallel_degree > 1 else set()
-                )
-                sharding_parallel_config = (
-                    set(args.sharding_parallel_config.split(" ")) if args.sharding_parallel_degree > 1 else set()
-                )
-                enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config
+                enable_dp_comm_overlap = "enable_dp_comm_overlap" in args.pipeline_parallel_config
                enable_release_grads = (
-                    "enable_release_grads" in pipeline_parallel_config
-                    or "enable_release_grads" in sharding_parallel_config
+                    "enable_release_grads" in args.pipeline_parallel_config
+                    or "enable_release_grads" in args.sharding_parallel_config
                )

                # Case 3: Pipeline parallel mode, overlap with dp

@@ -1992,8 +1986,7 @@ def get_expected_keys(inputs, keys):
                    "please upgrade your paddle (using nightly version)."
                )

-            sharding_parallel_config = set(self.args.sharding_parallel_config.split(" "))
-            if level == "os_g" and "enable_stage2_overlap" in sharding_parallel_config:
+            if level == "os_g" and "enable_stage2_overlap" in self.args.sharding_parallel_config:
                model._set_reduce_overlap(True)
                optimizer._set_broadcast_overlap(True, model)

@@ -2133,9 +2126,9 @@ def compute_loss(self, model, inputs, return_outputs=False):
    def _enable_delay_scale_loss(self):
        key = "enable_delay_scale_loss"
        if self.args.pipeline_parallel_degree > 1:
-            return key in self.args.pipeline_parallel_config.split(" ")
+            return key in self.args.pipeline_parallel_config
        elif self.args.tensor_parallel_degree > 1:
-            return key in self.args.tensor_parallel_config.split(" ")
+            return key in self.args.tensor_parallel_config
        else:
            return False
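
The _enable_delay_scale_loss hunk applies the same idea inside the Trainer itself. A standalone sketch of the resulting logic, with a hypothetical Args dataclass standing in for self.args (only the fields used here):

from dataclasses import dataclass

@dataclass
class Args:
    pipeline_parallel_degree: int = 1
    tensor_parallel_degree: int = 1
    pipeline_parallel_config: str = ""
    tensor_parallel_config: str = ""

def enable_delay_scale_loss(args: Args) -> bool:
    # Mirrors the updated Trainer._enable_delay_scale_loss: membership is
    # checked on the raw config value instead of a space-split list.
    key = "enable_delay_scale_loss"
    if args.pipeline_parallel_degree > 1:
        return key in args.pipeline_parallel_config
    elif args.tensor_parallel_degree > 1:
        return key in args.tensor_parallel_config
    return False

# Comma-separated flags are now recognized as well.
print(enable_delay_scale_loss(
    Args(pipeline_parallel_degree=2,
         pipeline_parallel_config="enable_delay_scale_loss,enable_release_grads")))  # True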

paddlenlp/trainer/training_args.py

Lines changed: 18 additions & 18 deletions
@@ -1039,10 +1039,10 @@ def __post_init__(self):
            strategy = fleet.DistributedStrategy()
            assert self.data_parallel_config == "", "data_parallle_config is not supported in hybrid parallel"
            if self.pipeline_parallel_degree > 1:
-                if " " in self.pipeline_parallel_config:
-                    pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
-                else:
+                if "," in self.pipeline_parallel_config:
                    pipeline_parallel_config = set(self.pipeline_parallel_config.split(","))
+                else:
+                    pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
                for x in pipeline_parallel_config:
                    if len(x) > 0:
                        if x not in [

@@ -1116,10 +1116,10 @@ def __post_init__(self):
            if self.tensor_parallel_degree > 1:
                strategy.tensor_parallel_configs = {"tensor_init_seed": self.seed}

-                if " " in self.tensor_parallel_config:
-                    mp_config = set(self.tensor_parallel_config.split(" "))
-                else:
+                if "," in self.tensor_parallel_config:
                    mp_config = set(self.tensor_parallel_config.split(","))
+                else:
+                    mp_config = set(self.tensor_parallel_config.split(" "))

                for x in mp_config:
                    if len(x) > 0:

@@ -1225,10 +1225,10 @@ def is_segment_parallel_supported():
            strategy.hybrid_configs = hybrid_configs

            if self.sharding_parallel_degree > 1:
-                if " " in self.sharding_parallel_config:
-                    sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
-                else:
+                if "," in self.sharding_parallel_config:
                    sharding_parallel_config = set(self.sharding_parallel_config.split(","))
+                else:
+                    sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
                for x in sharding_parallel_config:
                    if len(x) > 0:
                        if x not in [

@@ -1384,10 +1384,10 @@ def is_segment_parallel_supported():

        # navie-pp: pipeline_parallel_degree > 1 and gradient_accumulation_steps == 1
        if self.pipeline_parallel_degree > 1 and self.gradient_accumulation_steps > 1:
-            if " " in self.pipeline_parallel_config:
-                pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
-            else:
+            if "," in self.pipeline_parallel_config:
                pipeline_parallel_config = set(self.pipeline_parallel_config.split(","))
+            else:
+                pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
            for x in pipeline_parallel_config:
                if len(x) > 0:
                    if x not in [

@@ -1437,10 +1437,10 @@ def is_segment_parallel_supported():
        if self.tensor_parallel_degree > 1:
            mp_optimization = strategy.mp_optimization

-            if " " in self.tensor_parallel_config:
-                mp_config = set(self.tensor_parallel_config.split(" "))
-            else:
+            if "," in self.tensor_parallel_config:
                mp_config = set(self.tensor_parallel_config.split(","))
+            else:
+                mp_config = set(self.tensor_parallel_config.split(" "))

            for x in mp_config:
                if len(x) > 0:

@@ -1473,10 +1473,10 @@ def is_segment_parallel_supported():
            elif ShardingOption.FULL_SHARD in self.sharding:
                sharding.stage = 3

-            if " " in self.sharding_parallel_config:
-                sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
-            else:
+            if "," in self.sharding_parallel_config:
                sharding_parallel_config = set(self.sharding_parallel_config.split(","))
+            else:
+                sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
            for x in sharding_parallel_config:
                if len(x) > 0:
                    if x not in [
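
In training_args.py the fix is only the order of the delimiter check: when the config value contains both a comma and a space, the comma now takes precedence, so comma-separated flag lists are no longer split on the space. A minimal comparison of the old and new tokenization (split_flags_old / split_flags_new are hypothetical helpers, and only the split is shown; the validation loop around it is unchanged):

def split_flags_old(cfg: str) -> set:
    # Old order: split on spaces whenever a space is present.
    return set(cfg.split(" ")) if " " in cfg else set(cfg.split(","))

def split_flags_new(cfg: str) -> set:
    # New order: prefer the comma whenever one is present.
    return set(cfg.split(",")) if "," in cfg else set(cfg.split(" "))

cfg = "enable_dp_comm_overlap, enable_release_grads"
print(sorted(split_flags_old(cfg)))  # ['enable_dp_comm_overlap,', 'enable_release_grads']
print(sorted(split_flags_new(cfg)))  # [' enable_release_grads', 'enable_dp_comm_overlap']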
