@@ -56,6 +56,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
     const int max_input_length,
+    const float softmax_scale,
     const float quant_max_bound,
     const float quant_min_bound,
     const float out_linear_in_scale,
@@ -97,21 +98,21 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
   if (out_linear_in_scale > 0.0) {
     if (fabs(quant_max_bound - 127.0f) < 0.000001) {
       fmha_out = GetEmptyTensor(
-          {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
+          {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
           paddle::DataType::INT8,
           qkv.place());
     }
     else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
       fmha_out = GetEmptyTensor(
-          {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
+          {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
           paddle::DataType::FLOAT8_E4M3FN,
           qkv.place());
     }else {
       PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
     }
   } else {
     fmha_out = GetEmptyTensor(
-        {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
+        {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
         D,
         qkv.place());
   }
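
Note: for context, a minimal sketch of the allocation logic above (the helper name PickFmhaOutDtype is hypothetical and not part of this patch). A positive out_linear_in_scale selects a quantized output dtype from quant_max_bound, otherwise the compute dtype D is kept, and the output width now uses head_dims_v so that attention variants whose value head dim differs from the query/key head dim allocate a correctly sized output.

    // Hypothetical helper mirroring the branch above; not part of the patch.
    #include <cmath>
    #include "paddle/extension.h"

    paddle::DataType PickFmhaOutDtype(float out_linear_in_scale,
                                      float quant_max_bound,
                                      paddle::DataType compute_dtype) {
      if (out_linear_in_scale <= 0.0f) return compute_dtype;  // no output quantization
      if (std::fabs(quant_max_bound - 127.0f) < 0.000001f)
        return paddle::DataType::INT8;             // int8 output range
      if (std::fabs(quant_max_bound - 448.0f) < 0.000001f)
        return paddle::DataType::FLOAT8_E4M3FN;    // fp8 e4m3 output range
      PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
      return compute_dtype;  // unreachable; keeps the compiler quiet
    }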
@@ -203,6 +204,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
           encoder_block_shape_q,
           max_input_length,
           max_enc_len_this_time_data,
+          softmax_scale,
           quant_max_bound,
           quant_min_bound,
           out_linear_in_scale,
@@ -240,6 +242,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
           encoder_block_shape_q,
           max_input_length,
           max_enc_len_this_time_data,
+          softmax_scale,
           quant_max_bound,
           quant_min_bound,
           out_linear_in_scale,
@@ -282,6 +285,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
           encoder_block_shape_q,
           max_input_length,
           max_enc_len_this_time_data,
+          softmax_scale,
           quant_max_bound,
           quant_min_bound,
           out_linear_in_scale,
@@ -428,6 +432,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
           decoder_block_shape_q,
           max_input_length,
           max_len_kv_data,
+          softmax_scale,
           quant_max_bound,
           quant_min_bound,
           out_linear_in_scale,
@@ -465,6 +470,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
           decoder_block_shape_q,
           max_input_length,
           max_len_kv_data,
+          softmax_scale,
           quant_max_bound,
           quant_min_bound,
           out_linear_in_scale,
@@ -508,6 +514,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
           decoder_block_shape_q,
           max_input_length,
           max_len_kv_data,
+          softmax_scale,
           quant_max_bound,
           quant_min_bound,
           out_linear_in_scale,
@@ -565,6 +572,7 @@ std::vector<paddle::Tensor> AppendAttention(
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
     const int max_input_length,
+    const float softmax_scale,
     const float quant_max_bound,
     const float quant_min_bound,
     const float out_linear_in_scale,
@@ -578,9 +586,10 @@ std::vector<paddle::Tensor> AppendAttention(
   meta_data.token_nums = qkv_dims[0];
   meta_data.kv_num_heads = key_cache_dims[1];
   meta_data.head_dims = key_cache_dims[3];
-  const int total_num_head =
-      qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims;
-  meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads;
+  meta_data.head_dims_v = value_cache.dims()[3];
+  const int q_hidden_size =
+      qkv_dims[qkv_dims.size() - 1] - meta_data.kv_num_heads * (meta_data.head_dims + meta_data.head_dims_v);
+  meta_data.q_num_heads = q_hidden_size / meta_data.head_dims;

   meta_data.max_blocks_per_seq = block_tables.dims()[1];
   meta_data.block_size = key_cache.dims()[2];
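
Note: the head-count arithmetic above is the core of the change. The fused qkv hidden size is q_num_heads * head_dims + kv_num_heads * (head_dims + head_dims_v), so q_num_heads is recovered by first subtracting the K/V slice instead of assuming Q, K and V share one head dim. A small self-contained check with illustrative numbers (the sizes are assumptions for the example, not values from this patch):

    #include <cassert>

    int main() {
      // Illustrative sizes only: 8 KV heads, 128-dim Q/K heads, 64-dim V heads.
      const int kv_num_heads = 8;
      const int head_dims = 128;    // key_cache_dims[3]
      const int head_dims_v = 64;   // value_cache.dims()[3]
      const int expected_q_heads = 32;
      const int qkv_hidden =
          expected_q_heads * head_dims + kv_num_heads * (head_dims + head_dims_v);

      // Same derivation as the patch: strip the K/V portion, then divide by the Q/K head dim.
      const int q_hidden_size = qkv_hidden - kv_num_heads * (head_dims + head_dims_v);
      const int q_num_heads = q_hidden_size / head_dims;
      assert(q_num_heads == expected_q_heads);

      // The old formula, qkv_hidden / head_dims - 2 * kv_num_heads, only gives the
      // same answer when head_dims_v == head_dims, which is why it was replaced.
      return 0;
    }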
@@ -626,6 +635,7 @@ std::vector<paddle::Tensor> AppendAttention(
         cache_quant_type_str,
         use_neox_rotary_style,
         max_input_length,
+        softmax_scale,
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
@@ -672,6 +682,7 @@ std::vector<paddle::Tensor> AppendAttention(
         cache_quant_type_str,
         use_neox_rotary_style,
         max_input_length,
+        softmax_scale,
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
@@ -719,6 +730,7 @@ std::vector<paddle::Tensor> AppendAttention(
         cache_quant_type_str,
         use_neox_rotary_style,
         max_input_length,
+        softmax_scale,
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
@@ -764,6 +776,7 @@ std::vector<paddle::Tensor> AppendAttention(
         cache_quant_type_str,
         use_neox_rotary_style,
         max_input_length,
+        softmax_scale,
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
@@ -821,10 +834,12 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
     const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape) {
   const int token_num = qkv_shape[0];
   const int kv_num_heads = key_cache_shape[1];
-  const int head_dim = key_cache_shape[3];
-  const int total_num_head = qkv_shape[qkv_shape.size() - 1] / head_dim;
-  const int num_heads = total_num_head - 2 * kv_num_heads;
-  return {{token_num, num_heads * head_dim}, qkv_shape};
+  const int head_dim_qk = key_cache_shape[3];
+  const int head_dim_v = value_cache_shape[3];
+  const int q_hidden_size =
+      qkv_shape[qkv_shape.size() - 1] - kv_num_heads * (head_dim_qk + head_dim_v);
+  const int num_heads = q_hidden_size / head_dim_qk;
+  return {{token_num, num_heads * head_dim_v}, qkv_shape};
 }

 std::vector<paddle::DataType> AppendAttentionInferDtype(
@@ -865,6 +880,7 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
     const int max_input_length,
+    const float softmax_scale,
     const float quant_max_bound,
     const float quant_min_bound,
     const float out_linear_in_scale,
@@ -941,6 +957,7 @@ PD_BUILD_OP(append_attention)
             "cache_quant_type: std::string",
             "use_neox_rotary_style: bool",
             "max_input_length: int",
+            "softmax_scale: float",
             "quant_max_bound: float",
             "quant_min_bound: float",
             "out_linear_in_scale: float",