@@ -540,6 +540,11 @@ class TrainingArguments:
             )
         },
     )
+    sharding_comm_buffer_size_MB: int = field(
+        default=-1,
+        metadata={"help": ("Sharding fused comm buffer size in communication between sharding ranks. ")},
+    )
+
    save_sharded_model: bool = field(
        default=False,
        metadata={
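
The new field follows the same `field(default=..., metadata=...)` pattern as the surrounding dataclass attributes, so the argument parser picks it up automatically. A minimal sketch of setting it from the command line, assuming PaddleNLP's `PdArgumentParser`; the script name, the `--sharding "stage1"` setting, and the 512 MB value are illustrative and not part of this patch:

    # sketch.py -- hypothetical launcher; only sharding_comm_buffer_size_MB comes from this patch
    from paddlenlp.trainer import PdArgumentParser, TrainingArguments

    parser = PdArgumentParser(TrainingArguments)
    # e.g. python sketch.py --output_dir ./out --sharding "stage1" --sharding_comm_buffer_size_MB 512
    (args,) = parser.parse_args_into_dataclasses()

    # The default of -1 disables the override; any positive value is forwarded to
    # sharding_configs.comm_buffer_size_MB in __post_init__ (last hunk below).
    print(args.sharding_comm_buffer_size_MB)
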
@@ -1163,19 +1168,19 @@ def __post_init__(self):
                     # sync_param = True, sync_grad = False, sync_moment = False, sync_param_name = ["embedding", "layer_norm", ".b_"].

                     if sync_param or sync_grad or sync_moment:
-                        print(f"setting sync_param_name")
+                        print("setting sync_param_name")
                         strategy.sync_param_name = [""]

                     if sync_param:
-                        print(f"setting sync_param")
+                        print("setting sync_param")
                         strategy.hybrid_configs["mp_configs"].sync_param = True

                     if sync_grad:
-                        print(f"setting sync_grad")
+                        print("setting sync_grad")
                         strategy.hybrid_configs["mp_configs"].sync_grad = True

                     if sync_moment:
-                        print(f"setting sync_moment")
+                        print("setting sync_moment")
                         strategy.hybrid_configs["mp_configs"].sync_moment = True

                 except:
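
The four `print` changes in this hunk simply drop the `f` prefix from f-strings that contain no placeholders (flake8's F541); the printed output is unchanged.
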
@@ -1263,6 +1268,11 @@ def is_segment_parallel_supported():
             )

             try:
+                if self.sharding_comm_buffer_size_MB > 0:
+                    strategy.hybrid_configs["sharding_configs"].comm_buffer_size_MB = int(
+                        self.sharding_comm_buffer_size_MB
+                    )
+
                 if "split_param" in sharding_parallel_config:
                     strategy.hybrid_configs["sharding_configs"].split_param = True
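
Guarding on `> 0` means the default of -1 leaves Paddle's built-in buffer size untouched. A standalone sketch of the same wiring, assuming `paddle.distributed.fleet` and an illustrative 512 MB value; mutating the config object returned by `hybrid_configs` mirrors the pattern this file already uses for `mp_configs`:

    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()

    sharding_comm_buffer_size_MB = 512  # illustrative; the real value comes from TrainingArguments
    if sharding_comm_buffer_size_MB > 0:
        # Fuse communication between sharding ranks into buffers of this size (MB);
        # skipped entirely when the argument keeps its default of -1.
        strategy.hybrid_configs["sharding_configs"].comm_buffer_size_MB = int(sharding_comm_buffer_size_MB)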