File tree Expand file tree Collapse file tree 1 file changed +6
-3
lines changed Expand file tree Collapse file tree 1 file changed +6
-3
lines changed Original file line number Diff line number Diff line change @@ -435,6 +435,9 @@ def load_state_dict(
435
435
raise ValueError ("Currently unsupport paddle weights file, use numpy instead." )
436
436
if metadata .get ("format" , "np" ) == "np" :
437
437
thread_num = int (os .environ .get ("LOAD_STATE_DICT_THREAD_NUM" , "1" ))
438
+ if thread_num > 1 :
439
+ logger .info (f"Set loading state_dict thread num to { thread_num } " )
440
+ state_dict , scale_dict = {}, {}
438
441
if thread_num <= 1 :
439
442
with safe_open (checkpoint_file , framework = "np" ) as f :
440
443
state_dict , scale_dict = _load_part_state_dict (
@@ -461,9 +464,9 @@ def load_state_dict(
461
464
for keys in keys_groups
462
465
}
463
466
for future in concurrent .futures .as_completed (future_to_key ):
464
- state_dict , scale_dict = future .result ()
465
- state_dict .update (state_dict )
466
- scale_dict .update (scale_dict )
467
+ res_state_dict , res_scale_dict = future .result ()
468
+ state_dict .update (res_state_dict )
469
+ scale_dict .update (res_scale_dict )
467
470
468
471
if device == "cpu" :
469
472
for k in list (state_dict .keys ()):
You can’t perform that action at this time.
0 commit comments