@@ -96,12 +96,15 @@ def __init__(self, config: LlamaConfig):
        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
+        self.num_key_value_heads = config.num_key_value_heads
        self.intermediate_size = config.intermediate_size
        self.num_layers = config.num_hidden_layers
        self.epsilon = config.rms_norm_eps
        self.max_position_embeddings = config.max_position_embeddings
        self.quant_type = config.quant_type

+        self.rope_theta = config.rope_theta
+
        self.use_weight_only = False
        self.weight_only_quant_bits = config.weight_only_quant_bits

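The two new attributes come straight from LlamaConfig: num_key_value_heads enables grouped-query attention (GQA), where several query heads share one key/value head, and rope_theta makes the rotary base configurable instead of the hard-coded 10000.0 removed further down in forward(). A minimal sketch of the head grouping; the numbers are illustrative, not taken from this PR:

```python
# Illustrative GQA head grouping (example values, not from this PR).
num_attention_heads = 64   # query heads
num_key_value_heads = 8    # shared key/value heads; 1 would be multi-query attention

assert num_attention_heads % num_key_value_heads == 0
group_size = num_attention_heads // num_key_value_heads  # query heads served by each KV head
print(f"each KV head serves {group_size} query heads")
```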
@@ -188,8 +191,6 @@ def __init__(self, config: LlamaConfig):
        ffn2_bias_attrs = None

        if self.quant_type == "a8w8":
-            self.quant_round_type = config.quantization_config.quant_round_type
-
            qkv_out_scale_attrs = [
                paddle.ParamAttr(name="fusellama.{}.qkv_out_scale".format(i)) for i in range(self.num_layers)
            ]
@@ -277,9 +278,10 @@ def __init__(self, config: LlamaConfig):
        ]

        transformer_config = FusedMultiTransformerConfig(
-            self.hidden_size,
-            self.num_attention_heads,
-            self.intermediate_size,
+            embed_dim=self.hidden_size,
+            num_heads=self.num_attention_heads,
+            kv_num_heads=self.num_key_value_heads,
+            dim_feedforward=self.intermediate_size,
            weight_only_quant_bits=self.weight_only_quant_bits,
            activation="swiglu",
            num_layers=config.num_hidden_layers,
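Switching the leading arguments of FusedMultiTransformerConfig from positional to keyword form is what makes inserting kv_num_heads safe: with positional arguments, a new parameter in the middle of the signature would silently shift every later value. A generic sketch of the failure mode; the dataclass below is a hypothetical stand-in, not the Paddle config:

```python
from dataclasses import dataclass

# Hypothetical stand-in for a transformer config whose signature gains a new field.
@dataclass
class TransformerCfg:
    embed_dim: int
    num_heads: int
    kv_num_heads: int        # newly inserted field
    dim_feedforward: int

# A positional call written against the old three-field signature would now misassign
# values (or fail outright): TransformerCfg(4096, 32, 11008).
# Keyword form stays correct no matter where new fields are inserted:
cfg = TransformerCfg(embed_dim=4096, num_heads=32, kv_num_heads=8, dim_feedforward=11008)
```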
@@ -430,13 +432,12 @@ def forward(
        seq_lens = seq_len_decoder if is_decoder else seq_len_encoder

        position_offset = 0
-        theta = 10000.0
        if not is_decoder and pre_caches is not None:
            position_offset = 128
        from paddlenlp_ops import fused_get_rotary_embedding

        new_rope = fused_get_rotary_embedding(
-            input_ids, position_ids, self.head_dim_shape_tensor, position_offset, theta, True
+            input_ids, position_ids, self.head_dim_shape_tensor, position_offset, self.rope_theta, True
        )

        with dy2st_nocheck_guard_context():
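Passing self.rope_theta here lets checkpoints trained with a larger rotary base (common for long-context variants) keep that base at inference time rather than always falling back to 10000.0. A NumPy sketch of how the base enters the rotary frequencies; this is the standard RoPE formula, not the internals of fused_get_rotary_embedding:

```python
import numpy as np

def rotary_cos_sin(positions, head_dim, rope_theta=10000.0):
    """Standard RoPE tables: inv_freq_i = rope_theta ** (-2i / head_dim)."""
    inv_freq = rope_theta ** (-np.arange(0, head_dim, 2) / head_dim)
    angles = np.outer(positions, inv_freq)   # shape [seq_len, head_dim // 2]
    return np.cos(angles), np.sin(angles)

# Larger bases stretch the wavelengths, which is how long-context checkpoints use RoPE.
cos, sin = rotary_cos_sin(np.arange(16), head_dim=128, rope_theta=1_000_000.0)
```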
@@ -491,7 +492,7 @@ def set_state_dict(self, state_dict):
                        state_dict["llama.layers.{}.self_attn.qkv_proj.weight".format(idx)],
                        is_qkv=True,
                        num_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
-                        num_key_value_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
+                        num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree,
                    ),
                    axis=-1,
                ).transpose(1, 0)
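With GQA the query and key/value projections are sharded over tensor-parallel ranks separately, so the per-rank head counts differ, which is what this argument change encodes. A small sketch of the per-rank numbers; the values are examples and assume num_key_value_heads is divisible by the parallel degree:

```python
# Example per-rank head counts under tensor parallelism (illustrative values).
num_attention_heads = 32
num_key_value_heads = 8
tensor_parallel_degree = 4

heads_per_rank = num_attention_heads // tensor_parallel_degree      # 8 query heads per rank
kv_heads_per_rank = num_key_value_heads // tensor_parallel_degree   # 2 KV heads per rank
print(heads_per_rank, kv_heads_per_rank)
```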
@@ -517,10 +518,14 @@ def set_state_dict(self, state_dict):
                    )
                    .transpose(1, 0)
                    .reshape(
-                        3 * (self.num_attention_heads // self.config.tensor_parallel_degree) * (head_size),
+                        (
+                            self.num_attention_heads // self.config.tensor_parallel_degree
+                            + 2 * self.num_key_value_heads // self.config.tensor_parallel_degree
+                        )
+                        * (head_size),
                        self.hidden_size,
                    )
-                )  # reshape(3, self.num_attention_heads // self.config.tensor_parallel_degree, head_size, )
+                )
                if "llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx) in state_dict.keys():
                    concated_ffn1_weight = np.concatenate(
                        split_fn(state_dict["llama.layers.{}.mlp.gate_up_fused_proj.weight".format(idx)]), axis=-1
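The new reshape reflects the fused-QKV layout under GQA: instead of three equally sized Q, K, V blocks, the weight holds one Q block plus two smaller K and V blocks, so its first dimension per rank is (num_q_heads + 2 * num_kv_heads) * head_size rather than 3 * num_q_heads * head_size. A worked example with illustrative sizes:

```python
# Fused QKV weight rows per tensor-parallel rank (illustrative sizes).
hidden_size = 4096
num_attention_heads, num_key_value_heads, tp = 32, 8, 1
head_size = hidden_size // num_attention_heads  # 128

old_rows = 3 * (num_attention_heads // tp) * head_size                               # 12288 (MHA layout)
new_rows = (num_attention_heads // tp + 2 * num_key_value_heads // tp) * head_size   # 6144 with GQA

# The fused weight per rank is reshaped to [new_rows, hidden_size].
print(old_rows, new_rows)
```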
@@ -744,7 +749,7 @@ def set_state_dict(self, state_dict):
                cache_scale_json_path,
                cache_scale_map_dict,
                num_of_layers=self.config.num_hidden_layers,
-                num_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
+                num_heads=self.num_key_value_heads // self.config.tensor_parallel_degree,
            )
            for k, v in cache_scales_loader.scale.items():
                for i_layer, weight_scale in enumerate(v):
@@ -919,7 +924,7 @@ def get_cache_kvs_shape(
            [
                2,
                max_batch_size,
-                config.num_attention_heads // max(config.tensor_parallel_degree, 1),
+                config.num_key_value_heads // max(config.tensor_parallel_degree, 1),
                max_length,
                config.hidden_size // config.num_attention_heads,
            ]
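The KV cache is indexed per key/value head, so allocating it with num_key_value_heads instead of num_attention_heads shrinks it by the GQA group factor. A rough sketch of the per-layer element count under the static cache shape in this hunk; the sizes are example values:

```python
# Per-layer KV cache elements for the static cache shape above (example values).
max_batch_size, max_length = 8, 4096
num_attention_heads, num_key_value_heads, hidden_size, tp = 32, 8, 4096, 1
head_dim = hidden_size // num_attention_heads

def cache_elems(kv_heads):
    # [2, batch, kv_heads_per_rank, max_length, head_dim]
    return 2 * max_batch_size * (kv_heads // max(tp, 1)) * max_length * head_dim

saving = cache_elems(num_attention_heads) / cache_elems(num_key_value_heads)
print(f"GQA shrinks the KV cache by {saving:.0f}x")   # 4x with these numbers
```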
@@ -1205,7 +1210,7 @@ def get_cache_kvs_shape(
        for _ in range(config.num_hidden_layers):
            cache_kv_shape = [
                max_block_nums,
-                config.num_attention_heads // max(config.tensor_parallel_degree, 1),
+                config.num_key_value_heads // max(config.tensor_parallel_degree, 1),
                config.block_size,
                config.hidden_size // config.num_attention_heads,
            ]
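The block-based (paged) cache gets the same correction: each block stores block_size tokens for every key/value head on the rank, and the block count bounds how many tokens can be resident per layer. A small sketch of the shape and its token capacity, with example values:

```python
# Block (paged) KV cache shape from the hunk above, with example values.
max_block_nums, block_size = 512, 64
num_key_value_heads, num_attention_heads, hidden_size, tp = 8, 32, 4096, 1

cache_kv_shape = [
    max_block_nums,                          # number of cache blocks
    num_key_value_heads // max(tp, 1),       # KV heads per rank
    block_size,                              # tokens per block
    hidden_size // num_attention_heads,      # head_dim
]
max_cached_tokens = max_block_nums * block_size  # 32768 token slots per layer (per K or V)
print(cache_kv_shape, max_cached_tokens)
```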