
Commit 0e7a4a2

DesmonDay authored and Mangodadada committed
[Unified Checkpoint] Update async save info (PaddlePaddle#8982)
* [Unified checkpoint] update optimizer async save signal
* [Unified Checkpoint] Update async save info
1 parent 9e3053c commit 0e7a4a2

1 file changed: 19 additions, 0 deletions

paddlenlp/trainer/trainer.py

@@ -19,6 +19,7 @@
 import collections
 import contextlib
 import inspect
+import json
 import math
 import os
 import random
@@ -2475,6 +2476,24 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
         # Save a trained model and configuration using `save_pretrained()`.
         # They can then be reloaded using `from_pretrained()`
 
+        local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))
+        if (
+            strtobool(os.getenv("FLAG_LLM_PDC", "False"))
+            and local_rank == 0
+            and self.args.unified_checkpoint
+            and "async_save" in self.args.unified_checkpoint_config
+        ):
+            os.makedirs(self.args.logging_dir, exist_ok=True)
+            world_size = paddle.distributed.get_world_size()
+            save_info = {
+                "world_size": world_size,
+                "ignore_save_lr_and_optim": self.args.ignore_save_lr_and_optim,
+                "skip_save_model_weight": "skip_save_model_weight" in self.args.unified_checkpoint_config,
+            }
+            if not os.path.exists(os.path.join(self.args.logging_dir, "async_save_info.json")):
+                with open(os.path.join(self.args.logging_dir, "async_save_info.json"), "w") as f:
+                    json.dump(save_info, f)
+
         if self.args.should_save:
             if self.tokenizer is not None:
                 self.tokenizer.save_pretrained(output_dir)
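For context, here is how a downstream job (for example, a resume or monitoring step) might consume the async_save_info.json file this change writes. This is a minimal illustrative sketch, not code from the commit: the load_async_save_info helper and the ./logs path are assumptions, and only the three keys written by the diff are relied on.

import json
import os


def load_async_save_info(logging_dir):
    """Read the async-save metadata written during Trainer._save.

    Illustrative helper (not part of the commit). Returns None when the
    trainer never wrote the file, e.g. when async_save is disabled.
    """
    path = os.path.join(logging_dir, "async_save_info.json")
    if not os.path.exists(path):
        return None
    with open(path) as f:
        info = json.load(f)
    # Keys written by the diff above:
    #   world_size               -- distributed world size at save time
    #   ignore_save_lr_and_optim -- whether optimizer/LR-scheduler state was skipped
    #   skip_save_model_weight   -- whether model weights were skipped
    return info


if __name__ == "__main__":
    info = load_async_save_info("./logs")  # hypothetical logging_dir
    if info is not None:
        print("checkpoint was saved with world_size =", info["world_size"])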
