Commit e2cc486

bug fix
1 parent e7d96a0 commit e2cc486

3 files changed (+140, -160 lines)

paddlenlp/experimental/transformers/chatglm/modeling.py

Lines changed: 6 additions & 6 deletions
@@ -389,20 +389,20 @@ def set_state_dict(self, state_dict, use_structured_name=True):
         head_dim = embed_dim // config.num_attention_heads
 
         for k, v in state_dict.items():
-            if k.startswith("transformer.word_embeddings.weight"):
+            if k.startswith("chatglm.transformer.word_embeddings.weight"):
                 self.word_embeddings.weight.set_value(v.astype(dtype))
                 continue
-            elif k.startswith("transformer.final_layernorm.weight"):
+            elif k.startswith("chatglm.transformer.final_layernorm.weight"):
                 self.transformer_block.ffn_ln_scales[config.num_hidden_layers - 1].set_value(v.astype("float32"))
                 continue
-            elif k.startswith("transformer.final_layernorm.bias"):
+            elif k.startswith("chatglm.transformer.final_layernorm.bias"):
                 self.transformer_block.ffn_ln_biases[config.num_hidden_layers - 1].set_value(v.astype("float32"))
                 continue
             elif k.startswith("lm_head.weight"):
                 continue
             elif k.endswith("rotary_embeddings.inv_freq") or k.endswith("rotary_emb.inv_freq"):
                 continue
-            idx = int(k.split(".")[2])
+            idx = int(k.split(".")[3])
             if k.endswith("input_layernorm.weight"):
                 if idx == 0:
                     self.input_layernorm.weight.set_value(v.astype(dtype))
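
Note: the [2] -> [3] index change follows from the longer key prefix. With keys now shaped like chatglm.transformer.layers.<n>...., splitting on dots puts the layer number at position 3. A minimal worked example (the key below is illustrative, not taken from a real checkpoint):

    key = "chatglm.transformer.layers.7.input_layernorm.weight"
    parts = key.split(".")  # ["chatglm", "transformer", "layers", "7", "input_layernorm", "weight"]
    idx = int(parts[3])     # 7; under the old "transformer.layers.7..." layout the index sat at parts[2]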
@@ -584,7 +584,7 @@ def __init__(self, config: ChatGLMConfig):
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
-        return infererence_model_from_pretrained(cls, pretrained_model_name_or_path, args, kwargs)
+        return infererence_model_from_pretrained(cls, pretrained_model_name_or_path, args, kwargs, return_numpy=False)
 
     @classmethod
     def get_cache_kvs_shape(
@@ -745,6 +745,6 @@ def forward(
     @paddle.no_grad()
     def set_state_dict(self, state_dict):
         self.lm_head.weight.set_value(
-            state_dict["transformer.word_embeddings.weight"].astype(self.lm_head.weight.dtype)
+            state_dict["chatglm.transformer.word_embeddings.weight"].astype(self.lm_head.weight.dtype)
         )
         self.model.transformer.set_state_dict({k: state_dict[k] for k in state_dict.keys()})
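
Both hunks in this file depend on the same renamed checkpoint layout: from_pretrained now requests paddle tensors (return_numpy=False), and the lm_head tie reads the word embeddings under the chatglm.-prefixed key. A hedged sketch of what the flag changes (names as in this diff; model_path assumed to be a resolved checkpoint directory):

    state_dict = load_tp_checkpoint(model_path, cls, config, return_numpy=False)
    v = state_dict["chatglm.transformer.word_embeddings.weight"]
    # v is a paddle.Tensor here; with return_numpy=True it would be a numpy.ndarray.
    # Both types expose .astype(), so the set_state_dict code above works either way.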

paddlenlp/experimental/transformers/utils.py

Lines changed: 4 additions & 149 deletions
@@ -13,168 +13,23 @@
 # limitations under the License.
 from __future__ import annotations
 
-import json
 import os
-from functools import partial
 
-import numpy as np
 import paddle
-from tqdm import tqdm
 
-from paddlenlp.transformers import AutoConfig
 from paddlenlp.transformers.model_utils import (
-    _add_variant,
     dtype_guard,
-    load_state_dict,
+    load_tp_checkpoint,
     no_init_weights,
 )
 from paddlenlp.transformers.utils import (
     ContextManagers,
     is_paddle_support_lazy_init,
     is_safetensors_available,
-    paddlenlp_load,
 )
-from paddlenlp.utils.env import (
-    PADDLE_WEIGHTS_INDEX_NAME,
-    SAFE_MASTER_WEIGHTS_INDEX_NAME,
-    SAFE_PEFT_WEIGHTS_INDEX_NAME,
-    SAFE_WEIGHTS_INDEX_NAME,
-)
-
-try:
-    from paddlenlp.utils.safetensors import fast_load_file as safe_load_file
-    from paddlenlp.utils.safetensors import fast_safe_open as safe_open
-except:
-    from safetensors import safe_open
-    from safetensors.numpy import load_file as safe_load_file
-
-
-def load_sharded_checkpoint(folder, variant=None, return_numpy=False):
-    """
-    This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
-    loaded in the model.
-
-    Args:
-        folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
-        variant (`str`): The model variant.
-        return_numpy (`bool`): Whether to return numpy array instead of paddle tensor.
-    """
-    # Load the index
-    pdparams_file = os.path.join(folder, _add_variant("model_state.pdparams", variant))
-    lora_pdparams_file = os.path.join(folder, _add_variant("lora_model_state.pdparams", variant))
-    safetensors_file = os.path.join(folder, _add_variant("model.safetensors", variant))
-    if os.path.isfile(pdparams_file):
-        return paddle.load(pdparams_file, return_numpy=return_numpy)
-    if os.path.isfile(lora_pdparams_file):
-        return paddle.load(lora_pdparams_file, return_numpy=return_numpy)
-    if os.path.isfile(safetensors_file):
-        state_dict = safe_load_file(safetensors_file)
-        if not return_numpy:
-            for key in list(state_dict.keys()):
-                if isinstance(state_dict[key], np.ndarray):
-                    state_dict[key] = paddle.Tensor(state_dict.pop(key), zero_copy=True)
-        return state_dict
-
-    index_file = os.path.join(folder, _add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant))
-    safe_index_file = os.path.join(folder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant))
-    safe_master_file = os.path.join(folder, _add_variant(SAFE_MASTER_WEIGHTS_INDEX_NAME, variant))
-    safe_peft_file = os.path.join(folder, _add_variant(SAFE_PEFT_WEIGHTS_INDEX_NAME, variant))
-
-    index_present = os.path.isfile(index_file)
-    safe_index_present = os.path.isfile(safe_index_file)
-    safe_master_present = os.path.isfile(safe_master_file)
-    safe_peft_present = os.path.isfile(safe_peft_file)
-
-    load_safe = False
-    load_index = None
-    if safe_index_present:
-        load_safe = True  # load safe due to preference
-        load_index = safe_index_file
-    elif safe_master_present:
-        load_safe = True
-        load_index = safe_master_file
-    elif index_present:
-        load_index = index_file
-    elif safe_peft_present:
-        load_safe = True
-        load_index = safe_peft_file
-    else:
-        raise ValueError(f"Could not find {index_file} or {safe_index_file} or {safe_peft_file}")
-
-    with open(load_index, "r", encoding="utf-8") as f:
-        index = json.load(f)
-
-    shard_files = list(set(index["weight_map"].values()))
-    loader = safe_load_file if load_safe else partial(paddlenlp_load, map_location="np" if return_numpy else "cpu")
-
-    ret = {}
-    for shard_file in tqdm(shard_files):
-        state_dict = loader(os.path.join(folder, shard_file))
-        ret.update(state_dict)
-
-    if not return_numpy:
-        for key in list(ret.keys()):
-            if isinstance(ret[key], np.ndarray):
-                ret[key] = paddle.Tensor(ret.pop(key), zero_copy=True)
-
-    return ret
-
-
-def load_tp_checkpoint(folder, cls, config, return_numpy=False):
-    """
-    This load is performed efficiently: Load tp checkpoint only from cpu, no need to init the model.
-
-    Args:
-        folder (`str` or `os.PathLike`): A path to a folder containing the model checkpoint.
-        cls (`str`): The model class.
-        config (`AutoConfig`): The model config.
-        return_numpy (bool): Whether load the tp checkpoint as numpy.
-    """
-    config = AutoConfig.from_pretrained(folder)
-    if config.tensor_parallel_degree == 1 or config.tensor_parallel_degree == -1:
-        return load_sharded_checkpoint(folder, return_numpy=return_numpy)
-    else:
-        rank_model_path = os.path.join(folder, f"model_state.tp0{config.tensor_parallel_rank}.pdparams")
-        model_path = os.path.join(folder, "model_state.pdparams")
-        safe_model_path = os.path.join(folder, "model.safetensors")
-        if os.path.exists(rank_model_path):
-            return paddle.load(rank_model_path, return_numpy=return_numpy)
-        elif os.path.exists(model_path):
-            state_dict = cls.convert_tensor_parallel(model_path, config)
-        elif os.path.exists(safe_model_path):
-            with safe_open(safe_model_path, framework="np", device="cpu") as f:
-                loaded_keys = f.keys()
-            tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys)
-            state_dict = load_state_dict(safe_model_path, tp_actions)
-        else:  # shard files safetensors
-            resolved_archive_file, resolved_sharded_files, sharded_metadata, is_sharded = cls._resolve_model_file_path(
-                pretrained_model_name_or_path=folder,
-                use_safetensors=True,
-            )
-            if len(resolved_sharded_files) > 1:
-                resolved_sharded_files = tqdm(resolved_sharded_files, desc="Loading checkpoint shards")
-            loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
-            tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_state_dict_keys, ignore_error=True)
-            state_dict = {}
-            for shard_file in resolved_sharded_files:
-                shard_state_dict = load_state_dict(
-                    shard_file,
-                    tp_actions,
-                    loaded_state_dict_keys,
-                )
-                state_dict.update(shard_state_dict)
-        if return_numpy:
-            for k in list(state_dict.keys()):
-                if not isinstance(state_dict[k], np.ndarray):
-                    state_dict[k] = state_dict.pop(k).cpu().numpy()
-        return state_dict
 
 
-def infererence_model_from_pretrained(cls, pretrained_model_name_or_path, args, kwargs):
+def infererence_model_from_pretrained(cls, pretrained_model_name_or_path, args, kwargs, return_numpy=True):
     r"""
     Instantiate a pretrained model configuration from a pre-trained model name or path.
     """
@@ -203,7 +58,7 @@ def infererence_model_from_pretrained(cls, pretrained_model_name_or_path, args,
     with ContextManagers(init_contexts):
        model = cls(config)
 
-    resolved_archive_file, resolved_sharded_files, sharded_metadata, is_sharded = cls._resolve_model_file_path(
+    resolved_archive_file, _, _, _ = cls._resolve_model_file_path(
         pretrained_model_name_or_path,
         cache_dir=cache_dir,
         subfolder=subfolder,
@@ -216,7 +71,7 @@ def infererence_model_from_pretrained(cls, pretrained_model_name_or_path, args,
     )
 
     model_path = os.path.dirname(resolved_archive_file)
-    state_dict = load_tp_checkpoint(model_path, cls, config, return_numpy=True)
+    state_dict = load_tp_checkpoint(model_path, cls, config, return_numpy=return_numpy)
     model.set_state_dict(state_dict)
 
     return model
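
With return_numpy defaulting to True, existing inference models keep receiving numpy state dicts and only ChatGLM opts out. A sketch of the two call styles (the class names and paths are placeholders, not from this commit):

    # Default path: weights come back as numpy arrays, as before this commit.
    model = infererence_model_from_pretrained(SomeInferenceModel, "some/model-dir", args, kwargs)

    # ChatGLM's from_pretrained (modeling.py above) passes return_numpy=False,
    # so load_tp_checkpoint hands back paddle.Tensor values instead.
    model = infererence_model_from_pretrained(ChatGLMInferenceModel, "some/chatglm-dir", args, kwargs, return_numpy=False)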

paddlenlp/transformers/model_utils.py

Lines changed: 130 additions & 5 deletions
@@ -20,7 +20,6 @@
 import json
 import os
 import re
-import sys
 import tempfile
 import warnings
 from contextlib import contextmanager
@@ -59,6 +58,8 @@
     PADDLE_WEIGHTS_NAME,
     PYTORCH_WEIGHTS_INDEX_NAME,
     PYTORCH_WEIGHTS_NAME,
+    SAFE_MASTER_WEIGHTS_INDEX_NAME,
+    SAFE_PEFT_WEIGHTS_INDEX_NAME,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFE_WEIGHTS_NAME,
 )
@@ -109,13 +110,14 @@ def unwrap_optimizer(optimizer, optimizer_instances=()):
 
 
 if is_safetensors_available():
-    from safetensors.numpy import load_file as safe_load_file
     from safetensors.numpy import save_file as safe_save_file
 
-    if sys.platform.startswith("win"):
-        from safetensors import safe_open
-    else:
+    try:
+        from paddlenlp.utils.safetensors import fast_load_file as safe_load_file
         from paddlenlp.utils.safetensors import fast_safe_open as safe_open
+    except:
+        from safetensors import safe_open
+        from safetensors.numpy import load_file as safe_load_file
 
 
 def prune_linear_layer(layer: nn.Linear, index: paddle.Tensor, dim: int = 0) -> nn.Linear:
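
The try/except now prefers PaddleNLP's fast safetensors reader on every platform, where the old code sent Windows to the upstream safetensors package unconditionally. To confirm which implementation was imported at runtime, something like this works (illustrative check, not part of the commit):

    import paddlenlp.transformers.model_utils as model_utils
    # 'paddlenlp.utils.safetensors' when the fast path imported, 'safetensors.numpy' on fallback
    print(model_utils.safe_load_file.__module__)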
@@ -2665,3 +2667,126 @@ def set_state_dict(self, state_dict, *args, **kwargs):
 
     ret = super().set_state_dict(state_dict, *args, **kwargs)
     return ret
+
+
+def load_sharded_checkpoint_as_one(folder, variant=None, return_numpy=False):
+    """
+    This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
+    loaded in the model.
+
+    Args:
+        folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
+        variant (`str`): The model variant.
+        return_numpy (`bool`): Whether to return numpy array instead of paddle tensor.
+    """
+    # Load the index
+    pdparams_file = os.path.join(folder, _add_variant("model_state.pdparams", variant))
+    lora_pdparams_file = os.path.join(folder, _add_variant("lora_model_state.pdparams", variant))
+    safetensors_file = os.path.join(folder, _add_variant("model.safetensors", variant))
+    if os.path.isfile(pdparams_file):
+        return paddle.load(pdparams_file, return_numpy=return_numpy)
+    if os.path.isfile(lora_pdparams_file):
+        return paddle.load(lora_pdparams_file, return_numpy=return_numpy)
+    if os.path.isfile(safetensors_file):
+        state_dict = safe_load_file(safetensors_file)
+        if not return_numpy:
+            for key in list(state_dict.keys()):
+                if isinstance(state_dict[key], np.ndarray):
+                    state_dict[key] = paddle.Tensor(state_dict.pop(key), zero_copy=True)
+        return state_dict
+
+    index_file = os.path.join(folder, _add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant))
+    safe_index_file = os.path.join(folder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant))
+    safe_master_file = os.path.join(folder, _add_variant(SAFE_MASTER_WEIGHTS_INDEX_NAME, variant))
+    safe_peft_file = os.path.join(folder, _add_variant(SAFE_PEFT_WEIGHTS_INDEX_NAME, variant))
+
+    index_present = os.path.isfile(index_file)
+    safe_index_present = os.path.isfile(safe_index_file)
+    safe_master_present = os.path.isfile(safe_master_file)
+    safe_peft_present = os.path.isfile(safe_peft_file)
+
+    load_safe = False
+    load_index = None
+    if safe_index_present:
+        load_safe = True  # load safe due to preference
+        load_index = safe_index_file
+    elif safe_master_present:
+        load_safe = True
+        load_index = safe_master_file
+    elif index_present:
+        load_index = index_file
+    elif safe_peft_present:
+        load_safe = True
+        load_index = safe_peft_file
+    else:
+        raise ValueError(f"Could not find {index_file} or {safe_index_file} or {safe_peft_file}")
+
+    with open(load_index, "r", encoding="utf-8") as f:
+        index = json.load(f)
+
+    shard_files = list(set(index["weight_map"].values()))
+    loader = safe_load_file if load_safe else partial(paddlenlp_load, map_location="np" if return_numpy else "cpu")
+
+    ret = {}
+    for shard_file in tqdm(shard_files):
+        state_dict = loader(os.path.join(folder, shard_file))
+        ret.update(state_dict)
+
+    if not return_numpy:
+        for key in list(ret.keys()):
+            if isinstance(ret[key], np.ndarray):
+                ret[key] = paddle.Tensor(ret.pop(key), zero_copy=True)
+
+    return ret
+
+
+def load_tp_checkpoint(folder, cls, config, return_numpy=False):
+    """
+    This load is performed efficiently: Load tp checkpoint only from cpu, no need to init the model.
+
+    Args:
+        folder (`str` or `os.PathLike`): A path to a folder containing the model checkpoint.
+        cls (`str`): The model class.
+        config (`AutoConfig`): The model config.
+        return_numpy (bool): Whether load the tp checkpoint as numpy.
+    """
+    if config.tensor_parallel_degree == 1 or config.tensor_parallel_degree == -1:
+        return load_sharded_checkpoint_as_one(folder, return_numpy=return_numpy)
+    else:
+        rank_model_path = os.path.join(folder, f"model_state.tp0{config.tensor_parallel_rank}.pdparams")
+        model_path = os.path.join(folder, "model_state.pdparams")
+        safe_model_path = os.path.join(folder, "model.safetensors")
+        if os.path.exists(rank_model_path):
+            return paddle.load(rank_model_path, return_numpy=return_numpy)
+        elif os.path.exists(model_path):
+            state_dict = cls.convert_tensor_parallel(model_path, config)
+        elif os.path.exists(safe_model_path):
+            with safe_open(safe_model_path, framework="np", device="cpu") as f:
+                loaded_keys = f.keys()
+            tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys)
+            state_dict = load_state_dict(safe_model_path, tp_actions)
+        else:  # shard files safetensors
+            resolved_archive_file, resolved_sharded_files, sharded_metadata, is_sharded = cls._resolve_model_file_path(
+                pretrained_model_name_or_path=folder,
+                use_safetensors=True,
+            )
+            if len(resolved_sharded_files) > 1:
+                resolved_sharded_files = tqdm(resolved_sharded_files, desc="Loading checkpoint shards")
+            loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
+            tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_state_dict_keys, ignore_error=True)
+            state_dict = {}
+            for shard_file in resolved_sharded_files:
+                shard_state_dict = load_state_dict(
+                    shard_file,
+                    tp_actions,
+                    loaded_state_dict_keys,
+                )
+                state_dict.update(shard_state_dict)
+        if return_numpy:
+            for k in list(state_dict.keys()):
+                if not isinstance(state_dict[k], np.ndarray):
+                    state_dict[k] = state_dict.pop(k).cpu().numpy()
+        return state_dict
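
These two functions, moved here from experimental/transformers/utils.py (with load_sharded_checkpoint renamed to load_sharded_checkpoint_as_one, and the redundant AutoConfig re-read dropped in favor of the passed-in config), are now the shared entry point for inference-time weight loading. A minimal usage sketch under assumed inputs (the path and model class are placeholders):

    config = AutoConfig.from_pretrained("path/to/checkpoint")
    # Non-TP checkpoints fall through to load_sharded_checkpoint_as_one;
    # TP checkpoints are converted per rank via the model class's tensor-parallel actions.
    state_dict = load_tp_checkpoint("path/to/checkpoint", MyModel, config, return_numpy=True)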
