Skip to content

Commit c4d79f4

Browse files
authored
Fix multi-threading load_state_dict (#9464)
* Update model_utils.py
* Update model_utils.py
1 parent 2b3d7bf commit c4d79f4

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

paddlenlp/transformers/model_utils.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -435,6 +435,9 @@ def load_state_dict(
435435
raise ValueError("Currently unsupport paddle weights file, use numpy instead.")
436436
if metadata.get("format", "np") == "np":
437437
thread_num = int(os.environ.get("LOAD_STATE_DICT_THREAD_NUM", "1"))
438+
if thread_num > 1:
439+
logger.info(f"Set loading state_dict thread num to {thread_num}")
440+
state_dict, scale_dict = {}, {}
438441
if thread_num <= 1:
439442
with safe_open(checkpoint_file, framework="np") as f:
440443
state_dict, scale_dict = _load_part_state_dict(
@@ -461,9 +464,9 @@ def load_state_dict(
461464
for keys in keys_groups
462465
}
463466
for future in concurrent.futures.as_completed(future_to_key):
464-
state_dict, scale_dict = future.result()
465-
state_dict.update(state_dict)
466-
scale_dict.update(scale_dict)
467+
res_state_dict, res_scale_dict = future.result()
468+
state_dict.update(res_state_dict)
469+
scale_dict.update(res_scale_dict)
467470

468471
if device == "cpu":
469472
for k in list(state_dict.keys()):

0 commit comments

Comments
 (0)