
FluxPipeline does not support batch inference #9051

Closed
@lmxyy

Description


Describe the bug

The FLUX.1-dev model does not support batch inference. When I pass a list of two prompts to FluxPipeline, the call fails with a RuntimeError inside the transformer's forward pass (full traceback under Logs).

Reproduction

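import torch
from diffusers import FluxPipeline

# Load FLUX.1-dev in bfloat16 and offload idle submodules to CPU to save GPU memory.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", revision="refs/pr/3", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

prompt = "a tiny astronaut hatching from an egg on the moon"
# A list of two prompts gives batch size 2; this call raises the RuntimeError shown under Logs.
out = pipe(
    prompt=[prompt, prompt],
    guidance_scale=3.5,
    num_inference_steps=50,
).images[0]  # keep only the first generated image
out.save("dev.png")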
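Until batched calls work, one possible workaround (not confirmed anywhere in this issue) is to call the pipeline once per prompt, so the transformer only ever sees a batch of 1. The helper `generate_one_by_one` below is hypothetical; it only assumes the pipeline accepts `prompt=` plus keyword arguments and returns an object with an `.images` list, as the reproduction above already relies on. Leaving `num_images_per_prompt` at its default of 1 matters here, since a larger value would raise the batch size again.

```python
from typing import Any, Callable, List


def generate_one_by_one(pipe: Callable[..., Any], prompts: List[str], **call_kwargs: Any) -> list:
    """Hypothetical helper: run the pipeline once per prompt (batch size 1 per call).

    `pipe` is assumed to be a loaded FluxPipeline, or any callable with the same
    convention that returns an object exposing an `.images` list.
    """
    images = []
    for p in prompts:
        # A single prompt string per call keeps the batch dimension at 1.
        result = pipe(prompt=p, **call_kwargs)
        images.extend(result.images)
    return images


# Usage with the pipeline from the reproduction (requires the model weights):
# images = generate_one_by_one(pipe, [prompt, prompt], guidance_scale=3.5, num_inference_steps=50)
# for i, img in enumerate(images):
#     img.save(f"dev_{i}.png")
```

This trades throughput for correctness: each prompt pays the full denoising loop separately instead of sharing one batched forward pass.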

Logs

Traceback (most recent call last):
  File "/anaconda3/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda3/python3.11/site-packages/diffusers/pipelines/flux/pipeline_flux.py", line 706, in __call__
    noise_pred = self.transformer(
                 ^^^^^^^^^^^^^^^^^
  File "/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda3/lib/python3.11/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda3/lib/python3.11/site-packages/diffusers/models/transformers/transformer_flux.py", line 428, in forward
    hidden_states = block(
                    ^^^^^^
  File "/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda3/lib/python3.11/site-packages/diffusers/models/transformers/transformer_flux.py", line 126, in forward
    hidden_states = gate * self.proj_out(hidden_states)
                    ~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
RuntimeError: The size of tensor a (2) must match the size of tensor b (4608) at non-singleton dimension 1
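The last frame points at `gate * self.proj_out(hidden_states)` in transformer_flux.py. One plausible reading of the error, sketched below with dummy tensors, is a broadcasting mismatch: if `gate` has shape (batch, dim) while the projected hidden states have shape (batch, seq_len, dim), PyTorch aligns trailing dimensions and ends up comparing the batch size (2) with the sequence length (4608) at dimension 1. With one prompt the batch size is 1, which broadcasts silently, so the bug only shows up for batches. The shapes are assumptions inferred from the error text (4608 would match 512 text tokens plus 4096 image tokens for a 1024×1024 image), not read from the diffusers source, and `gate.unsqueeze(1)` is an illustrative fix rather than a confirmed patch.

```python
import torch

# Assumed shapes inferred from the error message; hidden size D is kept tiny.
B, L, D = 2, 4608, 4  # batch (two prompts), sequence length, hidden size

gate = torch.randn(B, D)          # per-sample gate, shape (B, D)
projected = torch.randn(B, L, D)  # stands in for self.proj_out(hidden_states), shape (B, L, D)

# Trailing-dim alignment treats gate as (1, B, D); dim 1 then compares B=2 with L=4608.
broadcast_error = None
try:
    gate * projected
except RuntimeError as err:
    broadcast_error = str(err)
print(broadcast_error)

# With batch size 1 the same expression broadcasts fine, which hides the bug.
single_ok = torch.randn(1, D) * torch.randn(1, L, D)

# Adding a sequence axis gives gate shape (B, 1, D), which broadcasts over L.
out = gate.unsqueeze(1) * projected
print(out.shape)  # torch.Size([2, 4608, 4])
```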

System Info

Python 3.11.9
CUDA 12.2
torch 2.3.1
diffusers 0.30.0.dev0
Platform: Ubuntu 20.04.5 LTS

Who can help?

@yiyixuxu

Labels: bug (Something isn't working)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions