Skip to content

Commit dd1e52b

Browse files
committed
Fix Llama rotary_emb interface: forward now returns only (cos, sin); the fused cos/sin table used by the GCU path moves to a new get_fused_cos_sin() method
1 parent 87e4c4f commit dd1e52b

File tree

2 files changed

+22
-10
lines changed

2 files changed

+22
-10
lines changed

paddlenlp/transformers/llama/fusion_ops.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,13 @@ def fusion_rope(query_states, key_states, value_states, hidden_states, position_
5757
assert past_key_value is None, "fuse rotary not support cache kv for now"
5858
batch_size, seq_length, num_heads, head_dim = query_states.shape
5959
_, kv_seq_len, num_key_value_heads, _ = key_states.shape
60-
cos, sin, cos_sin = rotary_emb(value_states, seq_len=kv_seq_len)
60+
if get_env_device() != "gcu":
61+
cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
6162
if get_env_device() == "npu":
6263
query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
6364
key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
6465
elif get_env_device() == "gcu":
66+
cos_sin = rotary_emb.get_fused_cos_sin(value_states, seq_len=kv_seq_len)
6567
query_states, key_states = core.eager._run_custom_op(
6668
"fused_rotary_embedding_gcu", query_states, key_states, cos_sin, position_ids, True
6769
)

paddlenlp/transformers/llama/modeling.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -419,11 +419,14 @@ def forward(self, x, seq_len=None):
419419
return (
420420
cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
421421
sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
422-
self.cos_sin_table.cast(x.dtype)
423-
if self.cos_sin_table is not None and self.cos_sin_table.dtype != x.dtype
424-
else self.cos_sin_table,
425422
)
426423

424+
def get_fused_cos_sin(self, x, seq_len=None):
    """Return the fused cos/sin lookup table aligned to ``x``'s dtype.

    The table may be None (never precomputed for this device path);
    in that case None is passed through unchanged. ``seq_len`` is
    accepted for interface parity with ``forward`` but is unused here.
    """
    table = self.cos_sin_table
    # Only cast when a table exists and its dtype actually differs.
    if table is None or table.dtype == x.dtype:
        return table
    return table.cast(x.dtype)
429+
427430

428431
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
429432
def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0):
@@ -482,19 +485,26 @@ def _scale_cos_sin(self, seq_len):
482485
def forward(self, x, seq_len=None):
    """Return ``(cos, sin)`` rotary tables for the first ``seq_len``
    positions, cast to ``x``'s dtype when needed.

    When ``seq_len`` exceeds ``max_position_embeddings`` the rescaled
    tables from ``_scale_cos_sin`` are used; otherwise the cached ones.
    Per the original note, x is assumed to be laid out as
    [bs, num_attention_heads, seq_len, head_size] — confirm at caller.
    """
    if seq_len > self.max_position_embeddings:
        full_cos, full_sin, _ = self._scale_cos_sin(seq_len=seq_len)
    else:
        full_cos = self.cos_cached
        full_sin = self.sin_cached

    def _trim(table):
        # Slice down to the requested length, then align dtype with x.
        sliced = table[:, :seq_len, :, ...]
        if sliced.dtype == x.dtype:
            return sliced
        return sliced.cast(x.dtype)

    return _trim(full_cos), _trim(full_sin)
497497

498+
def get_fused_cos_sin(self, x, seq_len=None):
    """Return the fused cos/sin table for ``seq_len`` positions in
    ``x``'s dtype.

    Uses the rescaled table from ``_scale_cos_sin`` when ``seq_len``
    overruns ``max_position_embeddings``, otherwise the cached table.
    A table of None is passed through untouched.
    """
    if seq_len > self.max_position_embeddings:
        _, _, table = self._scale_cos_sin(seq_len=seq_len)
    else:
        table = self.cos_sin_table
    # Cast only when a table exists and its dtype differs from x's.
    if table is None or table.dtype == x.dtype:
        return table
    return table.cast(x.dtype)
507+
498508

499509
def rotate_half(x):
500510
"""Rotates half the hidden dims of the input."""
@@ -943,7 +953,7 @@ def forward(
943953
sin.cast(value_states.dtype) if sin.dtype != value_states.dtype else sin,
944954
)
945955
else:
946-
cos, sin, _ = self.rotary_emb(value_states, seq_len=kv_seq_len)
956+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
947957

948958
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
949959

0 commit comments

Comments
 (0)