Skip to content

Commit c5af02d

Browse files
committed
fix rotary_emb for llama
1 parent b36b6a0 commit c5af02d

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

paddlenlp/transformers/llama/modeling.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -935,7 +935,7 @@ def forward(
935935

936936
else:
937937
if self.config.use_long_sequence_strategies:
938-
cos, sin = self.rotary_emb(seq_len=kv_seq_len)
938+
cos, sin, _ = self.rotary_emb(seq_len=kv_seq_len)
939939
cos = cos[None, :, None, :]
940940
sin = sin[None, :, None, :]
941941
cos, sin = (

paddlenlp/transformers/llama/modeling_auto.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -448,7 +448,7 @@ def forward(
448448
assert past_key_value is None, "fuse rotary not support cache kv for now"
449449
batch_size, seq_length, num_heads, head_dim = query_states.shape
450450
_, kv_seq_len, num_key_value_heads, _ = key_states.shape
451-
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
451+
cos, sin, _ = self.rotary_emb(value_states, seq_len=kv_seq_len)
452452

453453
paddle_version = float(paddle.__version__[:3])
454454
if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and (num_heads != num_key_value_heads):
@@ -481,7 +481,7 @@ def forward(
481481
use_neox_rotary_style=False,
482482
)
483483
else:
484-
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
484+
cos, sin, _ = self.rotary_emb(value_states, seq_len=kv_seq_len)
485485
# hack here, because elementwise infer spmd not support broadcast now
486486
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
487487

paddlenlp/transformers/llama/modeling_auto_static.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -421,7 +421,7 @@ def forward(
421421
if self.config.rope:
422422
if self.use_fused_rope:
423423
assert past_key_value is None, "fuse rotary not support cache kv for now"
424-
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
424+
cos, sin, _ = self.rotary_emb(value_states, seq_len=kv_seq_len)
425425
query_states, key_states, _ = fused_rotary_position_embedding(
426426
query_states,
427427
key_states,
@@ -432,7 +432,7 @@ def forward(
432432
use_neox_rotary_style=False,
433433
)
434434
else:
435-
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
435+
cos, sin, _ = self.rotary_emb(value_states, seq_len=kv_seq_len)
436436
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
437437

438438
# [bs, seq_len, num_head, head_dim]

0 commit comments

Comments (0)