From c63eb0070b90cdde68414d2137a86e7ea6d222d5 Mon Sep 17 00:00:00 2001
From: DesmonDay <908660116@qq.com>
Date: Fri, 26 Apr 2024 16:15:07 +0800
Subject: [PATCH] Fix sharding overlap bug

---
 paddlenlp/trainer/training_args.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index 3ff9e557378b..b0748d2c23cb 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -1026,6 +1026,11 @@ def __post_init__(self):
                 enable_dp_comm_overlap and enable_sharding_comm_overlap
             ), "dp_comm_overlap and sharding_comm_overlap cannot be enabled at the same time"
 
+            if enable_sharding_comm_overlap and not self.amp_master_grad:
+                raise ValueError(
+                    "If `enable_sharding_comm_overlap` is set in pipeline_parallel_config, `amp_master_grad` must be True."
+                )
+
             dygraph_pp_configs = {
                 "delay_scale_loss": True if "enable_delay_scale_loss" in pipeline_parallel_config else False,
                 "dp_comm_overlap": enable_dp_comm_overlap,