@@ -272,16 +272,14 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
         self.head_dim = self.hidden_size // config.num_attention_heads

         self.num_key_value_heads = config.num_key_value_heads
+        assert config.num_attention_heads // config.num_key_value_heads
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+        self.gqa_or_mqa = config.num_attention_heads != config.num_key_value_heads

         self.max_position_embeddings = config.max_position_embeddings
         self.seq_length = config.seq_length

         self.fuse_attention_qkv = config.fuse_attention_qkv
-        if self.fuse_attention_qkv and config.num_attention_heads != config.num_key_value_heads:
-            raise ValueError(
-                f"fuse_attention_qkv can't be True when num_attention_heads {config.num_attention_heads} != num_key_value_heads {config.num_key_value_heads}"
-            )

         self.kv_indices = None
         # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True
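The added lines above enable grouped-query attention (GQA) on the fused QKV path: `num_key_value_groups` is the number of query heads that share each key/value head, and `gqa_or_mqa` marks the case where the two head counts differ, which is why the old restriction on `fuse_attention_qkv` can be dropped. A minimal sketch of the grouping arithmetic, using illustrative head counts that are not taken from the PR:

    # Illustrative values only (not from the PR): 32 query heads sharing
    # 8 key/value heads gives 4 query heads per KV group.
    num_attention_heads = 32
    num_key_value_heads = 8
    # A modulo check makes the divisibility assumption explicit; the PR's
    # assert only requires the integer quotient to be non-zero.
    assert num_attention_heads % num_key_value_heads == 0
    num_key_value_groups = num_attention_heads // num_key_value_heads  # 4
    gqa_or_mqa = num_attention_heads != num_key_value_heads  # True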
@@ -303,7 +301,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
         if self.fuse_attention_qkv:
             self.qkv_proj = nn.Linear(
                 self.hidden_size,
-                3 * self.hidden_size,
+                self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
                 bias_attr=False,
             )
             self.qkv_proj.weight = dist.shard_tensor(
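With GQA, the fused projection no longer emits `3 * hidden_size` features: queries still need `hidden_size` outputs, but keys and values each need only `num_key_value_heads * head_dim`. A hedged sketch of the width computation with illustrative sizes (not from the PR):

    # Illustrative sizes only: hidden_size 4096, head_dim 128, 8 KV heads.
    hidden_size = 4096
    head_dim = 128
    num_key_value_heads = 8
    qkv_out = hidden_size + 2 * num_key_value_heads * head_dim  # 4096 + 2048 = 6144
    # When num_key_value_heads equals num_attention_heads (plain MHA), this
    # reduces to 3 * hidden_size, matching the old width.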
@@ -415,10 +413,16 @@ def forward(
         )

         if self.fuse_attention_qkv:
-            target_shape = [0, 0, self.num_heads, 3 * self.head_dim]
+            target_shape = [0, 0, self.num_key_value_heads, (self.num_key_value_groups + 2) * self.head_dim]
             mix_layer = self.qkv_proj(hidden_states)
             mix_layer = paddle.reshape_(mix_layer, target_shape)
-            query_states, key_states, value_states = paddle.split(mix_layer, num_or_sections=3, axis=-1)
+            query_states, key_states, value_states = paddle.split(
+                mix_layer,
+                num_or_sections=[self.num_key_value_groups * self.head_dim, self.head_dim, self.head_dim],
+                axis=-1,
+            )
+            if self.gqa_or_mqa:
+                query_states = paddle.reshape(query_states, [0, 0, self.num_heads, self.head_dim])
         else:
             target_query_shape = [0, 0, self.num_heads, self.head_dim]
             target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim]
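In `forward`, the fused output is first viewed as `num_key_value_heads` groups of `(num_key_value_groups + 2) * head_dim` features, split into query/key/value slices along the last axis, and the query slice is then reshaped so every query head gets its own `head_dim` slot. A runnable sketch of the shape flow with illustrative sizes (not from the PR):

    # Illustrative sizes only: batch 2, seq 16, 8 KV heads, 4 query heads
    # per KV group, head_dim 128 (so 32 query heads in total).
    import paddle

    batch, seq, num_kv_heads, groups, head_dim = 2, 16, 8, 4, 128
    mix_layer = paddle.randn([batch, seq, num_kv_heads, (groups + 2) * head_dim])
    query, key, value = paddle.split(
        mix_layer,
        num_or_sections=[groups * head_dim, head_dim, head_dim],
        axis=-1,
    )
    # query: [2, 16, 8, 512] -> [2, 16, 32, 128]; key/value stay [2, 16, 8, 128].
    query = paddle.reshape(query, [0, 0, num_kv_heads * groups, head_dim])
    print(query.shape, key.shape, value.shape)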