
Commit d4e1791

add comm buffer size (#8963)

1 parent: 0f0cc2d

1 file changed: paddlenlp/trainer/training_args.py (14 additions, 4 deletions)
@@ -540,6 +540,11 @@ class TrainingArguments:
             )
         },
     )
+    sharding_comm_buffer_size_MB: int = field(
+        default=-1,
+        metadata={"help": ("Sharding fused comm buffer size in communication between sharding ranks. ")},
+    )
+
     save_sharded_model: bool = field(
         default=False,
         metadata={
@@ -1163,19 +1168,19 @@ def __post_init__(self):
             # sync_param = True, sync_grad = False, sync_moment = False, sync_param_name = ["embedding", "layer_norm", ".b_"].

             if sync_param or sync_grad or sync_moment:
-                print(f"setting sync_param_name")
+                print("setting sync_param_name")
                 strategy.sync_param_name = [""]

             if sync_param:
-                print(f"setting sync_param")
+                print("setting sync_param")
                 strategy.hybrid_configs["mp_configs"].sync_param = True

             if sync_grad:
-                print(f"setting sync_grad")
+                print("setting sync_grad")
                 strategy.hybrid_configs["mp_configs"].sync_grad = True

             if sync_moment:
-                print(f"setting sync_moment")
+                print("setting sync_moment")
                 strategy.hybrid_configs["mp_configs"].sync_moment = True

         except:
@@ -1263,6 +1268,11 @@ def is_segment_parallel_supported():
             )

             try:
+                if self.sharding_comm_buffer_size_MB > 0:
+                    strategy.hybrid_configs["sharding_configs"].comm_buffer_size_MB = int(
+                        self.sharding_comm_buffer_size_MB
+                    )
+
                 if "split_param" in sharding_parallel_config:
                     strategy.hybrid_configs["sharding_configs"].split_param = True
 
