Commit 7b92936

zhangyuqin1998 authored and lvdongyi committed
[Auto Parallel] Fix ckpt_converter for auto_parallel (PaddlePaddle#9136)
1 parent 839dd33 · commit 7b92936

File tree

2 files changed, +5 −5 lines changed


paddlenlp/trainer/auto_trainer.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -780,5 +780,8 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
             raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
         self._load_ckpt_func(state_dict, ckpt_path)
 
+        if self.args.to_static:
+            self.model_wrapped.set_state_dict(model_state_dict)
+            self.model_wrapped.set_state_dict(optim_state_dict)
         # release memory
         del state_dict
```
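
For context, the added branch relies on Paddle's `Layer.set_state_dict`, which copies matching entries from a dict of tensors into a layer's parameters in place. Below is a minimal sketch of that call using an illustrative toy layer rather than the trainer's wrapped model; the names are not taken from the trainer.

```python
import paddle

# Toy layer standing in for the wrapped model (illustrative only).
layer = paddle.nn.Linear(4, 2)

# A plain dict of tensors, e.g. as produced by a checkpoint loader.
state_dict = {name: paddle.zeros_like(value) for name, value in layer.state_dict().items()}

# Copy the checkpointed values into the layer's parameters in place.
layer.set_state_dict(state_dict)
```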

paddlenlp/trainer/utils/ckpt_converter.py

Lines changed: 2 additions & 5 deletions
```diff
@@ -148,11 +148,8 @@ def load_from_hybrid_parallel_checkpoint(self):
 
         # In this scenario, the data type of the model state is bfloat16.
         for param_name, param_value in model_params.items():
-            if param_value.is_dist():
-                master_weight = self.auto_parallel_state_dict[param_name + ".master_weight"]
-                cast_master_weight = paddle.cast(master_weight._local_value(), param_value.dtype)
-                paddle.assign(cast_master_weight, param_value._local_value())
-            else:
+            if param_value._is_initialized():
+                # These codes are compatible for both dense tensor and dist tensor
                 master_weight = self.auto_parallel_state_dict[param_name + ".master_weight"]
                 cast_master_weight = paddle.cast(master_weight, param_value.dtype)
                 paddle.assign(cast_master_weight, param_value)
```
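
The retained branch follows the usual master-weight pattern: cast the higher-precision master weight to the parameter's dtype, then copy it into the parameter in place with `paddle.assign`. Below is a minimal sketch of that pattern with illustrative dtypes and shapes (the checkpoint path above targets bfloat16 parameters).

```python
import paddle

# Illustrative stand-ins: a low-precision parameter and its float32 master weight.
param_value = paddle.zeros([4], dtype="float16")
master_weight = paddle.ones([4], dtype="float32")

# Cast the master weight down to the parameter's dtype, then copy it in place,
# mirroring the paddle.cast / paddle.assign pair in the new code path.
cast_master_weight = paddle.cast(master_weight, param_value.dtype)
paddle.assign(cast_master_weight, param_value)
```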
