|
20 | 20 | from typing import List, Optional
|
21 | 21 |
|
22 | 22 | import paddle
|
| 23 | +from paddle.io.reader import use_pinned_memory |
23 | 24 |
|
24 | 25 | from paddlenlp.data.causal_dataset import (
|
25 | 26 | build_train_valid_test_datasets,
|
|
47 | 48 | from paddlenlp.utils.batch_sampler import DistributedBatchSampler
|
48 | 49 | from paddlenlp.utils.log import logger
|
49 | 50 | from paddlenlp.utils.tools import get_env_device
|
50 |
| -from paddle.io.reader import use_pinned_memory |
51 | 51 |
|
52 | 52 | # Pretraining Environment Variables to support sharding stage1 overlap optimization.
|
53 | 53 | os.environ["USE_CASUAL_MASK"] = "True"
|
@@ -403,6 +403,10 @@ def main():
|
403 | 403 | else:
|
404 | 404 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
405 | 405 |
|
| 406 | + print("--888--" * 100) |
| 407 | + print("training_args:", training_args) |
| 408 | + print("--888--" * 100) |
| 409 | + |
406 | 410 | if training_args.enable_linear_fused_grad_add:
|
407 | 411 | from fused_layers import mock_layers
|
408 | 412 |
|
@@ -499,6 +503,14 @@ def main():
|
499 | 503 | config.seq_length % config.context_parallel_degree == 0
|
500 | 504 | ), f"seq_length:{config.seq_length} must be divisible by context_parallel_degree {config.context_parallel_degree}"
|
501 | 505 |
|
| 506 | + if training_args.sharding_parallel_config is not None: |
| 507 | + # for stage1 overlap optimization |
| 508 | + if ( |
| 509 | + "enable_stage1_allgather_overlap" in training_args.sharding_parallel_config |
| 510 | + or "enable_stage1_broadcast_overlap" in training_args.sharding_parallel_config |
| 511 | + ): |
| 512 | + use_pinned_memory(False) |
| 513 | + |
502 | 514 | if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
|
503 | 515 | try:
|
504 | 516 | from paddle_xpu.layers.nn.linear import LinearConfig # noqa: F401
|
@@ -635,5 +647,4 @@ def main():
|
635 | 647 |
|
636 | 648 |
|
637 | 649 | if __name__ == "__main__":
|
638 |
| - use_pinned_memory(False) |
639 | 650 | main()
|
0 commit comments