13
13
# limitations under the License.
14
14
from __future__ import annotations
15
15
16
- import json
17
16
import os
18
- from functools import partial
19
17
20
- import numpy as np
21
18
import paddle
22
- from tqdm import tqdm
23
19
24
- from paddlenlp .transformers import AutoConfig
25
20
from paddlenlp .transformers .model_utils import (
26
- _add_variant ,
27
21
dtype_guard ,
28
- load_state_dict ,
22
+ load_tp_checkpoint ,
29
23
no_init_weights ,
30
24
)
31
25
from paddlenlp .transformers .utils import (
32
26
ContextManagers ,
33
27
is_paddle_support_lazy_init ,
34
28
is_safetensors_available ,
35
- paddlenlp_load ,
36
29
)
37
- from paddlenlp .utils .env import (
38
- PADDLE_WEIGHTS_INDEX_NAME ,
39
- SAFE_MASTER_WEIGHTS_INDEX_NAME ,
40
- SAFE_PEFT_WEIGHTS_INDEX_NAME ,
41
- SAFE_WEIGHTS_INDEX_NAME ,
42
- )
43
-
44
# Prefer PaddleNLP's fast safetensors loaders; fall back to the reference
# `safetensors` package when they are unavailable.
try:
    from paddlenlp.utils.safetensors import fast_load_file as safe_load_file
    from paddlenlp.utils.safetensors import fast_safe_open as safe_open
except ImportError:
    # Fix: was a bare `except:`, which would also swallow SystemExit /
    # KeyboardInterrupt. Only an import failure should trigger the fallback.
    from safetensors import safe_open
    from safetensors.numpy import load_file as safe_load_file
50
-
51
-
52
def load_sharded_checkpoint(folder, variant=None, return_numpy=False):
    """
    Load a checkpoint from *folder*, merging shards if the checkpoint is sharded.

    This load is performed efficiently: each shard is loaded into RAM one at a
    time and merged into the result, so only one shard is resident at once on
    top of the accumulated state dict.

    Args:
        folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
        variant (`str`): The model variant.
        return_numpy (`bool`): Whether to return numpy array instead of paddle tensor.
    """
    # Single-file checkpoints win over sharded indexes; check them first.
    pdparams_path = os.path.join(folder, _add_variant("model_state.pdparams", variant))
    if os.path.isfile(pdparams_path):
        return paddle.load(pdparams_path, return_numpy=return_numpy)

    lora_pdparams_path = os.path.join(folder, _add_variant("lora_model_state.pdparams", variant))
    if os.path.isfile(lora_pdparams_path):
        return paddle.load(lora_pdparams_path, return_numpy=return_numpy)

    single_safetensors_path = os.path.join(folder, _add_variant("model.safetensors", variant))
    if os.path.isfile(single_safetensors_path):
        state_dict = safe_load_file(single_safetensors_path)
        if not return_numpy:
            # Convert numpy arrays to paddle tensors without copying.
            for name in list(state_dict.keys()):
                if isinstance(state_dict[name], np.ndarray):
                    state_dict[name] = paddle.Tensor(state_dict.pop(name), zero_copy=True)
        return state_dict

    # Sharded checkpoint: locate an index file. Preference order: safetensors
    # index, then master-weights index, then paddle index, then peft index.
    index_file = os.path.join(folder, _add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant))
    safe_index_file = os.path.join(folder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant))
    safe_master_file = os.path.join(folder, _add_variant(SAFE_MASTER_WEIGHTS_INDEX_NAME, variant))
    safe_peft_file = os.path.join(folder, _add_variant(SAFE_PEFT_WEIGHTS_INDEX_NAME, variant))

    if os.path.isfile(safe_index_file):
        load_index, load_safe = safe_index_file, True
    elif os.path.isfile(safe_master_file):
        load_index, load_safe = safe_master_file, True
    elif os.path.isfile(index_file):
        load_index, load_safe = index_file, False
    elif os.path.isfile(safe_peft_file):
        load_index, load_safe = safe_peft_file, True
    else:
        raise ValueError(f"Could not find {index_file} or {safe_index_file} or {safe_peft_file}")

    with open(load_index, "r", encoding="utf-8") as f:
        index = json.load(f)

    # The index maps parameter name -> shard file; load each unique shard once.
    shard_files = list(set(index["weight_map"].values()))
    if load_safe:
        loader = safe_load_file
    else:
        loader = partial(paddlenlp_load, map_location="np" if return_numpy else "cpu")

    merged = {}
    for shard_file in tqdm(shard_files):
        merged.update(loader(os.path.join(folder, shard_file)))

    if not return_numpy:
        # Convert any numpy arrays to paddle tensors without copying.
        for name in list(merged.keys()):
            if isinstance(merged[name], np.ndarray):
                merged[name] = paddle.Tensor(merged.pop(name), zero_copy=True)

    return merged
123
-
124
-
125
def load_tp_checkpoint(folder, cls, config, return_numpy=False):
    """
    Load a tensor-parallel checkpoint efficiently: everything happens on CPU and
    no model weights need to be initialized first.

    Args:
        folder (`str` or `os.PathLike`): A path to a folder containing the model checkpoint.
        cls (`type`): The model class (supplies the tensor-parallel conversion hooks).
        config (`AutoConfig`): The model config.
            NOTE(review): this argument is immediately overwritten by
            ``AutoConfig.from_pretrained(folder)`` below, so the caller-supplied
            value is effectively ignored — confirm this is intentional.
        return_numpy (`bool`): Whether to load the checkpoint as numpy arrays.
    """
    config = AutoConfig.from_pretrained(folder)
    if config.tensor_parallel_degree == 1 or config.tensor_parallel_degree == -1:
        # Not tensor-parallel: defer to the plain (possibly sharded) loader.
        return load_sharded_checkpoint(folder, return_numpy=return_numpy)

    # Fix: format the rank as two zero-padded digits ("tp00" ... "tp10").
    # The previous f-string hard-coded a leading "0" ("tp0{rank}"), which
    # produced "tp010" instead of "tp10" for ranks >= 10 and silently missed
    # the pre-split per-rank file. Output is identical for ranks 0-9.
    rank_model_path = os.path.join(folder, f"model_state.tp{config.tensor_parallel_rank:0>2d}.pdparams")
    model_path = os.path.join(folder, "model_state.pdparams")
    safe_model_path = os.path.join(folder, "model.safetensors")

    if os.path.exists(rank_model_path):
        # Weights already split for this rank: load them directly.
        return paddle.load(rank_model_path, return_numpy=return_numpy)

    if os.path.exists(model_path):
        # Single full pdparams file: split it for this rank.
        state_dict = cls.convert_tensor_parallel(model_path, config)
    elif os.path.exists(safe_model_path):
        # Single safetensors file: derive the TP split actions from its keys.
        with safe_open(safe_model_path, framework="np", device="cpu") as f:
            loaded_keys = f.keys()
        tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys)
        state_dict = load_state_dict(safe_model_path, tp_actions)
    else:  # sharded safetensors files
        resolved_archive_file, resolved_sharded_files, sharded_metadata, is_sharded = cls._resolve_model_file_path(
            pretrained_model_name_or_path=folder,
            use_safetensors=True,
        )
        if len(resolved_sharded_files) > 1:
            resolved_sharded_files = tqdm(resolved_sharded_files, desc="Loading checkpoint shards")
        loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
        tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_state_dict_keys, ignore_error=True)
        state_dict = {}
        for shard_file in resolved_sharded_files:
            shard_state_dict = load_state_dict(
                shard_file,
                tp_actions,
                loaded_state_dict_keys,
            )
            state_dict.update(shard_state_dict)

    if return_numpy:
        # Move any remaining paddle tensors to host memory as numpy arrays.
        for key in list(state_dict.keys()):
            if not isinstance(state_dict[key], np.ndarray):
                state_dict[key] = state_dict.pop(key).cpu().numpy()
    return state_dict
175
30
176
31
177
- def infererence_model_from_pretrained (cls , pretrained_model_name_or_path , args , kwargs ):
32
+ def infererence_model_from_pretrained (cls , pretrained_model_name_or_path , args , kwargs , return_numpy = True ):
178
33
r"""
179
34
Instantiate a pretrained model configuration from a pre-trained model name or path.
180
35
"""
@@ -203,7 +58,7 @@ def infererence_model_from_pretrained(cls, pretrained_model_name_or_path, args,
203
58
with ContextManagers (init_contexts ):
204
59
model = cls (config )
205
60
206
- resolved_archive_file , resolved_sharded_files , sharded_metadata , is_sharded = cls ._resolve_model_file_path (
61
+ resolved_archive_file , _ , _ , _ = cls ._resolve_model_file_path (
207
62
pretrained_model_name_or_path ,
208
63
cache_dir = cache_dir ,
209
64
subfolder = subfolder ,
@@ -216,7 +71,7 @@ def infererence_model_from_pretrained(cls, pretrained_model_name_or_path, args,
216
71
)
217
72
218
73
model_path = os .path .dirname (resolved_archive_file )
219
- state_dict = load_tp_checkpoint (model_path , cls , config , return_numpy = True )
74
+ state_dict = load_tp_checkpoint (model_path , cls , config , return_numpy = return_numpy )
220
75
model .set_state_dict (state_dict )
221
76
222
77
return model
0 commit comments