Commit ef1fb18

Redefine fuse and split functions

1 parent 28ed30f commit ef1fb18

3 files changed: +144 -259 lines

paddlenlp/transformers/conversion_utils.py

Lines changed: 109 additions & 74 deletions
@@ -17,7 +17,6 @@
 import inspect
 import json
 import os
-import re
 from copy import deepcopy
 from dataclasses import dataclass
 from typing import (

@@ -490,22 +489,22 @@ def splited_qkv_to_tensor_parallel_qkv(weight_list, num_attention_heads):
     return naive_merged_qkv_to_tensor_parallel_qkv(weight)
 
 
-def merged_as_tensor_parallel_qkv(state_dict, q_name, k_name, v_name, num_hidden_layers):
-    q = state_dict[q_name]
-    k = state_dict[k_name]
-    v = state_dict[v_name]
+def fuse_param_func():
+    def fn(fuse_params: List[np.array]):
+        return np.concatenate(fuse_params, axis=-1)
 
-    naive_merged_qkv = np.concatenate((q, k, v), axis=-1)
+    return fn
 
-    return naive_merged_qkv_to_tensor_parallel_qkv(naive_merged_qkv, num_hidden_layers)
 
+def split_param_func(split_nums):
+    def fn(fused_param):
+        return np.split(fused_param, split_nums, axis=-1)
 
-def merge_as_naive_merged_qkv():
-    pass
+    return fn
 
 
-def merge_as_splited_qkv():
-    pass
+def split_or_fuse_func(is_fuse=True):
+    return fuse_param_func if is_fuse else split_param_func
 
 
 def get_tensor_parallel_merge_func(tensor_parallel_degree, tensor_parallel_rank, num_attention_heads=None):

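For reference, a minimal standalone sketch of what the new `fuse_param_func` / `split_param_func` closures do. The helpers are re-declared inline so the snippet runs without importing PaddleNLP; the shapes and values are made up for illustration:

```python
from typing import List

import numpy as np


def fuse_param_func():
    # Mirrors the helper added above: concatenate several weights along the last axis.
    def fn(fuse_params: List[np.ndarray]):
        return np.concatenate(fuse_params, axis=-1)

    return fn


def split_param_func(split_nums):
    # Inverse operation: cut a fused weight into `split_nums` equal slices.
    def fn(fused_param):
        return np.split(fused_param, split_nums, axis=-1)

    return fn


q = np.zeros((8, 8))
k = np.ones((8, 8))
v = np.full((8, 8), 2.0)

qkv = fuse_param_func()([q, k, v])      # shape (8, 24)
q2, k2, v2 = split_param_func(3)(qkv)   # back to three (8, 8) slices

assert qkv.shape == (8, 24)
assert np.array_equal(k2, k)
```

Note that `split_param_func` relies on `np.split`, so the fused axis must divide evenly by `split_nums`.
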
@@ -1101,19 +1100,10 @@ def _get_name_mappings(cls, config: PretrainedConfig) -> List[StateDictNameMappi
 
     @classmethod
     def get_tensor_parallel_convert_actions(
-        cls, config: PretrainedConfig, loaded_state_dict_keys, is_split=True, ignore_error=False, ignore_params=[]
+        cls, config: PretrainedConfig, loaded_state_dict_keys, is_split=True, ignore_error=False
     ):
         name_action_mappings = cls._get_tensor_parallel_mappings(config, is_split=is_split)
-
-        # avoid act on fuse parameters (qkv/gate-up), they are not consistant between config and loaded_state_dict_keys
-        name_map_list = cls._get_name_mappings(config)
-        for key in ignore_params:
-            for name_map in name_map_list:
-                if name_map.target_name == key:
-                    name_action_mappings.pop(name_map.source_name.split("model.")[-1], None)
-
         state_keys_map = cls._resolve_prefix_keys(name_action_mappings.keys(), loaded_state_dict_keys, ignore_error)
-
         for k, v in state_keys_map.items():
             name_action_mappings[v] = name_action_mappings.pop(k)
         return name_action_mappings

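A toy illustration of the rename step this method performs after `_resolve_prefix_keys`: mapping keys expressed without the model prefix onto the actual checkpoint names. The dictionaries below are invented for the example; the real resolution logic lives in `_resolve_prefix_keys`:

```python
# Invented example data: actions keyed by un-prefixed names, checkpoint keys carrying a "llama." prefix.
name_action_mappings = {"layers.0.self_attn.q_proj.weight": lambda w: w}
state_keys_map = {"layers.0.self_attn.q_proj.weight": "llama.layers.0.self_attn.q_proj.weight"}

# The method rewrites the mapping in place so every action is keyed by the real checkpoint name.
for k, v in state_keys_map.items():
    name_action_mappings[v] = name_action_mappings.pop(k)

print(list(name_action_mappings))  # ['llama.layers.0.self_attn.q_proj.weight']
```
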
@@ -1129,66 +1119,27 @@ def convert_tensor_parallel(
             config (PretrainedConfig): the PretrainedConfig instance of model
         """
 
-        def _apply_tp_action(name_action_mappings):
-            state_keys_map = cls._resolve_prefix_keys(name_action_mappings.keys(), state_dict.keys(), ignore_error)
-
-            for k, v in state_keys_map.items():
-                name_action_mappings[v] = name_action_mappings.pop(k)
-
-            for name, action in name_action_mappings.items():
-                if name not in state_dict:
-                    if not ignore_error:
-                        logger.warning(f"Key <{name}> not in the model state weight file.")
-                    continue
-                tensor = state_dict.pop(name)
-                new_tensor = action(tensor)
-                with device_guard("cpu"):
-                    state_dict[name] = paddle.Tensor(new_tensor, zero_copy=True)
-
+        name_action_mappings = cls._get_tensor_parallel_mappings(config)
         if state_dict is None:
             with device_guard("cpu"):
                 state_dict = paddle.load(weight_file, return_numpy=False)
             logger.info("Starting to convert orignal state_dict to tensor parallel state_dict.")
 
-        from paddlenlp.transformers.model_utils import select_fuse_parameter
+        state_keys_map = cls._resolve_prefix_keys(name_action_mappings.keys(), state_dict.keys(), ignore_error)
 
-        do_fuse_parameter_list, do_separate_parameter_list = select_fuse_parameter(cls, state_dict.keys(), config)
-        if "attention_qkv_proj" in do_fuse_parameter_list:
-            state_dict, fuse_success = cls.fuse_attention_parameters(
-                state_dict, ["attention_qkv_proj"], config
-            )  # design: q, k, v => qkv
-
-        name_action_mappings = cls._get_tensor_parallel_mappings(config)
+        for k, v in state_keys_map.items():
+            name_action_mappings[v] = name_action_mappings.pop(k)
 
-        # avoid act on fuse parameters (qkv/gate-up), they are not consistant between config and loaded_state_dict_keys
-        # pop qkv tp actions and apply the rest actions
-        if "attention_qkv_proj" in do_fuse_parameter_list:
-
-            name_map_list = [
-                lambda layer_id: re.sub(r"\d+", str(layer_id), "layers.0.self_attn.q_proj.weight"),
-                lambda layer_id: re.sub(r"\d+", str(layer_id), "layers.0.self_attn.k_proj.weight"),
-                lambda layer_id: re.sub(r"\d+", str(layer_id), "layers.0.self_attn.v_proj.weight"),
-                lambda layer_id: re.sub(r"\d+", str(layer_id), "layers.0.self_attn.qkv_proj.weight"),
-            ]
-            tp_action_keys = list(name_action_mappings.keys())
-            poped_param_names = []
-            for key in tp_action_keys:
-                for name_map in name_map_list:
-                    if re.sub(r"\d+", "0", key) == name_map(0):
-                        name_action_mappings.pop(key, None)
-                        poped_param_names.append(key)
-
-        _apply_tp_action(name_action_mappings)
-
-        # tail processing qkv parameters
-        if "attention_qkv_proj" in do_fuse_parameter_list:
-            name_action_mappings_fuse = cls._get_tensor_parallel_mappings(config)
-            tp_action_fuse_keys = list(name_action_mappings_fuse.keys())
-            for key in tp_action_fuse_keys:
-                if key not in poped_param_names:
-                    name_action_mappings_fuse.pop(key, None)
-
-            _apply_tp_action(name_action_mappings_fuse)
+        for name, action in name_action_mappings.items():
+            if name not in state_dict:
+                if not ignore_error:
+                    logger.warning(f"Key <{name}> not in the model state weight file.")
+                continue
+            tensor = state_dict.pop(name)
+            new_tensor = action(tensor)
+            with device_guard("cpu"):
+                state_dict[name] = paddle.Tensor(new_tensor, zero_copy=True)
+        state_dict = cls.convert_fuse_and_split(config, state_dict, name_action_mappings)
 
         return state_dict

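To make the reworked flow above concrete, here is a toy, NumPy-only walk-through of the same two-stage idea: apply one tensor-parallel action per parameter name, then fuse q/k/v into a single weight the way `cls.convert_fuse_and_split` would. The `column_split` action, the key names, and the tiny state dict are illustrative stand-ins, not PaddleNLP API:

```python
import numpy as np

tensor_parallel_degree, tensor_parallel_rank = 2, 0


def column_split(weight):
    # Illustrative tp action: keep only this rank's column shard.
    return np.split(weight, tensor_parallel_degree, axis=-1)[tensor_parallel_rank]


state_dict = {
    "layers.0.self_attn.q_proj.weight": np.ones((4, 4)),
    "layers.0.self_attn.k_proj.weight": np.ones((4, 4)),
    "layers.0.self_attn.v_proj.weight": np.ones((4, 4)),
}
name_action_mappings = {name: column_split for name in state_dict}

# Stage 1: plain tensor-parallel conversion, one action per parameter name.
for name, action in name_action_mappings.items():
    state_dict[name] = action(state_dict[name])

# Stage 2 (handled by cls.convert_fuse_and_split in the real code): fuse the
# q/k/v shards into one qkv weight when the target layout expects a fused key.
qkv_parts = [
    state_dict.pop("layers.0.self_attn.q_proj.weight"),
    state_dict.pop("layers.0.self_attn.k_proj.weight"),
    state_dict.pop("layers.0.self_attn.v_proj.weight"),
]
state_dict["layers.0.self_attn.qkv_proj.weight"] = np.concatenate(qkv_parts, axis=-1)
print(state_dict["layers.0.self_attn.qkv_proj.weight"].shape)  # (4, 6): three (4, 2) shards
```
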
@@ -1270,6 +1221,90 @@ def _resolve_prefix_keys(state_keys_base, state_keys_real, ignore_error=False):
 
         return state_keys_map
 
+    def convert_fuse_and_split(cls, config: PretrainedConfig, state_dict, tp_actions=None):
+        loaded_keys = state_dict.keys()
+        # collect and convert fuse/split action
+        fused_and_split_keys = []
+        fuse_actions = cls.get_fuse_or_split_param_convert_actions(config, loaded_keys, is_fuse=True)
+        for keys, action in fuse_actions.items():
+            origin_states = [state_dict[key] for key in keys]
+            state_dict[keys[-1]] = action(origin_states)
+            fused_and_split_keys.append(keys[-1])
+
+        split_actions = cls.get_fuse_or_split_param_convert_actions(config, loaded_keys, is_fuse=False)
+        for keys, action in split_actions.items():
+            origin_state = state_dict[keys[-1]]
+            split_states = action(origin_state)
+            for key, key_idx in enumerate(keys[:-1]):
+                state_dict[key] = split_states[key_idx]
+                fused_and_split_keys.append(key)
+
+        if tp_actions is not None:
+            for key in fused_and_split_keys:
+                if key in tp_actions:
+                    state_dict[key] = tp_actions[key](state_dict.pop(key))
+        return state_dict
+
+    def get_fuse_or_split_param_convert_actions(
+        cls,
+        config: PretrainedConfig,
+        loaded_state_dict_keys,
+        is_fuse=True,
+        ignore_error=False,
+    ):
+        name_action_mappings = cls._get_fuse_or_split_param_mappings(config, is_fuse)
+        state_keys_map = cls._resolve_prefix_keys_for_fuse_and_split(
+            name_action_mappings.keys(), loaded_state_dict_keys, ignore_error, is_fuse=True
+        )
+        for k, v in state_keys_map.items():
+            name_action_mappings[v] = name_action_mappings.pop(k)
+
+        filter_name_action = {}
+        for k, v in name_action_mappings.items():
+            if is_fuse:
+                cond = all(item in loaded_state_dict_keys for item in k[:-1])
+            else:
+                cond = k[-1] in loaded_state_dict_keys
+
+            if cond:
+                filter_name_action[k] = v
+
+        return filter_name_action
+
+    def _get_fuse_or_split_param_mappings(cls, config: PretrainedConfig, is_fuse=True) -> List[StateDictNameMapping]:
+        """get fused parameter mapping of PretrainedModel
+
+        Args:
+            config (PretrainedConfig): the configuration of name-mapping
+
+        Raises:
+            NotImplementedError:
+
+        Returns:
+            List[StateDictNameMapping]: the name-mappings for tensor_parallel
+        """
+        raise NotImplementedError(
+            f"`_get_fused_param_mappings` is not implemented for {cls.__name__}`. To implement it, you should "
+            f"overwrite this method in the class {cls.__name__} in `{cls.__module__}.py`"
+        )
+
+    @staticmethod
+    def _resolve_prefix_keys_for_fuse_and_split(state_keys_base, state_keys_real, ignore_error=False, is_fuse=True):
+        state_keys_map = {}
+
+        for keys in state_keys_base:
+            base_key = keys[0] if is_fuse else keys[-1]
+            prefix = ""
+            for x in state_keys_real:
+                if x.endswith(base_key):
+                    prefix = x.replace(x, base_key)
+                    break
+            new_keys = (prefix + key for key in keys)
+
+            state_keys_map[keys] = new_keys
+
+        return state_keys_map
+
 
 class Converter(ConversionMixin, LogitComparer):
     """some converters are implemented in ppdiffusers, so if remove it directly, it will make ppdiffusers down.


paddlenlp/transformers/llama/modeling.py

Lines changed: 28 additions & 24 deletions
@@ -16,7 +16,6 @@
 from __future__ import annotations
 
 import math
-import re
 import warnings
 from functools import partial
 from typing import Optional, Tuple

@@ -1279,32 +1278,37 @@ def get_tensor_parallel_split_mappings(num_layers):
             return mappings
 
     @classmethod
-    def _get_fused_param_mappings(cls):
+    def _get_fuse_or_split_param_mappings(cls, config: LlamaConfig, is_fuse=False):
         # return parameter fuse utils
-        from paddlenlp.transformers.conversion_utils import (
-            merged_as_tensor_parallel_qkv,
-        )
+        from paddlenlp.transformers.conversion_utils import split_or_fuse_func
+
+        fn = split_or_fuse_func(is_fuse=is_fuse)
+
+        final_actions = {}
+        if config.fuse_attention_qkv:
+            # last key is fused key, other keys are to be fused.
+            base_keys = (
+                "layers.0.self_attn.q_proj.weight",
+                "layers.0.self_attn.k_proj.weight",
+                "layers.0.self_attn.v_proj.weight",
+                "layers.0.self_attn.qkv_proj.weight",
+            )
 
-        # attention: q,k,v -> qkv, ffn: gate, up -> gate_up
-        mappings = {
-            "fuse_action": [merged_as_tensor_parallel_qkv, None],
-            "split_action": [None, None],
-            "attn_param_names": {
-                "qkv_proj": lambda layer_id: re.sub(r"\d+", str(layer_id), "llama.layers.0.self_attn.qkv_proj.weight"),
-                "q_proj": lambda layer_id: re.sub(r"\d+", str(layer_id), "llama.layers.0.self_attn.q_proj.weight"),
-                "k_proj": lambda layer_id: re.sub(r"\d+", str(layer_id), "llama.layers.0.self_attn.k_proj.weight"),
-                "v_proj": lambda layer_id: re.sub(r"\d+", str(layer_id), "llama.layers.0.self_attn.v_proj.weight"),
-            },
-            "ffn_param_names": {
-                "gate_up_proj": lambda layer_id: re.sub(
-                    r"\d+", str(layer_id), "llama.layers.0.mlp.gate_up_proj.weight"
-                ),
-                "gate_proj": lambda layer_id: re.sub(r"\d+", str(layer_id), "llama.layers.0.mlp.gate_proj.weight"),
-                "up_proj": lambda layer_id: re.sub(r"\d+", str(layer_id), "llama.layers.0.mlp.up_proj.weight"),
-            },
-        }
+            for i in range(config.num_hidden_layers):
+                keys = (key.replace("layers.0.", f"layers.{i}.") for key in base_keys)
+                final_actions[keys] = fn
 
-        return mappings
+        if config.fuse_attention_ffn:
+            base_keys = (
+                "llama.layers.0.mlp.gate_proj.weight",
+                "llama.layers.0.mlp.up_proj.weight",
+                "llama.layers.0.mlp.gate_up_proj.weight",
+            )
+            for i in range(config.num_hidden_layers):
+                keys = (key.replace("layers.0.", f"layers.{i}.") for key in base_keys)
+                final_actions[keys] = fn
+
+        return final_actions
 
     def _init_weights(self, layer):
         """Initialization hook"""

