
vipshop/cache-dit


🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers

DeepCache is designed for UNets, not DiTs, and most existing DiT cache speedups are complex and not training-free. CacheDiT offers a set of training-free cache accelerators for DiT: 🔥DBCache, DBPrune, FBCache, etc.🔥

🔥Supported Models🔥

🚀FLUX.1: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥
🚀Mochi: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥
🚀CogVideoX: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥
🚀CogVideoX1.5: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥
🚀Wan2.1: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥
🚀HunyuanVideo: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥

👋 Highlight

The CacheDiT codebase is adapted from FBCache. Special thanks to their excellent work! The FBCache support for Mochi, FLUX.1, CogVideoX, Wan2.1, and HunyuanVideo is directly adapted from the original FBCache.

🤗 Introduction

🔥 DBCache: Dual Block Caching for Diffusion Transformers

DBCache: Dual Block Caching for Diffusion Transformers. We have enhanced FBCache into a more general and customizable cache algorithm, DBCache, which brings fully UNet-style cache acceleration to DiT models. The number of compute blocks (e.g., F8B12) can be customized in DBCache, and the method is entirely training-free. This lets users tune the trade-off between performance and precision to fit their workload.
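To make the Fn/Bn idea concrete, here is a toy, framework-free sketch of one DBCache step. It is an illustration under stated assumptions, not cache-dit's implementation: blocks are plain Python functions on lists of floats, the relative L1 metric and the names `rel_l1`, `run`, and `dbcache_step` are hypothetical, and the cache is probed with the residual of the first Fn blocks, compared against the last fully computed step (an FBCache-style rule). The first Fn and last Bn blocks always run; only the middle blocks are skipped and replaced by their cached residual.

```python
from typing import Callable, Dict, List, Tuple

Vec = List[float]
Block = Callable[[Vec], Vec]


def rel_l1(a: Vec, b: Vec) -> float:
    """Relative L1 distance sum(|a - b|) / sum(|b|) (assumed metric)."""
    num = sum(abs(x - y) for x, y in zip(a, b))
    den = sum(abs(y) for y in b) or 1e-12
    return num / den


def run(blocks: List[Block], h: Vec) -> Vec:
    for blk in blocks:
        h = blk(h)
    return h


def dbcache_step(blocks: List[Block], h: Vec, fn: int, bn: int,
                 threshold: float, state: Dict[str, Vec]) -> Tuple[Vec, bool]:
    """One denoising step of a toy Fn/Bn dual block cache (hypothetical)."""
    front = blocks[:fn]
    middle = blocks[fn:len(blocks) - bn]
    back = blocks[len(blocks) - bn:]

    # The first Fn blocks always run; their residual is the cache probe.
    h_in = h
    h = run(front, h)
    front_res = [a - b for a, b in zip(h, h_in)]
    prev = state.get("front_res")
    use_cache = prev is not None and rel_l1(front_res, prev) < threshold

    if use_cache:
        # Skip the middle blocks: add back their residual from the last full step.
        h = [a + r for a, r in zip(h, state["middle_res"])]
    else:
        # Full compute: refresh both cached residuals.
        state["front_res"] = front_res
        h_mid = h
        h = run(middle, h)
        state["middle_res"] = [a - b for a, b in zip(h, h_mid)]

    # The last Bn blocks always run on the (possibly approximated) state.
    return run(back, h), use_cache


# Six toy affine "blocks"; k=k binds each block's own offset.
blocks = [lambda v, k=k: [1.01 * x + 0.1 * k for x in v] for k in range(6)]
state: Dict[str, Vec] = {}
out1, cached1 = dbcache_step(blocks, [1.0, 2.0, 3.0], 1, 1, 0.05, state)
out2, cached2 = dbcache_step(blocks, [1.001, 2.0, 3.0], 1, 1, 0.05, state)
print(cached1, cached2)  # False True
```

The second step barely changes the input, so the probe residual stays under the threshold and the four middle blocks are skipped, while the last block still refines the approximated hidden state.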

DBCache, L20x1 , Steps: 28, "A cat holding a sign that says hello world with complex background"

Baseline(L20x1) | F1B0 (0.08) | F1B0 (0.20) | F8B8 (0.15) | F12B12 (0.20) | F16B16 (0.20)
24.85s          | 15.59s      | 8.58s       | 15.41s      | 15.11s        | 17.74s

Baseline(L20x1) | F1B0 (0.08) | F8B8 (0.12) | F8B12 (0.20) | F8B16 (0.20) | F8B20 (0.20)
27.85s          | 6.04s       | 5.88s       | 5.77s        | 6.01s        | 6.20s

DBCache, L20x4 , Steps: 20, case to show the texture recovery ability of DBCache

These case studies demonstrate that even with relatively high thresholds (such as 0.12, 0.15, 0.2, etc.) under the DBCache F12B12 or F8B16 configuration, the detailed texture of the kitten's fur, colored cloth, and the clarity of text can still be preserved. This suggests that users can leverage DBCache to effectively balance performance and precision in their workflows!

🔥 DBPrune: Dynamic Block Prune with Residual Caching

DBPrune: We have further implemented a new Dynamic Block Prune algorithm based on residual caching for Diffusion Transformers, referred to as DBPrune. DBPrune caches each block's hidden states and residuals, then dynamically prunes blocks during inference by computing the L1 distance between the current and previous hidden states. When a block is pruned, its output is approximated using the cached residuals.
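A minimal sketch of block-level pruning with residual caching, assuming each block compares its current input with the input cached at the last step that actually computed it (this comparison rule, the relative L1 metric, and the names `rel_l1` and `dbprune_step` are assumptions for illustration, not cache-dit's API). A pruned block is replaced by "input plus cached residual".

```python
from typing import Callable, Dict, List, Tuple

Vec = List[float]
Block = Callable[[Vec], Vec]


def rel_l1(a: Vec, b: Vec) -> float:
    """Relative L1 distance sum(|a - b|) / sum(|b|) (assumed metric)."""
    num = sum(abs(x - y) for x, y in zip(a, b))
    den = sum(abs(y) for y in b) or 1e-12
    return num / den


def dbprune_step(blocks: List[Block], h: Vec, threshold: float,
                 cache: Dict[int, Dict[str, Vec]]) -> Tuple[Vec, int]:
    """Toy per-block pruning with residual caching (hypothetical helper)."""
    pruned = 0
    for i, blk in enumerate(blocks):
        entry = cache.get(i)
        if entry is not None and rel_l1(h, entry["input"]) < threshold:
            # Prune: approximate the block's output with its cached residual.
            h = [x + r for x, r in zip(h, entry["residual"])]
            pruned += 1
        else:
            # Compute the block and cache its input and residual.
            out = blk(h)
            cache[i] = {"input": h, "residual": [o - x for o, x in zip(out, h)]}
            h = out
    return h, pruned


blocks = [lambda v, k=k: [1.01 * x + 0.1 * k for x in v] for k in range(6)]
cache: Dict[int, Dict[str, Vec]] = {}
_, p1 = dbprune_step(blocks, [1.0, 2.0, 3.0], 0.05, cache)   # nothing cached yet
_, p2 = dbprune_step(blocks, [1.001, 2.0, 3.0], 0.05, cache) # tiny change
_, p3 = dbprune_step(blocks, [5.0, 5.0, 5.0], 0.05, cache)   # large change
print(p1, p2, p3)  # 0 6 0
```

Because the decision is made per block, the number of pruned blocks varies from step to step, which is what makes the pruning "dynamic".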

DBPrune, L20x1 , Steps: 28, "A cat holding a sign that says hello world with complex background"

Baseline(L20x1) | Pruned(24%) | Pruned(35%) | Pruned(38%) | Pruned(45%) | Pruned(60%)
24.85s          | 19.43s      | 16.82s      | 15.95s      | 14.24s      | 10.66s
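The latencies above translate into speedups relative to the baseline as follows; this only re-derives ratios from the reported numbers and adds no new measurements.

```python
# Latencies in seconds, copied from the DBPrune (L20x1, 28 steps) table above.
baseline = 24.85
pruned = {"24%": 19.43, "35%": 16.82, "38%": 15.95, "45%": 14.24, "60%": 10.66}

# Speedup = baseline latency / pruned latency.
speedups = {ratio: round(baseline / t, 2) for ratio, t in pruned.items()}
for ratio, s in speedups.items():
    print(f"Pruned({ratio}): {s:.2f}x")
```

By these numbers, pruning 24% of the blocks gives about a 1.28x speedup, and pruning 60% gives about 2.33x.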

🔥 Context Parallelism and Torch Compile

Moreover, CacheDiT is a plug-and-play solution that works hand-in-hand with ParaAttention, so users can easily tap into its Context Parallelism features for distributed inference. CacheDiT is also designed to be compatible with torch.compile, and the two can be combined for further speedups.

DBPrune + torch.compile + context parallelism
Steps: 28, "A cat holding a sign that says hello world with complex background"

Setting  | Baseline | Pruned(24%) | Pruned(35%) | Pruned(38%) | Pruned(45%) | Pruned(60%)
+compile | 20.43s   | 16.25s      | 14.12s      | 13.41s      | 12.00s      | 8.86s
+L20x4   | 7.75s    | 6.62s       | 6.03s       | 5.81s       | 5.24s       | 3.93s

♥️ Please consider leaving a ⭐️ Star to support us ~ ♥️

©️Citations

@misc{CacheDiT@2025,
  title={CacheDiT: A Training-free and Easy-to-use cache acceleration Toolbox for Diffusion Transformers},
  url={https://github.com/vipshop/cache-dit.git},
  note={Open-source software available at https://github.com/vipshop/cache-dit.git},
  author={vipshop.com},
  year={2025}
}


⚙️Installation

You can install the stable release of cache-dit from PyPI:

pip3 install cache-dit

Or you can install the latest development version from GitHub:

pip3 install git+https://github.com/vipshop/cache-dit.git

⚡️DBCache: Dual Block Cache

DBCache provides configurable parameters for custom optimization, enabling a balanced trade-off between performance and precision:

  • Fn: Specifies that DBCache uses the first n Transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 diff and delivering more accurate information to subsequent blocks.
  • Bn: Further fuses approximate information in the last n Transformer blocks to enhance prediction accuracy. These blocks act as an auto-scaler for approximate hidden states that use residual cache.
  • warmup_steps: (default: 0) DBCache does not apply the caching strategy while the number of running steps is less than or equal to this value, so the basic features are always fully computed during warmup.
  • max_cached_steps: (default: -1) DBCache disables the caching strategy once the number of previously cached steps exceeds this value, to prevent precision degradation. -1 means no limit.
  • residual_diff_threshold: The threshold on the residual diff. A higher value gives faster inference at the cost of lower precision.
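How these three options could combine into a per-step decision is sketched below. This is a hypothetical helper, not cache-dit code: it assumes 0-based step indices, that exactly the first `warmup_steps` steps are computed in full, and that at most `max_cached_steps` steps reuse the cache; the exact boundary conditions in the library may differ. The diff values are made up.

```python
from typing import List


def can_use_cache(step: int, cached_so_far: int, diff: float, *,
                  residual_diff_threshold: float,
                  warmup_steps: int = 0,
                  max_cached_steps: int = -1) -> bool:
    """Hypothetical per-step gate combining the three options above."""
    if step < warmup_steps:
        # Warmup: the first `warmup_steps` steps are always fully computed.
        return False
    if max_cached_steps >= 0 and cached_so_far >= max_cached_steps:
        # Cache budget used up: compute every remaining step.
        return False
    # Otherwise reuse the cache only when the residual diff is small enough.
    return diff < residual_diff_threshold


def simulate(num_steps: int, diffs: List[float], **opts) -> List[int]:
    """Return the indices of the steps that would reuse the cache."""
    cached: List[int] = []
    for step in range(num_steps):
        if can_use_cache(step, len(cached), diffs[step], **opts):
            cached.append(step)
    return cached


# 28 steps with a constant (made-up) diff of 0.05 and threshold 0.12.
diffs = [0.05] * 28
print(simulate(28, diffs, residual_diff_threshold=0.12,
               warmup_steps=8, max_cached_steps=8))
# [8, 9, 10, 11, 12, 13, 14, 15]
```

With `max_cached_steps=-1` the same run would reuse the cache for every step from 8 to 27.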

For a good balance between performance and precision, the default DBCache options (CacheType.default_options(CacheType.DBCache)) use F8B8, 8 warmup steps, and unlimited cached steps.

import torch
from diffusers import FluxPipeline
from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Option 1: default options, F8B8, good balance between performance and precision
cache_options = CacheType.default_options(CacheType.DBCache)

# Option 2: custom options, F8B16, higher precision (overrides Option 1)
cache_options = {
    "cache_type": CacheType.DBCache,
    "warmup_steps": 8,
    "max_cached_steps": 8,    # -1 means no limit
    "Fn_compute_blocks": 8,   # Fn, F8, etc.
    "Bn_compute_blocks": 16,  # Bn, B16, etc.
    "residual_diff_threshold": 0.12,
}

apply_cache_on_pipe(pipe, **cache_options)

# Run inference as usual
image = pipe(
    "A cat holding a sign that says hello world with complex background",
    num_inference_steps=28,
).images[0]
image.save("flux.png")

Moreover, users who configure higher Bn values (e.g., F8B16) but still want good performance can specify Bn_compute_blocks_ids together with Bn. DBCache will then compute only the specified blocks and estimate the remaining ones using the previous step's residual cache.

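# Custom options, F8B16, higher precision with good performance.
cache_options = {
    "cache_type": CacheType.DBCache,
    "warmup_steps": 8,
    "max_cached_steps": 8,    # -1 means no limit
    "Fn_compute_blocks": 8,
    "Bn_compute_blocks": 16,
    "residual_diff_threshold": 0.12,
    # Compute only Bn blocks 0, 2, 4, ..., 14 of [0, 16)
    "Bn_compute_blocks_ids": CacheType.range(0, 16, 2),
    # If the L1 difference is below this threshold, skip the Bn blocks
    # not in `Bn_compute_blocks_ids` (1, 3, ..., 15); otherwise,
    # compute these blocks.
    "non_compute_blocks_diff_threshold": 0.08,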
}
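The selection rule described in these comments can be sketched as a small pure function. The helper name is hypothetical, and `list(range(0, 16, 2))` is assumed to produce the same ids as `CacheType.range(0, 16, 2)`.

```python
from typing import List


def bn_blocks_to_compute(bn: int, compute_ids: List[int], diff: float,
                         non_compute_threshold: float) -> List[int]:
    """Indices within the last Bn blocks that run at a step (hypothetical)."""
    if diff >= non_compute_threshold:
        # Diff too large: compute all Bn blocks.
        return list(range(bn))
    # Diff small: compute only the listed blocks; the rest are estimated
    # from the previous step's residual cache.
    return sorted(i for i in set(compute_ids) if 0 <= i < bn)


ids = list(range(0, 16, 2))  # assumed to match CacheType.range(0, 16, 2)
print(bn_blocks_to_compute(16, ids, diff=0.05, non_compute_threshold=0.08))
# [0, 2, 4, 6, 8, 10, 12, 14]
print(len(bn_blocks_to_compute(16, ids, diff=0.10, non_compute_threshold=0.08)))
# 16
```

Halving the computed Bn blocks at low-diff steps keeps most of the B16 precision benefit while cutting its cost.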


🎉FBCache: First Block Cache

DBCache is a more general cache algorithm than FBCache. When Fn=1 and Bn=0, DBCache behaves identically to FBCache. Therefore, you can either use the original FBCache implementation directly or configure DBCache with F1B0 settings to achieve the same functionality.

import torch
from diffusers import FluxPipeline
from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Option 1: use FBCache directly
cache_options = CacheType.default_options(CacheType.FBCache)

# Option 2: use DBCache with F1B0 (overrides Option 1).
# Fn=1, Bn=0 means FBCache; any other setting means Dual Block Cache.
cache_options = {
    "cache_type": CacheType.DBCache,
    "warmup_steps": 8,
    "max_cached_steps": 8,   # -1 means no limit
    "Fn_compute_blocks": 1,  # Fn, F1, etc.
    "Bn_compute_blocks": 0,  # Bn, B0, etc.
    "residual_diff_threshold": 0.12,
}

apply_cache_on_pipe(pipe, **cache_options)
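Why F1B0 reduces to FBCache becomes clear when the blocks are split into the three groups. This sketch is hypothetical: it treats FLUX.1's 19 + 38 blocks as one flat list of 57, which is a simplification of how the library may address the two block types, and its range check is a requirement of the sketch only.

```python
from typing import List, Tuple


def partition_blocks(num_blocks: int, fn: int,
                     bn: int) -> Tuple[List[int], List[int], List[int]]:
    """Split block indices into (always-run Fn, cacheable middle, always-run Bn)."""
    if fn < 0 or bn < 0 or fn + bn > num_blocks:
        # Constraint of this sketch, not necessarily of cache-dit.
        raise ValueError("need 0 <= fn, 0 <= bn and fn + bn <= num_blocks")
    front = list(range(fn))
    middle = list(range(fn, num_blocks - bn))
    back = list(range(num_blocks - bn, num_blocks))
    return front, middle, back


front, middle, back = partition_blocks(57, fn=1, bn=0)  # F1B0 == FBCache
print(front, len(middle), back)  # [0] 56 []
front, middle, back = partition_blocks(57, fn=8, bn=8)  # F8B8
print(len(front), middle[0], middle[-1], back[0])  # 8 8 48 49
```

With F1B0 only the first block acts as the probe and nothing runs after the cached span, which is exactly the FBCache layout.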

⚡️DBPrune: Dynamic Block Prune

We have further implemented a new Dynamic Block Prune algorithm based on residual caching for Diffusion Transformers, referred to as DBPrune. DBPrune caches each block's hidden states and residuals, then dynamically prunes blocks during inference by computing the L1 distance between the current and previous hidden states. When a block is pruned, its output is approximated using the cached residuals. DBPrune is currently experimental, so please stay tuned for upcoming updates.

import torch
from diffusers import FluxPipeline
from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Using DBPrune with default options
cache_options = CacheType.default_options(CacheType.DBPrune)

apply_cache_on_pipe(pipe, **cache_options)

We have also brought the designs of DBCache to DBPrune to make it a more general and customizable block prune algorithm. You can specify the values of Fn and Bn for higher precision, or specify the list of blocks that are never pruned, non_prune_blocks_ids, to avoid aggressive pruning. For example:

# Custom options for DBPrune
cache_options = {
    "cache_type": CacheType.DBPrune,
    "residual_diff_threshold": 0.05,
    # Never prune the first `Fn` and last `Bn` blocks.
    "Fn_compute_blocks": 8,  # default 1
    "Bn_compute_blocks": 8,  # default 0
    "warmup_steps": 8,  # default -1
    # Disable the pruning strategy once the number of previously
    # pruned steps exceeds this value.
    "max_pruned_steps": 12,  # default -1, which means no limit
    # Enable a dynamic prune threshold within each step. A higher
    # `max_dynamic_prune_threshold` value may lead to a more
    # aggressive pruning strategy.
    "enable_dynamic_prune_threshold": True,
    "max_dynamic_prune_threshold": 2 * 0.05,
    # (New thresh) = mean(previous_block_diffs_within_step) * 1.25
    # (New thresh) = ((New thresh) if (New thresh) <
    # max_dynamic_prune_threshold else residual_diff_threshold)
    "dynamic_prune_threshold_relax_ratio": 1.25,
    # The step interval for updating the residual cache. For example,
    # 2 means the update steps will be [0, 2, 4, ...].
    "residual_cache_update_interval": 1,
    # You can set non-prune blocks to avoid aggressive pruning.
    # For example, FLUX.1 has 19 + 38 blocks, so we can set it
    # to 0, 2, 4, ..., 56, etc.
    "non_prune_blocks_ids": [],
}

apply_cache_on_pipe(pipe, **cache_options)
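The dynamic threshold formula and the set of prunable blocks described in these options can be written out as two small functions. Both are hypothetical sketches of the commented rules, not library code; in particular, falling back to the static threshold for the first block of a step (when no earlier diffs exist) is an assumption, and FLUX.1's 19 + 38 blocks are treated as one flat list of 57.

```python
from typing import List


def dynamic_threshold(prev_diffs: List[float], residual_diff_threshold: float,
                      max_dynamic_prune_threshold: float,
                      relax_ratio: float = 1.25) -> float:
    """Threshold for the next block, following the formula in the comments above."""
    if not prev_diffs:
        # Assumption: with no earlier diffs in this step, use the static threshold.
        return residual_diff_threshold
    new_thresh = sum(prev_diffs) / len(prev_diffs) * relax_ratio
    if new_thresh < max_dynamic_prune_threshold:
        return new_thresh
    return residual_diff_threshold


def prunable_blocks(num_blocks: int, fn: int, bn: int,
                    non_prune_ids: List[int]) -> List[int]:
    """Blocks that may be pruned: not in the first Fn, last Bn, or non-prune list."""
    protected = (set(range(fn)) | set(range(num_blocks - bn, num_blocks))
                 | set(non_prune_ids))
    return [i for i in range(num_blocks) if i not in protected]


# Values from the custom DBPrune options above.
print(dynamic_threshold([0.04, 0.06], 0.05, 2 * 0.05))  # about 0.0625
print(dynamic_threshold([0.1, 0.1], 0.05, 2 * 0.05))    # 0.05 (falls back)
candidates = prunable_blocks(57, fn=8, bn=8, non_prune_ids=list(range(0, 57, 2)))
print(len(candidates), candidates[:3])  # 20 [9, 11, 13]
```

Capping the relaxed threshold at `max_dynamic_prune_threshold` keeps one noisy step from loosening the criterion without bound.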

Important

Please note that for GPUs with lower VRAM, DBPrune may not be suitable for use on video DiTs, as it caches the hidden states and residuals of each block, leading to higher GPU memory requirements. In such cases, please use DBCache, which only caches the hidden states and residuals of 2 blocks.
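A rough estimate shows why caching every block adds up. The shapes below are assumptions for illustration (about 4608 tokens, i.e. 4096 image tokens plus 512 text tokens for FLUX.1 at 1024x1024, hidden size 3072, bf16, batch 1), and the helper is hypothetical; it counts one hidden-state tensor and one residual tensor per cached block, taking the note's "2 blocks" for DBCache.

```python
def residual_cache_bytes(num_cached_blocks: int, tokens: int, hidden_dim: int,
                         bytes_per_elem: int = 2, tensors_per_block: int = 2,
                         batch: int = 1) -> int:
    """Rough size of cached hidden states + residuals (hypothetical estimate)."""
    return (num_cached_blocks * tensors_per_block * batch
            * tokens * hidden_dim * bytes_per_elem)


# Assumed shapes: 4608 tokens, hidden size 3072, bf16 (2 bytes per element).
tokens, hidden = 4608, 3072
dbprune = residual_cache_bytes(57, tokens, hidden)  # every block cached
dbcache = residual_cache_bytes(2, tokens, hidden)   # 2 blocks cached
print(f"DBPrune: {dbprune / 2**30:.2f} GiB, DBCache: {dbcache / 2**20:.0f} MiB")
# DBPrune: 3.01 GiB, DBCache: 108 MiB
```

The estimate grows linearly with the token count, so video DiTs, whose sequences are far longer, multiply these figures accordingly.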


🎉Context Parallelism

CacheDiT is a plug-and-play solution that works hand-in-hand with ParaAttention, so users can easily tap into its Context Parallelism features for distributed inference. First, install para-attn from PyPI:

pip3 install para-attn  # or install `para-attn` from sources.

Then, you can run DBCache or DBPrune with Context Parallelism on 4 GPUs:

import torch
import torch.distributed as dist
from diffusers import FluxPipeline
from para_attn.context_parallel import init_context_parallel_mesh
from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

# Init distributed process group (one process per GPU, launched by torchrun)
dist.init_process_group()
torch.cuda.set_device(dist.get_rank())

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Context Parallel from ParaAttention
parallelize_pipe(
    pipe, mesh=init_context_parallel_mesh(
        pipe.device.type, max_ulysses_dim_size=4
    )
)

# DBPrune with default options from this library
apply_cache_on_pipe(
    pipe, **CacheType.default_options(CacheType.DBPrune)
)

# Every rank runs the pipeline; only rank 0 decodes to PIL and saves the image
image = pipe(
    "A cat holding a sign that says hello world with complex background",
    num_inference_steps=28,
    output_type="pil" if dist.get_rank() == 0 else "pt",
).images[0]

if dist.get_rank() == 0:
    image.save("flux_parallel.png")

dist.destroy_process_group()

Save the script above as parallel_cache.py, then run it with torchrun:

torchrun --nproc_per_node=4 parallel_cache.py

🔥Torch Compile

CacheDiT is also designed to be compatible with torch.compile, so you can combine the two for further speedups. For example:

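import torch
from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

# Apply the cache first (`pipe` is a loaded FluxPipeline, as above)
apply_cache_on_pipe(
    pipe, **CacheType.default_options(CacheType.DBPrune)
)
# Then compile the Transformer module
pipe.transformer = torch.compile(pipe.transformer)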

However, if you use CacheDiT with DiTs that have dynamic input shapes, consider increasing the recompile limits of torch._dynamo. Otherwise, the recompile limit may be reached, causing the module to fall back to eager mode.

import torch._dynamo

torch._dynamo.config.recompile_limit = 96  # default is 8
torch._dynamo.config.accumulated_recompile_limit = 2048  # default is 256

👋Contribute

How to contribute? Star ⭐️ this repo to support us or check CONTRIBUTE.md.

©️License

We follow the original License of ParaAttention; please check LICENSE for more details.
