Skip to content

Commit 45f6314

Browse files
committed
update
1 parent 28590c5 commit 45f6314

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

llm/run_pretrain.py

Lines changed: 13 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -20,6 +20,7 @@
2020
from typing import List, Optional
2121

2222
import paddle
23+
from paddle.io.reader import use_pinned_memory
2324

2425
from paddlenlp.data.causal_dataset import (
2526
build_train_valid_test_datasets,
@@ -47,7 +48,6 @@
4748
from paddlenlp.utils.batch_sampler import DistributedBatchSampler
4849
from paddlenlp.utils.log import logger
4950
from paddlenlp.utils.tools import get_env_device
50-
from paddle.io.reader import use_pinned_memory
5151

5252
# Pretaining Environment Variables to support sharding stage1 overlap optimization.
5353
os.environ["USE_CASUAL_MASK"] = "True"
@@ -403,6 +403,10 @@ def main():
403403
else:
404404
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
405405

406+
print("--888--" * 100)
407+
print("training_args:", training_args)
408+
print("--888--" * 100)
409+
406410
if training_args.enable_linear_fused_grad_add:
407411
from fused_layers import mock_layers
408412

@@ -499,6 +503,14 @@ def main():
499503
config.seq_length % config.context_parallel_degree == 0
500504
), f"seq_length:{config.seq_length} must be divisible by context_parallel_degree {config.context_parallel_degree}"
501505

506+
if training_args.sharding_parallel_config is not None:
507+
# for stage1 overlap optimization
508+
if (
509+
"enable_stage1_allgather_overlap" in training_args.sharding_parallel_config
510+
or "enable_stage1_broadcast_overlap" in training_args.sharding_parallel_config
511+
):
512+
use_pinned_memory(False)
513+
502514
if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
503515
try:
504516
from paddle_xpu.layers.nn.linear import LinearConfig # noqa: F401
@@ -635,5 +647,4 @@ def main():
635647

636648

637649
if __name__ == "__main__":
638-
use_pinned_memory(False)
639650
main()

0 commit comments

Comments (0)