Skip to content

Commit feea4c3

Browse files
authored
[Cherry-pick] fix multi-threading load, fix single card load. (#9560)
* Fix multi-threading load_state_dict (#9464) * Update model_utils.py * Update model_utils.py * [Unified Checkpoint] fix single card loading without master weights (#9540)
1 parent d3be336 commit feea4c3

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

paddlenlp/trainer/unified_checkpoint/load_save_single_card.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,8 @@ def load_single_card_optimizer(model, optimizer, resume_from_checkpoint: str):
225225
key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
226226
else:
227227
key_name = "_".join([static_name, key_name[1]])
228+
else:
229+
key_name = "_".join([static_name, key_name[1]])
228230
returned_optim_state_dict[key_name] = state_dict_optim.pop(key)
229231
returned_optim_state_dict[key_name].name = key_name
230232
if has_master_weights:

paddlenlp/transformers/model_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ def load_state_dict(
437437
thread_num = int(os.environ.get("LOAD_STATE_DICT_THREAD_NUM", "1"))
438438
if thread_num > 1:
439439
logger.info(f"Set loading state_dict thread num to {thread_num}")
440-
state_dict = {}
440+
state_dict, scale_dict = {}, {}
441441
if thread_num <= 1:
442442
with safe_open(checkpoint_file, framework="np") as f:
443443
state_dict, scale_dict = _load_part_state_dict(
@@ -464,9 +464,9 @@ def load_state_dict(
464464
for keys in keys_groups
465465
}
466466
for future in concurrent.futures.as_completed(future_to_key):
467-
state_dict, scale_dict = future.result()
468-
state_dict.update(state_dict)
469-
scale_dict.update(scale_dict)
467+
res_state_dict, res_scale_dict = future.result()
468+
state_dict.update(res_state_dict)
469+
scale_dict.update(res_scale_dict)
470470

471471
if device == "cpu":
472472
for k in list(state_dict.keys()):

0 commit comments

Comments
 (0)