
Commit 9e3053c

ForFishes authored and Mangodadada committed
[Cherry-pick] add comm buffer size (PaddlePaddle#8963) (PaddlePaddle#9031)
* add comm buffer size (PaddlePaddle#8963)
* add doc
1 parent 4452d29 commit 9e3053c

File tree

2 files changed: +24 -0 lines changed


docs/trainer.md

Lines changed: 8 additions & 0 deletions
@@ -506,6 +506,14 @@ Trainer is a simple but feature-complete Paddle training and evaluation module, and
     with 8 cards, then set sharding_degree=8, sharding will only communicate inside the machine.
     default -1 means sharding parameters between all workers. (`int`, *optional*, defaults to `-1`)
+--sharding_comm_buffer_size_MB
+    Sets the size of the fused gradient buffer used in sharding communication. This option only
+    takes effect when the sharding option is enabled. The default value of -1 means the fused
+    gradient size for all communication follows the default configuration of 256MB.
+    (`int`, optional, defaults to `-1`)
+
+    Set the size of the fuse gradient in sharding communication. This option only takes effect when
+    the sharding option is turned on. The default value is -1, which means that the gradient size of
+    all communication fuses follows the default configuration, which is 256MB.
+    (`int`, optional, default `-1`)
 --tensor_parallel_degree
     Tensor parallelism is the tensor partitioning method from the Megatron paper for the Transformer
     architecture. This method splits the computation of a single transformer layer across different cards.
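For readers wiring this up in a script rather than on the command line, here is a minimal sketch (not part of the commit) of setting the new option when constructing training arguments; the output_dir, sharding, and sharding_parallel_degree values are illustrative assumptions, and only sharding_comm_buffer_size_MB comes from this change:

from paddlenlp.trainer import TrainingArguments

# Assumes the script is launched with paddle.distributed.launch on a sharded multi-GPU setup.
args = TrainingArguments(
    output_dir="./checkpoints",        # illustrative value
    sharding="stage1",                 # the new flag only takes effect when sharding is enabled
    sharding_parallel_degree=8,        # illustrative value
    sharding_comm_buffer_size_MB=512,  # fuse gradients into 512MB communication buffers
)
# Leaving sharding_comm_buffer_size_MB at its default of -1 keeps the framework default of 256MB.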

paddlenlp/trainer/training_args.py

Lines changed: 16 additions & 0 deletions
@@ -549,6 +549,17 @@ class TrainingArguments:
             )
         },
     )
+    sharding_comm_buffer_size_MB: int = field(
+        default=-1,
+        metadata={
+            "help": (
+                "Set the size of the fuse gradient in sharding communication. This option only takes effect when "
+                "the sharding option is turned on. The default value is -1, which means that the gradient size of "
+                "all communication fuses follows the default configuration, which is 256MB. "
+            )
+        },
+    )
     save_sharded_model: bool = field(
         default=False,
         metadata={
@@ -1293,6 +1304,11 @@ def is_segment_parallel_supported():
                 )

             try:
+                if self.sharding_comm_buffer_size_MB > 0:
+                    strategy.hybrid_configs["sharding_configs"].comm_buffer_size_MB = int(
+                        self.sharding_comm_buffer_size_MB
+                    )
+
                 if "split_param" in sharding_parallel_config:
                     strategy.hybrid_configs["sharding_configs"].split_param = True
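For context on what the new branch does at runtime, a minimal sketch (an assumption-laden illustration, not part of the commit) of configuring the same fleet hybrid strategy directly; the parallel degrees and variable names here are illustrative, while the sharding_configs access pattern and comm_buffer_size_MB attribute mirror the diff above:

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,        # illustrative parallel degrees
    "mp_degree": 1,
    "pp_degree": 1,
    "sharding_degree": 8,
}

comm_buffer_size_MB = 512  # would come from --sharding_comm_buffer_size_MB
if comm_buffer_size_MB > 0:
    # Mirrors the added branch: only override when a positive size is requested;
    # -1 keeps the framework default of 256MB per fused gradient buffer.
    strategy.hybrid_configs["sharding_configs"].comm_buffer_size_MB = int(comm_buffer_size_MB)

fleet.init(is_collective=True, strategy=strategy)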
