@@ -415,6 +415,8 @@ def __init__(self, config: FusedMultiTransformerConfig):
415
415
mscale = self .config .mla_config .mscale
416
416
self .softmax_scale = self .softmax_scale * mscale * mscale
417
417
418
+ self .position_ids : list [int ] = []
419
+
418
420
self .weight_dtype = self ._dtype
419
421
self .create_params_type = self .get_weight_create_dype ()
420
422
@@ -949,7 +951,7 @@ def compute_layernorm_before_qkv(self, src, i):
949
951
950
952
return ln_out
951
953
952
- def compute_qkv_linear (self , ln_out , i , position_ids = None ):
954
+ def compute_qkv_linear (self , ln_out , i ):
953
955
if self .config .mla_config .use_mla ():
954
956
if self .config .mla_config .q_lora_rank is not None :
955
957
query = paddle .matmul (ln_out , self .q_a_proj_weights [i ])
@@ -989,7 +991,7 @@ def compute_qkv_linear(self, ln_out, i, position_ids=None):
989
991
key_value , [self .config .mla_config .qk_nope_head_dim , self .config .mla_config .v_head_dim ], axis = - 1
990
992
)
991
993
992
- query_pe , key_pe = self .config .rotary_emb (position_ids , query_pe , key_pe )
994
+ query_pe , key_pe = self .config .rotary_emb (self . position_ids , query_pe , key_pe )
993
995
994
996
query [..., self .config .mla_config .qk_nope_head_dim :] = query_pe
995
997
key = paddle .empty_like (query )
@@ -1017,9 +1019,9 @@ def compute_qkv_linear(self, ln_out, i, position_ids=None):
1017
1019
1018
1020
return qkv_out
1019
1021
1020
def compute_qkv(self, src, residual_input, i):
    """Run the pre-attention layer norm and the QKV projection for layer ``i``.

    Args:
        src: Input passed to ``compute_layernorm_before_qkv``.
        residual_input: Returned unchanged alongside the QKV output.
        i: Layer index forwarded to both helper calls.

    Returns:
        A ``(qkv_out, residual_input)`` tuple.
    """
    # Normalise first, then feed the normalised activations straight into
    # the QKV projection of the same layer.
    normed = self.compute_layernorm_before_qkv(src, i)
    return self.compute_qkv_linear(normed, i), residual_input
1024
1026
1025
1027
def compute_max_len (self , seq_lens_encoder , seq_lens_decoder , cum_offsets ):
@@ -1298,7 +1300,20 @@ def compute_shared_expert(self, tmp_out, i):
1298
1300
return ffn2_out
1299
1301
1300
1302
def pre_process(self, **kwargs):
    """Build the flattened position ids for the current step.

    Reads ``seq_lens_this_time``, ``seq_lens_encoder`` and
    ``seq_lens_decoder`` from ``kwargs`` and stores the resulting list on
    ``self.position_ids`` (the value later handed to
    ``self.config.rotary_emb`` by ``compute_qkv_linear``).

    For each batch entry, in order:

    * encoder length > 0: appends ``0, 1, ..., length - 1``;
    * otherwise, decoder length > 0: appends that single decoder length;
    * otherwise: appends nothing.

    Args:
        **kwargs: Must provide ``seq_lens_this_time`` (only ``shape[0]`` is
            used, as the batch size) and ``seq_lens_encoder``.
            ``seq_lens_decoder`` is only indexed for entries whose encoder
            length is not positive; its elements must support ``.item()``.
    """
    # Only the leading dimension is needed: it is the number of batch entries.
    bsz = kwargs.get("seq_lens_this_time", None).shape[0]

    # Look the per-entry length containers up once instead of on every
    # loop iteration.
    seq_lens_encoder = kwargs.get("seq_lens_encoder", None)
    seq_lens_decoder = kwargs.get("seq_lens_decoder", None)

    position_ids = []
    for i in range(bsz):
        cur_seq_len = seq_lens_encoder[i]
        if cur_seq_len > 0:
            position_ids.extend(range(cur_seq_len))
        else:
            ids = seq_lens_decoder[i].item()
            if ids > 0:
                position_ids.append(ids)

    self.position_ids = position_ids
1302
1317
1303
1318
def post_process (self , ** kwargs ):
1304
1319
time_step = kwargs .get ("time_step" , None )
@@ -1405,23 +1420,10 @@ def forward(
1405
1420
kwargs .get ("block_size" , 64 ),
1406
1421
self .config .speculate_config .speculate_max_draft_token_num ,
1407
1422
)
1408
- seq_lens_this_time = kwargs .get ("seq_lens_this_time" , None )
1409
- bsz = seq_lens_this_time .shape [0 ]
1410
- position_ids = []
1411
- for i in range (bsz ):
1412
- cur_seq_len = kwargs .get ("seq_lens_encoder" , None )[i ]
1413
- if cur_seq_len > 0 :
1414
- for j in range (cur_seq_len ):
1415
- position_ids .append (j )
1416
- else :
1417
- ids = kwargs .get ("seq_lens_decoder" , None )[i ].item ()
1418
1423
1419
- if ids > 0 :
1420
- position_ids .append (ids )
1421
- # print("position_ids;", position_ids)
1422
1424
residual_input = src
1423
1425
for i in range (self .num_layers ):
1424
- qkv_out , residual_input = self .compute_qkv (src , residual_input , i , position_ids )
1426
+ qkv_out , residual_input = self .compute_qkv (src , residual_input , i )
1425
1427
out_linear_out = self .compute_attn (
1426
1428
time_step ,
1427
1429
qkv_out ,
@@ -1856,8 +1858,7 @@ def compute_qkv_linear(self, ln_out, i):
1856
1858
key_value , [self .config .mla_config .qk_nope_head_dim , self .config .mla_config .v_head_dim ], axis = - 1
1857
1859
)
1858
1860
1859
- position_ids = paddle .arange (ln_out .shape [0 ]).reshape ((1 , - 1 ))
1860
- query_pe , key_pe = self .config .rotary_emb (position_ids , query_pe , key_pe )
1861
+ query_pe , key_pe = self .config .rotary_emb (self .position_ids , query_pe , key_pe )
1861
1862
1862
1863
query [..., self .config .mla_config .qk_nope_head_dim :] = query_pe
1863
1864
key = paddle .empty_like (query )
0 commit comments