
Commit ae0bea9

update (#8359)
* change llama/modeling.py to optimize NPU performance
* update
* update
* Update modeling.py
* add judge
* update

Co-authored-by: Wang Huan <wanghuan29@baidu.com>
1 parent 09a0ce7 commit ae0bea9

File tree

1 file changed (+3, -2 lines)


paddlenlp/transformers/llama/modeling.py

Lines changed: 3 additions & 2 deletions
@@ -96,7 +96,7 @@ def swiglu(x, y=None):
     "LlamaForCausalLM",
     "LlamaPretrainingCriterion",
 ]
-
+global npu_is_casual
 npu_is_casual = False
 
 def _get_interleave(n):
@@ -213,7 +213,7 @@ def scaled_dot_product_attention(
 ):
     bsz, q_len, num_heads, head_dim = query_states.shape
     _, kv_seq_len, _, _ = value_states.shape
-
+    global npu_is_casual
     if config.use_flash_attention and flash_attention:
         # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
         # Torch Flash Attention input [ bz, nhead, seqlen, head_dim]
@@ -1613,6 +1613,7 @@ def forward(
         attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
         )  # [bs, 1, seq_len, seq_len]
+        global npu_is_casual
         if self.config.use_flash_attention:
             is_casual = is_casual_mask(attention_mask)
             if get_env_device() != "npu":
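All three added lines declare `global npu_is_casual`, presumably so that assignments to the flag inside `scaled_dot_product_attention` and the model's `forward` rebind the module-level variable rather than silently creating a function-local one (the declaration at module scope itself is a no-op). A minimal standalone sketch of that Python behavior, with hypothetical function names not taken from modeling.py:

# Minimal sketch (not from modeling.py) of why `global` is needed when a
# function reassigns a module-level flag such as npu_is_casual.

npu_is_casual = False  # module-level flag


def set_flag_without_global():
    # Assigning here only creates a local variable; the module-level
    # npu_is_casual is left unchanged.
    npu_is_casual = True  # noqa: F841
    return npu_is_casual


def set_flag_with_global():
    # The global declaration makes the assignment rebind the module-level name.
    global npu_is_casual
    npu_is_casual = True


set_flag_without_global()
print(npu_is_casual)  # False: module flag untouched

set_flag_with_global()
print(npu_is_casual)  # True: module flag updated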
