File tree Expand file tree Collapse file tree 1 file changed +10
-0
lines changed Expand file tree Collapse file tree 1 file changed +10
-0
lines changed Original file line number Diff line number Diff line change @@ -498,6 +498,16 @@ def main():
498
498
config .seq_length % config .context_parallel_degree == 0
499
499
), f"seq_length:{ config .seq_length } must be divisible by context_parallel_degree { config .context_parallel_degree } "
500
500
501
+ if training_args .sharding_parallel_config is not None :
502
+ # for stage1 overlap optimization
503
+ if (
504
+ "enable_stage1_allgather_overlap" in training_args .sharding_parallel_config
505
+ or "enable_stage1_broadcast_overlap" in training_args .sharding_parallel_config
506
+ ):
507
+ from paddle .io .reader import use_pinned_memory
508
+
509
+ use_pinned_memory (False )
510
+
501
511
if get_env_device () == "xpu" and training_args .gradient_accumulation_steps > 1 :
502
512
try :
503
513
from paddle_xpu .layers .nn .linear import LinearConfig # noqa: F401
You can’t perform that action at this time.
0 commit comments