
Commit 2ffbc24

Adapt fa for npu (#706)
1 parent d7d9694 commit 2ffbc24

File tree

2 files changed: 51 additions & 2 deletions


ppdiffusers/ppdiffusers/patches/paddle_patch.py

Lines changed: 49 additions & 2 deletions
@@ -351,9 +351,44 @@ def to(self=None, device=None, dtype=None, blocking=None):
 
 nn.Layer.to = to
 
-from ..utils.import_utils import is_ppxformers_available
+from ..utils.import_utils import is_ppxformers_available, is_npu_available
 
-if is_ppxformers_available():
+if is_npu_available():
+    for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
+        if lib.endswith(".so"):
+            paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(
+                lib
+            )
+    from paddle.base import core
+    def scaled_dot_product_attention_npu(query,
+                                         key,
+                                         value,
+                                         attn_mask=None,
+                                         dropout_p=0.0,
+                                         is_causal=False,
+                                         training=True,
+                                         name=None,
+                                         fixed_seed_offset=None,
+                                         return_softmax=False,
+                                         is_triangle_upper_mask=True,
+                                         ):
+        out = core.eager._run_custom_op(
+            "flash_attention_npu",
+            query,
+            key,
+            value,
+            fixed_seed_offset,
+            attn_mask,
+            dropout_p,
+            is_causal,
+            return_softmax,
+            not training,
+            is_triangle_upper_mask,
+        )[0]
+        return out
+    paddle.nn.functional.scaled_dot_product_attention_npu = scaled_dot_product_attention_npu
+
+if is_ppxformers_available() or is_npu_available():
     from paddle.incubate.nn.memory_efficient_attention import memory_efficient_attention
 
     try:
@@ -392,6 +427,8 @@ def scaled_dot_product_attention_(
             attention_op = "cutlass"
             if is_support_flash_attention and query.dtype not in [paddle.float32]:
                 attention_op = "flash"
+        elif is_npu_available() and query.dtype not in [paddle.float32]:
+            attention_op = "flash_npu"
         else:
             if attention_op == "flash" and flash_attn_error is not None:
                 raise OSError(flash_attn_error)
@@ -473,6 +510,16 @@ def scaled_dot_product_attention_(
                 is_causal=bool(is_causal),
                 training=training,
             )
+        elif attention_op == "flash_npu":
+            output = paddle.nn.functional.scaled_dot_product_attention_npu(
+                query,
+                key,
+                value,
+                attn_mask=None if is_causal else attn_mask,
+                dropout_p=dropout_p if training else 0.0,
+                is_causal=bool(is_causal),
+                training=training,
+            )
         else:
             raise ValueError(
                 "ppxformers's attention_op shoulde be in ['auto', 'math', 'cutlass', `memory_efficient`, 'flash']."

ppdiffusers/ppdiffusers/utils/import_utils.py

Lines changed: 2 additions & 0 deletions
@@ -375,6 +375,8 @@ def is_scipy_available():
 def is_librosa_available():
     return _librosa_available
 
+def is_npu_available():
+    return paddle.device.get_device().startswith("npu")
 
 def is_ppxformers_available():
     USE_PPXFORMERS = str2bool(os.getenv("USE_PPXFORMERS", True))
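For reference, the new helper reports NPU availability from the device Paddle currently has selected rather than by probing installed hardware. A small standalone sketch of the same check; the "npu:0" device name is an illustrative assumption.

import paddle

def is_npu_available():
    # Mirrors the helper added above: the current device string starts with
    # "npu" (e.g. "npu:0") only when an NPU custom device has been selected.
    return paddle.device.get_device().startswith("npu")

paddle.set_device("cpu")
print(is_npu_available())  # False

# On a build with the NPU custom-device plugin installed:
# paddle.set_device("npu:0")
# print(is_npu_available())  # True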
