
Commit 583e17f

fix wint8 precision and refine code
1 parent acc025f commit 583e17f

File tree (2 files changed, +32 −29 lines):

  paddlenlp/experimental/transformers/deepseek_v2/modeling.py
  paddlenlp/experimental/transformers/fused_transformer_layers.py

paddlenlp/experimental/transformers/deepseek_v2/modeling.py

Lines changed: 10 additions & 8 deletions
@@ -126,28 +126,28 @@ def forward(
         query: paddle.Tensor,
         key: paddle.Tensor,
     ) -> Tuple[paddle.Tensor, paddle.Tensor]:
-        q = query[..., : self.rotary_dim]
-        k = key[..., : self.rotary_dim]
+        query_rot = query[..., : self.rotary_dim]
+        key_rot = key[..., : self.rotary_dim]
         if self.rotary_dim < self.head_size:
             query_pass = query[..., self.rotary_dim :]
             key_pass = key[..., self.rotary_dim :]
         cos_sin = self.cos_sin_cache[position_ids].unsqueeze(1)
         cos, sin = cos_sin.chunk(2, axis=-1)
 
-        s, h, d = q.shape
-        q = q.reshape([s, h, d // 2, 2]).transpose([0, 1, 3, 2]).reshape([s, h, d])
+        s, h, d = query_rot.shape
+        query_rot = query_rot.reshape([s, h, d // 2, 2]).transpose([0, 1, 3, 2]).reshape([s, h, d])
 
-        s, h, d = k.shape
-        k = k.reshape([s, h, d // 2, 2]).transpose([0, 1, 3, 2]).reshape([s, h, d])
+        s, h, d = key_rot.shape
+        key_rot = key_rot.reshape([s, h, d // 2, 2]).transpose([0, 1, 3, 2]).reshape([s, h, d])
 
         def rotate_half(x):
             """Rotates half the hidden axes of the input."""
             x1 = x[..., : x.shape[-1] // 2]
             x2 = x[..., x.shape[-1] // 2 :]
             return paddle.concat([-x2, x1], axis=-1)  # shape is the same as x
 
-        query_rot = (q * cos) + (rotate_half(q) * sin)
-        key_rot = (k * cos) + (rotate_half(k) * sin)
+        query_rot = (query_rot * cos) + (rotate_half(query_rot) * sin)
+        key_rot = (key_rot * cos) + (rotate_half(key_rot) * sin)
 
         if self.rotary_dim < self.head_size:
             query = paddle.concat((query_rot, query_pass), axis=-1)
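Note: the reshape/transpose above converts the interleaved pair layout [x0, y0, x1, y1, ...] into the half-split layout [x0, x1, ..., y0, y1, ...] expected by the rotate_half formulation of RoPE; the rename from q/k to query_rot/key_rot is a readability refinement and the math is unchanged. A minimal standalone sketch with toy shapes (not part of the commit):

    import paddle

    s, h, d = 1, 1, 8  # seq_len, num_heads, rotary_dim (toy sizes)
    q = paddle.arange(s * h * d, dtype="float32").reshape([s, h, d])

    # interleaved pairs [x0, y0, x1, y1, ...] -> half-split [x0, x1, ..., y0, y1, ...]
    q_half = q.reshape([s, h, d // 2, 2]).transpose([0, 1, 3, 2]).reshape([s, h, d])

    print(q.numpy().ravel())       # [0. 1. 2. 3. 4. 5. 6. 7.]
    print(q_half.numpy().ravel())  # [0. 2. 4. 6. 1. 3. 5. 7.]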
@@ -564,6 +564,7 @@ def set_state_dict(self, state_dict):
                     q_b_proj_weight, algo=self.quant_algo
                 )
                 self.transformer_block.q_b_proj_weights[idx].set_value(q_b_proj_quanted_weight)
+                self.transformer_block.q_a_layernorm_weights[idx].set_value(q_a_layernorm_weight)
                 self.transformer_block.q_b_proj_weights_scale[idx].set_value(q_b_proj_weight_scale)
             else:
                 self.transformer_block.q_a_proj_weights[idx].set_value(q_a_proj_weight)
@@ -602,6 +603,7 @@ def set_state_dict(self, state_dict):
                     kv_b_proj_weight, algo=self.quant_algo
                 )
                 self.transformer_block.kv_b_proj_weights[idx].set_value(kv_b_proj_quanted_weight)
+                self.transformer_block.kv_a_layernorm_weights[idx].set_value(kv_a_layernorm_weight)
                 self.transformer_block.kv_b_proj_weights_scale[idx].set_value(kv_b_proj_weight_scale)
             else:
                 self.transformer_block.kv_a_proj_with_mqa_weights[idx].set_value(kv_a_proj_with_mqa_weight)
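Note: the single added set_value line in each hunk is the substantive part of the wint8 fix. In the weight-only int8 branch the q_a/kv_a layernorm weights were previously never assigned, so quantized runs presumably kept default-initialized norm weights instead of the checkpoint values. A hedged sketch of the intended load pattern for the q branch (the helper name is hypothetical, and weight_quantize from paddle.nn.quant is one way to produce the quantized weight and scale; the commit's own quantization helper may differ):

    from paddle.nn.quant import weight_quantize  # weight-only quantization helper (assumed API)

    def load_q_branch_wint8(block, idx, q_b_proj_weight, q_a_layernorm_weight, algo="weight_only_int8"):
        # quantize the dense projection and keep its per-channel scale
        q_b_proj_quanted_weight, q_b_proj_weight_scale = weight_quantize(q_b_proj_weight, algo=algo)
        block.q_b_proj_weights[idx].set_value(q_b_proj_quanted_weight)
        # the assignment that was missing from the wint8 branch before this commit
        block.q_a_layernorm_weights[idx].set_value(q_a_layernorm_weight)
        block.q_b_proj_weights_scale[idx].set_value(q_b_proj_weight_scale)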

paddlenlp/experimental/transformers/fused_transformer_layers.py

Lines changed: 22 additions & 21 deletions
@@ -415,6 +415,8 @@ def __init__(self, config: FusedMultiTransformerConfig):
             mscale = self.config.mla_config.mscale
             self.softmax_scale = self.softmax_scale * mscale * mscale
 
+        self.position_ids: list[int] = []
+
         self.weight_dtype = self._dtype
         self.create_params_type = self.get_weight_create_dype()
 
@@ -949,7 +951,7 @@ def compute_layernorm_before_qkv(self, src, i):
 
         return ln_out
 
-    def compute_qkv_linear(self, ln_out, i, position_ids=None):
+    def compute_qkv_linear(self, ln_out, i):
         if self.config.mla_config.use_mla():
             if self.config.mla_config.q_lora_rank is not None:
                 query = paddle.matmul(ln_out, self.q_a_proj_weights[i])
@@ -989,7 +991,7 @@ def compute_qkv_linear(self, ln_out, i, position_ids=None):
                 key_value, [self.config.mla_config.qk_nope_head_dim, self.config.mla_config.v_head_dim], axis=-1
             )
 
-            query_pe, key_pe = self.config.rotary_emb(position_ids, query_pe, key_pe)
+            query_pe, key_pe = self.config.rotary_emb(self.position_ids, query_pe, key_pe)
 
             query[..., self.config.mla_config.qk_nope_head_dim :] = query_pe
             key = paddle.empty_like(query)
@@ -1017,9 +1019,9 @@ def compute_qkv_linear(self, ln_out, i, position_ids=None):
 
         return qkv_out
 
-    def compute_qkv(self, src, residual_input, i, position_ids=None):
+    def compute_qkv(self, src, residual_input, i):
         ln_out = self.compute_layernorm_before_qkv(src, i)
-        qkv_out = self.compute_qkv_linear(ln_out, i, position_ids)
+        qkv_out = self.compute_qkv_linear(ln_out, i)
         return qkv_out, residual_input
 
     def compute_max_len(self, seq_lens_encoder, seq_lens_decoder, cum_offsets):
@@ -1298,7 +1300,20 @@ def compute_shared_expert(self, tmp_out, i):
         return ffn2_out
 
     def pre_process(self, **kwargs):
-        pass
+        seq_lens_this_time = kwargs.get("seq_lens_this_time", None)
+        bsz = seq_lens_this_time.shape[0]
+        position_ids = []
+        for i in range(bsz):
+            cur_seq_len = kwargs.get("seq_lens_encoder", None)[i]
+            if cur_seq_len > 0:
+                for j in range(cur_seq_len):
+                    position_ids.append(j)
+            else:
+                ids = kwargs.get("seq_lens_decoder", None)[i].item()
+                if ids > 0:
+                    position_ids.append(ids)
+
+        self.position_ids = position_ids
 
     def post_process(self, **kwargs):
         time_step = kwargs.get("time_step", None)
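Note: pre_process now builds the rotary position ids once per forward pass instead of inside forward. For a prefill request (seq_lens_encoder[i] > 0) it appends 0 .. len-1 for every prompt token; for a decode request it appends the current decoded length, i.e. the position of the single new token. A standalone sketch of the same logic with toy inputs (plain Python ints standing in for the tensors the real kwargs carry):

    def build_position_ids(seq_lens_encoder, seq_lens_decoder):
        position_ids = []
        for enc_len, dec_len in zip(seq_lens_encoder, seq_lens_decoder):
            if enc_len > 0:                    # prefill: one position per prompt token
                position_ids.extend(range(enc_len))
            elif dec_len > 0:                  # decode: one new token at the cache length
                position_ids.append(dec_len)
        return position_ids

    # prefill of a 4-token prompt plus a sequence decoding its 7th token
    print(build_position_ids([4, 0], [0, 6]))  # [0, 1, 2, 3, 6]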
@@ -1405,23 +1420,10 @@ def forward(
                 kwargs.get("block_size", 64),
                 self.config.speculate_config.speculate_max_draft_token_num,
             )
-        seq_lens_this_time = kwargs.get("seq_lens_this_time", None)
-        bsz = seq_lens_this_time.shape[0]
-        position_ids = []
-        for i in range(bsz):
-            cur_seq_len = kwargs.get("seq_lens_encoder", None)[i]
-            if cur_seq_len > 0:
-                for j in range(cur_seq_len):
-                    position_ids.append(j)
-            else:
-                ids = kwargs.get("seq_lens_decoder", None)[i].item()
 
-                if ids > 0:
-                    position_ids.append(ids)
-        # print("position_ids;", position_ids)
         residual_input = src
         for i in range(self.num_layers):
-            qkv_out, residual_input = self.compute_qkv(src, residual_input, i, position_ids)
+            qkv_out, residual_input = self.compute_qkv(src, residual_input, i)
             out_linear_out = self.compute_attn(
                 time_step,
                 qkv_out,
@@ -1856,8 +1858,7 @@ def compute_qkv_linear(self, ln_out, i):
                 key_value, [self.config.mla_config.qk_nope_head_dim, self.config.mla_config.v_head_dim], axis=-1
             )
 
-            position_ids = paddle.arange(ln_out.shape[0]).reshape((1, -1))
-            query_pe, key_pe = self.config.rotary_emb(position_ids, query_pe, key_pe)
+            query_pe, key_pe = self.config.rotary_emb(self.position_ids, query_pe, key_pe)
 
             query[..., self.config.mla_config.qk_nope_head_dim :] = query_pe
             key = paddle.empty_like(query)
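Note: the deleted paddle.arange line derived rotary positions from the flattened token count of ln_out, which is only correct for a single prefill starting at position 0; during batched decode every sequence needs its own offset, which self.position_ids (built in pre_process) now provides. An illustrative toy comparison (not the commit's code):

    # two sequences that have already generated 5 and 9 tokens; decode adds one token each
    seq_lens_decoder = [5, 9]

    wrong = list(range(len(seq_lens_decoder)))      # paddle.arange(ln_out.shape[0]) -> [0, 1]
    right = [n for n in seq_lens_decoder if n > 0]  # self.position_ids from pre_process -> [5, 9]

    print(wrong, right)  # [0, 1] [5, 9]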
