
Commit 6944c25

update llama predict
1 parent 62a18c4 commit 6944c25

File tree

3 files changed (+245, -47 lines)

.pre-commit-config.yaml

Lines changed: 0 additions & 5 deletions
@@ -44,11 +44,6 @@ repos:
         entry: python .copyright.hook
         language: system
         files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|xpu|kps|py|sh)$
-  # For Markdown files
-  - repo: https://github.com/igorshubovych/markdownlint-cli
-    rev: v0.41.0
-    hooks:
-      - id: markdownlint
   - repo: local
     hooks:
       - id: add-spaces-between-chinese-and-english

paddlenlp/experimental/transformers/llama/modeling.py

Lines changed: 81 additions & 42 deletions
@@ -43,6 +43,7 @@
     GenerationBlockInferenceModel,
     GenerationInferenceModel,
 )
+from paddlenlp.experimental.transformers.utils import load_tp_checkpoint
 from paddlenlp.transformers import LlamaConfig, LlamaPretrainedModel
 from paddlenlp.transformers.conversion_utils import split_param_func
 from paddlenlp.transformers.llama.modeling import LlamaLMHead
@@ -51,9 +52,16 @@
     CausalLMOutputWithCrossAttentions,
 )
 from paddlenlp.transformers.model_utils import (
+    dtype_guard,
     dy2st_nocheck_guard_context,
+    no_init_weights,
     register_base_model,
 )
+from paddlenlp.transformers.utils import (
+    ContextManagers,
+    is_paddle_support_lazy_init,
+    is_safetensors_available,
+)
 from paddlenlp.utils.log import logger
 
 __all__ = [
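The `use_safetensors` default used by both rewritten `from_pretrained` bodies below follows an availability-probe pattern: leave the choice open (`None`) when the safetensors package is installed, otherwise disable it (`False`). A minimal, self-contained sketch of that pattern; `is_safetensors_available_stub` is an illustrative stand-in, not PaddleNLP's implementation:

    import importlib.util

    def is_safetensors_available_stub() -> bool:
        # Stand-in for paddlenlp.transformers.utils.is_safetensors_available:
        # report whether the safetensors package can be imported.
        return importlib.util.find_spec("safetensors") is not None

    # Mirrors the kwarg default in the hunks below: None means "prefer
    # safetensors, decide at resolve time"; False disables it outright.
    use_safetensors = None if is_safetensors_available_stub() else False
    print(use_safetensors)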
@@ -1238,9 +1246,47 @@ def __init__(self, config):
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
-        # TODO: Support safetensors loading.
-        kwargs["use_safetensors"] = False
-        return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
+        config = kwargs.pop("config", None)
+        cache_dir = kwargs.pop("cache_dir", None)
+        dtype = kwargs.pop("dtype", None)
+        if dtype is None:
+            dtype = config.dtype
+        subfolder = kwargs.pop("subfolder", None)
+        if subfolder is None:
+            subfolder = ""
+        variant = kwargs.pop("variant", None)
+        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
+
+        init_contexts = []
+        if low_cpu_mem_usage or config.quantization_config.is_weight_quantize():
+            # Instantiate model.
+            init_contexts.append(no_init_weights(_enable=True))
+            if is_paddle_support_lazy_init():
+                init_contexts.append(paddle.LazyGuard())
+        if dtype:
+            init_contexts.append(dtype_guard(dtype))
+
+        # init the model
+        with ContextManagers(init_contexts):
+            model = cls(config)
+
+        resolved_archive_file, resolved_sharded_files, sharded_metadata, is_sharded = cls._resolve_model_file_path(
+            pretrained_model_name_or_path,
+            cache_dir=cache_dir,
+            subfolder=subfolder,
+            from_hf_hub=False,
+            from_aistudio=False,
+            config=config,
+            convert_from_torch=False,
+            use_safetensors=use_safetensors,
+            variant=variant,
+        )
+
+        model_path = os.path.dirname(resolved_archive_file)
+        state_dict = load_tp_checkpoint(model_path, cls, config)
+        model.set_state_dict(state_dict)
+        return model
 
     @classmethod
     def get_cache_kvs_shape(
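Both rewritten loaders stack their init guards (`no_init_weights`, `paddle.LazyGuard`, `dtype_guard`) into a list and enter them through a single `ContextManagers` wrapper. A self-contained sketch of how such a wrapper composes context managers, using `contextlib.ExitStack` and a hypothetical stand-in guard (the real guards live in `paddlenlp.transformers.model_utils`):

    from contextlib import ExitStack, contextmanager

    @contextmanager
    def dtype_guard_stub(dtype):
        # Hypothetical stand-in for dtype_guard: a real guard would set the
        # default dtype on entry and restore the previous one on exit.
        print(f"entering dtype guard: {dtype}")
        try:
            yield
        finally:
            print("leaving dtype guard")

    class ContextManagersSketch:
        """Enter a list of context managers as one combined `with` block,
        mirroring ContextManagers(init_contexts) in the hunks above and below."""

        def __init__(self, managers):
            self.managers = managers
            self.stack = ExitStack()

        def __enter__(self):
            for manager in self.managers:
                self.stack.enter_context(manager)

        def __exit__(self, *exc):
            self.stack.__exit__(*exc)

    init_contexts = [dtype_guard_stub("float16")]
    with ContextManagersSketch(init_contexts):
        pass  # `model = cls(config)` runs here with every guard active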
@@ -1477,53 +1523,46 @@ def get_tensor_parallel_split_mappings(num_layers):
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
-        # TODO: Support safetensors loading.
-        kwargs["use_safetensors"] = False
-        from paddlenlp.transformers.utils import (
-            ContextManagers,
-            is_safetensors_available,
-        )
-
-        from_hf_hub = kwargs.pop("from_hf_hub", False)
         config = kwargs.pop("config", None)
-        from_aistudio = kwargs.get("from_aistudio", False)
-        subfolder = kwargs.get("subfolder", None)
+        cache_dir = kwargs.pop("cache_dir", None)
+        dtype = kwargs.pop("dtype", None)
+        if dtype is None:
+            dtype = config.dtype
+        subfolder = kwargs.pop("subfolder", None)
+        if subfolder is None:
+            subfolder = ""
         variant = kwargs.pop("variant", None)
         use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
-        convert_from_torch = kwargs.pop("convert_from_torch", None)
-        cache_dir = kwargs.pop("cache_dir", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
 
         init_contexts = []
+        if low_cpu_mem_usage or config.quantization_config.is_weight_quantize():
+            # Instantiate model.
+            init_contexts.append(no_init_weights(_enable=True))
+            if is_paddle_support_lazy_init():
+                init_contexts.append(paddle.LazyGuard())
+        if dtype:
+            init_contexts.append(dtype_guard(dtype))
+
+        # init the model
         with ContextManagers(init_contexts):
             model = cls(config)
 
-        if not config.single_card_ptq:
-            resolved_archive_file = pretrained_model_name_or_path
-        else:
-            resolved_archive_file = cls._resolve_model_file_path(
-                pretrained_model_name_or_path,
-                cache_dir=cache_dir,
-                subfolder=subfolder,
-                from_hf_hub=from_hf_hub,
-                from_aistudio=from_aistudio,
-                config=config,
-                convert_from_torch=convert_from_torch,
-                use_safetensors=use_safetensors,
-                variant=variant,
-            )[0]
-        logger.info(f"Load model form {resolved_archive_file}")
-
-        if config.tensor_parallel_degree > 1 and config.single_card_ptq:
-            logger.info(f"convert_tensor_parallel {config.tensor_parallel_degree}")
-            model.state_dict = model.convert_tensor_parallel(resolved_archive_file, config)
-        elif config.tensor_parallel_degree > 1:
-            resolved_archive_file = os.path.join(
-                resolved_archive_file, f"mp_{config.tensor_parallel_rank:0>2d}_sharding_00_pp_00", "model.pdparams"
-            )
-            model.state_dict = paddle.load(resolved_archive_file, return_numpy=True)
-        else:
-            model.state_dict = paddle.load(resolved_archive_file, return_numpy=True)
-        model.set_state_dict(model.state_dict)
+        resolved_archive_file, resolved_sharded_files, sharded_metadata, is_sharded = cls._resolve_model_file_path(
+            pretrained_model_name_or_path,
+            cache_dir=cache_dir,
+            subfolder=subfolder,
+            from_hf_hub=False,
+            from_aistudio=False,
+            config=config,
+            convert_from_torch=False,
+            use_safetensors=use_safetensors,
+            variant=variant,
+        )
+
+        model_path = os.path.dirname(resolved_archive_file)
+        state_dict = load_tp_checkpoint(model_path, cls, config)
+        model.set_state_dict(state_dict)
 
         return model
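A hedged usage sketch of the rewritten entry point. `LlamaInferenceModel` is a stand-in name for whichever inference class in this module owns the method, and the checkpoint path is hypothetical; the call only succeeds against a real exported directory:

    from paddlenlp.transformers import LlamaConfig

    ckpt_dir = "./llama-inference-ckpt"  # hypothetical local checkpoint
    config = LlamaConfig.from_pretrained(ckpt_dir)

    # config is required: the new body reads config.dtype (the dtype
    # fallback) and config.quantization_config (the no-init decision).
    model = LlamaInferenceModel.from_pretrained(
        ckpt_dir,
        config=config,
        dtype="float16",
        low_cpu_mem_usage=True,  # skip weight init; LazyGuard when supported
    )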

paddlenlp/experimental/transformers/utils.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import json
+import os
+from functools import partial
+
+import numpy as np
+import paddle
+from tqdm import tqdm
+
+from paddlenlp.transformers import AutoConfig
+from paddlenlp.transformers.model_utils import _add_variant, load_state_dict
+from paddlenlp.transformers.utils import paddlenlp_load
+from paddlenlp.utils.env import (
+    PADDLE_WEIGHTS_INDEX_NAME,
+    SAFE_MASTER_WEIGHTS_INDEX_NAME,
+    SAFE_PEFT_WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
+)
+
+try:
+    from paddlenlp.utils.safetensors import fast_load_file as safe_load_file
+    from paddlenlp.utils.safetensors import fast_safe_open as safe_open
+except:
+    from safetensors import safe_open
+    from safetensors.numpy import load_file as safe_load_file
+
+
+def load_sharded_checkpoint(folder, variant=None, return_numpy=False):
+    """
+
+    This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
+    loaded in the model.
+
+    Args:
+        folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
+        variant (`str`): The model variant.
+        return_numpy (`bool`): Whether to return numpy array instead of paddle tensor.
+
+    """
+    # Load the index
+    pdparams_file = os.path.join(folder, _add_variant("model_state.pdparams", variant))
+    lora_pdparams_file = os.path.join(folder, _add_variant("lora_model_state.pdparams", variant))
+    safetensors_file = os.path.join(folder, _add_variant("model.safetensors", variant))
+    if os.path.isfile(pdparams_file):
+        return paddle.load(pdparams_file, return_numpy=return_numpy)
+    if os.path.isfile(lora_pdparams_file):
+        return paddle.load(lora_pdparams_file, return_numpy=return_numpy)
+    if os.path.isfile(safetensors_file):
+        state_dict = safe_load_file(safetensors_file)
+        if not return_numpy:
+            for key in list(state_dict.keys()):
+                if isinstance(state_dict[key], np.ndarray):
+                    state_dict[key] = paddle.Tensor(state_dict.pop(key), zero_copy=True)
+        return state_dict
+
+    index_file = os.path.join(folder, _add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant))
+    safe_index_file = os.path.join(folder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant))
+    safe_master_file = os.path.join(folder, _add_variant(SAFE_MASTER_WEIGHTS_INDEX_NAME, variant))
+    safe_peft_file = os.path.join(folder, _add_variant(SAFE_PEFT_WEIGHTS_INDEX_NAME, variant))
+
+    index_present = os.path.isfile(index_file)
+    safe_index_present = os.path.isfile(safe_index_file)
+    safe_master_present = os.path.isfile(safe_master_file)
+    safe_peft_present = os.path.isfile(safe_peft_file)
+
+    load_safe = False
+    load_index = None
+    if safe_index_present:
+        load_safe = True  # load safe due to preference
+        load_index = safe_index_file
+    elif safe_master_present:
+        load_safe = True
+        load_index = safe_master_file
+    elif index_present:
+        load_index = index_file
+    elif safe_peft_present:
+        load_safe = True
+        load_index = safe_peft_file
+    else:
+        raise ValueError(f"Could not find {index_file} or {safe_index_file} or {safe_peft_file}")
+
+    with open(load_index, "r", encoding="utf-8") as f:
+        index = json.load(f)
+
+    shard_files = list(set(index["weight_map"].values()))
+    loader = safe_load_file if load_safe else partial(paddlenlp_load, map_location="np" if return_numpy else "cpu")
+
+    ret = {}
+    for shard_file in tqdm(shard_files):
+        state_dict = loader(os.path.join(folder, shard_file))
+        ret.update(state_dict)
+
+    if not return_numpy:
+        for key in list(ret.keys()):
+            if isinstance(ret[key], np.ndarray):
+                ret[key] = paddle.Tensor(ret.pop(key), zero_copy=True)
+
+    return ret
+
+
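A short usage sketch of `load_sharded_checkpoint`; the folder is hypothetical and must already contain one of the files probed above (`model_state.pdparams`, `lora_model_state.pdparams`, `model.safetensors`, or a shard index), otherwise the call raises ValueError:

    # Hypothetical checkpoint folder.
    weights = load_sharded_checkpoint("./checkpoints/llama", return_numpy=True)
    print(f"loaded {len(weights)} tensors")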
+def load_tp_checkpoint(folder, cls, config, return_numpy=False):
+    """
+
+    This load is performed efficiently: the tp checkpoint is loaded on CPU only, with no need to init the model.
+
+    Args:
+        folder (`str` or `os.PathLike`): A path to a folder containing the model checkpoint.
+        cls: The model class.
+        config (`AutoConfig`): The model config.
+        return_numpy (`bool`): Whether to load the tp checkpoint as numpy arrays.
+    """
+
+    config = AutoConfig.from_pretrained(folder)
+    if config.tensor_parallel_degree == 1:
+        return load_sharded_checkpoint(folder, return_numpy=return_numpy)
+    else:
+        rank_model_path = os.path.join(folder, f"model_state.tp0{config.tensor_parallel_rank}.pdparams")
+        model_path = os.path.join(folder, "model_state.pdparams")
+        safe_model_path = os.path.join(folder, "model.safetensors")
+        if os.path.exists(rank_model_path):
+            return paddle.load(rank_model_path, return_numpy=return_numpy)
+        elif os.path.exists(model_path):
+            state_dict = cls.convert_tensor_parallel(model_path, config)
+        elif os.path.exists(safe_model_path):
+            with safe_open(safe_model_path, framework="np", device="cpu") as f:
+                loaded_keys = f.keys()
+            tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys)
+            state_dict = load_state_dict(safe_model_path, tp_actions)
+        else:  # shard files safetensors
+            resolved_archive_file, resolved_sharded_files, sharded_metadata, is_sharded = cls._resolve_model_file_path(
+                pretrained_model_name_or_path=folder,
+                use_safetensors=True,
+            )
+            if len(resolved_sharded_files) > 1:
+                resolved_sharded_files = tqdm(resolved_sharded_files, desc="Loading checkpoint shards")
+            loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
+            tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_state_dict_keys, ignore_error=True)
+            state_dict = {}
+            for shard_file in resolved_sharded_files:
+                shard_state_dict = load_state_dict(
+                    shard_file,
+                    tp_actions,
+                    loaded_state_dict_keys,
+                )
+                state_dict.update(shard_state_dict)
+        if return_numpy:
+            for k in list(state_dict.keys()):
+                if not isinstance(state_dict[k], np.ndarray):
+                    state_dict[k] = state_dict.pop(k).cpu().numpy()
+        return state_dict
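The branching in `load_tp_checkpoint` amounts to a fixed probe order for tensor-parallel checkpoints. A runnable sketch of just that selection logic; illustrative only, since the real function also converts and slices whatever it finds:

    import os

    def pick_tp_checkpoint_file(folder: str, tp_rank: int) -> str:
        """Probe order mirrored from load_tp_checkpoint: rank-sliced
        .pdparams, then a full .pdparams split at load time, then a single
        safetensors file; anything else falls through to sharded
        safetensors resolution."""
        candidates = [
            os.path.join(folder, f"model_state.tp0{tp_rank}.pdparams"),
            os.path.join(folder, "model_state.pdparams"),
            os.path.join(folder, "model.safetensors"),
        ]
        for path in candidates:
            if os.path.exists(path):
                return path
        return ""  # empty: caller should resolve sharded safetensors

    print(pick_tp_checkpoint_file(".", 0) or "no single-file checkpoint here")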
