
Flux ControlNet train example runs out of memory on validation step #9546

Closed
@Night1099

Description

Describe the bug

With the default settings from the Flux ControlNet train example README and 10 validation images, training errors out with a CUDA out-of-memory error during the validation step, on an A100 80GB.

09/28/2024 00:34:14 - INFO - __main__ - Running validation... 
model_index.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 536/536 [00:00<00:00, 1.54MB/s]
{'controlnet'} was not found in config. Values will be initialized to default values.
Loaded tokenizer_2 as T5TokenizerFast from `tokenizer_2` subfolder of black-forest-labs/FLUX.1-dev.
Loaded scheduler as FlowMatchEulerDiscreteScheduler from `scheduler` subfolder of black-forest-labs/FLUX.1-dev.
Loaded vae as AutoencoderKL from `vae` subfolder of black-forest-labs/FLUX.1-dev.
Loaded tokenizer as CLIPTokenizer from `tokenizer` subfolder of black-forest-labs/FLUX.1-dev.
Loaded text_encoder as CLIPTextModel from `text_encoder` subfolder of black-forest-labs/FLUX.1-dev.
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.82it/s]
Loaded text_encoder_2 as T5EncoderModel from `text_encoder_2` subfolder of black-forest-labs/FLUX.1-dev.
Loading pipeline components...: 100%|██████████| 7/7 [00:01<00:00,  4.41it/s]
Traceback (most recent call last):
  File "/workspace/diffusers/examples/controlnet/train_controlnet_flux.py", line 1434, in <module>
    main(args)
  File "/workspace/diffusers/examples/controlnet/train_controlnet_flux.py", line 1370, in main
    image_logs = log_validation(
                 ^^^^^^^^^^^^^^^
  File "/workspace/diffusers/examples/controlnet/train_controlnet_flux.py", line 146, in log_validation
    image = pipeline(
            ^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py", line 860, in __call__
    controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
                                                                ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/accelerate/utils/operations.py", line 820, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/accelerate/utils/operations.py", line 808, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/diffusers/src/diffusers/models/controlnet_flux.py", line 336, in forward
    encoder_hidden_states, hidden_states = block(
                                           ^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 172, in forward
    attn_output, context_attn_output = self.attn(
                                       ^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/diffusers/src/diffusers/models/attention_processor.py", line 490, in forward
    return self.processor(
           ^^^^^^^^^^^^^^^
  File "/workspace/diffusers/src/diffusers/models/attention_processor.py", line 1762, in __call__
    query = apply_rotary_emb(query, image_rotary_emb)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/diffusers/src/diffusers/models/embeddings.py", line 680, in apply_rotary_emb
    out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
           ~~~~~~~~~~^~~~~
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 54.00 MiB. GPU 0 has a total capacity of 79.14 GiB of which 52.75 MiB is free. Process 2301333 has 79.08 GiB memory in use. Of the allocated memory 78.35 GiB is allocated by PyTorch, and 217.84 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
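
A workaround I have not verified yet (besides the PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True hint in the error message itself): run the validation pipeline with model CPU offload and free it explicitly before training resumes. Rough sketch below; the helper name and where it would hook into log_validation are my assumptions, and the pipeline argument names assume the standard Flux ControlNet pipeline:

# Hypothetical sketch, not a confirmed fix: generate validation images with
# model CPU offload enabled, then release the pipeline and cached CUDA blocks
# before the next training step. Helper name and call site are assumptions.
import gc
import torch

def run_validation_with_cleanup(pipeline, validation_prompts, validation_images, generator):
    # Stream pipeline components to the GPU only while each one is needed
    # instead of keeping the whole pipeline resident next to the training state.
    pipeline.enable_model_cpu_offload()

    images = []
    for prompt, control_image in zip(validation_prompts, validation_images):
        with torch.inference_mode():
            result = pipeline(
                prompt=prompt,
                control_image=control_image,
                num_inference_steps=28,
                generator=generator,
            )
        images.append(result.images[0])

    # Drop the pipeline and return cached memory to the allocator so the
    # training loop resumes with as much free VRAM as possible.
    del pipeline
    gc.collect()
    torch.cuda.empty_cache()
    return images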

Reproduction

Run the Flux ControlNet train example with the default args from the Flux README and 10 validation images.

Logs

No response

System Info

  • 🤗 Diffusers version: 0.31.0.dev0
  • Platform: Linux-6.5.0-41-generic-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.11.10
  • PyTorch version (GPU?): 2.4.1+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.25.1
  • Transformers version: 4.45.1
  • Accelerate version: 0.34.2
  • PEFT version: not installed
  • Bitsandbytes version: not installed
  • Safetensors version: 0.4.5
  • xFormers version: not installed
  • Accelerator: NVIDIA A100 80GB PCIe, 81920 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@sayakpaul @PromeAIpro
