
Commit 439f8f3

[Bug Fix] fix sharding stage1 allgather overlap bug, which requires disabling pinned memory (#8594)
* disable pinned memory
1 parent 2dc4d7b commit 439f8f3

File tree

1 file changed: +10 −0 lines


llm/run_pretrain.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -498,6 +498,16 @@ def main():
         config.seq_length % config.context_parallel_degree == 0
     ), f"seq_length:{config.seq_length} must be divisible by context_parallel_degree {config.context_parallel_degree}"

+    if training_args.sharding_parallel_config is not None:
+        # for stage1 overlap optimization
+        if (
+            "enable_stage1_allgather_overlap" in training_args.sharding_parallel_config
+            or "enable_stage1_broadcast_overlap" in training_args.sharding_parallel_config
+        ):
+            from paddle.io.reader import use_pinned_memory
+
+            use_pinned_memory(False)
+
     if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
         try:
             from paddle_xpu.layers.nn.linear import LinearConfig  # noqa: F401
```
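The added guard can be sketched as a standalone predicate. This is a minimal illustration only: the helper name `should_disable_pinned_memory` is hypothetical, and it assumes `sharding_parallel_config` is a string of space-separated flags (which is why the diff uses substring `in` checks), as suggested by the patched code.

```python
def should_disable_pinned_memory(sharding_parallel_config):
    """Hypothetical helper mirroring the commit's guard: return True when
    stage1 allgather/broadcast overlap is enabled, in which case DataLoader
    pinned memory must be turned off (the bug this commit fixes).

    Assumes sharding_parallel_config is a flag string or None, matching the
    substring checks in the diff above.
    """
    if sharding_parallel_config is None:
        return False
    return (
        "enable_stage1_allgather_overlap" in sharding_parallel_config
        or "enable_stage1_broadcast_overlap" in sharding_parallel_config
    )
```

In `run_pretrain.py` the real code then calls `paddle.io.reader.use_pinned_memory(False)` when this condition holds, so the data loader stops allocating pinned host buffers that would conflict with the overlapped stage1 communication.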
