Commit c4c822b

[Core] fix QKV fusion for attention (#8829)
* start debugging the problem
* start
* fix
* fix
* fix imports
* handle hunyuan
* remove residuals
* add a check for making sure there's appropriate procs
* add more rigor to the tests
* fix test
* remove redundant check
* fix-copies
* move check_qkv_fusion_matches_attn_procs_length and check_qkv_fusion_processors_exist
1 parent df4e3f4 commit c4c822b

13 files changed: +209 -9 lines changed

src/diffusers/models/attention_processor.py

Lines changed: 118 additions & 0 deletions
@@ -677,6 +677,21 @@ def fuse_projections(self, fuse=True):
                 concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
                 self.to_kv.bias.copy_(concatenated_bias)

+        # handle added projections for SD3 and others.
+        if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"):
+            concatenated_weights = torch.cat(
+                [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
+            )
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            self.to_added_qkv = nn.Linear(in_features, out_features, bias=True, device=device, dtype=dtype)
+            self.to_added_qkv.weight.copy_(concatenated_weights)
+            concatenated_bias = torch.cat(
+                [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
+            )
+            self.to_added_qkv.bias.copy_(concatenated_bias)
+
         self.fused_projections = fuse
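For context, the `to_added_qkv` layer built above is consumed at attention time by a fused processor (such as the joint SD3 processor) that runs a single matmul and splits the result back into the three added projections. A minimal sketch of that consumption path, assuming `attn` is an Attention module on which fuse_projections() has already been called and whose added projections share the same output width (as in SD3); the helper name is illustrative, not diffusers API:

import torch

def split_added_qkv(attn, encoder_hidden_states: torch.Tensor):
    # One matmul through the fused layer created in fuse_projections() above ...
    qkv = attn.to_added_qkv(encoder_hidden_states)
    # ... then undo the concatenation: the weights were stacked in q, k, v order,
    # so three equal chunks along the last dimension recover the projections.
    split_size = qkv.shape[-1] // 3
    return torch.split(qkv, split_size, dim=-1)  # (add_q, add_k, add_v)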
@@ -1708,6 +1723,109 @@ def __call__(
         return hidden_states


+class FusedHunyuanAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused
+    projection layers. This is used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on
+    query and key vector.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FusedHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        from .embeddings import apply_rotary_emb
+
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        if encoder_hidden_states is None:
+            qkv = attn.to_qkv(hidden_states)
+            split_size = qkv.shape[-1] // 3
+            query, key, value = torch.split(qkv, split_size, dim=-1)
+        else:
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+            query = attn.to_q(hidden_states)
+
+            kv = attn.to_kv(encoder_hidden_states)
+            split_size = kv.shape[-1] // 2
+            key, value = torch.split(kv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None:
+            query = apply_rotary_emb(query, image_rotary_emb)
+            if not attn.is_cross_attention:
+                key = apply_rotary_emb(key, image_rotary_emb)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class LuminaAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
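The reason the fused processors above are drop-in replacements is that concatenating the projection weights along the output dimension and splitting the result of a single matmul is numerically identical to running the projections separately. A small self-contained check of that equivalence, using plain nn.Linear layers as stand-ins for an Attention module's to_q/to_k/to_v (illustrative only, not diffusers API):

import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 32
to_q, to_k, to_v = (nn.Linear(dim, dim, bias=True) for _ in range(3))

# Build a fused layer the same way fuse_projections() does: concatenate weights and biases.
fused = nn.Linear(dim, 3 * dim, bias=True)
with torch.no_grad():
    fused.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight]))
    fused.bias.copy_(torch.cat([to_q.bias, to_k.bias, to_v.bias]))

x = torch.randn(2, 7, dim)
q, k, v = torch.split(fused(x), dim, dim=-1)
assert torch.allclose(q, to_q(x), atol=1e-6)
assert torch.allclose(k, to_k(x), atol=1e-6)
assert torch.allclose(v, to_v(x), atol=1e-6)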

src/diffusers/models/autoencoders/autoencoder_kl.py

Lines changed: 3 additions & 0 deletions
@@ -26,6 +26,7 @@
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
+    FusedAttnProcessor2_0,
 )
 from ..modeling_outputs import AutoencoderKLOutput
 from ..modeling_utils import ModelMixin
@@ -492,6 +493,8 @@ def fuse_qkv_projections(self):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)

+        self.set_attn_processor(FusedAttnProcessor2_0())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
         """Disables the fused QKV projection if enabled.

src/diffusers/models/controlnet_sd3.py

Lines changed: 4 additions & 2 deletions
@@ -22,7 +22,7 @@
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ..models.attention import JointTransformerBlock
-from ..models.attention_processor import Attention, AttentionProcessor
+from ..models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
 from ..models.modeling_outputs import Transformer2DModelOutput
 from ..models.modeling_utils import ModelMixin
 from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
@@ -196,7 +196,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)

-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    # Copied from diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel.fuse_qkv_projections
     def fuse_qkv_projections(self):
         """
         Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
@@ -220,6 +220,8 @@ def fuse_qkv_projections(self):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)

+        self.set_attn_processor(FusedJointAttnProcessor2_0())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
         """Disables the fused QKV projection if enabled.

src/diffusers/models/controlnet_xs.py

Lines changed: 3 additions & 0 deletions
@@ -29,6 +29,7 @@
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
+    FusedAttnProcessor2_0,
 )
 from .controlnet import ControlNetConditioningEmbedding
 from .embeddings import TimestepEmbedding, Timesteps
@@ -1001,6 +1002,8 @@ def fuse_qkv_projections(self):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)

+        self.set_attn_processor(FusedAttnProcessor2_0())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
         """Disables the fused QKV projection if enabled.

src/diffusers/models/transformers/hunyuan_transformer_2d.py

Lines changed: 4 additions & 2 deletions
@@ -20,7 +20,7 @@
 from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
-from ..attention_processor import Attention, AttentionProcessor, HunyuanAttnProcessor2_0
+from ..attention_processor import Attention, AttentionProcessor, FusedHunyuanAttnProcessor2_0, HunyuanAttnProcessor2_0
 from ..embeddings import (
     HunyuanCombinedTimestepTextSizeStyleEmbedding,
     PatchEmbed,
@@ -317,7 +317,7 @@ def __init__(
         self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedHunyuanAttnProcessor2_0
     def fuse_qkv_projections(self):
         """
         Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
@@ -341,6 +341,8 @@ def fuse_qkv_projections(self):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)

+        self.set_attn_processor(FusedHunyuanAttnProcessor2_0())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
         """Disables the fused QKV projection if enabled.

src/diffusers/models/transformers/transformer_sd3.py

Lines changed: 4 additions & 2 deletions
@@ -23,7 +23,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...models.attention import JointTransformerBlock
-from ...models.attention_processor import Attention, AttentionProcessor
+from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
@@ -211,7 +211,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)

-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
     def fuse_qkv_projections(self):
         """
         Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
@@ -235,6 +235,8 @@ def fuse_qkv_projections(self):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)

+        self.set_attn_processor(FusedJointAttnProcessor2_0())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
         """Disables the fused QKV projection if enabled.

src/diffusers/models/unets/unet_2d_condition.py

Lines changed: 3 additions & 0 deletions
@@ -30,6 +30,7 @@
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
+    FusedAttnProcessor2_0,
 )
 from ..embeddings import (
     GaussianFourierProjection,
@@ -890,6 +891,8 @@ def fuse_qkv_projections(self):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)

+        self.set_attn_processor(FusedAttnProcessor2_0())
+
     def unfuse_qkv_projections(self):
         """Disables the fused QKV projection if enabled.

src/diffusers/models/unets/unet_3d_condition.py

Lines changed: 3 additions & 0 deletions
@@ -31,6 +31,7 @@
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
+    FusedAttnProcessor2_0,
 )
 from ..embeddings import TimestepEmbedding, Timesteps
 from ..modeling_utils import ModelMixin
@@ -532,6 +533,8 @@ def fuse_qkv_projections(self):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)

+        self.set_attn_processor(FusedAttnProcessor2_0())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
         """Disables the fused QKV projection if enabled.

src/diffusers/models/unets/unet_i2vgen_xl.py

Lines changed: 3 additions & 0 deletions
@@ -29,6 +29,7 @@
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
+    FusedAttnProcessor2_0,
 )
 from ..embeddings import TimestepEmbedding, Timesteps
 from ..modeling_utils import ModelMixin
@@ -498,6 +499,8 @@ def fuse_qkv_projections(self):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)

+        self.set_attn_processor(FusedAttnProcessor2_0())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
         """Disables the fused QKV projection if enabled.

src/diffusers/models/unets/unet_motion_model.py

Lines changed: 3 additions & 0 deletions
@@ -29,6 +29,7 @@
     AttnAddedKVProcessor,
     AttnProcessor,
     AttnProcessor2_0,
+    FusedAttnProcessor2_0,
     IPAdapterAttnProcessor,
     IPAdapterAttnProcessor2_0,
 )
@@ -929,6 +930,8 @@ def fuse_qkv_projections(self):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)

+        self.set_attn_processor(FusedAttnProcessor2_0())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
         """Disables the fused QKV projection if enabled.

tests/pipelines/hunyuan_dit/test_hunyuan_dit.py

Lines changed: 16 additions & 1 deletion
@@ -36,7 +36,12 @@
 )

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, to_np
+from ..test_pipelines_common import (
+    PipelineTesterMixin,
+    check_qkv_fusion_matches_attn_procs_length,
+    check_qkv_fusion_processors_exist,
+    to_np,
+)


 enable_full_determinism()
@@ -261,6 +266,16 @@ def test_fused_qkv_projections(self):
         original_image_slice = image[0, -3:, -3:, -1]

         pipe.transformer.fuse_qkv_projections()
+        # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
+        # to the pipeline level.
+        pipe.transformer.fuse_qkv_projections()
+        assert check_qkv_fusion_processors_exist(
+            pipe.transformer
+        ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+        assert check_qkv_fusion_matches_attn_procs_length(
+            pipe.transformer, pipe.transformer.original_attn_processors
+        ), "Something wrong with the attention processors concerning the fused QKV projections."
+
         inputs = self.get_dummy_inputs(device)
         inputs["return_dict"] = False
         image_fused = pipe(**inputs)[0]
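The two helpers imported above are not part of this diff; per the commit message they were moved into tests/pipelines/test_pipelines_common.py. A plausible shape for them, consistent with how the assertions use them (a sketch, not the verbatim test utilities):

def check_qkv_fusion_processors_exist(model):
    # After fusion, every attention processor on the model should be a Fused* variant.
    processor_names = [proc.__class__.__name__ for proc in model.attn_processors.values()]
    return all(name.startswith("Fused") for name in processor_names)


def check_qkv_fusion_matches_attn_procs_length(model, original_attn_processors):
    # Fusion should swap processors in place, not add or drop any.
    return len(model.attn_processors) == len(original_attn_processors)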

tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py

Lines changed: 14 additions & 1 deletion
@@ -13,7 +13,11 @@
     torch_device,
 )

-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import (
+    PipelineTesterMixin,
+    check_qkv_fusion_matches_attn_procs_length,
+    check_qkv_fusion_processors_exist,
+)


 class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@@ -191,7 +195,16 @@ def test_fused_qkv_projections(self):
         image = pipe(**inputs).images
         original_image_slice = image[0, -3:, -3:, -1]

+        # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
+        # to the pipeline level.
         pipe.transformer.fuse_qkv_projections()
+        assert check_qkv_fusion_processors_exist(
+            pipe.transformer
+        ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+        assert check_qkv_fusion_matches_attn_procs_length(
+            pipe.transformer, pipe.transformer.original_attn_processors
+        ), "Something wrong with the attention processors concerning the fused QKV projections."
+
         inputs = self.get_dummy_inputs(device)
         image = pipe(**inputs).images
         image_slice_fused = image[0, -3:, -3:, -1]
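Both test hunks stop where the fused pass begins; the part of test_fused_qkv_projections that gives the new assertions their teeth (already present in the tests, so not shown in this diff) compares output slices across the unfused, fused, and re-unfused runs, roughly along these lines:

        pipe.transformer.unfuse_qkv_projections()
        inputs = self.get_dummy_inputs(device)
        image_slice_disabled = pipe(**inputs).images[0, -3:, -3:, -1]

        # Fusing and then unfusing the projections must not change the outputs
        # beyond numerical noise.
        assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3)
        assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3)
        assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2)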
