FLUX IPAdapter fails when transformers are quantized #10337

Closed
@vladmandic

Describe the bug

Using the new IP-Adapter implementation from #10336.
It works if the FLUX.1 model is loaded without quantization, but load_ip_adapter fails if the transformer is quantized using BnB.
Not sure why the conversion would require gradients. Maybe the fix is as simple as wrapping the conversion in a no_grad context?

Reproduction

import torch
import diffusers

dtype = torch.bfloat16
cache_dir = '/mnt/models/Diffusers'

quantization_config = diffusers.BitsAndBytesConfig()
print('load transformer')
transformer = diffusers.FluxTransformer2DModel.from_pretrained(
    'black-forest-labs/FLUX.1-schnell',
    subfolder="transformer",
    cache_dir=cache_dir,
    torch_dtype=dtype,
    quantization_config=quantization_config
)
print('load pipe')
pipe = diffusers.FluxPipeline.from_pretrained(
    'black-forest-labs/FLUX.1-schnell',
    transformer=transformer,
    cache_dir=cache_dir,
    torch_dtype=dtype,
)
print('load ipadapter')
pipe.load_ip_adapter(
    "XLabs-AI/flux-ip-adapter",
    weight_name="ip_adapter.safetensors",
    image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14"
)

Logs

Traceback (most recent call last):
  File "/home/vlado/dev/sdnext/tmp/flux.py", line 21, in <module>
    torch_dtype=dtype,
^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/diffusers/loaders/ip_adapter.py", line 553, in load_ip_adapter
    self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
  File "/home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/diffusers/loaders/transformer_flux.py", line 168, in _load_ip_adapter_weights
    attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/diffusers/loaders/transformer_flux.py", line 135, in _convert_ip_adapter_attn_to_diffusers
    attn_procs[name] = attn_processor_class(
                       ^^^^^^^^^^^^^^^^^^^^^
  File "/home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/diffusers/models/attention_processor.py", line 2683, in __init__
    nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
  File "/home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/torch/nn/modules/linear.py", line 105, in __init__
    self.weight = Parameter(
                  ^^^^^^^^^^
  File "/home/vlado/dev/sdnext/venv/lib/python3.12/site-packages/torch/nn/parameter.py", line 46, in __new__
    return torch.Tensor._make_subclass(cls, data, requires_grad)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Only Tensors of floating point and complex dtype can require gradients
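
The last frames of the traceback can be reproduced without diffusers or bitsandbytes. The sketch below assumes that the `dtype` reaching `nn.Linear` is an integer type such as `torch.uint8` (a guess about what a BnB-quantized transformer reports, not something confirmed in this report). The helper `try_linear` is hypothetical and exists only for illustration.

```python
import torch
from torch import nn


def try_linear(dtype):
    """Build a small nn.Linear with the given dtype; return the error text, or None on success."""
    try:
        nn.Linear(8, 4, bias=True, dtype=dtype)
    except RuntimeError as exc:
        return str(exc)
    return None


# A floating point dtype, as used in the reproduction, constructs normally.
float_error = try_linear(torch.bfloat16)

# ASSUMPTION: torch.uint8 stands in for the dtype of the quantized transformer.
# nn.Linear wraps its weight in nn.Parameter, which defaults to requires_grad=True.
int_error = try_linear(torch.uint8)

# The same integer tensor is accepted as a Parameter once requires_grad is turned off.
frozen = nn.Parameter(torch.empty(4, 8, dtype=torch.uint8), requires_grad=False)

print("bfloat16:", float_error)
print("uint8:", int_error)
print("frozen uint8 parameter:", frozen.dtype, frozen.requires_grad)
```

If that assumption holds, the gradient requirement comes from the default `requires_grad=True` in `nn.Parameter` rather than from the conversion itself, so building the IP-Adapter layers with a floating point dtype would be one direction to investigate.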

System Info

diffusers==0.32.0.dev from main

Who can help?

@yiyixuxu @sayakpaul @DN6 @hlky

Labels: bug (Something isn't working)
