@@ -351,9 +351,44 @@ def to(self=None, device=None, dtype=None, blocking=None):
 
 nn.Layer.to = to
 
-from ..utils.import_utils import is_ppxformers_available
+from ..utils.import_utils import is_ppxformers_available, is_npu_available
 
-if is_ppxformers_available():
+if is_npu_available():
+    for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
+        if lib.endswith(".so"):
+            paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(
+                lib
+            )
+    from paddle.base import core
+    def scaled_dot_product_attention_npu(query,
+                                         key,
+                                         value,
+                                         attn_mask=None,
+                                         dropout_p=0.0,
+                                         is_causal=False,
+                                         training=True,
+                                         name=None,
+                                         fixed_seed_offset=None,
+                                         return_softmax=False,
+                                         is_triangle_upper_mask=True,
+                                         ):
+        out = core.eager._run_custom_op(
+            "flash_attention_npu",
+            query,
+            key,
+            value,
+            fixed_seed_offset,
+            attn_mask,
+            dropout_p,
+            is_causal,
+            return_softmax,
+            not training,
+            is_triangle_upper_mask,
+        )[0]
+        return out
+    paddle.nn.functional.scaled_dot_product_attention_npu = scaled_dot_product_attention_npu
+
+if is_ppxformers_available() or is_npu_available():
     from paddle.incubate.nn.memory_efficient_attention import memory_efficient_attention
 
     try:
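Note: below is a minimal usage sketch of the wrapper registered in the hunk above. It assumes an Ascend NPU build of PaddleCustomDevice with the flash_attention_npu custom op installed under CUSTOM_DEVICE_ROOT and the patch already applied; the device string, shapes, and dtype are illustrative, not taken from this PR.

import paddle

# Assumed device string for a PaddleCustomDevice NPU build; adjust to your environment.
paddle.set_device("npu")

# Flash-attention kernels generally expect fp16/bf16 inputs; float32 is routed to the
# cutlass/math paths by the dispatch change further down in this diff.
q = paddle.randn([1, 1024, 8, 64], dtype="float16")  # [batch, seq_len, num_heads, head_dim]
k = paddle.randn([1, 1024, 8, 64], dtype="float16")
v = paddle.randn([1, 1024, 8, 64], dtype="float16")

# Calls the monkey-patched wrapper, which forwards to the flash_attention_npu custom op.
out = paddle.nn.functional.scaled_dot_product_attention_npu(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True, training=False
)
print(out.shape)  # expected: [1, 1024, 8, 64]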
@@ -392,6 +427,8 @@ def scaled_dot_product_attention_(
             attention_op = "cutlass"
             if is_support_flash_attention and query.dtype not in [paddle.float32]:
                 attention_op = "flash"
+            elif is_npu_available() and query.dtype not in [paddle.float32]:
+                attention_op = "flash_npu"
         else:
             if attention_op == "flash" and flash_attn_error is not None:
                 raise OSError(flash_attn_error)
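For readability, the auto-dispatch rule introduced in this hunk can be distilled into a small standalone function. This is only an illustration of the branch order (cutlass by default, flash when supported, flash_npu on an NPU for non-float32 inputs), not code from the PR; the real code reads is_support_flash_attention and is_npu_available() from the patched module.

def pick_attention_op(attention_op, dtype, supports_flash, npu_available):
    # Mirrors the "auto" selection above; explicit ops pass through unchanged.
    if attention_op in [None, "auto"]:
        attention_op = "cutlass"
        if supports_flash and dtype != "float32":
            attention_op = "flash"
        elif npu_available and dtype != "float32":
            attention_op = "flash_npu"
    return attention_op

assert pick_attention_op("auto", "float16", False, True) == "flash_npu"
assert pick_attention_op("auto", "float32", False, True) == "cutlass"
assert pick_attention_op("flash", "float16", False, True) == "flash"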
@@ -473,6 +510,16 @@ def scaled_dot_product_attention_(
                 is_causal=bool(is_causal),
                 training=training,
             )
+        elif attention_op == "flash_npu":
+            output = paddle.nn.functional.scaled_dot_product_attention_npu(
+                query,
+                key,
+                value,
+                attn_mask=None if is_causal else attn_mask,
+                dropout_p=dropout_p if training else 0.0,
+                is_causal=bool(is_causal),
+                training=training,
+            )
         else:
             raise ValueError(
                 "ppxformers's attention_op shoulde be in ['auto', 'math', 'cutlass', `memory_efficient`, 'flash']."