@@ -96,7 +96,7 @@ def swiglu(x, y=None):
     "LlamaForCausalLM",
     "LlamaPretrainingCriterion",
 ]
-
+global npu_is_casual
 npu_is_casual = False
 
 def _get_interleave(n):
@@ -213,7 +213,7 @@ def scaled_dot_product_attention(
 ):
     bsz, q_len, num_heads, head_dim = query_states.shape
     _, kv_seq_len, _, _ = value_states.shape
-
+    global npu_is_casual
     if config.use_flash_attention and flash_attention:
         # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
         # Torch Flash Attention input [ bz, nhead, seqlen, head_dim]
@@ -1613,6 +1613,7 @@ def forward(
         attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
         )  # [bs, 1, seq_len, seq_len]
+        global npu_is_casual
        if self.config.use_flash_attention:
            is_casual = is_casual_mask(attention_mask)
            if get_env_device() != "npu":
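For readers unfamiliar with the pattern, the sketch below illustrates in plain Python, outside PaddleNLP, why the patch adds `global npu_is_casual` declarations: a function that rebinds the module-level flag needs the `global` statement, otherwise the assignment would only create a function-local name. The helper names `update_flag` and `use_flag` are hypothetical and exist only for illustration; the flag name `npu_is_casual` is the only identifier taken from the diff.

```python
# Minimal sketch of a module-level flag updated from inside functions.
# `update_flag` and `use_flag` are hypothetical helpers, not PaddleNLP APIs.

npu_is_casual = False  # module-level flag, mirroring modeling.py


def update_flag(is_casual: bool) -> None:
    # Without the `global` declaration, the assignment below would bind a
    # new local variable instead of updating the module-level flag.
    global npu_is_casual
    npu_is_casual = is_casual


def use_flag() -> str:
    # Reading the flag alone does not require `global`; declaring it anyway
    # (as the patch does) makes the shared-state dependency explicit.
    return "causal attention path" if npu_is_casual else "masked attention path"


if __name__ == "__main__":
    print(use_flag())   # masked attention path
    update_flag(True)
    print(use_flag())   # causal attention path
```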