diff --git a/llm/run_pretrain.py b/llm/run_pretrain.py
index 1e5d9b32de05..20f887389dd1 100644
--- a/llm/run_pretrain.py
+++ b/llm/run_pretrain.py
@@ -498,6 +498,16 @@ def main():
             config.seq_length % config.context_parallel_degree == 0
         ), f"seq_length:{config.seq_length} must be divisible by context_parallel_degree {config.context_parallel_degree}"
 
+    if training_args.sharding_parallel_config is not None:
+        # for stage1 overlap optimization
+        if (
+            "enable_stage1_allgather_overlap" in training_args.sharding_parallel_config
+            or "enable_stage1_broadcast_overlap" in training_args.sharding_parallel_config
+        ):
+            from paddle.io.reader import use_pinned_memory
+
+            use_pinned_memory(False)
+
     if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
         try:
             from paddle_xpu.layers.nn.linear import LinearConfig  # noqa: F401
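
For reference, a minimal standalone sketch of the guard this hunk adds. The helper name is hypothetical, and it assumes `sharding_parallel_config` is the raw option string (or list of option strings) carried by the training arguments; the substring check mirrors the diff:

```python
# Minimal sketch, extracted from the hunk above for standalone use.
# Assumption: `sharding_parallel_config` is either None, a raw option
# string, or a list of option strings -- `in` covers both cases.
from paddle.io.reader import use_pinned_memory

_STAGE1_OVERLAP_OPTIONS = (
    "enable_stage1_allgather_overlap",
    "enable_stage1_broadcast_overlap",
)


def disable_pinned_memory_for_stage1_overlap(sharding_parallel_config):
    """Turn off DataLoader pinned memory when either stage1 overlap
    option is enabled, matching the behavior patched into main()."""
    if sharding_parallel_config is None:
        return
    if any(opt in sharding_parallel_config for opt in _STAGE1_OVERLAP_OPTIONS):
        use_pinned_memory(False)
```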