Skip to content

Commit d7ae1d0

Browse files
committed
update async_save_info
1 parent 7deb33c commit d7ae1d0

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

paddlenlp/trainer/trainer.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2236,16 +2236,7 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
22362236
self.model_wrapped.get_all_parameters(convert2cpu=True)
22372237

22382238
if self.args.should_save_model_state:
2239-
unified_checkpoint_config_backup = self.args.unified_checkpoint_config
2240-
# backup and remove unified_checkpoint_config when not in train stage
2241-
if not self.is_in_train:
2242-
self.args.unified_checkpoint_config = []
2243-
22442239
self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)
2245-
2246-
# recover unified_checkpoint_config when not in train stage
2247-
if not self.is_in_train:
2248-
self.args.unified_checkpoint_config = unified_checkpoint_config_backup
22492240
else:
22502241
if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
22512242
os.makedirs(output_dir, exist_ok=True)
@@ -2523,10 +2514,9 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
25232514
# Save a trained model and configuration using `save_pretrained()`.
25242515
# They can then be reloaded using `from_pretrained()`
25252516

2526-
local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))
25272517
if (
25282518
strtobool(os.getenv("FLAG_LLM_PDC", "False"))
2529-
and local_rank == 0
2519+
and paddle.distributed.get_rank() == 0
25302520
and self.args.unified_checkpoint
25312521
and "async_save" in self.args.unified_checkpoint_config
25322522
):
@@ -2537,9 +2527,10 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
25372527
"ignore_save_lr_and_optim": self.args.ignore_save_lr_and_optim,
25382528
"skip_save_model_weight": "skip_save_model_weight" in self.args.unified_checkpoint_config,
25392529
}
2540-
if not os.path.exists(os.path.join(self.args.logging_dir, "async_save_info.json")):
2541-
with open(os.path.join(self.args.logging_dir, "async_save_info.json"), "w") as f:
2542-
json.dump(save_info, f)
2530+
if os.path.exists(os.path.join(self.args.logging_dir, "async_save_info.json")): # afs cannot overwrite
2531+
os.remove(os.path.join(self.args.logging_dir, "async_save_info.json"))
2532+
with open(os.path.join(self.args.logging_dir, "async_save_info.json"), "w") as f:
2533+
json.dump(save_info, f)
25432534

25442535
if self.args.should_save:
25452536
if self.tokenizer is not None:
@@ -2548,7 +2539,17 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
25482539
paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
25492540

25502541
if self.args.unified_checkpoint:
2542+
unified_checkpoint_config_backup = self.args.unified_checkpoint_config
2543+
# backup and remove unified_checkpoint_config when not in train stage
2544+
if not self.is_in_train:
2545+
self.args.unified_checkpoint_config = []
2546+
25512547
self.unified_checkpoint_handler.save_unified_checkpoint(self.model, self.optimizer, output_dir)
2548+
2549+
# recover unified_checkpoint_config when not in train stage
2550+
if not self.is_in_train:
2551+
self.args.unified_checkpoint_config = unified_checkpoint_config_backup
2552+
25522553
return
25532554

25542555
merge_tensor_parallel = merge_tensor_parallel and self.args.use_hybrid_parallel

0 commit comments

Comments
 (0)