[LoRA] Add LoRA support to AuraFlow #9017


Open · wants to merge 13 commits into main

Conversation

Warlord-K
Contributor

What does this PR do?

Adds LoRA support to AuraFlow

import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow-v0.2", torch_dtype=torch.float16).to("cuda")
# Set weight_name="lora_peft_format.safetensors" to test loading from the PEFT format
pipe.load_lora_weights("Warlord-K/gorkem-auraflow-lora", weight_name="pytorch_lora_weights.safetensors")
image = pipe("gorkem in a black tuxedo", generator=torch.Generator().manual_seed(2347862)).images[0]
image.save("test.png")

(attached image showing an example output)

The following functions have also been tested, using the SD3 LoRA tests as a reference:

pipe.load_lora_weights()
pipe.unload_lora_weights()
pipe.fuse_lora()
pipe.unfuse_lora()

Fusing the LoRA decreases inference time by ~1.5 s, and unfusing it restores the original timing.
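For reference, here is a minimal sketch (not part of the PR) of how the fuse/unfuse round trip could be timed, continuing from the snippet above:

import time

def timed_generate(prompt):
    # Time a single generation with the pipeline created above.
    start = time.perf_counter()
    image = pipe(prompt, generator=torch.Generator().manual_seed(2347862)).images[0]
    return image, time.perf_counter() - start

_, t_lora = timed_generate("gorkem in a black tuxedo")     # LoRA loaded, not fused
pipe.fuse_lora()                                           # merge the LoRA weights into the transformer
_, t_fused = timed_generate("gorkem in a black tuxedo")    # ~1.5 s faster in the run reported above
pipe.unfuse_lora()                                         # undo the merge
_, t_unfused = timed_generate("gorkem in a black tuxedo")  # timing goes back up
print(t_lora, t_fused, t_unfused)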

Before submitting

@sayakpaul Please review

P.S. make style && make quality fails on some other file, hence I wasn't able to run it.

Member

@sayakpaul sayakpaul left a comment

Thanks! Left some comments.

I think the TODOs are:

LMK if you have questions.

Comment on lines 325 to 326
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
def fuse_qkv_projections(self):
Member

This is not a part of the PR. Let's tackle this separately. Also #8952.

Contributor Author

Sorry, I have removed it now.

@@ -329,6 +371,7 @@ def forward(
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
timestep: torch.LongTensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
Member

Hmm, AuraFlow has two kinds of attention, right? MMDiT blocks have joint attention and Single DiT blocks have regular attention. So I'm wondering if it's right to call it joint_attention_kwargs.

Contributor Author

Yes, attention_kwargs would be more appropriate; I have replaced it with that.


_lora_loadable_modules = ["transformer"]
transformer_name = TRANSFORMER_NAME
text_encoder_name = TEXT_ENCODER_NAME
Member

Do we need the text_encoder_name then?

Contributor Author

Sorry, I missed that; it has been removed.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Warlord-K
Contributor Author

I have added the tests using the LoRA I tested with, and I am building the docs via https://github.com/huggingface/diffusers/tree/main/docs. Just to confirm: do I need to build the docs for the AuraFlowLoRALoaderMixin locally and then push? Please let me know if anything is incorrect with this.

@sayakpaul
Member

@Warlord-K you don't have to build the docs locally. You just have to add the entry to the corresponding loaders doc: https://huggingface.co/docs/diffusers/main/en/api/loaders/lora.

@Warlord-K
Contributor Author

Ah, then I think it should be done. Please check.

Member

@sayakpaul sayakpaul left a comment

Some more comments.

@@ -17,6 +17,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`StableDiffusionLoraLoaderMixin`] provides functions for loading and unloading, fusing and unfusing, enabling and disabling, and more functions for managing LoRA weights. This class can be used with any model.
- [`StableDiffusionXLLoraLoaderMixin`] is a [Stable Diffusion (SDXL)](../../api/pipelines/stable_diffusion/stable_diffusion_xl) version of the [`StableDiffusionLoraLoaderMixin`] class for loading and saving LoRA weights. It can only be used with the SDXL model.
- [`SD3LoraLoaderMixin`] provides similar functions for [Stable Diffusion 3](https://huggingface.co/blog/sd3).
- [`AuraFlowLoraLoaderMixin`] provides similar functions for [AuraFlow](https://huggingface.co/fal/AuraFlow).
Member

This should suffice for the docs.

safe_serialization=safe_serialization,
)

# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.fuse_lora
Member

Hmm, this wouldn't be a copy because SD3 has more components.

components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],

Contributor Author

Oh, OK. I am removing the "Copied from" line from the other functions that have small changes, too.
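For context, here is a minimal sketch of what the transformer-only fuse_lora override inside AuraFlowLoraLoaderMixin could look like once the "Copied from" line is dropped; the exact signature and the delegation to the base loader are assumptions modeled on the SD3 mixin, not the final code:

from typing import List, Optional

def fuse_lora(
    self,
    components: List[str] = ["transformer"],
    lora_scale: float = 1.0,
    safe_fusing: bool = False,
    adapter_names: Optional[List[str]] = None,
    **kwargs,
):
    # Hypothetical sketch: delegate to the shared fusion logic in the base
    # loader, restricted to the transformer since these AuraFlow LoRAs do not
    # touch the text encoder.
    super().fuse_lora(
        components=components,
        lora_scale=lora_scale,
        safe_fusing=safe_fusing,
        adapter_names=adapter_names,
        **kwargs,
    )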

components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)

# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.unfuse_lora
Member

Same as above.

@@ -232,7 +233,7 @@ def forward(
return encoder_hidden_states, hidden_states


class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
Member

I don't think we need FromOriginalModelMixin. No?

Contributor Author

@Warlord-K Warlord-K Jul 30, 2024

Nope, sorry. Removed.



@require_peft_backend
class AFLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
Member

Suggested change
class AFLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

Comment on lines 39 to 46
"sample_size": 64,
"patch_size": 2,
"in_channels": 4,
"num_mmdit_layers": 4,
"num_single_dit_layers": 32,
"attention_head_dim": 256,
"num_attention_heads": 12,
"joint_attention_dim": 2048,
Member

These are very big numbers for fast tests. Please consider using significantly smaller numbers, as done in SD3 and others.
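For illustration only, a sketch of what much smaller test kwargs could look like; every value below is an assumption chosen to keep the test model tiny, not what this PR settled on:

# Hypothetical tiny config for fast CI tests; all values are illustrative.
transformer_kwargs = {
    "sample_size": 32,
    "patch_size": 2,
    "in_channels": 4,
    "num_mmdit_layers": 1,
    "num_single_dit_layers": 1,
    "attention_head_dim": 8,
    "num_attention_heads": 2,
    "joint_attention_dim": 32,
}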

Comment on lines 51 to 69
vae_kwargs = {
"sample_size": 1024,
"in_channels": 3,
"out_channels": 3,
"block_out_channels": [
128,
256,
512,
512
],
"layers_per_block": 2,
"latent_channels": 4,
"norm_num_groups": 32,
"use_quant_conv": True,
"use_post_quant_conv": True,
"shift_factor": None,
"scaling_factor": 0.13025,
}
has_three_text_encoders = False
Member

Same.

Comment on lines 51 to 69

has_three_text_encoders = False
Member

We don't have to explicitly specify has_three_text_encoders = False, since these are already the defaults:

has_two_text_encoders = False

has_three_text_encoders = False

def test_af_lora(self):
"""
Test loading the loras that are saved with the diffusers and peft formats.
Related PR: https://github.com/huggingface/diffusers/pull/8584
Member

How is that PR related?

has_three_text_encoders = False

@require_torch_gpu
def test_af_lora(self):
Member

I think we can safely remove this test.

@Warlord-K
Contributor Author

I have made the required changes and added tests and doc mentions. I have tried to follow #9057 for the tests, since utils.py was significantly changed to accommodate the newer models, but I get all 26 tests skipped for both Flux and AuraFlow when I run them on my laptop. @sayakpaul Please review and let me know if I am making any mistake while running the tests.

@sayakpaul
Member

but I get all 26 tests skipped for both Flux and AuraFlow when I run them on my laptop.

Do you have peft installed in the env you are running this from? We have this constraint:

@require_peft_backend
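As a quick sanity check (a minimal sketch, not part of the PR), you can confirm that peft is importable in the test environment; tests gated by @require_peft_backend are skipped when it is not:

# Minimal environment check; install with `pip install -U peft` if missing.
import importlib.util

if importlib.util.find_spec("peft") is None:
    print("peft is not installed, so the LoRA tests will be skipped")
else:
    import peft
    print(f"peft {peft.__version__} is available")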

Member

@sayakpaul sayakpaul left a comment

Thanks much for the changes. I just left my comments.


@classmethod
@validate_hf_hub_args
def lora_state_dict(
Member

There should be a "Copied from ..." statement here, like:

# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict

components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)

# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict with text_encoder removed from components
Member

"Copied from ..." statements won't work with keywords like "removed from ...". Better to remove.

@@ -434,6 +435,7 @@ def forward(
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
timestep: torch.LongTensor = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
Member

Actually, sorry for my oversight here. We can call it joint_attention_kwargs, as that is what we call it in Flux as well.
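For illustration, this is the kind of call-time usage that the Flux-style kwarg name enables; whether AuraFlowPipeline accepts a LoRA scale this way at this stage of the PR is an assumption:

# Hypothetical usage mirroring the Flux convention of passing a LoRA scale
# through joint_attention_kwargs; support in AuraFlowPipeline is assumed here.
image = pipe(
    "gorkem in a black tuxedo",
    joint_attention_kwargs={"scale": 0.8},
).images[0]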

@vladmandic
Contributor

Any updates here? It seems like great progress was made and then there have been no updates for the past month.

Contributor

github-actions bot commented Oct 7, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Oct 7, 2024
@sayakpaul
Member

@Warlord-K a gentle ping here.

@sayakpaul sayakpaul removed the stale Issues that haven't received updates label Oct 7, 2024
Contributor

github-actions bot commented Nov 1, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Nov 1, 2024
@yiyixuxu yiyixuxu removed the stale Issues that haven't received updates label Nov 6, 2024
@hameerabbasi
Contributor

I can take over here if that's okay, @Warlord-K.

@sayakpaul
Member

Feel free to cherry-pick commits! Thanks for offering to help :)

Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Dec 12, 2024
@sayakpaul
Member

@hameerabbasi, are you still interested in picking this up?

@sayakpaul sayakpaul removed the stale Issues that haven't received updates label Dec 13, 2024
@hameerabbasi
Contributor

hameerabbasi commented Dec 13, 2024

Right — I didn’t find the bandwidth, but happy to let others take over.

Edit: On second thought, I can spend a few hours today.

Contributor

github-actions bot commented Jan 6, 2025

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Jan 6, 2025
@hameerabbasi
Contributor

I guess this one can be closed; it's superseded by #10216.

@github-actions github-actions bot removed the stale Issues that haven't received updates label Jan 7, 2025
Contributor

github-actions bot commented Feb 1, 2025

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Feb 1, 2025
Labels
stale Issues that haven't received updates

6 participants