Skip to content

Commit 7c3ab53

Browse files
authored
Fix sharding overlap bug (#8333)
1 parent ea2926c commit 7c3ab53

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

paddlenlp/trainer/training_args.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,6 +1026,11 @@ def __post_init__(self):
                 enable_dp_comm_overlap and enable_sharding_comm_overlap
             ), "dp_comm_overlap and sharding_comm_overlap cannot be enabled at the same time"
 
+            if enable_sharding_comm_overlap and not self.amp_master_grad:
+                raise ValueError(
+                    "If `enable_sharding_comm_overlap` in pipeline_parallel_configs, `amp_master_grad` must be True."
+                )
+
             dygraph_pp_configs = {
                 "delay_scale_loss": True if "enable_delay_scale_loss" in pipeline_parallel_config else False,
                 "dp_comm_overlap": enable_dp_comm_overlap,

0 commit comments

Comments (0)