Commit ab562b7

update, using sep_group
1 parent 63b2be8 commit ab562b7

7 files changed (+242, -817 lines)
paddlenlp/transformers/llama/fusion_ops.py

Lines changed: 13 additions & 7 deletions
@@ -52,9 +52,18 @@ def swiglu(x, y=None):
     flash_attention = None
 
 from paddlenlp.transformers.ring_flash_attention import RingFlashAttention
-from paddlenlp.transformers.context_parallel_utils import split_inputs_sequence_dim_load_balance
 
-def fusion_rope(query_states, key_states, value_states, hidden_states, position_ids, past_key_value, rotary_emb, cp_parallel_degree=-1):
+
+def fusion_rope(
+    query_states,
+    key_states,
+    value_states,
+    hidden_states,
+    position_ids,
+    past_key_value,
+    rotary_emb,
+    cp_parallel_degree=-1,
+):
     if get_env_device() != "gcu":
         assert past_key_value is None, "fuse rotary not support cache kv for now"
     batch_size, seq_length, num_heads, head_dim = query_states.shape
@@ -64,9 +73,6 @@ def fusion_rope(query_states, key_states, value_states, hidden_states, position_
         kv_seq_len *= cp_parallel_degree
     if get_env_device() != "gcu":
         cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
-        if cp_parallel_degree > 1:
-            cos = split_inputs_sequence_dim_load_balance(cos)
-            sin = split_inputs_sequence_dim_load_balance(sin)
     if get_env_device() == "npu":
         query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
         key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
@@ -165,7 +171,7 @@ def fusion_flash_attention(
         attention_mask = attention_mask.cast(alibi.dtype) + alibi
     if get_env_device() == "npu":
         if config.cp_parallel_degree > 1:
-            raise ValueError(f"Context parallel is not implemented for npu")
+            raise ValueError("Context parallel is not implemented for npu")
         attn_output = core.eager._run_custom_op(
             "flash_attention_npu",
             query_states,
@@ -181,7 +187,7 @@ def fusion_flash_attention(
         )[0]
     elif get_env_device() == "gcu":
         if config.cp_parallel_degree > 1:
-            raise ValueError(f"Context parallel is not implemented for gcu")
+            raise ValueError("Context parallel is not implemented for gcu")
         attn_output = core.eager._run_custom_op(
             "fused_sdp_flash_attention_gcu",
             query_states,
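
Note on the fusion_ops.py change: the load-balanced split of the rotary cos/sin tables (split_inputs_sequence_dim_load_balance) is no longer applied inside fusion_rope; based on what these hunks show, cp_parallel_degree is now only used to scale kv_seq_len before the tables are built, and any per-rank slicing is left to the caller. A minimal, hypothetical NumPy sketch of that behaviour (toy_rotary_emb and fusion_rope_sketch are made-up names, not the PaddleNLP API):

import numpy as np

def toy_rotary_emb(seq_len, head_dim=8, base=10000.0):
    # standard RoPE tables: cos/sin of shape [seq_len, head_dim]
    inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2) / head_dim))
    freqs = np.outer(np.arange(seq_len), inv_freq)
    emb = np.concatenate([freqs, freqs], axis=-1)
    return np.cos(emb), np.sin(emb)

def fusion_rope_sketch(key_states, cp_parallel_degree=-1):
    # key_states: [batch, local_seq_len, num_heads, head_dim]
    kv_seq_len = key_states.shape[1]
    if cp_parallel_degree > 1:
        # each context-parallel rank only holds seq_len / degree tokens locally,
        # so the rotary tables are built for the full sequence length
        kv_seq_len *= cp_parallel_degree
    cos, sin = toy_rotary_emb(kv_seq_len, head_dim=key_states.shape[-1])
    # after this commit, any load-balanced slicing of cos/sin happens outside
    # the helper; full-length tables are returned here
    return cos, sin

cos, sin = fusion_rope_sketch(np.zeros((1, 4, 2, 8)), cp_parallel_degree=2)
print(cos.shape)  # (8, 8): full sequence length x head_dim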

paddlenlp/transformers/llama/modeling.py

Lines changed: 1 addition & 3 deletions
@@ -99,7 +99,6 @@ def swiglu(x, y=None):
 ]
 
 
-
 def _get_interleave(n):
     def _get_interleave_power_of_2(n):
         start = 2 ** (-(2 ** -(math.log2(n) - 3)))
@@ -956,7 +955,7 @@ def forward(
                 position_ids,
                 past_key_value,
                 self.rotary_emb,
-                self.cp_parallel_degree
+                self.config.cp_parallel_degree,
             )
 
         else:
@@ -972,7 +971,6 @@ def forward(
                )
            else:
                cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
        # [bs, seq_len, num_head, head_dim]
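
Note on the modeling.py change: the fusion_rope call site now reads the degree from the model config (self.config.cp_parallel_degree) rather than from the attention layer itself (self.cp_parallel_degree), which presumably was never set as a layer attribute. A minimal, hypothetical sketch of the pattern (toy classes, not the PaddleNLP ones):

class ToyConfig:
    # assumed field name, mirroring cp_parallel_degree from the diff
    cp_parallel_degree = 2

class ToyAttention:
    def __init__(self, config):
        self.config = config

    def forward(self):
        # old code: self.cp_parallel_degree -> likely AttributeError unless the
        # layer set it elsewhere; new code reads the value from the config
        return self.config.cp_parallel_degree

print(ToyAttention(ToyConfig()).forward())  # 2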
