Skip to content

Commit acc025f

Browse files
authored
Merge pull request #4 from lizhenyun01/fix_rope
Fix rope & fix precision
2 parents feabdb8 + 9657580 commit acc025f

File tree

2 files changed

+34
-21
lines changed

2 files changed

+34
-21
lines changed

paddlenlp/experimental/transformers/deepseek_v2/modeling.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,9 @@ def _compute_cos_sin_cache(self) -> paddle.Tensor:
114114
inv_freq = self._compute_inv_freq(self.scaling_factor)
115115
t = paddle.arange(self.max_position_embeddings * self.scaling_factor, dtype=paddle.float32)
116116
freqs = paddle.einsum("i,j -> ij", t, inv_freq)
117-
cos = freqs.cos() * self.mscale
118-
sin = freqs.sin() * self.mscale
117+
emb = paddle.concat((freqs, freqs), axis=-1)
118+
cos = emb.cos() * self.mscale
119+
sin = emb.sin() * self.mscale
119120
cache = paddle.concat((cos, sin), axis=-1)
120121
return cache
121122

@@ -125,28 +126,28 @@ def forward(
125126
query: paddle.Tensor,
126127
key: paddle.Tensor,
127128
) -> Tuple[paddle.Tensor, paddle.Tensor]:
128-
query_rot = query[..., : self.rotary_dim]
129-
key_rot = key[..., : self.rotary_dim]
129+
q = query[..., : self.rotary_dim]
130+
k = key[..., : self.rotary_dim]
130131
if self.rotary_dim < self.head_size:
131132
query_pass = query[..., self.rotary_dim :]
132133
key_pass = key[..., self.rotary_dim :]
133-
134-
cos_sin = self.cos_sin_cache[position_ids]
134+
cos_sin = self.cos_sin_cache[position_ids].unsqueeze(1)
135135
cos, sin = cos_sin.chunk(2, axis=-1)
136136

137-
cos = cos.repeat_interleave(2, axis=-1).unsqueeze(-2)
138-
sin = sin.repeat_interleave(2, axis=-1).unsqueeze(-2)
137+
s, h, d = q.shape
138+
q = q.reshape([s, h, d // 2, 2]).transpose([0, 1, 3, 2]).reshape([s, h, d])
139139

140-
def _rotate_gptj(x: paddle.Tensor) -> paddle.Tensor:
141-
x1 = x[..., ::2]
142-
x2 = x[..., 1::2]
143-
x = paddle.stack((-x2, x1), axis=-1)
144-
return x.flatten(-2)
140+
s, h, d = k.shape
141+
k = k.reshape([s, h, d // 2, 2]).transpose([0, 1, 3, 2]).reshape([s, h, d])
145142

146-
rotate_fn = _rotate_gptj
143+
def rotate_half(x):
144+
"""Rotates half the hidden axes of the input."""
145+
x1 = x[..., : x.shape[-1] // 2]
146+
x2 = x[..., x.shape[-1] // 2 :]
147+
return paddle.concat([-x2, x1], axis=-1) # shape is the same as x
147148

148-
query_rot = query_rot * cos + rotate_fn(query_rot) * sin
149-
key_rot = key_rot * cos + rotate_fn(key_rot) * sin
149+
query_rot = (q * cos) + (rotate_half(q) * sin)
150+
key_rot = (k * cos) + (rotate_half(k) * sin)
150151

151152
if self.rotary_dim < self.head_size:
152153
query = paddle.concat((query_rot, query_pass), axis=-1)

paddlenlp/experimental/transformers/fused_transformer_layers.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -949,7 +949,7 @@ def compute_layernorm_before_qkv(self, src, i):
949949

950950
return ln_out
951951

952-
def compute_qkv_linear(self, ln_out, i):
952+
def compute_qkv_linear(self, ln_out, i, position_ids=None):
953953
if self.config.mla_config.use_mla():
954954
if self.config.mla_config.q_lora_rank is not None:
955955
query = paddle.matmul(ln_out, self.q_a_proj_weights[i])
@@ -989,7 +989,6 @@ def compute_qkv_linear(self, ln_out, i):
989989
key_value, [self.config.mla_config.qk_nope_head_dim, self.config.mla_config.v_head_dim], axis=-1
990990
)
991991

992-
position_ids = paddle.arange(ln_out.shape[0]).reshape((1, -1))
993992
query_pe, key_pe = self.config.rotary_emb(position_ids, query_pe, key_pe)
994993

995994
query[..., self.config.mla_config.qk_nope_head_dim :] = query_pe
@@ -1018,9 +1017,9 @@ def compute_qkv_linear(self, ln_out, i):
10181017

10191018
return qkv_out
10201019

1021-
def compute_qkv(self, src, residual_input, i):
1020+
def compute_qkv(self, src, residual_input, i, position_ids=None):
10221021
ln_out = self.compute_layernorm_before_qkv(src, i)
1023-
qkv_out = self.compute_qkv_linear(ln_out, i)
1022+
qkv_out = self.compute_qkv_linear(ln_out, i, position_ids)
10241023
return qkv_out, residual_input
10251024

10261025
def compute_max_len(self, seq_lens_encoder, seq_lens_decoder, cum_offsets):
@@ -1406,10 +1405,23 @@ def forward(
14061405
kwargs.get("block_size", 64),
14071406
self.config.speculate_config.speculate_max_draft_token_num,
14081407
)
1408+
seq_lens_this_time = kwargs.get("seq_lens_this_time", None)
1409+
bsz = seq_lens_this_time.shape[0]
1410+
position_ids = []
1411+
for i in range(bsz):
1412+
cur_seq_len = kwargs.get("seq_lens_encoder", None)[i]
1413+
if cur_seq_len > 0:
1414+
for j in range(cur_seq_len):
1415+
position_ids.append(j)
1416+
else:
1417+
ids = kwargs.get("seq_lens_decoder", None)[i].item()
14091418

1419+
if ids > 0:
1420+
position_ids.append(ids)
1421+
# print("position_ids;", position_ids)
14101422
residual_input = src
14111423
for i in range(self.num_layers):
1412-
qkv_out, residual_input = self.compute_qkv(src, residual_input, i)
1424+
qkv_out, residual_input = self.compute_qkv(src, residual_input, i, position_ids)
14131425
out_linear_out = self.compute_attn(
14141426
time_step,
14151427
qkv_out,

0 commit comments

Comments
 (0)