Skip to content

Commit 823d69f

Browse files
authored
[BugFix] Fix async hang (#9276)
* update async save signal
* fix async save hang
1 parent cf0f478 commit 823d69f

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

paddlenlp/trainer/auto_trainer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,13 @@ def _save_checkpoint(self, model, metrics=None):
         # For ckpt integrity
         paddle.save(self.state.global_step, os.path.join(output_dir, ".checkpoint_done"))

-    def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_parallel=False):
+    def _save(
+        self,
+        output_dir: Optional[str] = None,
+        state_dict=None,
+        merge_tensor_parallel=False,
+        signal_dir: Optional[str] = None,
+    ):
         output_dir = output_dir if output_dir is not None else self.args.output_dir
         os.makedirs(output_dir, exist_ok=True)
         logger.info(f"Saving model checkpoint to {output_dir}")

paddlenlp/trainer/trainer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,9 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
         # Load potential model checkpoint
         if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
             uc_async_save = self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config
-            resume_from_checkpoint = get_last_checkpoint(self.args.output_dir, uc_async_save)
+            resume_from_checkpoint = get_last_checkpoint(
+                self.args.output_dir, signal_folder=self.args.output_signal_dir, uc_async_save=uc_async_save
+            )
             if resume_from_checkpoint is None:
                 raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})")
587589

@@ -2509,7 +2511,7 @@ def _save_checkpoint(self, model, metrics=None):
         need_to_rotate_checkpoints = need_to_rotate_checkpoints and self.args.local_rank == 0
         if need_to_rotate_checkpoints:
             self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
-            self._rotate_checkpoints(use_mtime=False, output_dir=run_signal_dir)
+            self._rotate_checkpoints(use_mtime=True, output_dir=run_signal_dir)

         if strtobool(os.getenv("FLAG_LLM_PDC", "False")) and not ("async_save" in self.args.unified_checkpoint_config):
             # save checkpoint_done file to ensure checkpoint is complete

0 commit comments

Comments (0)