Commit 568507b

chenruibiao committed
[DEV] Support sync params in tensor parallel config
1 parent beb433a commit 568507b

2 files changed (+48, -6 lines)

docs/trainer.md

Lines changed: 15 additions & 1 deletion
@@ -521,6 +521,20 @@ Trainer 是一个简单,但功能完整的 Paddle训练和评估模块,并
 default -1 for not using tensor parallel. tensor_parallel_degree <= 8 is suggested for better performance.
 Note, this needs model support in the source code; currently GPT/BLOOM/LLAMA/BLOOM/CLM/CHATGLM are supported.
 
+--tensor_parallel_config
+For tensor parallel, some options affect training performance. These options are collected here and passed in as a str.
+The following options are supported:
+enable_delay_scale_loss : accumulate gradients until the optimizer step and divide all gradients by the accumulation count, instead of dividing the loss by the accumulation count directly.
+sync_param : in the optimizer step, use broadcast to sync all parameters with is_distributed=False.
+sync_grad : in the optimizer step, use broadcast to sync all gradients with is_distributed=False.
+sync_moment : in the optimizer step, use broadcast to sync all momentums with is_distributed=False.
+
+Some additional config highly affects the usage of tensor parallel; we provide some options to configure it.
+The following config is supported:
+enable_delay_scale_loss, accumulate gradients until the optimizer step and divide all gradients by the accumulation steps, instead of dividing the loss by the accumulation steps directly.
+sync_param, in the optimizer step, use broadcast to sync parameters whose 'is_distributed' attr is False.
+sync_grad, in the optimizer step, use broadcast to sync gradients whose 'is_distributed' attr is False.
+sync_moment, in the optimizer step, use broadcast to sync momentums whose 'is_distributed' attr is False.
 
 --pipeline_parallel_degree
 Pipeline parallelism is the layer-wise partitioning method proposed in the Megatron paper for multi-layer Transformer structures.
@@ -549,7 +563,7 @@ Trainer 是一个简单,但功能完整的 Paddle训练和评估模块,并
 The following config is supported:
 disable_p2p_cache_shape, if your max sequence length is varying, please set disable_p2p_cache_shape.
 disable_partial_send_recv, optimize send speed for tensor parallel.
-enable_delay_scale_loss, accumulate gradients util optimizer step, all gradients div by inner pipeline accumute step. instead of div accumute step on loss directly.
+enable_delay_scale_loss, accumulate gradients until the optimizer step and divide all gradients by the inner pipeline accumulation steps, instead of dividing the loss by the accumulation steps directly.
 enable_dp_comm_overlap, fuse data parallel gradient communication.
 
 --data_parallel_config

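As a quick illustration of the new flag documented above, here is a minimal sketch of how a training script might enable the options, assuming the space-separated format used by the other *_parallel_config flags; the output_dir, degree and option combination are placeholders, and only tensor_parallel_config and its option names come from this commit:

    from paddlenlp.trainer import TrainingArguments

    # Placeholder values everywhere except tensor_parallel_config; the option
    # names are the ones documented above. In practice these arguments come
    # from the command line (e.g. --tensor_parallel_config "sync_param sync_grad")
    # in a script launched with paddle.distributed.launch.
    args = TrainingArguments(
        output_dir="./checkpoints",
        tensor_parallel_degree=8,
        tensor_parallel_config="enable_delay_scale_loss sync_param sync_grad sync_moment",
    )
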
paddlenlp/trainer/training_args.py

Lines changed: 33 additions & 5 deletions
@@ -241,13 +241,16 @@ class TrainingArguments:
             enable_mp_async_allreduce, it supports all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear backward when it set True, which can accelerate model parallel performance.
             enable_mp_skip_c_identity, it supports skip c_identity in ColumnParallelLinear and RowParallelLinear. It only works when set mp_async_allreduce is True. It can accelerate model parallel further.
             enable_mp_fused_linear_param_grad_add, it supports fused_linear_param_grad_add in ColumnParallelLinear (cuda >= 11.6). It only works when mp_async_allreduce is true. It can accelerate model parallel further.
-            enable_delay_scale_loss, accumulate gradients util optimizer step, all gradients div by accumute step. instead of div accumute step on loss directly.
+            enable_delay_scale_loss, accumulate gradients until the optimizer step; all gradients are divided by the accumulation steps, instead of dividing the loss by the accumulation steps directly.
+            sync_param, in the optimizer step, use broadcast to sync parameters whose 'is_distributed' attr is False.
+            sync_grad, in the optimizer step, use broadcast to sync gradients whose 'is_distributed' attr is False.
+            sync_moment, in the optimizer step, use broadcast to sync momentums whose 'is_distributed' attr is False.
         pipeline_parallel_config (`str`, *optional*)(
             Some additional config it highly affect the useage of pipeline parallel, we provide some option to config it.
             following config is support:
             disable_p2p_cache_shape, if you max sequence length is varying, please set disable_p2p_cache_shape.
             disable_partial_send_recv, optmize send speed for tensor parallel.
-            enable_delay_scale_loss, accumulate gradients util optimizer step, all gradients div by inner pipeline accumute step. instead of div accumute step on loss directly.
+            enable_delay_scale_loss, accumulate gradients until the optimizer step; all gradients are divided by the inner pipeline accumulation steps, instead of dividing the loss by the accumulation steps directly.
             enable_dp_comm_overlap, fuse data parallel gradient communication.
             enable_sharding_comm_overlap, fuse sharding stage 1 parallel gradient communication.
             enable_release_grads, reduce peak memory usage by releasing gradients after each iteration. The creation of gradients will be postponed until backward propagation of the next iteration.
@@ -600,7 +603,10 @@ class TrainingArguments:
                 "enable_mp_async_allreduce, it supports all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear backward when it set True, which can accelerate model parallel performance. \n"
                 "enable_mp_skip_c_identity, it supports skip c_identity in ColumnParallelLinear and RowParallelLinear. It only works when set mp_async_allreduce is True. It can accelerate model parallel further.\n"
                 "enable_mp_fused_linear_param_grad_add, it supports fused_linear_param_grad_add in ColumnParallelLinear (cuda >= 11.6). It only works when mp_async_allreduce is true. It can accelerate model parallel further.\n"
-                "enable_delay_scale_loss, accumulate gradients util optimizer step, all gradients div by accumute step. instead of div accumute step on loss directly.\n"
+                "enable_delay_scale_loss, accumulate gradients until the optimizer step; all gradients are divided by the accumulation steps, instead of dividing the loss by the accumulation steps directly.\n"
+                "sync_param, in the optimizer step, use broadcast to sync parameters whose 'is_distributed' attr is False.\n"
+                "sync_grad, in the optimizer step, use broadcast to sync gradients whose 'is_distributed' attr is False.\n"
+                "sync_moment, in the optimizer step, use broadcast to sync momentums whose 'is_distributed' attr is False.\n"
             )
         },
     )
@@ -612,7 +618,7 @@ class TrainingArguments:
                 "following config is support:\n"
                 "disable_p2p_cache_shape, if you max sequence length is varying, please set disable_p2p_cache_shape. \n"
                 "disable_partial_send_recv, optmize send speed for tensor parallel.\n"
-                "enable_delay_scale_loss, accumulate gradients util optimizer step, all gradients div by inner pipeline accumute step. instead of div accumute step on loss directly.\n"
+                "enable_delay_scale_loss, accumulate gradients until the optimizer step; all gradients are divided by the inner pipeline accumulation steps, instead of dividing the loss by the accumulation steps directly.\n"
                 "enable_dp_comm_overlap, fuse data parallel gradient communication. \n"
                 "enable_sharding_comm_overlap, fuse sharding stage 1 parallel gradient communication. \n"
                 "enable_overlap_p2p_comm, overlap p2p communication with computation. \n"
@@ -1062,10 +1068,13 @@ def __post_init__(self):
                     "enable_mp_skip_c_identity",
                     "enable_mp_fused_linear_param_grad_add",
                     "enable_delay_scale_loss",
+                    "sync_param",
+                    "sync_grad",
+                    "sync_moment",
                 ]:
                     raise ValueError(
                         f"Found unknown tensor parallell config {x}, "
-                        f"accept config is enable_mp_async_allreduce, enable_mp_skip_c_identity and enable_mp_fused_linear_param_grad_add"
+                        f"accept config is enable_mp_async_allreduce, enable_mp_skip_c_identity, enable_mp_fused_linear_param_grad_add, sync_param, sync_grad and sync_moment."
                     )
                 try:
                     if "enable_mp_async_allreduce" in mp_config:
@@ -1083,6 +1092,25 @@ def __post_init__(self):
                             warnings.warn(
                                 "enable_mp_fused_linear_param_grad_add only works with enable_mp_async_allreduce. It will not work."
                             )
+
+                    sync_param = sync_grad = sync_moment = True  # For CI test
+
+                    # sync_param = "sync_param" in mp_config
+                    # sync_grad = "sync_grad" in mp_config
+                    # sync_moment = "sync_moment" in mp_config
+
+                    # sync_param_name = [""] matches any parameter name.
+                    # If sync_param, sync_grad and sync_moment are not set, the default value in Paddle is :
+                    # sync_param = True, sync_grad = False, sync_moment = False, sync_param_name = ["embedding", "layer_norm", ".b_"].
+                    if sync_param:
+                        strategy.hybrid_configs["mp_configs"].sync_param = True
+                        strategy.hybrid_configs["mp_configs"].sync_param_name = [""]
+                    if sync_grad:
+                        strategy.hybrid_configs["mp_configs"].sync_grad = True
+                        strategy.hybrid_configs["mp_configs"].sync_grad_name = [""]
+                    if sync_moment:
+                        strategy.hybrid_configs["mp_configs"].sync_moment = True
+                        strategy.hybrid_configs["mp_configs"].sync_moment_name = [""]
                 except:
                     warnings.warn(
                         "The enable_mp_async_allreduce, enable_mp_skip_c_identity and enable_mp_fused_linear_param_grad_add are not supported "

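Note that the added branch currently hard-codes sync_param = sync_grad = sync_moment = True ("For CI test") and keeps the real parsing commented out. Below is a small sketch of what that commented-out parsing amounts to once enabled; the helper name parse_sync_options is hypothetical, and splitting the flag string on spaces is an assumption about the config format:

    # Hypothetical helper mirroring the commented-out lines above;
    # mp_config is the raw --tensor_parallel_config string.
    def parse_sync_options(mp_config: str) -> dict:
        options = mp_config.split(" ")
        return {
            "sync_param": "sync_param" in options,
            "sync_grad": "sync_grad" in options,
            "sync_moment": "sync_moment" in options,
        }

    print(parse_sync_options("enable_delay_scale_loss sync_param"))
    # -> {'sync_param': True, 'sync_grad': False, 'sync_moment': False}

With the CI override in place, all three broadcasts run whenever this code path is hit; Paddle's defaults (per the comment added in the diff) only sync parameters, and only those whose names match "embedding", "layer_norm" or ".b_".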