We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 145522c · commit cd0a4a8 — Copy full SHA for cd0a4a8
examples/dreambooth/train_dreambooth_lora_sana.py
@@ -995,7 +995,8 @@ def main(args):
995
if args.enable_npu_flash_attention:
996
if is_torch_npu_available():
997
logger.info("npu flash attention enabled.")
998
- transformer.enable_npu_flash_attention()
+ for block in transformer.transformer_blocks:
999
+ block.attn2.set_use_npu_flash_attention(True)
1000
else:
1001
raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device ")
1002
0 commit comments