Closed
Description
Describe the bug
The FLUX.1-dev model does not support batch inference. When I pass a list of two prompts, the transformer's forward pass raises a RuntimeError about mismatched tensor sizes.
Reproduction
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", revision="refs/pr/3", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

prompt = "a tiny astronaut hatching from an egg on the moon"
# Passing a list of two prompts (batch size 2) triggers the error.
out = pipe(
    prompt=[prompt, prompt],
    guidance_scale=3.5,
    num_inference_steps=50,
).images[0]
out.save("dev.png")
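A possible stopgap, assuming single-prompt calls succeed (as the report implies), is to call the pipeline once per prompt instead of passing a list. The helper name `generate_one_by_one` is hypothetical and not part of diffusers; it only relies on the call returning an object with an `.images` list, as in the reproduction above. This is a sketch, not a fix for the underlying bug.

```python
def generate_one_by_one(pipe, prompts, **call_kwargs):
    """Call `pipe` once per prompt and collect the first image of each call.

    `pipe` is any callable that accepts `prompt=...` and returns an object
    with an `.images` list (hypothetical helper, not a diffusers API).
    """
    images = []
    for prompt in prompts:
        # One prompt per call keeps the batch size at 1.
        result = pipe(prompt=prompt, **call_kwargs)
        images.append(result.images[0])
    return images


# Usage with the pipeline from the reproduction (not executed here):
# images = generate_one_by_one(
#     pipe, [prompt, prompt], guidance_scale=3.5, num_inference_steps=50
# )
```

This trades throughput for correctness: each prompt runs the full set of denoising steps on its own.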
Logs
Traceback (most recent call last):
  File "/anaconda3/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda3/python3.11/site-packages/diffusers/pipelines/flux/pipeline_flux.py", line 706, in __call__
    noise_pred = self.transformer(
                 ^^^^^^^^^^^^^^^^^
  File "/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda3/lib/python3.11/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda3/lib/python3.11/site-packages/diffusers/models/transformers/transformer_flux.py", line 428, in forward
    hidden_states = block(
                    ^^^^^^
  File "/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda3/lib/python3.11/site-packages/diffusers/models/transformers/transformer_flux.py", line 126, in forward
    hidden_states = gate * self.proj_out(hidden_states)
                    ~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
RuntimeError: The size of tensor a (2) must match the size of tensor b (4608) at non-singleton dimension 1
System Info
Python 3.11.9
CUDA 12.2
torch 2.3.1
diffusers 0.30.0.dev0
Platform: Ubuntu 20.04.5 LTS
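The error message in the logs is consistent with a broadcasting mismatch: the batch size (2) is being lined up against a dimension of size 4608 at index 1. The sketch below reproduces that rule in plain Python so it can run without a GPU. The shapes are assumptions inferred from the error message, not taken from the diffusers source: `gate` is assumed to be 2-D `(batch, hidden)`, the output of `proj_out` is assumed to be 3-D `(batch, 4608, hidden)`, and `HIDDEN = 8` is a small stand-in for the real width. `broadcast_shape` is a hypothetical helper that mirrors the trailing-dimension alignment used by NumPy and PyTorch.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Return the broadcast shape of two shapes, or raise ValueError.

    Shapes are aligned from the trailing dimension; missing leading
    dimensions are treated as size 1.
    """
    result = []
    pairs = zip_longest(reversed(a), reversed(b), fillvalue=1)
    for i, (x, y) in enumerate(pairs):
        if x != y and x != 1 and y != 1:
            dim = max(len(a), len(b)) - 1 - i
            raise ValueError(f"size {x} must match size {y} at dimension {dim}")
        result.append(max(x, y))
    return tuple(reversed(result))


BATCH, SEQ_LEN, HIDDEN = 2, 4608, 8  # HIDDEN is a stand-in, not the real width
proj_out_shape = (BATCH, SEQ_LEN, HIDDEN)

# Assumed 2-D gate: its batch axis lines up with the sequence axis and fails.
try:
    broadcast_shape((BATCH, HIDDEN), proj_out_shape)
except ValueError as err:
    print("2-D gate:", err)

# A gate with a singleton sequence axis is compatible.
print("3-D gate:", broadcast_shape((BATCH, 1, HIDDEN), proj_out_shape))

# With a single prompt the 2-D gate broadcasts, which would hide the problem.
print("batch of 1:", broadcast_shape((1, HIDDEN), (1, SEQ_LEN, HIDDEN)))
```

If these assumptions hold, a single prompt works only because a batch of 1 broadcasts against any size, and giving `gate` an explicit singleton axis at position 1 would make the shapes line up for larger batches.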