Commit 7330593

[XPU] qwen2 supports flash_attn on XPU (#9549)
1 parent a5ec6bf commit 7330593

File tree

1 file changed: +9 -1 lines changed

paddlenlp/transformers/qwen2/modeling.py

Lines changed: 9 additions & 1 deletion
@@ -40,6 +40,7 @@
     create_skip_config_for_refined_recompute,
     recompute,
 )
+from paddlenlp.utils.tools import get_env_device
 
 from .. import linear_utils
 from ..activations import ACT2FN
@@ -1020,7 +1021,14 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
                 past_key_values_length=past_key_values_length,
             )
         # Convert bool attention_mask to float attention mask, which will be added to attention_scores later
-        expanded_attn_mask = paddle.where(expanded_attn_mask.to("bool"), 0.0, paddle.finfo(dtype).min).astype(dtype)
+        if get_env_device() == "xpu":
+            x = paddle.to_tensor(0.0, dtype="float32")
+            y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32")
+            expanded_attn_mask = paddle.where(expanded_attn_mask, x, y)
+        else:
+            expanded_attn_mask = paddle.where(expanded_attn_mask.to("bool"), 0.0, paddle.finfo(dtype).min).astype(
+                dtype
+            )
         return expanded_attn_mask
 
     @paddle.jit.not_to_static
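
For reference, below is a minimal, self-contained sketch (not part of the commit; the toy bool_mask and the generic_mask/xpu_mask names are illustrative) of the bool-to-float mask conversion this hunk branches on. The generic path fills masked positions with paddle.finfo(dtype).min, while the new XPU path uses tensor operands and the hard-coded float32 constant -1.7005809656952787e38, presumably a value the XPU flash-attention kernel handles safely.

# Illustrative sketch only (not from the commit): converting a bool attention
# mask into the additive float mask used by _prepare_decoder_attention_mask.
import paddle

# Toy mask: True = attend, False = masked out (values are made up).
bool_mask = paddle.to_tensor([[True, True, False]])
dtype = paddle.float32

# Generic path: 0.0 where attention is allowed, the dtype's most-negative
# value where it is masked.
generic_mask = paddle.where(bool_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)

# XPU path introduced by this commit: tensor operands and a fixed float32 constant.
x = paddle.to_tensor(0.0, dtype="float32")
y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32")
xpu_mask = paddle.where(bool_mask, x, y)

print(generic_mask)  # masked slot is roughly -3.4e38 (float32 min)
print(xpu_mask)      # masked slot is roughly -1.7e38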

0 commit comments