[LoRA] Add LoRA support to AuraFlow #9017
Conversation
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
def fuse_qkv_projections(self):
This is not a part of the PR. Let's tackle this separately. Also #8952.
Sorry, I have removed it now.
@@ -329,6 +371,7 @@ def forward(
     hidden_states: torch.FloatTensor,
     encoder_hidden_states: torch.FloatTensor = None,
     timestep: torch.LongTensor = None,
+    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
Hmm, AuraFlow has two kinds of attention, right? MMDiT blocks have joint attention and single DiT blocks have regular attention. So I'm wondering if it's right to call it joint_attention_kwargs.
Yes, attention_kwargs would be more appropriate; I have replaced it with that.
_lora_loadable_modules = ["transformer"]
transformer_name = TRANSFORMER_NAME
text_encoder_name = TEXT_ENCODER_NAME
Do we need the text_encoder_name then?
Sorry, I missed that; it has been removed.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
I have added the tests using the LoRA I tested, and am building the docs via https://github.com/huggingface/diffusers/tree/main/docs. Just to confirm, I need to build the docs for the
@Warlord-K you don't have to build the docs locally. You just have to add the entry to the corresponding loaders doc: https://huggingface.co/docs/diffusers/main/en/api/loaders/lora.
Ah, I think it should be done then. Please check.
Some more comments.
@@ -17,6 +17,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
 - [`StableDiffusionLoraLoaderMixin`] provides functions for loading and unloading, fusing and unfusing, enabling and disabling, and more functions for managing LoRA weights. This class can be used with any model.
 - [`StableDiffusionXLLoraLoaderMixin`] is a [Stable Diffusion (SDXL)](../../api/pipelines/stable_diffusion/stable_diffusion_xl) version of the [`StableDiffusionLoraLoaderMixin`] class for loading and saving LoRA weights. It can only be used with the SDXL model.
 - [`SD3LoraLoaderMixin`] provides similar functions for [Stable Diffusion 3](https://huggingface.co/blog/sd3).
+- [`AuraFlowLoraLoaderMixin`] provides similar functions for [AuraFlow](https://huggingface.co/fal/AuraFlow).
This should suffice for the docs.
        safe_serialization=safe_serialization,
    )

    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.fuse_lora
Hmm, this wouldn't be a copy because SD3 has more components.
diffusers/src/diffusers/loaders/lora_pipeline.py, line 1416 in e5b94b4:

    components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
Oh OK, I am removing the "Copied from" line from the other functions that have small changes too.
        components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
    )

    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.unfuse_lora
Same as above.
@@ -232,7 +233,7 @@ def forward(
     return encoder_hidden_states, hidden_states


-class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
+class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
I don't think we need FromOriginalModelMixin. No?
Nope, sorry. Removed.
tests/lora/test_lora_layers_af.py (outdated)

@require_peft_backend
class AFLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
Suggested change:
-class AFLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
tests/lora/test_lora_layers_af.py (outdated)

    "sample_size": 64,
    "patch_size": 2,
    "in_channels": 4,
    "num_mmdit_layers": 4,
    "num_single_dit_layers": 32,
    "attention_head_dim": 256,
    "num_attention_heads": 12,
    "joint_attention_dim": 2048,
These are very big numbers for fast tests. Please consider using significantly smaller numbers, as done in SD3 and others.
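For illustration, a much smaller configuration in the spirit of the tiny models used in the SD3 LoRA tests might look like the sketch below; these values are hypothetical placeholders, not the ones that ended up in the PR:

transformer_kwargs = {
    "sample_size": 32,           # down from 64
    "patch_size": 2,
    "in_channels": 4,
    "num_mmdit_layers": 1,       # down from 4
    "num_single_dit_layers": 1,  # down from 32
    "attention_head_dim": 8,     # down from 256
    "num_attention_heads": 2,    # down from 12
    "joint_attention_dim": 32,   # down from 2048
}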
tests/lora/test_lora_layers_af.py (outdated)

    vae_kwargs = {
        "sample_size": 1024,
        "in_channels": 3,
        "out_channels": 3,
        "block_out_channels": [128, 256, 512, 512],
        "layers_per_block": 2,
        "latent_channels": 4,
        "norm_num_groups": 32,
        "use_quant_conv": True,
        "use_post_quant_conv": True,
        "shift_factor": None,
        "scaling_factor": 0.13025,
    }
    has_three_text_encoders = False
Same.
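Likewise, a tiny VAE configuration for the fast tests could be sketched as follows (illustrative values only; the channel counts must stay divisible by norm_num_groups):

vae_kwargs = {
    "sample_size": 32,          # down from 1024
    "in_channels": 3,
    "out_channels": 3,
    "block_out_channels": [4],  # a single narrow block instead of [128, 256, 512, 512]
    "layers_per_block": 1,
    "latent_channels": 4,
    "norm_num_groups": 4,       # divides the 4 channels above
    "use_quant_conv": True,
    "use_post_quant_conv": True,
    "shift_factor": None,
    "scaling_factor": 0.13025,
}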
tests/lora/test_lora_layers_af.py (outdated)

    def test_af_lora(self):
        """
        Test loading the loras that are saved with the diffusers and peft formats.
        Related PR: https://github.com/huggingface/diffusers/pull/8584
How is that PR related?
tests/lora/test_lora_layers_af.py (outdated)

    has_three_text_encoders = False

    @require_torch_gpu
    def test_af_lora(self):
I think we can safely remove this test.
I have made the required changes and added tests and doc mentions. I have tried to follow #9057 for the tests, since utils.py was significantly changed to accommodate the newer models, but I get all 26 tests skipped for both Flux and AuraFlow when I run them on my laptop. @sayakpaul Please review and let me know if I am making any mistake while running the tests.
Do you have
Thanks much for the changes. I just left my comments.
    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
There should be a "Copied from ..." statement here, like:

diffusers/src/diffusers/loaders/lora_pipeline.py, line 1492 in 98930ee:

    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
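For concreteness, the convention places the annotation directly above the mirrored function, roughly like this (a sketch; the signature is abbreviated, and the copy checker also accepts simple `with old->new` substitutions in the comment):

    @classmethod
    @validate_hf_hub_args
    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
    def lora_state_dict(cls, pretrained_model_name_or_path_or_dict, **kwargs):
        ...

make fix-copies then keeps the annotated body in sync with the SD3 original.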
        components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
    )

    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict with text_encoder removed from components
"Copied from ..." statements won't work with keywords like "removed from ...". Better to remove.
@@ -434,6 +435,7 @@ def forward(
     hidden_states: torch.FloatTensor,
     encoder_hidden_states: torch.FloatTensor = None,
     timestep: torch.LongTensor = None,
+    attention_kwargs: Optional[Dict[str, Any]] = None,
Actually, sorry for my oversight here. We can call it joint_attention_kwargs, as that is what we call them in Flux as well.
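For context, this is roughly how Flux threads the LoRA scale through those kwargs inside forward; a condensed sketch of that existing pattern, not the exact AuraFlow code:

from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers

def forward(self, hidden_states, encoder_hidden_states=None, timestep=None,
            joint_attention_kwargs=None):
    if joint_attention_kwargs is not None:
        joint_attention_kwargs = joint_attention_kwargs.copy()
        lora_scale = joint_attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0
    if USE_PEFT_BACKEND:
        # temporarily scale every PEFT LoRA layer by `lora_scale`
        scale_lora_layers(self, lora_scale)
    # ... transformer blocks run here ...
    if USE_PEFT_BACKEND:
        # undo the scaling so repeated calls see unmodified weights
        unscale_lora_layers(self, lora_scale)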
Any updates here? It seems like great progress was made and then no updates for the past month.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
@Warlord-K a gentle ping here.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
I can take over here if that's okay, @Warlord-K.
Feel free to cherry-pick commits! Thanks for offering to help :)
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
@hameerabbasi, are you still interested in picking this up?
Right, I didn't find the bandwidth, but I'm happy to let others take over. Edit: On second thought, I can spend a few hours today.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
I guess this one can be closed; it's superseded by #10216.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
What does this PR do?
Adds LoRA support to AuraFlow.
The following functions have also been tested, taking the SD3 LoRA tests as a reference:
Fusing the LoRA decreases inference time by ~1.5s, and unfusing it increases it again.
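With this PR, loading, fusing, and unfusing would look roughly like the snippet below; the LoRA repo id and weight file name are placeholders, not a real checkpoint:

import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")
# hypothetical LoRA checkpoint, for illustration only
pipe.load_lora_weights("user/my-auraflow-lora", weight_name="pytorch_lora_weights.safetensors")

pipe.fuse_lora()    # fold the LoRA weights into the transformer for faster inference
image = pipe("a watercolor fox in a forest", num_inference_steps=25).images[0]
pipe.unfuse_lora()  # restore the original, unfused weights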
Before submitting
- documentation guidelines, and
- here are tips on formatting docstrings.
@sayakpaul Please review
P.S. make style && make quality fails on some other file, hence I wasn't able to run it.