Commit b1f172f: refine code

1 parent acc025f

File tree: 1 file changed (+22, -21 lines)

1 file changed

+22
-21
lines changed

paddlenlp/experimental/transformers/fused_transformer_layers.py

Lines changed: 22 additions & 21 deletions
@@ -415,6 +415,8 @@ def __init__(self, config: FusedMultiTransformerConfig):
             mscale = self.config.mla_config.mscale
             self.softmax_scale = self.softmax_scale * mscale * mscale
 
+        self.position_ids: list[int] = []
+
         self.weight_dtype = self._dtype
         self.create_params_type = self.get_weight_create_dype()
 
@@ -949,7 +951,7 @@ def compute_layernorm_before_qkv(self, src, i):
 
         return ln_out
 
-    def compute_qkv_linear(self, ln_out, i, position_ids=None):
+    def compute_qkv_linear(self, ln_out, i):
        if self.config.mla_config.use_mla():
            if self.config.mla_config.q_lora_rank is not None:
                query = paddle.matmul(ln_out, self.q_a_proj_weights[i])
@@ -989,7 +991,7 @@ def compute_qkv_linear(self, ln_out, i, position_ids=None):
                 key_value, [self.config.mla_config.qk_nope_head_dim, self.config.mla_config.v_head_dim], axis=-1
             )
 
-            query_pe, key_pe = self.config.rotary_emb(position_ids, query_pe, key_pe)
+            query_pe, key_pe = self.config.rotary_emb(self.position_ids, query_pe, key_pe)
 
             query[..., self.config.mla_config.qk_nope_head_dim :] = query_pe
             key = paddle.empty_like(query)
@@ -1017,9 +1019,9 @@ def compute_qkv_linear(self, ln_out, i):
 
         return qkv_out
 
-    def compute_qkv(self, src, residual_input, i, position_ids=None):
+    def compute_qkv(self, src, residual_input, i):
         ln_out = self.compute_layernorm_before_qkv(src, i)
-        qkv_out = self.compute_qkv_linear(ln_out, i, position_ids)
+        qkv_out = self.compute_qkv_linear(ln_out, i)
         return qkv_out, residual_input
 
     def compute_max_len(self, seq_lens_encoder, seq_lens_decoder, cum_offsets):
@@ -1298,7 +1300,20 @@ def compute_shared_expert(self, tmp_out, i):
         return ffn2_out
 
     def pre_process(self, **kwargs):
-        pass
+        seq_lens_this_time = kwargs.get("seq_lens_this_time", None)
+        bsz = seq_lens_this_time.shape[0]
+        position_ids = []
+        for i in range(bsz):
+            cur_seq_len = kwargs.get("seq_lens_encoder", None)[i]
+            if cur_seq_len > 0:
+                for j in range(cur_seq_len):
+                    position_ids.append(j)
+            else:
+                ids = kwargs.get("seq_lens_decoder", None)[i].item()
+                if ids > 0:
+                    position_ids.append(ids)
+
+        self.position_ids = position_ids
 
     def post_process(self, **kwargs):
         time_step = kwargs.get("time_step", None)
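
For reference, a standalone sketch of the position-id construction that pre_process now owns. Plain Python ints stand in for the paddle tensors carried by the kwargs, and build_position_ids is a hypothetical helper name used only for illustration, not part of the commit:

    def build_position_ids(seq_lens_encoder, seq_lens_decoder):
        # Mirrors the loop above: a prefill request contributes one position
        # per prompt token (0 .. enc_len - 1); a decode request contributes
        # the single position of the token generated this step.
        position_ids = []
        for enc_len, dec_len in zip(seq_lens_encoder, seq_lens_decoder):
            if enc_len > 0:
                position_ids.extend(range(enc_len))
            elif dec_len > 0:
                position_ids.append(dec_len)
        return position_ids

    # A batch packing one 4-token prefill with one decode step at position 7:
    assert build_position_ids([4, 0], [0, 7]) == [0, 1, 2, 3, 7]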
@@ -1405,23 +1420,10 @@ def forward(
             kwargs.get("block_size", 64),
             self.config.speculate_config.speculate_max_draft_token_num,
         )
-        seq_lens_this_time = kwargs.get("seq_lens_this_time", None)
-        bsz = seq_lens_this_time.shape[0]
-        position_ids = []
-        for i in range(bsz):
-            cur_seq_len = kwargs.get("seq_lens_encoder", None)[i]
-            if cur_seq_len > 0:
-                for j in range(cur_seq_len):
-                    position_ids.append(j)
-            else:
-                ids = kwargs.get("seq_lens_decoder", None)[i].item()
 
-                if ids > 0:
-                    position_ids.append(ids)
-        # print("position_ids;", position_ids)
         residual_input = src
         for i in range(self.num_layers):
-            qkv_out, residual_input = self.compute_qkv(src, residual_input, i, position_ids)
+            qkv_out, residual_input = self.compute_qkv(src, residual_input, i)
             out_linear_out = self.compute_attn(
                 time_step,
                 qkv_out,
@@ -1856,8 +1858,7 @@ def compute_qkv_linear(self, ln_out, i):
                 key_value, [self.config.mla_config.qk_nope_head_dim, self.config.mla_config.v_head_dim], axis=-1
             )
 
-            position_ids = paddle.arange(ln_out.shape[0]).reshape((1, -1))
-            query_pe, key_pe = self.config.rotary_emb(position_ids, query_pe, key_pe)
+            query_pe, key_pe = self.config.rotary_emb(self.position_ids, query_pe, key_pe)
 
             query[..., self.config.mla_config.qk_nope_head_dim :] = query_pe
             key = paddle.empty_like(query)
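
For contrast, a small hypothetical example of what the removed paddle.arange line produced versus what self.position_ids carries after pre_process, assuming a packed batch of one 3-token prefill plus one decode step at position 7; the motivation is inferred from the change itself, not stated in the commit:

    # ln_out packs 4 token rows: 3 prefill tokens and 1 decode token.
    old_ids = list(range(4))  # [0, 1, 2, 3]: offsets into the packed batch
    new_ids = [0, 1, 2, 7]    # per-request positions built by pre_process
    # The decode token keeps its true position within its own sequence, so its
    # rotary embedding stays consistent with the positions used at prefill time.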
