
Commit 7551730

[Unified Checkpoint] Accelerate checkpoint loading with multi-threading (#9034)
1 parent 6211e3d commit 7551730
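
The speed-up is opt-in: `load_state_dict` reads the `LOAD_STATE_DICT_THREAD_NUM` environment variable (default `"1"`, which keeps the original single-threaded path), as the diff below shows. A minimal sketch of enabling it; the entry point and checkpoint path are illustrative, not part of this commit:

```python
import os

# Must be set before load_state_dict runs; the default "1" keeps the
# original single-threaded behavior.
os.environ["LOAD_STATE_DICT_THREAD_NUM"] = "8"

from paddlenlp.transformers import AutoModel

# Illustrative checkpoint path; safetensors checkpoints in "np" format
# take the multi-threaded branch added by this commit.
model = AutoModel.from_pretrained("path/to/checkpoint")
```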


paddlenlp/transformers/model_utils.py

Lines changed: 89 additions & 16 deletions
```diff
@@ -13,6 +13,7 @@
 # limitations under the License.
 from __future__ import annotations
 
+import concurrent.futures
 import contextlib
 import copy
 import gc
```
```diff
@@ -319,6 +320,65 @@ def get_parameter_dtype(parameter: nn.Layer) -> paddle.dtype:
     return last_dtype
 
 
+def _split_keys_evenly(keys: list, n: int) -> list:
+    """Split a list into n sublists of nearly equal length.
+
+    Args:
+        keys (list): the list to be split
+        n (int): number of splits
+
+    Returns:
+        result: list of lists
+    """
+
+    total_len = len(keys)
+    base_size = total_len // n
+    extra = total_len % n
+
+    result = []
+    index = 0
+    for _ in range(n):
+        part_size = base_size + 1 if extra > 0 else base_size
+        extra -= 1
+        result.append(keys[index : index + part_size])
+        index += part_size
+
+    return result
+
+
+def _load_part_state_dict(
+    keys, checkpoint_file: Union[str, os.PathLike], tensor_parallel_split_mapping, fliter_dict_keys, device
+):
+    """Load part of the state dict from a checkpoint file.
+
+    Args:
+        keys (list): the keys of the partial state dict to load
+        checkpoint_file (str): the path of the checkpoint file
+        tensor_parallel_split_mapping (dict): mapping from key to its tensor-parallel split function
+        fliter_dict_keys (list): if not None, only these keys are loaded
+
+    Returns:
+        part_state_dict (dict): the partial state dict
+
+    """
+    part_state_dict = {}
+    with safe_open(checkpoint_file, framework="np") as f:
+        for key in keys:
+            if fliter_dict_keys is not None and key not in fliter_dict_keys:
+                continue
+            py_safe_slice_ = f.get_slice(key)
+            if key in tensor_parallel_split_mapping:
+                weight = tensor_parallel_split_mapping[key](py_safe_slice_)
+            else:
+                weight = py_safe_slice_[:]
+            if device == "expected":
+                with device_guard():
+                    weight = paddle.Tensor(weight, zero_copy=True)
+                weight = weight._copy_to(paddle.framework._current_expected_place(), False)
+            part_state_dict[key] = weight
+    return part_state_dict
+
+
 def load_state_dict(
     checkpoint_file: Union[str, os.PathLike], tensor_parallel_split_mapping=None, fliter_dict_keys=None, device="cpu"
 ):
```
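
For intuition, a hypothetical standalone call showing how `_split_keys_evenly` hands the remainder to the earliest groups (key names made up):

```python
keys = ["k0", "k1", "k2", "k3", "k4", "k5", "k6"]

# 7 keys over 3 threads: base size 2, remainder 1, so the first group
# receives one extra key.
print(_split_keys_evenly(keys, 3))
# [['k0', 'k1', 'k2'], ['k3', 'k4'], ['k5', 'k6']]
```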
```diff
@@ -343,21 +403,36 @@ def load_state_dict(
         if metadata.get("format", "np") == "pd":
             raise ValueError("Currently unsupport paddle weights file, use numpy instead.")
         if metadata.get("format", "np") == "np":
+            thread_num = int(os.environ.get("LOAD_STATE_DICT_THREAD_NUM", "1"))
             state_dict = {}
-            with safe_open(checkpoint_file, framework="np") as f:
-                for key in f.keys():
-                    if fliter_dict_keys is not None and key not in fliter_dict_keys:
-                        continue
-                    py_safe_slice_ = f.get_slice(key)
-                    if key in tensor_parallel_split_mapping:
-                        weight = tensor_parallel_split_mapping[key](py_safe_slice_)
-                    else:
-                        weight = py_safe_slice_[:]
-                    if device == "expected":
-                        with device_guard():
-                            weight = paddle.Tensor(weight, zero_copy=True)
-                        weight = weight._copy_to(paddle.framework._current_expected_place(), False)
-                    state_dict[key] = weight
+            if thread_num <= 1:
+                with safe_open(checkpoint_file, framework="np") as f:
+                    state_dict = _load_part_state_dict(
+                        list(f.keys()),
+                        checkpoint_file,
+                        tensor_parallel_split_mapping,
+                        fliter_dict_keys,
+                        device,
+                    )
+            else:
+                # Load the state dict with multiple threads to speed up loading
+                with safe_open(checkpoint_file, framework="np") as f:
+                    keys_groups = _split_keys_evenly(list(f.keys()), thread_num)
+                with concurrent.futures.ThreadPoolExecutor(max_workers=thread_num) as executor:
+                    future_to_key = {
+                        executor.submit(
+                            _load_part_state_dict,
+                            keys,
+                            checkpoint_file,
+                            tensor_parallel_split_mapping,
+                            fliter_dict_keys,
+                            device,
+                        ): keys
+                        for keys in keys_groups
+                    }
+                    for future in concurrent.futures.as_completed(future_to_key):
+                        result = future.result()
+                        state_dict.update(result)
 
         if device == "cpu":
             for k in list(state_dict.keys()):
```
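
The threaded branch above is a plain fan-out/merge: each worker reopens the checkpoint, loads a disjoint group of keys, and the partial dicts are merged as futures complete. A self-contained sketch of the pattern, with a hypothetical `_load_part` standing in for `_load_part_state_dict`:

```python
import concurrent.futures

def _load_part(keys):
    # Stand-in loader: returns a partial dict for its own disjoint
    # group of keys.
    return {k: f"tensor_for_{k}" for k in keys}

keys_groups = [["a", "b"], ["c", "d"], ["e"]]
state_dict = {}
with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
    futures = [executor.submit(_load_part, keys) for keys in keys_groups]
    for future in concurrent.futures.as_completed(futures):
        # No lock needed: update() runs only in the main thread and the
        # partial dicts hold disjoint keys.
        state_dict.update(future.result())

assert sorted(state_dict) == ["a", "b", "c", "d", "e"]
```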
```diff
@@ -1963,7 +2038,6 @@ def _fuse_or_split_keys(
 
             if config.quantization_config.is_weight_quantize():
                 filter_dict_keys = None
-
             state_dict = load_state_dict(
                 shard_file, tp_actions if pre_tensor_parallel_split else None, filter_dict_keys
             )
```
```diff
@@ -2279,7 +2353,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         else:
             raise ValueError(f"Unexpected file: {resolved_archive_file} for weight conversion.")
         # load pt weights early so that we know which dtype to init the model under
-
         if not is_sharded and state_dict is None:
             # 4. loading non-sharded ckpt from the state dict
             if config.tensor_parallel_degree > 1 and resolved_archive_file.endswith("model_state.pdparams"):
```
