@@ -52,9 +52,18 @@ def swiglu(x, y=None):
    flash_attention = None

from paddlenlp.transformers.ring_flash_attention import RingFlashAttention
-from paddlenlp.transformers.context_parallel_utils import split_inputs_sequence_dim_load_balance

-def fusion_rope(query_states, key_states, value_states, hidden_states, position_ids, past_key_value, rotary_emb, cp_parallel_degree=-1):
+
+def fusion_rope(
+    query_states,
+    key_states,
+    value_states,
+    hidden_states,
+    position_ids,
+    past_key_value,
+    rotary_emb,
+    cp_parallel_degree=-1,
+):
    if get_env_device() != "gcu":
        assert past_key_value is None, "fuse rotary not support cache kv for now"
    batch_size, seq_length, num_heads, head_dim = query_states.shape
@@ -64,9 +73,6 @@ def fusion_rope(query_states, key_states, value_states, hidden_states, position_
        kv_seq_len *= cp_parallel_degree
    if get_env_device() != "gcu":
        cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
-    if cp_parallel_degree > 1:
-        cos = split_inputs_sequence_dim_load_balance(cos)
-        sin = split_inputs_sequence_dim_load_balance(sin)
    if get_env_device() == "npu":
        query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
        key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
@@ -165,7 +171,7 @@ def fusion_flash_attention(
            attention_mask = attention_mask.cast(alibi.dtype) + alibi
        if get_env_device() == "npu":
            if config.cp_parallel_degree > 1:
-                raise ValueError(f"Context parallel is not implemented for npu")
+                raise ValueError("Context parallel is not implemented for npu")
            attn_output = core.eager._run_custom_op(
                "flash_attention_npu",
                query_states,
@@ -181,7 +187,7 @@ def fusion_flash_attention(
            )[0]
        elif get_env_device() == "gcu":
            if config.cp_parallel_degree > 1:
-                raise ValueError(f"Context parallel is not implemented for gcu")
+                raise ValueError("Context parallel is not implemented for gcu")
            attn_output = core.eager._run_custom_op(
                "fused_sdp_flash_attention_gcu",
                query_states,