Skip to content

Commit ffb986b

Browse files
committed
add FLAGS instead of max_partition_size
1 parent f445a7a commit ffb986b

26 files changed

+45
-148
lines changed

csrc/gpu/append_attention.cu

Lines changed: 2 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
5959
const float quant_max_bound,
6060
const float quant_min_bound,
6161
const float out_linear_in_scale,
62-
const int encoder_block_shape_q,
63-
const int decoder_block_shape_q,
64-
const int max_partition_size,
65-
const int encoder_max_partition_size,
6662
const int speculate_max_draft_token_num,
6763
const bool causal,
6864
const bool speculate_decoder) {
@@ -76,7 +72,8 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
7672
int max_enc_len_this_time_data = max_enc_len_this_time.data<int>()[0];
7773
int max_dec_len_this_time_data = max_dec_len_this_time.data<int>()[0];
7874
int max_len_kv_data = max_len_kv.data<int>()[0];
79-
75+
const int encoder_block_shape_q = get_encoder_block_shape_q();
76+
const int decoder_block_shape_q = get_decoder_block_shape_q();
8077
auto main_stream = qkv.stream();
8178
static cudaEvent_t main_event;
8279
static cudaEvent_t decoder_event;
@@ -209,8 +206,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
209206
quant_max_bound,
210207
quant_min_bound,
211208
out_linear_in_scale,
212-
max_partition_size,
213-
encoder_max_partition_size,
214209
speculate_max_draft_token_num,
215210
causal,
216211
false,
@@ -248,8 +243,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
248243
quant_max_bound,
249244
quant_min_bound,
250245
out_linear_in_scale,
251-
max_partition_size,
252-
encoder_max_partition_size,
253246
speculate_max_draft_token_num,
254247
causal,
255248
false,
@@ -292,8 +285,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
292285
quant_max_bound,
293286
quant_min_bound,
294287
out_linear_in_scale,
295-
max_partition_size,
296-
encoder_max_partition_size,
297288
speculate_max_draft_token_num,
298289
causal,
299290
false,
@@ -440,8 +431,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
440431
quant_max_bound,
441432
quant_min_bound,
442433
out_linear_in_scale,
443-
max_partition_size,
444-
encoder_max_partition_size,
445434
speculate_max_draft_token_num,
446435
causal,
447436
!speculate_decoder,
@@ -479,8 +468,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
479468
quant_max_bound,
480469
quant_min_bound,
481470
out_linear_in_scale,
482-
max_partition_size,
483-
encoder_max_partition_size,
484471
speculate_max_draft_token_num,
485472
causal,
486473
!speculate_decoder,
@@ -524,8 +511,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
524511
quant_max_bound,
525512
quant_min_bound,
526513
out_linear_in_scale,
527-
max_partition_size,
528-
encoder_max_partition_size,
529514
speculate_max_draft_token_num,
530515
causal,
531516
!speculate_decoder,
@@ -583,10 +568,6 @@ std::vector<paddle::Tensor> AppendAttention(
583568
const float quant_max_bound,
584569
const float quant_min_bound,
585570
const float out_linear_in_scale,
586-
const int encoder_block_shape_q,
587-
const int decoder_block_shape_q,
588-
const int max_partition_size,
589-
const int encoder_max_partition_size,
590571
const int speculate_max_draft_token_num,
591572
const bool causal,
592573
const bool speculate_decoder) {
@@ -648,10 +629,6 @@ std::vector<paddle::Tensor> AppendAttention(
648629
quant_max_bound,
649630
quant_min_bound,
650631
out_linear_in_scale,
651-
encoder_block_shape_q,
652-
decoder_block_shape_q,
653-
max_partition_size,
654-
encoder_max_partition_size,
655632
speculate_max_draft_token_num,
656633
causal,
657634
speculate_decoder);
@@ -698,10 +675,6 @@ std::vector<paddle::Tensor> AppendAttention(
698675
quant_max_bound,
699676
quant_min_bound,
700677
out_linear_in_scale,
701-
encoder_block_shape_q,
702-
decoder_block_shape_q,
703-
max_partition_size,
704-
encoder_max_partition_size,
705678
speculate_max_draft_token_num,
706679
causal,
707680
speculate_decoder);
@@ -749,10 +722,6 @@ std::vector<paddle::Tensor> AppendAttention(
749722
quant_max_bound,
750723
quant_min_bound,
751724
out_linear_in_scale,
752-
encoder_block_shape_q,
753-
decoder_block_shape_q,
754-
max_partition_size,
755-
encoder_max_partition_size,
756725
speculate_max_draft_token_num,
757726
causal,
758727
speculate_decoder);
@@ -798,10 +767,6 @@ std::vector<paddle::Tensor> AppendAttention(
798767
quant_max_bound,
799768
quant_min_bound,
800769
out_linear_in_scale,
801-
encoder_block_shape_q,
802-
decoder_block_shape_q,
803-
max_partition_size,
804-
encoder_max_partition_size,
805770
speculate_max_draft_token_num,
806771
causal,
807772
speculate_decoder);
@@ -903,10 +868,6 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
903868
const float quant_max_bound,
904869
const float quant_min_bound,
905870
const float out_linear_in_scale,
906-
const int encoder_block_shape_q,
907-
const int decoder_block_shape_q,
908-
const int max_partition_size,
909-
const int encoder_max_partition_size,
910871
const int speculate_max_draft_token_num,
911872
const bool causal,
912873
const bool speculate_decoder) {
@@ -983,10 +944,6 @@ PD_BUILD_OP(append_attention)
983944
"quant_max_bound: float",
984945
"quant_min_bound: float",
985946
"out_linear_in_scale: float",
986-
"encoder_block_shape_q: int",
987-
"decoder_block_shape_q: int",
988-
"max_partition_size: int",
989-
"encoder_max_partition_size: int",
990947
"speculate_max_draft_token_num: int",
991948
"causal: bool",
992949
"speculate_decoder: bool"})

csrc/gpu/append_attn/append_attention_c16_impl.cuh

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -786,8 +786,6 @@ void MultiQueryAppendAttention(
786786
const float quant_max_bound,
787787
const float quant_min_bound,
788788
const float in_scale,
789-
const int max_partition_size,
790-
const int encoder_max_partition_size,
791789
const int speculate_max_draft_token_num,
792790
const bool is_decoder,
793791
cudaStream_t &stream,
@@ -839,9 +837,9 @@ void MultiQueryAppendAttention(
839837
int sm_count;
840838
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
841839

842-
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
840+
uint32_t chunk_size = get_max_partition_size();
843841
if (!is_decoder) {
844-
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
842+
chunk_size = get_encoder_max_partition_size();
845843
}
846844
const int num_chunks = div_up(max_dec_len, chunk_size);
847845
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
@@ -1058,9 +1056,9 @@ void MultiQueryAppendAttention(
10581056
int sm_count;
10591057
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
10601058

1061-
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
1059+
uint32_t chunk_size = get_max_partition_size();
10621060
if (!is_decoder) {
1063-
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
1061+
chunk_size = get_encoder_max_partition_size();
10641062
}
10651063
const int num_chunks = div_up(max_dec_len, chunk_size);
10661064

@@ -1301,8 +1299,6 @@ void CascadeAppendAttentionC16Kernel(
13011299
const float quant_max_bound,
13021300
const float quant_min_bound,
13031301
const float in_scale,
1304-
const int max_partition_size,
1305-
const int encoder_max_partition_size,
13061302
const int speculate_max_draft_token_num,
13071303
const bool causal,
13081304
const bool is_decoder,
@@ -1363,8 +1359,6 @@ void CascadeAppendAttentionC16Kernel(
13631359
quant_max_bound,
13641360
quant_min_bound,
13651361
in_scale,
1366-
max_partition_size,
1367-
encoder_max_partition_size,
13681362
speculate_max_draft_token_num,
13691363
is_decoder,
13701364
stream,

csrc/gpu/append_attn/append_attention_c4_impl.cuh

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -973,8 +973,6 @@ void MultiQueryAppendC4Attention(
973973
const float quant_max_bound,
974974
const float quant_min_bound,
975975
const float in_scale,
976-
const int max_partition_size,
977-
const int encoder_max_partition_size,
978976
const int speculate_max_draft_token_num,
979977
const bool is_decoder,
980978
cudaStream_t &stream,
@@ -1036,9 +1034,9 @@ void MultiQueryAppendC4Attention(
10361034
const float ratio = static_cast<float>(num_blocks_need) /
10371035
static_cast<float>(num_blocks_per_wave);
10381036

1039-
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
1037+
uint32_t chunk_size = get_max_partition_size();
10401038
if (!is_decoder) {
1041-
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
1039+
chunk_size = get_encoder_max_partition_size();
10421040
}
10431041
const int num_chunks = div_up(max_dec_len, chunk_size);
10441042

@@ -1282,9 +1280,9 @@ void MultiQueryAppendC4Attention(
12821280
static_cast<float>(num_blocks_per_wave);
12831281

12841282

1285-
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
1283+
static uint32_t chunk_size = get_max_partition_size();
12861284
if (!is_decoder) {
1287-
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
1285+
chunk_size = get_encoder_max_partition_size();
12881286
}
12891287
const int num_chunks = div_up(max_dec_len, chunk_size);
12901288
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
@@ -1538,8 +1536,6 @@ void CascadeAppendAttentionC4Kernel(
15381536
const float quant_max_bound,
15391537
const float quant_min_bound,
15401538
const float in_scale,
1541-
const int max_partition_size,
1542-
const int encoder_max_partition_size,
15431539
const int speculate_max_draft_token_num,
15441540
const bool causal,
15451541
const bool is_decoder,
@@ -1604,8 +1600,6 @@ void CascadeAppendAttentionC4Kernel(
16041600
quant_max_bound,
16051601
quant_min_bound,
16061602
in_scale,
1607-
max_partition_size,
1608-
encoder_max_partition_size,
16091603
speculate_max_draft_token_num,
16101604
is_decoder,
16111605
stream,

csrc/gpu/append_attn/append_attention_c8_impl.cuh

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -860,8 +860,6 @@ void MultiQueryAppendC8Attention(
860860
const float quant_max_bound,
861861
const float quant_min_bound,
862862
const float in_scale,
863-
const int max_partition_size,
864-
const int encoder_max_partition_size,
865863
const int speculate_max_draft_token_num,
866864
const bool is_decoder,
867865
cudaStream_t &stream,
@@ -914,9 +912,9 @@ void MultiQueryAppendC8Attention(
914912
const int dev_id = 0;
915913
int sm_count;
916914
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
917-
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
915+
uint32_t chunk_size = get_max_partition_size();
918916
if (!is_decoder) {
919-
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
917+
chunk_size = get_encoder_max_partition_size();
920918
}
921919
const int num_chunks = div_up(max_dec_len, chunk_size);
922920
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
@@ -1136,9 +1134,9 @@ void MultiQueryAppendC8Attention(
11361134
const int dev_id = 0;
11371135
int sm_count;
11381136
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
1139-
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
1137+
uint32_t chunk_size = get_max_partition_size();
11401138
if (!is_decoder) {
1141-
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
1139+
chunk_size = get_encoder_max_partition_size();
11421140
}
11431141

11441142
const int num_chunks = div_up(max_dec_len, chunk_size);
@@ -1377,8 +1375,6 @@ void CascadeAppendAttentionC8Kernel(
13771375
const float quant_max_bound,
13781376
const float quant_min_bound,
13791377
const float in_scale,
1380-
const int max_partition_size,
1381-
const int encoder_max_partition_size,
13821378
const int speculate_max_draft_token_num,
13831379
const bool causal,
13841380
const bool is_decoder,
@@ -1441,8 +1437,6 @@ void CascadeAppendAttentionC8Kernel(
14411437
quant_max_bound,
14421438
quant_min_bound,
14431439
in_scale,
1444-
max_partition_size,
1445-
encoder_max_partition_size,
14461440
speculate_max_draft_token_num,
14471441
is_decoder,
14481442
stream,

csrc/gpu/append_attn/append_attention_kernel.h

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,6 @@ void CascadeAppendAttentionC16Kernel(
5252
const float quant_max_bound,
5353
const float quant_min_bound,
5454
const float in_scale,
55-
const int max_partition_size,
56-
const int encoder_max_partition_size,
5755
const int speculate_max_draft_token_num,
5856
const bool causal,
5957
const bool is_decoder,
@@ -97,8 +95,6 @@ void CascadeAppendAttentionC8Kernel(
9795
const float quant_max_bound,
9896
const float quant_min_bound,
9997
const float in_scale,
100-
const int max_partition_size,
101-
const int encoder_max_partition_size,
10298
const int speculate_max_draft_token_num,
10399
const bool causal,
104100
const bool is_decoder,
@@ -142,8 +138,6 @@ void CascadeAppendAttentionC4Kernel(
142138
const float quant_max_bound,
143139
const float quant_min_bound,
144140
const float in_scale,
145-
const int max_partition_size,
146-
const int encoder_max_partition_size,
147141
const int speculate_max_draft_token_num,
148142
const bool causal,
149143
const bool is_decoder,
@@ -188,8 +182,6 @@ void CascadeAppendAttentionKernel(
188182
const float quant_max_bound,
189183
const float quant_min_bound,
190184
const float in_scale,
191-
const int max_partition_size,
192-
const int encoder_max_partition_size,
193185
const int speculate_max_draft_token_num,
194186
const bool causal,
195187
const bool is_decoder,
@@ -223,8 +215,6 @@ void CascadeAppendAttentionKernel(
223215
quant_max_bound,
224216
quant_min_bound,
225217
in_scale,
226-
max_partition_size,
227-
encoder_max_partition_size,
228218
speculate_max_draft_token_num,
229219
causal,
230220
is_decoder,
@@ -258,8 +248,6 @@ void CascadeAppendAttentionKernel(
258248
quant_max_bound,
259249
quant_min_bound,
260250
in_scale,
261-
max_partition_size,
262-
encoder_max_partition_size,
263251
speculate_max_draft_token_num,
264252
causal,
265253
is_decoder,
@@ -293,8 +281,6 @@ void CascadeAppendAttentionKernel(
293281
quant_max_bound,
294282
quant_min_bound,
295283
in_scale,
296-
max_partition_size,
297-
encoder_max_partition_size,
298284
speculate_max_draft_token_num,
299285
causal,
300286
is_decoder,
@@ -307,3 +293,17 @@ void CascadeAppendAttentionKernel(
307293
"cache_int4_zp]");
308294
}
309295
}
296+
297+
inline uint32_t get_max_partition_size() {
298+
static const char* max_partition_size_env = std::getenv("FLAGS_cascade_attention_max_partition_size");
299+
static const uint32_t max_partition_size =
300+
max_partition_size_env == nullptr ? 128 : std::stoul(std::string(max_partition_size_env));
301+
return max_partition_size;
302+
}
303+
304+
inline uint32_t get_encoder_max_partition_size() {
305+
static const char* encoder_max_partition_size_env = std::getenv("FLAGS_cascade_encoder_attention_max_partition_size");
306+
static const uint32_t encoder_max_partition_size =
307+
encoder_max_partition_size_env == nullptr ? 32768 : std::stoul(std::string(encoder_max_partition_size_env));
308+
return encoder_max_partition_size;
309+
}

0 commit comments

Comments (0)