import inspect
import json
import os
-import re
from copy import deepcopy
from dataclasses import dataclass
from typing import (
@@ -490,22 +489,22 @@ def splited_qkv_to_tensor_parallel_qkv(weight_list, num_attention_heads):
    return naive_merged_qkv_to_tensor_parallel_qkv(weight)


-def merged_as_tensor_parallel_qkv(state_dict, q_name, k_name, v_name, num_hidden_layers):
-    q = state_dict[q_name]
-    k = state_dict[k_name]
-    v = state_dict[v_name]
+def fuse_param_func():
+    def fn(fuse_params: List[np.array]):
+        return np.concatenate(fuse_params, axis=-1)

-    naive_merged_qkv = np.concatenate((q, k, v), axis=-1)
+    return fn

-    return naive_merged_qkv_to_tensor_parallel_qkv(naive_merged_qkv, num_hidden_layers)

+def split_param_func(split_nums):
+    def fn(fused_param):
+        return np.split(fused_param, split_nums, axis=-1)

-def merge_as_naive_merged_qkv():
-    pass
+    return fn


-def merge_as_splited_qkv():
-    pass
+def split_or_fuse_func(is_fuse=True):
+    return fuse_param_func if is_fuse else split_param_func


def get_tensor_parallel_merge_func(tensor_parallel_degree, tensor_parallel_rank, num_attention_heads=None):
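For reference, a minimal usage sketch of the two helpers added above; the array shapes and the three-way split count are illustrative only, not taken from this diff:

import numpy as np

q = np.zeros((8, 4))
k = np.ones((8, 4))
v = np.full((8, 4), 2.0)

fuse = fuse_param_func()      # returned fn concatenates a list of arrays along the last axis
qkv = fuse([q, k, v])         # shape (8, 12)

split = split_param_func(3)   # returned fn splits a fused array into 3 equal chunks along the last axis
q2, k2, v2 = split(qkv)
assert np.array_equal(q, q2) and np.array_equal(k, k2) and np.array_equal(v, v2)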
@@ -1101,19 +1100,10 @@ def _get_name_mappings(cls, config: PretrainedConfig) -> List[StateDictNameMapping]:

    @classmethod
    def get_tensor_parallel_convert_actions(
-        cls, config: PretrainedConfig, loaded_state_dict_keys, is_split=True, ignore_error=False, ignore_params=[]
+        cls, config: PretrainedConfig, loaded_state_dict_keys, is_split=True, ignore_error=False
    ):
        name_action_mappings = cls._get_tensor_parallel_mappings(config, is_split=is_split)
-
-        # avoid acting on fused parameters (qkv/gate-up); they are not consistent between config and loaded_state_dict_keys
-        name_map_list = cls._get_name_mappings(config)
-        for key in ignore_params:
-            for name_map in name_map_list:
-                if name_map.target_name == key:
-                    name_action_mappings.pop(name_map.source_name.split("model.")[-1], None)
-
        state_keys_map = cls._resolve_prefix_keys(name_action_mappings.keys(), loaded_state_dict_keys, ignore_error)
-
        for k, v in state_keys_map.items():
            name_action_mappings[v] = name_action_mappings.pop(k)
        return name_action_mappings
@@ -1129,66 +1119,27 @@ def convert_tensor_parallel(
            config (PretrainedConfig): the PretrainedConfig instance of model
        """

-        def _apply_tp_action(name_action_mappings):
-            state_keys_map = cls._resolve_prefix_keys(name_action_mappings.keys(), state_dict.keys(), ignore_error)
-
-            for k, v in state_keys_map.items():
-                name_action_mappings[v] = name_action_mappings.pop(k)
-
-            for name, action in name_action_mappings.items():
-                if name not in state_dict:
-                    if not ignore_error:
-                        logger.warning(f"Key <{name}> not in the model state weight file.")
-                    continue
-                tensor = state_dict.pop(name)
-                new_tensor = action(tensor)
-                with device_guard("cpu"):
-                    state_dict[name] = paddle.Tensor(new_tensor, zero_copy=True)
-
+        name_action_mappings = cls._get_tensor_parallel_mappings(config)
        if state_dict is None:
            with device_guard("cpu"):
                state_dict = paddle.load(weight_file, return_numpy=False)
            logger.info("Starting to convert original state_dict to tensor parallel state_dict.")

-        from paddlenlp.transformers.model_utils import select_fuse_parameter
+        state_keys_map = cls._resolve_prefix_keys(name_action_mappings.keys(), state_dict.keys(), ignore_error)

-        do_fuse_parameter_list, do_separate_parameter_list = select_fuse_parameter(cls, state_dict.keys(), config)
-        if "attention_qkv_proj" in do_fuse_parameter_list:
-            state_dict, fuse_success = cls.fuse_attention_parameters(
-                state_dict, ["attention_qkv_proj"], config
-            )  # design: q, k, v => qkv
-
-        name_action_mappings = cls._get_tensor_parallel_mappings(config)
+        for k, v in state_keys_map.items():
+            name_action_mappings[v] = name_action_mappings.pop(k)

-        # avoid acting on fused parameters (qkv/gate-up); they are not consistent between config and loaded_state_dict_keys
-        # pop the qkv tp actions and apply the remaining actions
-        if "attention_qkv_proj" in do_fuse_parameter_list:
-
-            name_map_list = [
-                lambda layer_id: re.sub(r"\d+", str(layer_id), "layers.0.self_attn.q_proj.weight"),
-                lambda layer_id: re.sub(r"\d+", str(layer_id), "layers.0.self_attn.k_proj.weight"),
-                lambda layer_id: re.sub(r"\d+", str(layer_id), "layers.0.self_attn.v_proj.weight"),
-                lambda layer_id: re.sub(r"\d+", str(layer_id), "layers.0.self_attn.qkv_proj.weight"),
-            ]
-            tp_action_keys = list(name_action_mappings.keys())
-            poped_param_names = []
-            for key in tp_action_keys:
-                for name_map in name_map_list:
-                    if re.sub(r"\d+", "0", key) == name_map(0):
-                        name_action_mappings.pop(key, None)
-                        poped_param_names.append(key)
-
-        _apply_tp_action(name_action_mappings)
-
-        # tail processing of the qkv parameters
-        if "attention_qkv_proj" in do_fuse_parameter_list:
-            name_action_mappings_fuse = cls._get_tensor_parallel_mappings(config)
-            tp_action_fuse_keys = list(name_action_mappings_fuse.keys())
-            for key in tp_action_fuse_keys:
-                if key not in poped_param_names:
-                    name_action_mappings_fuse.pop(key, None)
-
-            _apply_tp_action(name_action_mappings_fuse)
+        for name, action in name_action_mappings.items():
+            if name not in state_dict:
+                if not ignore_error:
+                    logger.warning(f"Key <{name}> not in the model state weight file.")
+                continue
+            tensor = state_dict.pop(name)
+            new_tensor = action(tensor)
+            with device_guard("cpu"):
+                state_dict[name] = paddle.Tensor(new_tensor, zero_copy=True)
+        state_dict = cls.convert_fuse_and_split(config, state_dict, name_action_mappings)

        return state_dict
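As a reading aid for the new flow: convert_fuse_and_split (added below) consumes action dictionaries whose keys are tuples of parameter names, where every name but the last is a source and the last name is the fused target (the split direction reads the last name and writes the others). A hypothetical fuse entry with LLaMA-style names, for illustration only:

# hypothetical entry: per-head q/k/v weights -> one fused qkv weight
fuse_actions = {
    (
        "llama.layers.0.self_attn.q_proj.weight",
        "llama.layers.0.self_attn.k_proj.weight",
        "llama.layers.0.self_attn.v_proj.weight",
        "llama.layers.0.self_attn.qkv_proj.weight",  # fused target, written by convert_fuse_and_split
    ): fuse_param_func(),
}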
@@ -1270,6 +1221,90 @@ def _resolve_prefix_keys(state_keys_base, state_keys_real, ignore_error=False):
        return state_keys_map

+    @classmethod
+    def convert_fuse_and_split(cls, config: PretrainedConfig, state_dict, tp_actions=None):
+        loaded_keys = state_dict.keys()
+        # collect and apply the fuse/split actions
+        fused_and_split_keys = []
+        fuse_actions = cls.get_fuse_or_split_param_convert_actions(config, loaded_keys, is_fuse=True)
+        for keys, action in fuse_actions.items():
+            # all but the last key are the source parameters; the last key is the fused target
+            origin_states = [state_dict[key] for key in keys[:-1]]
+            state_dict[keys[-1]] = action(origin_states)
+            fused_and_split_keys.append(keys[-1])
+
+        split_actions = cls.get_fuse_or_split_param_convert_actions(config, loaded_keys, is_fuse=False)
+        for keys, action in split_actions.items():
+            # the last key is the fused source; the split results are written back to the other keys
+            origin_state = state_dict[keys[-1]]
+            split_states = action(origin_state)
+            for key_idx, key in enumerate(keys[:-1]):
+                state_dict[key] = split_states[key_idx]
+                fused_and_split_keys.append(key)
+
+        if tp_actions is not None:
+            for key in fused_and_split_keys:
+                if key in tp_actions:
+                    state_dict[key] = tp_actions[key](state_dict.pop(key))
+        return state_dict
+
+    @classmethod
+    def get_fuse_or_split_param_convert_actions(
+        cls,
+        config: PretrainedConfig,
+        loaded_state_dict_keys,
+        is_fuse=True,
+        ignore_error=False,
+    ):
+        name_action_mappings = cls._get_fuse_or_split_param_mappings(config, is_fuse)
+        state_keys_map = cls._resolve_prefix_keys_for_fuse_and_split(
+            name_action_mappings.keys(), loaded_state_dict_keys, ignore_error, is_fuse=is_fuse
+        )
+        for k, v in state_keys_map.items():
+            name_action_mappings[v] = name_action_mappings.pop(k)
+
+        # keep only the actions whose required keys are actually present in the checkpoint
+        filter_name_action = {}
+        for k, v in name_action_mappings.items():
+            if is_fuse:
+                # fusing needs every source key (all but the last) to be present
+                cond = all(item in loaded_state_dict_keys for item in k[:-1])
+            else:
+                # splitting only needs the fused key (the last one) to be present
+                cond = k[-1] in loaded_state_dict_keys
+
+            if cond:
+                filter_name_action[k] = v
+
+        return filter_name_action
+
+    @classmethod
+    def _get_fuse_or_split_param_mappings(cls, config: PretrainedConfig, is_fuse=True) -> List[StateDictNameMapping]:
+        """get the fuse or split parameter mappings of PretrainedModel
+
+        Args:
+            config (PretrainedConfig): the configuration of name-mapping
+
+        Raises:
+            NotImplementedError:
+
+        Returns:
+            List[StateDictNameMapping]: the name-mappings for fusing or splitting parameters
+        """
+        raise NotImplementedError(
+            f"`_get_fuse_or_split_param_mappings` is not implemented for `{cls.__name__}`. To implement it, you should "
+            f"overwrite this method in the class {cls.__name__} in `{cls.__module__}.py`"
+        )
+
+    @staticmethod
+    def _resolve_prefix_keys_for_fuse_and_split(state_keys_base, state_keys_real, ignore_error=False, is_fuse=True):
+        state_keys_map = {}
+
+        for keys in state_keys_base:
+            # resolve the checkpoint prefix from the key that must exist in the checkpoint:
+            # the first (source) key when fusing, the last (fused) key when splitting
+            base_key = keys[0] if is_fuse else keys[-1]
+            prefix = ""
+            for x in state_keys_real:
+                if x.endswith(base_key):
+                    prefix = x.replace(base_key, "")
+                    break
+            new_keys = tuple(prefix + key for key in keys)
+
+            state_keys_map[keys] = new_keys
+
+        return state_keys_map
+

class Converter(ConversionMixin, LogitComparer):
    """some converters are implemented in ppdiffusers, so if remove it directly, it will make ppdiffusers down.