
Commit 9dbffc8

a-r-r-o-w and sayakpaul committed
PAG variant for HunyuanDiT, PAG refactor (#8936)

* copy hunyuandit pipeline
* pag variant of hunyuan dit
* add tests
* update docs
* make style
* make fix-copies
* Update src/diffusers/pipelines/pag/pag_utils.py
* remove incorrect copied from
* remove pag hunyuan attn procs to resolve conflicts
* add pag attn procs again
* new implementation for pag_utils
* revert pag changes
* add pag refactor back; update pixart sigma
* update pixart pag tests
* apply suggestions from review (Co-Authored-By: yixu310@gmail.com)
* make style
* update docs, fix tests
* fix tests
* fix test_components_function since list not accepted as valid __init__ param
* apply patch to fix broken tests (Co-Authored-By: Sayak Paul <spsayakpaul@gmail.com>)
* make style
* fix hunyuan tests

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent fa55429 · commit 9dbffc8

16 files changed (+1737, -354 lines)

docs/source/en/api/pipelines/pag.md

Lines changed: 19 additions & 1 deletion
@@ -20,11 +20,29 @@ The abstract from the paper is:
 
 *Recent studies have demonstrated that diffusion models are capable of generating high-quality samples, but their quality heavily depends on sampling guidance techniques, such as classifier guidance (CG) and classifier-free guidance (CFG). These techniques are often not applicable in unconditional generation or in various downstream tasks such as image restoration. In this paper, we propose a novel sampling guidance, called Perturbed-Attention Guidance (PAG), which improves diffusion sample quality across both unconditional and conditional settings, achieving this without requiring additional training or the integration of external modules. PAG is designed to progressively enhance the structure of samples throughout the denoising process. It involves generating intermediate samples with degraded structure by substituting selected self-attention maps in diffusion U-Net with an identity matrix, by considering the self-attention mechanisms' ability to capture structural information, and guiding the denoising process away from these degraded samples. In both ADM and Stable Diffusion, PAG surprisingly improves sample quality in conditional and even unconditional scenarios. Moreover, PAG significantly improves the baseline performance in various downstream tasks where existing guidances such as CG or CFG cannot be fully utilized, including ControlNet with empty prompts and image restoration such as inpainting and deblurring.*
 
+PAG can be used by specifying `pag_applied_layers` as a parameter when instantiating a PAG pipeline (a usage sketch follows this file's diff). It can be a single string or a list of strings. Each string can be a unique layer identifier or a regular expression matching one or more layers.
+
+- Full identifier as a normal string: `down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor`
+- Full identifier as a RegEx: `down_blocks.2.(attentions|motion_modules).0.transformer_blocks.0.attn1.processor`
+- Partial identifier as a RegEx: `down_blocks.2` or `attn1`
+- List of identifiers (can be a combination of plain strings and RegEx): `["blocks.1", "blocks.(14|20)", r"down_blocks\.(2,3)"]`
+
+<Tip warning={true}>
+
+Since RegEx is supported for matching layer identifiers, it is crucial to use it correctly; otherwise there may be unexpected behaviour. The recommended way to use PAG is to specify layers as `blocks.{layer_index}` and `blocks.({layer_index_1|layer_index_2|...})`. Using it in any other way, while possible, may bypass our basic validation checks and give unexpected results.
+
+</Tip>
+
 ## AnimateDiffPAGPipeline
 [[autodoc]] AnimateDiffPAGPipeline
 - all
 - __call__
 
+## HunyuanDiTPAGPipeline
+[[autodoc]] HunyuanDiTPAGPipeline
+- all
+- __call__
+
 ## StableDiffusionPAGPipeline
 [[autodoc]] StableDiffusionPAGPipeline
 - all
@@ -59,4 +77,4 @@ The abstract from the paper is:
 ## PixArtSigmaPAGPipeline
 [[autodoc]] PixArtSigmaPAGPipeline
 - all
-- __call__
+- __call__
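
To make the new documentation concrete, here is a minimal usage sketch of the pipeline added by this commit. The checkpoint name and the chosen layer are illustrative assumptions, not part of the diff; `pag_applied_layers` is passed at instantiation as described above, while `pag_scale` controls the guidance strength at call time.

import torch
from diffusers import HunyuanDiTPAGPipeline

# Sketch only: checkpoint and layer choice are assumptions for illustration.
pipe = HunyuanDiTPAGPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers",  # assumed checkpoint
    pag_applied_layers=["blocks.16"],  # recommended `blocks.{layer_index}` form
    torch_dtype=torch.float16,
).to("cuda")

# pag_scale sets the strength of perturbed-attention guidance;
# guidance_scale > 1 additionally enables classifier-free guidance.
image = pipe(
    "an astronaut riding a horse on the moon",
    guidance_scale=7.0,
    pag_scale=3.0,
).images[0]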

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -252,6 +252,7 @@
         "CycleDiffusionPipeline",
         "FluxPipeline",
         "HunyuanDiTControlNetPipeline",
+        "HunyuanDiTPAGPipeline",
         "HunyuanDiTPipeline",
         "I2VGenXLPipeline",
         "IFImg2ImgPipeline",
@@ -675,6 +676,7 @@
         CycleDiffusionPipeline,
         FluxPipeline,
         HunyuanDiTControlNetPipeline,
+        HunyuanDiTPAGPipeline,
         HunyuanDiTPipeline,
         I2VGenXLPipeline,
         IFImg2ImgPipeline,
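
With the name added to both the stub list and the import branch above, the new pipeline becomes importable from the package root. A quick sanity check (assuming an environment with this commit installed):

# Resolved via the entries added above.
from diffusers import HunyuanDiTPAGPipeline

print(HunyuanDiTPAGPipeline.__name__)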

src/diffusers/models/attention_processor.py

Lines changed: 249 additions & 0 deletions
@@ -2147,6 +2147,253 @@ def __call__(
         return hidden_states
 
 
+class PAGHunyuanAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+    used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on the query and key vectors.
+    This variant of the processor employs [Perturbed Attention Guidance](https://arxiv.org/abs/2403.17377).
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "PAGHunyuanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        from .embeddings import apply_rotary_emb
+
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        # chunk
+        hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
+
+        # 1. Original Path
+        batch_size, sequence_length, _ = (
+            hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states_org)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states_org
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None:
+            query = apply_rotary_emb(query, image_rotary_emb)
+            if not attn.is_cross_attention:
+                key = apply_rotary_emb(key, image_rotary_emb)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states_org = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states_org = hidden_states_org.to(query.dtype)
+
+        # linear proj
+        hidden_states_org = attn.to_out[0](hidden_states_org)
+        # dropout
+        hidden_states_org = attn.to_out[1](hidden_states_org)
+
+        if input_ndim == 4:
+            hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        # 2. Perturbed Path
+        if attn.group_norm is not None:
+            hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
+
+        hidden_states_ptb = attn.to_v(hidden_states_ptb)
+        hidden_states_ptb = hidden_states_ptb.to(query.dtype)
+
+        # linear proj
+        hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+        # dropout
+        hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+
+        if input_ndim == 4:
+            hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        # cat
+        hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class PAGCFGHunyuanAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+    used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on the query and key vectors.
+    This variant of the processor employs [Perturbed Attention Guidance](https://arxiv.org/abs/2403.17377).
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "PAGCFGHunyuanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        from .embeddings import apply_rotary_emb
+
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        # chunk
+        hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
+        hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
+
+        # 1. Original Path
+        batch_size, sequence_length, _ = (
+            hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states_org)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states_org
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None:
+            query = apply_rotary_emb(query, image_rotary_emb)
+            if not attn.is_cross_attention:
+                key = apply_rotary_emb(key, image_rotary_emb)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states_org = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states_org = hidden_states_org.to(query.dtype)
+
+        # linear proj
+        hidden_states_org = attn.to_out[0](hidden_states_org)
+        # dropout
+        hidden_states_org = attn.to_out[1](hidden_states_org)
+
+        if input_ndim == 4:
+            hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        # 2. Perturbed Path
+        if attn.group_norm is not None:
+            hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
+
+        hidden_states_ptb = attn.to_v(hidden_states_ptb)
+        hidden_states_ptb = hidden_states_ptb.to(query.dtype)
+
+        # linear proj
+        hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+        # dropout
+        hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+
+        if input_ndim == 4:
+            hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        # cat
+        hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class LuminaAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
@@ -3468,4 +3715,6 @@ def __init__(self):
     CustomDiffusionAttnProcessor2_0,
     PAGCFGIdentitySelfAttnProcessor2_0,
     PAGIdentitySelfAttnProcessor2_0,
+    PAGCFGHunyuanAttnProcessor2_0,
+    PAGHunyuanAttnProcessor2_0,
 ]
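
A note on the perturbed path above: PAG replaces the self-attention map with an identity matrix, and attention with an identity map simply returns each position's value vector. That is why `hidden_states_ptb` goes through `attn.to_v` and the output projection but skips the query/key computation entirely. A self-contained sketch of that equivalence and of the batch layout the two processors assume (illustrative only, not diffusers API):

import torch

# Identity attention map: position i attends only to itself, so A @ V == V.
batch, heads, seq_len, head_dim = 2, 4, 16, 8
value = torch.randn(batch, heads, seq_len, head_dim)
identity_map = torch.eye(seq_len).expand(batch, heads, seq_len, seq_len)
assert torch.allclose(identity_map @ value, value)

# Batch layout assumed by the processors:
#   PAGHunyuanAttnProcessor2_0:    chunk(2) -> [conditional, perturbed]
#   PAGCFGHunyuanAttnProcessor2_0: chunk(3) -> [unconditional, conditional, perturbed]
latents = torch.randn(3, seq_len, 64)
uncond, cond, perturbed = latents.chunk(3)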

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -145,6 +145,7 @@
     _import_structure["pag"].extend(
         [
             "AnimateDiffPAGPipeline",
+            "HunyuanDiTPAGPipeline",
             "StableDiffusionPAGPipeline",
             "StableDiffusionControlNetPAGPipeline",
             "StableDiffusionXLPAGPipeline",
@@ -532,6 +533,7 @@
     from .musicldm import MusicLDMPipeline
     from .pag import (
         AnimateDiffPAGPipeline,
+        HunyuanDiTPAGPipeline,
         PixArtSigmaPAGPipeline,
         StableDiffusionControlNetPAGPipeline,
         StableDiffusionPAGPipeline,

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 2 additions & 0 deletions
@@ -50,6 +50,7 @@
 from .kolors import KolorsImg2ImgPipeline, KolorsPipeline
 from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
 from .pag import (
+    HunyuanDiTPAGPipeline,
     PixArtSigmaPAGPipeline,
     StableDiffusionControlNetPAGPipeline,
     StableDiffusionPAGPipeline,
@@ -85,6 +86,7 @@
         ("stable-diffusion-3", StableDiffusion3Pipeline),
         ("if", IFPipeline),
         ("hunyuan", HunyuanDiTPipeline),
+        ("hunyuan-pag", HunyuanDiTPAGPipeline),
         ("kandinsky", KandinskyCombinedPipeline),
         ("kandinsky22", KandinskyV22CombinedPipeline),
         ("kandinsky3", Kandinsky3Pipeline),

src/diffusers/pipelines/pag/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,7 @@
 else:
     _import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"]
     _import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"]
+    _import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"]
     _import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"]
     _import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
     _import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"]
@@ -41,6 +42,7 @@
 else:
     from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline
     from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline
+    from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline
     from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline
     from .pipeline_pag_sd import StableDiffusionPAGPipeline
     from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline
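
Both hunks follow diffusers' lazy-import convention: `_import_structure` maps submodule names to the symbols they export so that heavy modules load only on first attribute access, while the direct imports serve type checkers and environments where all dependencies are present. A simplified sketch of the pattern; the real implementation is `_LazyModule` in `diffusers.utils`, and details here are illustrative:

import sys
from typing import TYPE_CHECKING

_import_structure = {"pipeline_pag_hunyuandit": ["HunyuanDiTPAGPipeline"]}

if TYPE_CHECKING:
    # Eager imports for static type checkers only.
    from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline
else:
    from diffusers.utils import _LazyModule

    # Attributes resolve lazily: the submodule is imported the first time
    # HunyuanDiTPAGPipeline is accessed, not at package import time.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)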
