From 5451d31486b95957864e21cbe1151f4968554ab0 Mon Sep 17 00:00:00 2001
From: DesmonDay <908660116@qq.com>
Date: Wed, 21 Aug 2024 13:53:49 +0800
Subject: [PATCH 1/2] [Unified checkpoint] update optimizer async save signal

---
 paddlenlp/trainer/trainer.py       | 7 ++++++-
 paddlenlp/trainer/trainer_utils.py | 3 ++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index b77c45b1427c..58207834a0bc 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -2305,7 +2305,12 @@ def _save_checkpoint(self, model, metrics=None):
                             self._save_ckpt_func(state_dict, save_path)
                             with open(saved_signal_path, mode="w+") as f:
                                 f.write("1")
-
+                else:
+                    if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
+                        global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
+                        paddle.save(global_rank, os.path.join(output_dir, f".optimizer_weight.done.{global_rank}"))
+                        if "skip_save_model_weight" not in self.args.unified_checkpoint_config:
+                            paddle.save(global_rank, os.path.join(output_dir, f".master_weight.done.{global_rank}"))
             if self.args.should_save or self.args.use_expert_parallel:
                 if not self.args.use_hybrid_parallel:
                     logger.info("Saving optimizer files.")
diff --git a/paddlenlp/trainer/trainer_utils.py b/paddlenlp/trainer/trainer_utils.py
index a385e36550de..86504648cc48 100644
--- a/paddlenlp/trainer/trainer_utils.py
+++ b/paddlenlp/trainer/trainer_utils.py
@@ -46,6 +46,7 @@
 from ..transformers.tokenizer_utils_base import BatchEncoding
 from ..utils.import_utils import is_paddle_cuda_available, is_psutil_available
 from ..utils.log import logger
+from .utils.helper import distributed_file

 __all__ = [
     "TrainOutput",
@@ -273,7 +274,7 @@ def get_last_checkpoint(folder, uc_async_save=False):
             if os.path.exists(os.path.join(current_path, ".checkpoint_done")):
                 return current_path
             else:
-                saving_info = paddle.load(os.path.join(current_path, ".saving_info"))
+                saving_info = paddle.load(distributed_file(os.path.join(current_path, ".saving_info")))
                 pre_world_size = saving_info.get("world_size", 1)
                 ignore_save_lr_and_optim = saving_info.get("ignore_save_lr_and_optim", False)
                 skip_save_model_weight = saving_info.get("skip_save_model_weight", False)

From 469139997e22c1b12ceb05c9dd76d208f96110ce Mon Sep 17 00:00:00 2001
From: DesmonDay <908660116@qq.com>
Date: Wed, 21 Aug 2024 20:23:31 +0800
Subject: [PATCH 2/2] [Unified Checkpoint] Update async save info

---
 paddlenlp/trainer/trainer.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index 58207834a0bc..957bb6aea306 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -19,6 +19,7 @@
 import collections
 import contextlib
 import inspect
+import json
 import math
 import os
 import random
@@ -2471,6 +2472,24 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
         # Save a trained model and configuration using `save_pretrained()`.
         # They can then be reloaded using `from_pretrained()`

+        local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))
+        if (
+            strtobool(os.getenv("FLAG_LLM_PDC", "False"))
+            and local_rank == 0
+            and self.args.unified_checkpoint
+            and "async_save" in self.args.unified_checkpoint_config
+        ):
+            os.makedirs(self.args.logging_dir, exist_ok=True)
+            world_size = paddle.distributed.get_world_size()
+            save_info = {
+                "world_size": world_size,
+                "ignore_save_lr_and_optim": self.args.ignore_save_lr_and_optim,
+                "skip_save_model_weight": "skip_save_model_weight" in self.args.unified_checkpoint_config,
+            }
+            if not os.path.exists(os.path.join(self.args.logging_dir, "async_save_info.json")):
+                with open(os.path.join(self.args.logging_dir, "async_save_info.json"), "w") as f:
+                    json.dump(save_info, f)
+
         if self.args.should_save:
             if self.tokenizer is not None:
                 self.tokenizer.save_pretrained(output_dir)
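
Note on how the two patches fit together (reviewer sketch, not part of the patch): patch 1 makes each rank drop a per-rank `.optimizer_weight.done.{global_rank}` signal file (and, unless `skip_save_model_weight` is configured, a `.master_weight.done.{global_rank}` file) when async unified-checkpoint saving is enabled, while patch 2 records the save-time configuration once in `async_save_info.json` under `logging_dir`. A resume path such as `get_last_checkpoint` in trainer_utils.py can then compare the number of signal files against the recorded `world_size` to tell a finished checkpoint from a partially written one. The helper below is a minimal illustrative sketch of that check, not PaddleNLP code; the name `checkpoint_is_complete` and the exact completeness rule (one `.done` file per rank) are assumptions.

import glob
import json
import os


def checkpoint_is_complete(checkpoint_dir: str, logging_dir: str) -> bool:
    """Hypothetical helper: decide whether an async-saved checkpoint finished.

    Reads the save info written at save time and counts the per-rank
    .optimizer_weight.done.{rank} / .master_weight.done.{rank} signal files
    that the trainer drops next to the checkpoint.
    """
    info_path = os.path.join(logging_dir, "async_save_info.json")
    if not os.path.exists(info_path):
        # No async save info recorded; nothing to verify against.
        return True
    with open(info_path) as f:
        save_info = json.load(f)

    world_size = save_info.get("world_size", 1)
    if save_info.get("ignore_save_lr_and_optim", False):
        # Optimizer state was intentionally skipped; no signal files expected.
        return True

    # Expect one optimizer signal file per rank that took part in the save.
    optim_done = len(glob.glob(os.path.join(checkpoint_dir, ".optimizer_weight.done.*")))
    if optim_done < world_size:
        return False
    if not save_info.get("skip_save_model_weight", False):
        master_done = len(glob.glob(os.path.join(checkpoint_dir, ".master_weight.done.*")))
        if master_done < world_size:
            return False
    return True

Guarding the JSON write with `os.path.exists` presumably keeps the recorded world size tied to the run that first produced the checkpoints, even if later saves happen under a different configuration.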