Commit d6ac1bd
[Performance] Optimize unified checkpoint save/load speed. (#8204)
* opt unified checkpoint save/load speed.
* fix bug.
* add fast safe open API.
* mix file open and mmap.
* fix
* add test for read fast read tensors.
* fix
* fix tests.
* remove profile log.
* fix
* fix ci
1 parent 2f3faff commit d6ac1bd

File tree: 7 files changed, +485 -39 lines changed
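The "add fast safe open API" and "mix file open and mmap" items above refer to the new paddlenlp.utils.safetensors.fast_safe_open helper that the diffs below alias over safetensors.safe_open. Its implementation is not part of this page, so the snippet below is only a minimal sketch of the underlying technique under stated assumptions: parse the safetensors JSON header with ordinary file I/O, then memory-map the data region so tensor bytes are only touched lazily. The function name read_tensor_lazily and the small dtype table are illustrative, not the actual API.

import json
import struct

import numpy as np


def read_tensor_lazily(path, name):
    # Sketch only -- NOT the actual fast_safe_open implementation.
    # A .safetensors file is: an 8-byte little-endian header size, a JSON header
    # mapping tensor names to {dtype, shape, data_offsets}, then raw tensor bytes.
    with open(path, "rb") as f:
        header_size = struct.unpack("<Q", f.read(8))[0]
        header = json.loads(f.read(header_size))
    meta = header[name]
    start, end = meta["data_offsets"]  # offsets relative to the data section
    # Map only the bytes backing this tensor; nothing is copied until it is sliced.
    raw = np.memmap(path, dtype=np.uint8, mode="r", offset=8 + header_size + start, shape=(end - start,))
    np_dtype = {"F32": np.float32, "F16": np.float16, "I64": np.int64}[meta["dtype"]]  # subset of the real dtypes
    return raw.view(np_dtype).reshape(meta["shape"])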


paddlenlp/trainer/plugins/unified_checkpoint.py

Lines changed: 28 additions & 21 deletions
@@ -30,6 +30,7 @@
 from paddlenlp.transformers.model_utils import (
     PretrainedModel,
     _load_state_dict_into_model,
+    faster_set_state_dict,
     get_parameter_dtype,
     load_state_dict,
     unwrap_model,
@@ -64,9 +65,10 @@
 from paddlenlp.utils.log import logger

 if is_safetensors_available():
-    from safetensors import safe_open
+    # from safetensors import safe_open
     from safetensors.numpy import save_file as safe_save_file

+    from paddlenlp.utils.safetensors import fast_safe_open as safe_open

 FP32_MASTER = "fp32_master_0"
 optimizer_scalar_name = [
@@ -195,7 +197,6 @@ def load_unified_checkpoint(args, model, optimizer, resume_from_checkpoint: str,
     Returns:
         None
     """
-
     if paddle.distributed.get_world_size() <= 1:
         load_single_card_checkpoint(args, model, resume_from_checkpoint)
         return
@@ -221,7 +222,6 @@ def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, sa
         pretrained_model_name_or_path=resume_from_checkpoint,
         index_filename=os.path.join(resume_from_checkpoint, index_filename),
     )
-
     loaded_keys = sharded_metadata["all_checkpoint_keys"]

     model_state_dict = get_expected_state_dict(model)
@@ -265,7 +265,9 @@ def _remove_unused_keys(
         else:
             tp_actions = model.get_tensor_parallel_convert_actions(model.config, loaded_keys, ignore_error=True)
         # Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors
-        state_dict = load_state_dict(shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys)
+        state_dict = load_state_dict(
+            shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys, device="expected"
+        )

         if not pre_tensor_parallel_split:
             # Since we load all keys but we only need one of pipeline stages
@@ -278,11 +280,12 @@ def _remove_unused_keys(
                 None, model.config, state_dict=state_dict, ignore_error=len(resolved_archive_file) > 1
             )

-        error_msgs += _load_state_dict_into_model(model, state_dict, "")
+        # error_msgs += _load_state_dict_into_model(model, state_dict, "")
+        error_msgs += faster_set_state_dict(model, state_dict, strict_dtype=False)

         # force memory release
         del state_dict
-        gc.collect()
+        # gc.collect()

     if len(error_msgs) > 0:
         error_msg = "\n\t".join(error_msgs)
@@ -336,6 +339,7 @@ def unified_checkpoint_into_shards(
         tp_actions = model_to_save.get_tensor_parallel_convert_actions(
             model_to_save.config, state_dict.keys(), is_split=False, ignore_error=True
         )
+        logger.info("Unified model tensor parallel weights in shards")
         state_dict = merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys)

     # build index json file
@@ -489,6 +493,7 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin
     # This should always be a list but, just to be sure.
     if not isinstance(resolved_archive_file, list):
         resolved_archive_file = [resolved_archive_file]
+
     if len(resolved_archive_file) > 1:
         resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards")

@@ -536,10 +541,10 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
                     tp_actions = mapping_optimizer_tp_actions(tp_actions, expected_keys)

                     # Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors
-                    state_dict = load_state_dict(shard_file, tp_actions, expected_keys)
+                    state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="expected")
                 else:
                     # for pipeline model, we don't need to use tp_actions
-                    state_dict = load_state_dict(shard_file, None, expected_keys)
+                    state_dict = load_state_dict(shard_file, None, expected_keys, device="expected")

                 returned_state_dict.update(state_dict)
                 # force memory release
@@ -552,7 +557,6 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
        state_dict_master_weight = load_resolved_archive_file(
            resolved_archive_file_mw, sharded_metadata_mw, expected_keys_mw, is_master_weights=True
        )
-
    # rename optimizer param
    for key in list(state_dict_optim.keys()):
        key_name = key.split("/")
@@ -561,13 +565,13 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
            key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
        else:
            key_name = "_".join([static_name, key_name[1]])
-        returned_optim_state_dict[key_name] = state_dict_optim[key]
+        returned_optim_state_dict[key_name] = state_dict_optim.pop(key)
        returned_optim_state_dict[key_name].name = key_name

    if has_master_weights:
        for key in list(state_dict_master_weight.keys()):
            static_name = struct2static_name_mappings[key]
-            returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight[key]
+            returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key)
            returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])

    returned_optim_state_dict = nested_copy_place(
@@ -639,6 +643,7 @@ def unified_optimizer_into_shards(
        tp_actions = model.get_tensor_parallel_convert_actions(
            model.config, model_keys, is_split=False, ignore_error=True
        )
+        logger.info("Unified optimizer tensor parallel in shards")
        optim_state_dict = merge_tensor_parallel_for_optimizer(
            optim_state_dict,
            tp_actions,
@@ -647,6 +652,7 @@ def unified_optimizer_into_shards(
        paddle.device.cuda.empty_cache()

        if master_weights is not None:
+            logger.info("Unified master weight tensor parallel in shards")
            master_weights = merge_tensor_parallel_for_optimizer(
                master_weights,
                tp_actions,
@@ -702,7 +708,6 @@ def unified_optimizer_into_shards(
 def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serialization=False):
     index_filename = select_model_weight_index(args, model, resume_from_checkpoint, safe_serialization, local=False)
     index_filename = os.path.join(resume_from_checkpoint, index_filename)
-
     # Find index json file and distribute this file in global group.
     if distributed_isfile(index_filename):
         distributed_file(index_filename)
@@ -1604,7 +1609,9 @@ def gather_sharded_object(index_file, total_size, is_optimizer=False):
     tp_group = hcg.get_model_parallel_group()
     pp_group = hcg.get_pipe_parallel_group()

-    logger.info("Unified checkpoint generating sharded_index json files.")
+    logger.info(
+        f"Unified checkpoint: generating sharded_index json files for {'optimizer or master weight' if is_optimizer else 'model weight'}."
+    )

     if tp_group.nranks > 1:
         dist.all_gather_object(index_file_list, index_file, tp_group)
@@ -1713,8 +1720,6 @@ def filter_params(model_to_save, state_dict, is_optimizer=False):


 def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
-    logger.info("Unified checkpoint merge tensor parallel in shards")
-
     hcg = fleet.get_hybrid_communicate_group()
     tp_group = hcg.get_model_parallel_group()
     tp_rank = tp_group.rank
@@ -1740,7 +1745,7 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
                 action = tp_actions.pop(key)
                 tensor = action(ret) if is_dst else None
             else:
-                tensor = tensor._copy_to(paddle.CPUPlace(), False) if is_dst else None
+                tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False) if is_dst else None

             if is_dst:
                 state_dict_to_save[key] = tensor
@@ -1753,8 +1758,7 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):


 def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys):
-    logger.info("Unified optimizer tensor parallel in shards")
-
+    # Core function for UC
     hcg = fleet.get_hybrid_communicate_group()
     tp_group = hcg.get_model_parallel_group()
     tp_rank = tp_group.rank
@@ -1773,14 +1777,14 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys)
                 # for example: beta1, beta2
                 if tensor.numel().item() == 1:
                     tensor = (
-                        tensor._copy_to(paddle.CPUPlace(), False) if is_dst else None
+                        tensor._copy_to(paddle.CUDAPinnedPlace(), False) if is_dst else None
                     )  # Need broadcast when loaded
                 else:
                     ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
                     action = tp_actions[model_key]
                     tensor = action(ret) if is_dst else None
             else:
-                tensor = tensor._copy_to(paddle.CPUPlace(), False) if is_dst else None
+                tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False) if is_dst else None

             if is_dst:
                 state_dict_to_save[filter_keys[i]] = tensor
@@ -1892,7 +1896,10 @@ def nested_copy_place(inputs, place=None, blocking=False):
             outputs[key] = nested_copy_place(inputs[key], place, blocking)
         return outputs
     if isinstance(inputs, paddle.Tensor):
-        inputs = inputs if inputs.place == place else inputs._copy_to(place, blocking)
+        if inputs.place._equals(place):
+            return inputs
+        else:
+            return inputs._copy_to(place, blocking)
     return inputs
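Read together, the changes above replace the old path (load every shard to CPU, copy tensors into the model, then gc.collect()) with one that places each tensor on the expected device as it is read and then shares it into the model parameters in place. Below is a minimal sketch of how the two new knobs combine, assuming a single resolved shard; shard_file, expected_keys, and model are placeholders standing in for what load_unified_checkpoint_locally computes above.

from paddlenlp.transformers.model_utils import faster_set_state_dict, load_state_dict

# Placeholders for illustration; the real loop iterates every resolved shard file.
shard_file = "model-00001-of-00002.safetensors"  # hypothetical shard name
expected_keys = None  # or the subset of parameter names this rank actually needs

# device="expected" copies each tensor straight to the current expected place
# (e.g. the GPU) while it is read, instead of staging everything on CPU first.
state_dict = load_state_dict(shard_file, None, expected_keys, device="expected")

# strict_dtype=False lets differing floating-point dtypes be cast instead of
# raising; the returned messages list checkpoint keys the model did not use.
error_msgs = faster_set_state_dict(model, state_dict, strict_dtype=False)  # `model` is a loaded PretrainedModel
del state_dict  # the loaded tensors are now shared with the model parameters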

paddlenlp/trainer/trainer.py

Lines changed: 1 addition & 0 deletions
@@ -2406,6 +2406,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             self.runtime_timer.stop()
             return

+        logger.info("Loading optimizer and scheduler...")
         if (not self.args.should_load_sharding_stage1_model) and self.args.ignore_load_lr_and_optim:
             self.runtime_timer.stop()
             return

paddlenlp/transformers/conversion_utils.py

Lines changed: 13 additions & 3 deletions
@@ -285,8 +285,12 @@ def naive_fuse_merge_tp(weight_list, is_column=True, fuse_tensor_parts=2):

     if isinstance(weight_list[0], np.ndarray):
         return np.concatenate([reorder[i] for i in index], axis=axis)
+    else:
+        tensor = paddle.concat([reorder[i] for i in index], axis=axis)

-    return paddle.concat([reorder[i] for i in index], axis=axis)._copy_to(paddle.CPUPlace(), False)
+    if tensor.place.is_gpu_place():
+        tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
+    return tensor


 def naive_fuse_split_tp(
@@ -361,12 +365,18 @@ def normal_fuse_merge_tp(weight_list, is_column=True):
         if isinstance(weight_list[0], np.ndarray):
             return np.concatenate(weight_list, axis=-1)
         else:
-            return paddle.concat(weight_list, axis=-1)._copy_to(paddle.CPUPlace(), False)
+            tensor = paddle.concat(weight_list, axis=-1)
+            if tensor.place.is_gpu_place():
+                tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
+            return tensor
     else:
         if isinstance(weight_list[0], np.ndarray):
             return np.concatenate(weight_list, axis=0)
         else:
-            return paddle.concat(weight_list, axis=0)._copy_to(paddle.CPUPlace(), False)
+            tensor = paddle.concat(weight_list, axis=0)
+            if tensor.place.is_gpu_place():
+                tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
+            return tensor


 def normal_fuse_split_tp(weight, tensor_parallel_degree, tensor_parallel_rank=None, is_column=True):
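The recurring edit in this file (and in the merge_* helpers earlier) swaps the copy to paddle.CPUPlace() for a copy to paddle.CUDAPinnedPlace(): the merged tensor still leaves GPU memory, but it lands in page-locked host memory, which allows faster transfers when the shard is later written or broadcast. Below is a small hedged helper showing the same check-then-copy idiom; the name offload_to_pinned is illustrative and not part of this diff.

import paddle


def offload_to_pinned(tensor):
    # Mirror of the idiom used above: move GPU tensors to pinned (page-locked)
    # host memory; leave CPU/numpy inputs untouched. The second argument of
    # _copy_to is the blocking flag, False here as in the diff.
    if isinstance(tensor, paddle.Tensor) and tensor.place.is_gpu_place():
        return tensor._copy_to(paddle.CUDAPinnedPlace(), False)
    return tensor


# Usage (requires a CUDA build of Paddle):
# x = paddle.randn([1024, 1024])   # allocated on the GPU by default
# x_host = offload_to_pinned(x)    # now resides in pinned host memory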

paddlenlp/transformers/model_utils.py

Lines changed: 34 additions & 15 deletions
@@ -108,10 +108,13 @@ def unwrap_optimizer(optimizer, optimizer_instances=()):


 if is_safetensors_available():
-    from safetensors import safe_open
+
+    # from safetensors import safe_open
     from safetensors.numpy import load_file as safe_load_file
     from safetensors.numpy import save_file as safe_save_file

+    from paddlenlp.utils.safetensors import fast_safe_open as safe_open
+

 def prune_linear_layer(layer: nn.Linear, index: paddle.Tensor, dim: int = 0) -> nn.Linear:
     """
@@ -312,7 +315,7 @@ def get_parameter_dtype(parameter: nn.Layer) -> paddle.dtype:


 def load_state_dict(
-    checkpoint_file: Union[str, os.PathLike], tensor_parallel_split_mapping=None, fliter_dict_keys=None
+    checkpoint_file: Union[str, os.PathLike], tensor_parallel_split_mapping=None, fliter_dict_keys=None, device="cpu"
 ):
     """
     Reads a PaddlePaddle checkpoint file, returning properly formatted errors if they arise.
@@ -345,11 +348,16 @@ def load_state_dict(
                     weight = tensor_parallel_split_mapping[key](py_safe_slice_)
                 else:
                     weight = py_safe_slice_[:]
+                if device == "expected":
+                    with device_guard():
+                        weight = paddle.Tensor(weight, zero_copy=True)
+                    weight = weight._copy_to(paddle.framework._current_expected_place(), False)
                 state_dict[key] = weight

-        for k in list(state_dict.keys()):
-            with device_guard():
-                state_dict[k] = paddle.Tensor(state_dict.pop(k), zero_copy=True)
+        if device == "cpu":
+            for k in list(state_dict.keys()):
+                with device_guard():
+                    state_dict[k] = paddle.Tensor(state_dict.pop(k), zero_copy=True)

         return state_dict

@@ -671,8 +679,10 @@ def load_sharded_checkpoint(model, folder, variant=None, strict=True, prefer_saf
     return missing_keys, unexpected_keys


-def faster_set_state_dict(model, state_dict):
+def faster_set_state_dict(model, state_dict, strict_dtype=True):
     # the state_dict will be destroied.
+    unused_keys = set(state_dict.keys())
+    unset_keys = set(model.state_dict().keys())
     with paddle.no_grad():
         for k, v in model.state_dict().items():
             if k in state_dict:
@@ -682,8 +692,10 @@ def faster_set_state_dict(model, state_dict):
                         f"faster_set_state_dict need state dict with paddle.Tensor, but got {type(v_new)}"
                     )
                 # 2. cast param / Tensor to dtype
+                #
                 if v.dtype != v_new.dtype:
-                    raise ValueError(f"for key: {k}, expect dtype {v.dtype}, but got {v_new.dtype}")
+                    if strict_dtype or (not v.is_floating_point() or not v_new.is_floating_point()):
+                        raise ValueError(f"for key: {k}, expect dtype {v.dtype}, but got {v_new.dtype}")
                 # check shape
                 if list(v.shape) != list(v_new.shape):
                     raise ValueError(f"for key: {k}, expect shape {v.shape}, but got {v_new.shape}")
@@ -699,9 +711,22 @@ def faster_set_state_dict(model, state_dict):
                 else:
                     new_t = v_new

+                if not strict_dtype and v.dtype != new_t.dtype:
+                    new_t = new_t.astype(v.dtype)
+
                 # 4. share Tensor to origin param / Tensor
                 src_tensor = new_t.value().get_tensor()
                 dst_tensor._share_data_with(src_tensor)
+                unset_keys.remove(k)
+                unused_keys.remove(k)
+
+    error_msgs = []
+    # if len(unset_keys) > 0:
+    #     error_msgs.append(f"Those weight of model is not initialized: {list(unset_keys)}")
+    if len(unused_keys) > 0:
+        error_msgs.append(f"Those state dict keys are not using in model: {list(unused_keys)}")
+
+    return error_msgs


 def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
@@ -733,22 +758,16 @@ def _convert_state_dict_dtype_and_shape(state_dict, model_to_load):
     def is_0d_or_1d(tensor):
         return len(tensor.shape) == 0 or list(tensor.shape) == [1]

-    expected_place = paddle.framework._current_expected_place()
     for key, value in model_to_load.state_dict().items():
-        if key in state_dict:
+        if key in list(state_dict.keys()):
            if isinstance(state_dict[key], np.ndarray):
                raise ValueError(
                    "convert_state_dict_dtype expected paddle.Tensor not numpy.ndarray, plase convert numpy.ndarray to paddle.Tensor"
                )
            # confirm parameter cast is executed on the same device as model
            # TODO: cast(FP32 -> FP16) has diff on different devices, need to fix it
            if state_dict[key].is_floating_point() and state_dict[key].dtype != value.dtype:
-                value_pop = state_dict.pop(key)
-                value_new_place = (
-                    value_pop if value_pop.place == expected_place else value_pop._copy_to(expected_place, False)
-                )
-                state_dict[key] = paddle.cast(value_new_place, value.dtype)._copy_to(value_pop.place, False)
-                del value_new_place
+                state_dict[key] = paddle.cast(state_dict.pop(key), value.dtype)
            # unified 0d and 1d tensor
            if is_0d_or_1d(value) and is_0d_or_1d(state_dict[key]):
                if list(value.shape) != list(state_dict[key].shape):
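Because fast_safe_open is imported as safe_open here and in unified_checkpoint.py, call sites such as load_state_dict keep the familiar safetensors access pattern unchanged. A hedged usage sketch follows; the framework="np" keyword and the keys/get_slice surface are assumed to mirror safetensors.safe_open (which the drop-in aliasing implies but this diff does not show), and the file name is hypothetical.

from paddlenlp.utils.safetensors import fast_safe_open as safe_open

# Assumed drop-in behaviour: iterate keys, take lazy slices, and only
# materialize the rows/columns a tensor-parallel rank actually needs.
with safe_open("model-00001-of-00002.safetensors", framework="np") as f:
    for key in f.keys():
        lazy_slice = f.get_slice(key)  # no tensor data read yet
        first_rows = lazy_slice[:2]    # materializes just this slice as numpy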
