@@ -13,6 +13,7 @@
 # limitations under the License.
 from __future__ import annotations

+import concurrent.futures
 import contextlib
 import copy
 import gc
@@ -319,6 +320,65 @@ def get_parameter_dtype(parameter: nn.Layer) -> paddle.dtype:
     return last_dtype


+def _split_keys_evenly(keys: list, n: int) -> list:
+    """Split a list into n lists with an approximately equal number of elements.
+
+    Args:
+        keys (list): the list to be split
+        n (int): number of splits
+
+    Returns:
+        result: list of lists
+    """
+
+    total_len = len(keys)
+    base_size = total_len // n
+    extra = total_len % n
+
+    result = []
+    index = 0
+    for _ in range(n):
+        part_size = base_size + 1 if extra > 0 else base_size
+        extra -= 1
+        result.append(keys[index : index + part_size])
+        index += part_size
+
+    return result
+
+
+def _load_part_state_dict(
+    keys, checkpoint_file: Union[str, os.PathLike], tensor_parallel_split_mapping, fliter_dict_keys, device
+):
+    """Load part of the state dict from a checkpoint file.
+
+    Args:
+        keys (list): the keys of the partial state dict to load
+        checkpoint_file (str): the path of the checkpoint file
+        tensor_parallel_split_mapping (dict): mapping from key to its tensor-parallel split function
+        fliter_dict_keys (list): if not None, only keys in this list are loaded
+
+    Returns:
+        part_state_dict (dict): the partial state dict
+
+    """
+    part_state_dict = {}
+    with safe_open(checkpoint_file, framework="np") as f:
+        for key in keys:
+            if fliter_dict_keys is not None and key not in fliter_dict_keys:
+                continue
+            py_safe_slice_ = f.get_slice(key)
+            if key in tensor_parallel_split_mapping:
+                weight = tensor_parallel_split_mapping[key](py_safe_slice_)
+            else:
+                weight = py_safe_slice_[:]
+            if device == "expected":
+                with device_guard():
+                    weight = paddle.Tensor(weight, zero_copy=True)
+                weight = weight._copy_to(paddle.framework._current_expected_place(), False)
+            part_state_dict[key] = weight
+    return part_state_dict
+
+
 def load_state_dict(
     checkpoint_file: Union[str, os.PathLike], tensor_parallel_split_mapping=None, fliter_dict_keys=None, device="cpu"
 ):
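To make the remainder handling concrete: when `len(keys)` is not divisible by `n`, the first `len(keys) % n` groups each receive one extra element. A minimal sketch, assuming the `_split_keys_evenly` helper added above is in scope (the key names here are made up for illustration):

```python
keys = ["w1", "w2", "w3", "w4", "w5"]

groups = _split_keys_evenly(keys, n=2)
# The remainder (5 % 2 == 1) goes to the first group:
# groups == [["w1", "w2", "w3"], ["w4", "w5"]]
```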
@@ -343,21 +403,36 @@ def load_state_dict(
     if metadata.get("format", "np") == "pd":
         raise ValueError("Currently unsupport paddle weights file, use numpy instead.")
     if metadata.get("format", "np") == "np":
+        thread_num = int(os.environ.get("LOAD_STATE_DICT_THREAD_NUM", "1"))
         state_dict = {}
-        with safe_open(checkpoint_file, framework="np") as f:
-            for key in f.keys():
-                if fliter_dict_keys is not None and key not in fliter_dict_keys:
-                    continue
-                py_safe_slice_ = f.get_slice(key)
-                if key in tensor_parallel_split_mapping:
-                    weight = tensor_parallel_split_mapping[key](py_safe_slice_)
-                else:
-                    weight = py_safe_slice_[:]
-                if device == "expected":
-                    with device_guard():
-                        weight = paddle.Tensor(weight, zero_copy=True)
-                    weight = weight._copy_to(paddle.framework._current_expected_place(), False)
-                state_dict[key] = weight
+        if thread_num <= 1:
+            with safe_open(checkpoint_file, framework="np") as f:
+                state_dict = _load_part_state_dict(
+                    list(f.keys()),
+                    checkpoint_file,
+                    tensor_parallel_split_mapping,
+                    fliter_dict_keys,
+                    device,
+                )
+        else:
+            # Load the state dict with multiple threads to speed up loading
+            with safe_open(checkpoint_file, framework="np") as f:
+                keys_groups = _split_keys_evenly(list(f.keys()), thread_num)
+            with concurrent.futures.ThreadPoolExecutor(max_workers=thread_num) as executor:
+                future_to_key = {
+                    executor.submit(
+                        _load_part_state_dict,
+                        keys,
+                        checkpoint_file,
+                        tensor_parallel_split_mapping,
+                        fliter_dict_keys,
+                        device,
+                    ): keys
+                    for keys in keys_groups
+                }
+                for future in concurrent.futures.as_completed(future_to_key):
+                    result = future.result()
+                    state_dict.update(result)

     if device == "cpu":
         for k in list(state_dict.keys()):
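Usage-wise, the threaded path is opt-in via the `LOAD_STATE_DICT_THREAD_NUM` environment variable; the default of `"1"` keeps the original single-threaded behavior. Note that the handle used to enumerate keys is closed before the executor starts, so each worker thread opens its own `safe_open` handle inside `_load_part_state_dict`. A minimal sketch (the shard filename below is illustrative, not a real file):

```python
import os

# Opt in to multi-threaded loading before calling load_state_dict.
os.environ["LOAD_STATE_DICT_THREAD_NUM"] = "4"

# Hypothetical shard name; any safetensors checkpoint path works here.
state_dict = load_state_dict("model-00001-of-00002.safetensors")
```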
@@ -1963,7 +2038,6 @@ def _fuse_or_split_keys(

             if config.quantization_config.is_weight_quantize():
                 filter_dict_keys = None
-
             state_dict = load_state_dict(
                 shard_file, tp_actions if pre_tensor_parallel_split else None, filter_dict_keys
             )
@@ -2279,7 +2353,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
             else:
                 raise ValueError(f"Unexpected file: {resolved_archive_file} for weight conversion.")
         # load pt weights early so that we know which dtype to init the model under
-
         if not is_sharded and state_dict is None:
             # 4. loading non-sharded ckpt from the state dict
             if config.tensor_parallel_degree > 1 and resolved_archive_file.endswith("model_state.pdparams"):