Commit 79cb8b6

fix llama fa bug (#8237)
1 parent: c1aad02

File tree

1 file changed: +1 −1


paddlenlp/transformers/llama/modeling.py

Lines changed: 1 addition & 1 deletion
@@ -1550,7 +1550,7 @@ def forward(
             attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
         )  # [bs, 1, seq_len, seq_len]
         if self.config.use_flash_attention:
-            if get_env_device != "npu":
+            if get_env_device() != "npu":
                 is_casual = is_casual_mask(attention_mask)
                 if is_casual and alibi is None:
                     attention_mask = None
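The removed line compares the function object get_env_device itself against the string "npu". A function object is never equal to a string, so the check was always true and the non-NPU branch ran even on NPU devices. The fix calls get_env_device() so the returned device string is actually compared. A minimal sketch of the difference, using a stub that stands in for PaddleNLP's real device lookup:

    # Minimal sketch illustrating the fix; this stub stands in for
    # PaddleNLP's real get_env_device and always reports "npu".
    def get_env_device():
        return "npu"

    # Buggy comparison: a function object never equals a string,
    # so this is True on every device and the NPU case is never detected.
    print(get_env_device != "npu")    # True

    # Fixed comparison: the call returns the device string, so the
    # check actually distinguishes NPU from other devices.
    print(get_env_device() != "npu")  # False on "npu"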
