[Performance] Optimize unified checkpoint save/load speed. #8204

Merged
merged 12 commits · May 8, 2024
49 changes: 28 additions & 21 deletions paddlenlp/trainer/plugins/unified_checkpoint.py
@@ -30,6 +30,7 @@
from paddlenlp.transformers.model_utils import (
PretrainedModel,
_load_state_dict_into_model,
faster_set_state_dict,
get_parameter_dtype,
load_state_dict,
unwrap_model,
@@ -64,9 +65,10 @@
from paddlenlp.utils.log import logger

if is_safetensors_available():
from safetensors import safe_open
# from safetensors import safe_open
from safetensors.numpy import save_file as safe_save_file

from paddlenlp.utils.safetensors import fast_safe_open as safe_open

FP32_MASTER = "fp32_master_0"
optimizer_scalar_name = [
@@ -195,7 +197,6 @@
Returns:
None
"""

if paddle.distributed.get_world_size() <= 1:
load_single_card_checkpoint(args, model, resume_from_checkpoint)
return
@@ -221,7 +222,6 @@
pretrained_model_name_or_path=resume_from_checkpoint,
index_filename=os.path.join(resume_from_checkpoint, index_filename),
)

loaded_keys = sharded_metadata["all_checkpoint_keys"]

model_state_dict = get_expected_state_dict(model)
@@ -265,7 +265,9 @@
else:
tp_actions = model.get_tensor_parallel_convert_actions(model.config, loaded_keys, ignore_error=True)
# Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors
state_dict = load_state_dict(shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys)
state_dict = load_state_dict(
shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys, device="expected"
)

if not pre_tensor_parallel_split:
# Since we load all keys but we only need one of pipeline stages
@@ -278,11 +280,12 @@
None, model.config, state_dict=state_dict, ignore_error=len(resolved_archive_file) > 1
)

error_msgs += _load_state_dict_into_model(model, state_dict, "")
# error_msgs += _load_state_dict_into_model(model, state_dict, "")
error_msgs += faster_set_state_dict(model, state_dict, strict_dtype=False)

# force memory release
del state_dict
gc.collect()
# gc.collect()

if len(error_msgs) > 0:
error_msg = "\n\t".join(error_msgs)
@@ -336,6 +339,7 @@
tp_actions = model_to_save.get_tensor_parallel_convert_actions(
model_to_save.config, state_dict.keys(), is_split=False, ignore_error=True
)
logger.info("Unified model tensor parallel weights in shards")

state_dict = merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys)

# build index json file
@@ -489,6 +493,7 @@
# This should always be a list but, just to be sure.
if not isinstance(resolved_archive_file, list):
resolved_archive_file = [resolved_archive_file]

if len(resolved_archive_file) > 1:
resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards")

@@ -536,10 +541,10 @@
tp_actions = mapping_optimizer_tp_actions(tp_actions, expected_keys)

# Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors
state_dict = load_state_dict(shard_file, tp_actions, expected_keys)
state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="expected")

else:
# for pipeline model, we don't need to use tp_actions
state_dict = load_state_dict(shard_file, None, expected_keys)
state_dict = load_state_dict(shard_file, None, expected_keys, device="expected")

returned_state_dict.update(state_dict)
# force memory release
@@ -552,7 +557,6 @@
state_dict_master_weight = load_resolved_archive_file(
resolved_archive_file_mw, sharded_metadata_mw, expected_keys_mw, is_master_weights=True
)

# rename optimizer param
for key in list(state_dict_optim.keys()):
key_name = key.split("/")
@@ -561,13 +565,13 @@
key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
else:
key_name = "_".join([static_name, key_name[1]])
returned_optim_state_dict[key_name] = state_dict_optim[key]
returned_optim_state_dict[key_name] = state_dict_optim.pop(key)

returned_optim_state_dict[key_name].name = key_name

if has_master_weights:
for key in list(state_dict_master_weight.keys()):
static_name = struct2static_name_mappings[key]
returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight[key]
returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key)

returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])

returned_optim_state_dict = nested_copy_place(
@@ -639,6 +643,7 @@
tp_actions = model.get_tensor_parallel_convert_actions(
model.config, model_keys, is_split=False, ignore_error=True
)
logger.info("Unified optimizer tensor parallel in shards")

optim_state_dict = merge_tensor_parallel_for_optimizer(
optim_state_dict,
tp_actions,
@@ -647,6 +652,7 @@
paddle.device.cuda.empty_cache()

if master_weights is not None:
logger.info("Unified master weight tensor parallel in shards")

master_weights = merge_tensor_parallel_for_optimizer(
master_weights,
tp_actions,
@@ -702,7 +708,6 @@
def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serialization=False):
index_filename = select_model_weight_index(args, model, resume_from_checkpoint, safe_serialization, local=False)
index_filename = os.path.join(resume_from_checkpoint, index_filename)

# Find index json file and distribute this file in global group.
if distributed_isfile(index_filename):
distributed_file(index_filename)
@@ -1604,7 +1609,9 @@
tp_group = hcg.get_model_parallel_group()
pp_group = hcg.get_pipe_parallel_group()

logger.info("Unified checkpoint generating sharded_index json files.")
logger.info(
f"Unified checkpoint: generating sharded_index json files for {'optimizer or master weight' if is_optimizer else 'model weight'}."
)

if tp_group.nranks > 1:
dist.all_gather_object(index_file_list, index_file, tp_group)
@@ -1713,8 +1720,6 @@


def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
logger.info("Unified checkpoint merge tensor parallel in shards")

hcg = fleet.get_hybrid_communicate_group()
tp_group = hcg.get_model_parallel_group()
tp_rank = tp_group.rank
@@ -1740,7 +1745,7 @@
action = tp_actions.pop(key)
tensor = action(ret) if is_dst else None
else:
tensor = tensor._copy_to(paddle.CPUPlace(), False) if is_dst else None
tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False) if is_dst else None

if is_dst:
state_dict_to_save[key] = tensor
@@ -1753,8 +1758,7 @@


def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys):
logger.info("Unified optimizer tensor parallel in shards")

# Core function for UC
hcg = fleet.get_hybrid_communicate_group()
tp_group = hcg.get_model_parallel_group()
tp_rank = tp_group.rank
@@ -1773,14 +1777,14 @@
# for example: beta1, beta2
if tensor.numel().item() == 1:
tensor = (
tensor._copy_to(paddle.CPUPlace(), False) if is_dst else None
tensor._copy_to(paddle.CUDAPinnedPlace(), False) if is_dst else None
) # Need broadcast when loaded
else:
ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
action = tp_actions[model_key]
tensor = action(ret) if is_dst else None
else:
tensor = tensor._copy_to(paddle.CPUPlace(), False) if is_dst else None
tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False) if is_dst else None

if is_dst:
state_dict_to_save[filter_keys[i]] = tensor
@@ -1892,7 +1896,10 @@
outputs[key] = nested_copy_place(inputs[key], place, blocking)
return outputs
if isinstance(inputs, paddle.Tensor):
inputs = inputs if inputs.place == place else inputs._copy_to(place, blocking)
if inputs.place._equals(place):
return inputs

else:
return inputs._copy_to(place, blocking)

return inputs


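In unified_checkpoint.py above, the merge helpers now copy gathered tensor-parallel shards to paddle.CUDAPinnedPlace() instead of paddle.CPUPlace(), and load_state_dict(..., device="expected") materializes weights straight on the training device. The snippet below is a minimal editor's sketch of that pinned-memory staging idea, not code from this PR; it assumes a CUDA build of Paddle and uses illustrative variable names.

import numpy as np
import paddle

# Stage a host-side weight through pinned memory before moving it to the device.
weight = paddle.to_tensor(np.ones((4, 4), dtype="float32"), place=paddle.CPUPlace())
if paddle.is_compiled_with_cuda():
    # Non-blocking copy into pinned host memory; the later host-to-device
    # transfer is cheaper than one from ordinary pageable CPU memory.
    pinned = weight._copy_to(paddle.CUDAPinnedPlace(), False)
    on_device = pinned._copy_to(paddle.framework._current_expected_place(), False)
else:
    on_device = weight  # CPU-only build: nothing to stage
print(on_device.place)

Keeping merged shards in pinned host memory frees GPU memory for training state while still allowing fast, potentially asynchronous transfers when the checkpoint is written or re-loaded.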
1 change: 1 addition & 0 deletions paddlenlp/trainer/trainer.py
@@ -2406,6 +2406,7 @@
self.runtime_timer.stop()
return

logger.info("Loading optimizer and scheduler...")

if (not self.args.should_load_sharding_stage1_model) and self.args.ignore_load_lr_and_optim:
self.runtime_timer.stop()
return
16 changes: 13 additions & 3 deletions paddlenlp/transformers/conversion_utils.py
@@ -285,8 +285,12 @@

if isinstance(weight_list[0], np.ndarray):
return np.concatenate([reorder[i] for i in index], axis=axis)
else:
tensor = paddle.concat([reorder[i] for i in index], axis=axis)

return paddle.concat([reorder[i] for i in index], axis=axis)._copy_to(paddle.CPUPlace(), False)
if tensor.place.is_gpu_place():
tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
return tensor


def naive_fuse_split_tp(
@@ -361,12 +365,18 @@
if isinstance(weight_list[0], np.ndarray):
return np.concatenate(weight_list, axis=-1)
else:
return paddle.concat(weight_list, axis=-1)._copy_to(paddle.CPUPlace(), False)
tensor = paddle.concat(weight_list, axis=-1)
if tensor.place.is_gpu_place():
tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
return tensor

else:
if isinstance(weight_list[0], np.ndarray):
return np.concatenate(weight_list, axis=0)
else:
return paddle.concat(weight_list, axis=0)._copy_to(paddle.CPUPlace(), False)
tensor = paddle.concat(weight_list, axis=0)
if tensor.place.is_gpu_place():
tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
return tensor


def normal_fuse_split_tp(weight, tensor_parallel_degree, tensor_parallel_rank=None, is_column=True):
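The conversion_utils.py hunks change the fuse helpers to build the concatenated weight first and offload it to pinned memory only when it actually lives on the GPU. Below is a condensed sketch of that guard, assuming the same place.is_gpu_place() / _copy_to API the patch uses; the helper name concat_to_host is illustrative, not from the patch.

import numpy as np
import paddle

def concat_to_host(weight_list, axis=0):
    # numpy inputs stay in numpy; GPU tensors are concatenated and then parked
    # in pinned host memory so they stop occupying device memory.
    if isinstance(weight_list[0], np.ndarray):
        return np.concatenate(weight_list, axis=axis)
    tensor = paddle.concat(weight_list, axis=axis)
    if tensor.place.is_gpu_place():
        tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
    return tensor

parts = [paddle.ones([2, 3]), paddle.zeros([2, 3])]
print(concat_to_host(parts, axis=-1).shape)  # [2, 6]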
49 changes: 34 additions & 15 deletions paddlenlp/transformers/model_utils.py
@@ -108,10 +108,13 @@


if is_safetensors_available():
from safetensors import safe_open

# from safetensors import safe_open
from safetensors.numpy import load_file as safe_load_file
from safetensors.numpy import save_file as safe_save_file

from paddlenlp.utils.safetensors import fast_safe_open as safe_open


def prune_linear_layer(layer: nn.Linear, index: paddle.Tensor, dim: int = 0) -> nn.Linear:
"""
@@ -312,7 +315,7 @@


def load_state_dict(
checkpoint_file: Union[str, os.PathLike], tensor_parallel_split_mapping=None, fliter_dict_keys=None
checkpoint_file: Union[str, os.PathLike], tensor_parallel_split_mapping=None, fliter_dict_keys=None, device="cpu"
):
"""
Reads a PaddlePaddle checkpoint file, returning properly formatted errors if they arise.
@@ -345,11 +348,16 @@
weight = tensor_parallel_split_mapping[key](py_safe_slice_)
else:
weight = py_safe_slice_[:]
if device == "expected":
with device_guard():
weight = paddle.Tensor(weight, zero_copy=True)
weight = weight._copy_to(paddle.framework._current_expected_place(), False)

state_dict[key] = weight

for k in list(state_dict.keys()):
with device_guard():
state_dict[k] = paddle.Tensor(state_dict.pop(k), zero_copy=True)
if device == "cpu":
for k in list(state_dict.keys()):
with device_guard():
state_dict[k] = paddle.Tensor(state_dict.pop(k), zero_copy=True)

return state_dict
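
The new device argument gives load_state_dict two placement modes: "cpu" keeps the previous behaviour (zero-copy host tensors), while "expected" moves each weight straight to paddle.framework._current_expected_place(). A rough plain-Paddle equivalent is sketched below; paddle.to_tensor stands in for the zero-copy construction the patch performs under device_guard(), and place_weight is a hypothetical helper name.

import numpy as np
import paddle

def place_weight(np_weight, device="cpu"):
    # Stand-in for the zero-copy CPU wrap done under device_guard() in the patch.
    host = paddle.to_tensor(np_weight, place=paddle.CPUPlace())
    if device == "expected":
        # Materialize directly on whatever device training runs on (GPU/XPU/CPU).
        return host._copy_to(paddle.framework._current_expected_place(), False)
    return host  # device == "cpu": leave the tensor in host memory

w = place_weight(np.zeros((2, 2), dtype="float32"), device="expected")
print(w.place)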

@@ -671,8 +679,10 @@
return missing_keys, unexpected_keys


def faster_set_state_dict(model, state_dict):
def faster_set_state_dict(model, state_dict, strict_dtype=True):
# the state_dict will be destroyed.
unused_keys = set(state_dict.keys())
unset_keys = set(model.state_dict().keys())

with paddle.no_grad():
for k, v in model.state_dict().items():
if k in state_dict:
@@ -682,8 +692,10 @@
f"faster_set_state_dict need state dict with paddle.Tensor, but got {type(v_new)}"
)
# 2. cast param / Tensor to dtype
#
if v.dtype != v_new.dtype:
raise ValueError(f"for key: {k}, expect dtype {v.dtype}, but got {v_new.dtype}")
if strict_dtype or (not v.is_floating_point() or not v_new.is_floating_point()):
raise ValueError(f"for key: {k}, expect dtype {v.dtype}, but got {v_new.dtype}")

# check shape
if list(v.shape) != list(v_new.shape):
raise ValueError(f"for key: {k}, expect shape {v.shape}, but got {v_new.shape}")
@@ -699,9 +711,22 @@
else:
new_t = v_new

if not strict_dtype and v.dtype != new_t.dtype:
new_t = new_t.astype(v.dtype)

# 4. share Tensor to origin param / Tensor
src_tensor = new_t.value().get_tensor()
dst_tensor._share_data_with(src_tensor)
unset_keys.remove(k)
unused_keys.remove(k)

error_msgs = []

# if len(unset_keys) > 0:
# error_msgs.append(f"These model weights were not initialized: {list(unset_keys)}")
if len(unused_keys) > 0:
error_msgs.append(f"These state dict keys are not used by the model: {list(unused_keys)}")

return error_msgs


def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
@@ -733,22 +758,16 @@
def is_0d_or_1d(tensor):
return len(tensor.shape) == 0 or list(tensor.shape) == [1]

expected_place = paddle.framework._current_expected_place()
for key, value in model_to_load.state_dict().items():
if key in state_dict:
if key in list(state_dict.keys()):
if isinstance(state_dict[key], np.ndarray):
raise ValueError(
"convert_state_dict_dtype expected paddle.Tensor not numpy.ndarray, plase convert numpy.ndarray to paddle.Tensor"
)
# confirm parameter cast is executed on the same device as model
# TODO: cast(FP32 -> FP16) has diff on different devices, need to fix it
if state_dict[key].is_floating_point() and state_dict[key].dtype != value.dtype:
value_pop = state_dict.pop(key)
value_new_place = (
value_pop if value_pop.place == expected_place else value_pop._copy_to(expected_place, False)
)
state_dict[key] = paddle.cast(value_new_place, value.dtype)._copy_to(value_pop.place, False)
del value_new_place
state_dict[key] = paddle.cast(state_dict.pop(key), value.dtype)
# unified 0d and 1d tensor
if is_0d_or_1d(value) and is_0d_or_1d(state_dict[key]):
if list(value.shape) != list(state_dict[key].shape):
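Taken together, the model_utils.py changes replace the copy-based _load_state_dict_into_model path with faster_set_state_dict, which re-points each parameter at the loaded tensor's storage and, with strict_dtype=False, casts floating-point dtype mismatches instead of raising. The toy example below sketches that storage-sharing idea; it assumes the same eager-tensor .value().get_tensor()._share_data_with(...) API the patch relies on, and the layer and state dict here are illustrative only.

import paddle

layer = paddle.nn.Linear(4, 4)
loaded = {"weight": paddle.rand([4, 4]), "bias": paddle.zeros([4])}  # toy state dict

with paddle.no_grad():
    for name, param in layer.state_dict().items():
        new_t = loaded.pop(name)  # the loaded dict is consumed, as in the patch
        assert list(param.shape) == list(new_t.shape)
        if param.dtype != new_t.dtype:
            new_t = new_t.astype(param.dtype)  # relaxed-dtype path (strict_dtype=False)
        # Re-point the parameter at the loaded buffer instead of copying into it.
        param.value().get_tensor()._share_data_with(new_t.value().get_tensor())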