Commit cd0a4a8 — authored by leisuzz, J石页, and sayakpaul

[bugfix] NPU Adaption for Sana (#10724)
Commit message (squashed):

* NPU Adaption for Sana (series of iterative commits)
* [bugfix] NPU Adaption for Sana

Co-authored-by: J石页 <jiangshuo9@h-partners.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

1 parent 145522c — commit cd0a4a8

File tree: 1 file changed (+2, −1 lines)

examples/dreambooth/train_dreambooth_lora_sana.py

Lines changed: 2 additions & 1 deletion
@@ -995,7 +995,8 @@ def main(args):
     if args.enable_npu_flash_attention:
         if is_torch_npu_available():
             logger.info("npu flash attention enabled.")
-            transformer.enable_npu_flash_attention()
+            for block in transformer.transformer_blocks:
+                block.attn2.set_use_npu_flash_attention(True)
         else:
             raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device ")
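The change replaces a single model-level call with a loop that toggles NPU flash attention on each transformer block's `attn2` (cross-attention) module individually. The following is a minimal, self-contained sketch of that per-block pattern; the stub classes (`AttnStub`, `BlockStub`, `TransformerStub`) are hypothetical stand-ins for the real diffusers modules, and only the final loop mirrors the actual diff.

```python
# Hypothetical stand-ins illustrating the per-block toggle pattern from the
# diff above; the real classes live in diffusers and torch_npu.

class AttnStub:
    """Stand-in for an attention module exposing an NPU flash-attention flag."""
    def __init__(self):
        self.use_npu_flash_attention = False

    def set_use_npu_flash_attention(self, value: bool) -> None:
        # The real method would also swap in an NPU-specific attention processor.
        self.use_npu_flash_attention = value

class BlockStub:
    """Stand-in for a Sana transformer block; attn2 is its cross-attention."""
    def __init__(self):
        self.attn2 = AttnStub()

class TransformerStub:
    """Stand-in for the Sana transformer holding a list of blocks."""
    def __init__(self, num_blocks: int = 2):
        self.transformer_blocks = [BlockStub() for _ in range(num_blocks)]

transformer = TransformerStub()

# The fix: enable NPU flash attention per block instead of via one
# model-level helper call.
for block in transformer.transformer_blocks:
    block.attn2.set_use_npu_flash_attention(True)

assert all(b.attn2.use_npu_flash_attention for b in transformer.transformer_blocks)
```

Looping over `transformer_blocks` keeps the toggle working even when the model class itself does not implement a convenience method like `enable_npu_flash_attention()`.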

0 commit comments