
[core] FreeNoise #8948


Merged
40 commits merged on Aug 7, 2024
Changes from 23 commits
Commits (40)
80e530f
initial work draft for freenoise; needs massive cleanup
a-r-r-o-w Jul 23, 2024
441d321
fix freeinit bug
a-r-r-o-w Jul 24, 2024
5d0f4c3
add animatediff controlnet implementation
a-r-r-o-w Jul 24, 2024
2e97ba7
Merge branch 'main' into freenoise
a-r-r-o-w Jul 24, 2024
690dad6
Merge branch 'main' into freenoise
a-r-r-o-w Jul 27, 2024
610f433
revert attention changes
a-r-r-o-w Jul 27, 2024
10b65b3
add freenoise
a-r-r-o-w Jul 27, 2024
a41f843
remove old helper functions
a-r-r-o-w Jul 27, 2024
f6897ae
add decode batch size param to all pipelines
a-r-r-o-w Jul 27, 2024
024e2da
make style
a-r-r-o-w Jul 27, 2024
1bb0984
fix copied from comments
a-r-r-o-w Jul 27, 2024
1b7bc00
make fix-copies
a-r-r-o-w Jul 27, 2024
dc96a8d
make style
a-r-r-o-w Jul 27, 2024
691facf
copy animatediff controlnet implementation from #8972
a-r-r-o-w Jul 27, 2024
5a60a62
add experimental support for num_frames not perfectly fitting context…
a-r-r-o-w Jul 27, 2024
58c2ddc
make unet motion model lora work again based on #8995
a-r-r-o-w Jul 27, 2024
7000186
copy load video utils from #8972
a-r-r-o-w Jul 27, 2024
c5db39f
copied from AnimateDiff::prepare_latents
a-r-r-o-w Jul 28, 2024
594d2d2
address the case where last batch of frames does not match length of …
a-r-r-o-w Jul 28, 2024
fb9ca34
decode_batch_size->vae_batch_size; batch vae encode support in animat…
a-r-r-o-w Jul 28, 2024
77ee296
revert sparsectrl and sdxl freenoise changes
a-r-r-o-w Jul 28, 2024
52884b3
revert pia
a-r-r-o-w Jul 28, 2024
1e2ef4d
add freenoise tests
a-r-r-o-w Jul 28, 2024
5d5a7ea
Merge branch 'main' into freenoise
a-r-r-o-w Jul 30, 2024
3d9b183
make fix-copies
a-r-r-o-w Jul 30, 2024
44e40a2
improve docstrings
a-r-r-o-w Jul 30, 2024
a61ffff
add freenoise tests to animatediff controlnet
a-r-r-o-w Jul 30, 2024
d82228e
update tests
a-r-r-o-w Jul 30, 2024
037ee07
Update src/diffusers/models/unets/unet_motion_model.py
a-r-r-o-w Jul 30, 2024
ac3d8c6
Merge branch 'main' into freenoise
a-r-r-o-w Aug 2, 2024
d19ddb4
add freenoise to animatediff pag
a-r-r-o-w Aug 2, 2024
12cc84a
address review comments
a-r-r-o-w Aug 2, 2024
6f48356
make style
a-r-r-o-w Aug 2, 2024
1f0ccfd
update tests
a-r-r-o-w Aug 2, 2024
6a4aab8
make fix-copies
a-r-r-o-w Aug 2, 2024
2f77c69
update
DN6 Aug 3, 2024
8564dc3
fix error message
a-r-r-o-w Aug 3, 2024
b32b1d7
remove copied from comment
a-r-r-o-w Aug 3, 2024
045ae36
fix imports in tests
a-r-r-o-w Aug 3, 2024
2d9aa42
update
DN6 Aug 6, 2024
1 change: 1 addition & 0 deletions src/diffusers/loaders/peft.py
@@ -31,6 +31,7 @@
_SET_ADAPTER_SCALE_FN_MAPPING = {
"UNet2DConditionModel": _maybe_expand_lora_scales,
"SD3Transformer2DModel": lambda model_cls, weights: weights,
"UNetMotionModel": lambda model_cls, weights: weights,
}


11 changes: 11 additions & 0 deletions src/diffusers/models/attention.py
@@ -272,6 +272,17 @@ def __init__(
attention_out_bias: bool = True,
):
super().__init__()
self.dim = dim
Member Author (a-r-r-o-w) commented:

These changes were made so that FreeNoiseTransformerBlock can be initialized correctly. I'm not sure how else we could determine these attributes in a "simple" way without accessing the internal PyTorch dimensions, which adds many extra LOC after `make style`.
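To make the motivation concrete, here is a minimal sketch of how the stored attributes could be read back when swapping a block for its FreeNoise counterpart (a hypothetical helper for illustration only; the actual swap logic lives in FreeNoiseMixin and may differ):

from diffusers.models.unets.unet_motion_model import FreeNoiseTransformerBlock


def convert_to_free_noise_block(basic_block, context_length=16, context_stride=4):
    # Read the constructor arguments back from the attributes stored above,
    # instead of inferring them from internal torch module shapes.
    free_noise_block = FreeNoiseTransformerBlock(
        dim=basic_block.dim,
        num_attention_heads=basic_block.num_attention_heads,
        attention_head_dim=basic_block.attention_head_dim,
        dropout=basic_block.dropout,
        cross_attention_dim=basic_block.cross_attention_dim,
        activation_fn=basic_block.activation_fn,
        attention_bias=basic_block.attention_bias,
        only_cross_attention=basic_block.only_cross_attention,
        double_self_attention=basic_block.double_self_attention,
        norm_elementwise_affine=basic_block.norm_elementwise_affine,
        positional_embeddings=basic_block.positional_embeddings,
        num_positional_embeddings=basic_block.num_positional_embeddings,
        context_length=context_length,
        context_stride=context_stride,
    )
    # Reuse the trained weights of the original block; parameter names match,
    # so a non-strict load is enough for this sketch.
    free_noise_block.load_state_dict(basic_block.state_dict(), strict=False)
    return free_noise_block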

self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.dropout = dropout
self.cross_attention_dim = cross_attention_dim
self.activation_fn = activation_fn
self.attention_bias = attention_bias
self.double_self_attention = double_self_attention
self.norm_elementwise_affine = norm_elementwise_affine
self.positional_embeddings = positional_embeddings
self.num_positional_embeddings = num_positional_embeddings
self.only_cross_attention = only_cross_attention

# We keep these boolean flags for backward-compatibility.
306 changes: 302 additions & 4 deletions src/diffusers/models/unets/unet_motion_model.py
@@ -11,16 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint

from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ...loaders import FromOriginalModelMixin, UNet2DConditionLoadersMixin
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
Member Author (a-r-r-o-w) commented:

From #8995, to fix LoRA in UNetMotionModel. Once that PR is merged, these changes should no longer show up here.

from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward, _chunked_feed_forward
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
@@ -33,7 +35,7 @@
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
)
from ..embeddings import TimestepEmbedding, Timesteps
from ..embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from ..transformers.transformer_temporal import TransformerTemporalModel
from .unet_2d_blocks import UNetMidBlock2DCrossAttn
@@ -53,6 +55,302 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


@maybe_allow_in_graph
class FreeNoiseTransformerBlock(nn.Module):
r"""
A FreeNoise Transformer block.

Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (`int`, *optional*):
The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (`bool`, *optional*, defaults to `False`):
Configure if the attentions should contain a bias parameter.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
upcast_attention (`bool`, *optional*):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
final_dropout (`bool` *optional*, defaults to False):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, *optional*, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
positional_embeddings (`str`, *optional*, defaults to `None`):
The type of positional embeddings to apply.
num_positional_embeddings (`int`, *optional*, defaults to `None`):
The maximum number of positional embeddings to apply.
"""

def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout: float = 0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm",
norm_eps: float = 1e-5,
final_dropout: bool = False,
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
context_length: int = 16,
context_stride: int = 4,
weighting_scheme: str = "pyramid",
):
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.dropout = dropout
self.cross_attention_dim = cross_attention_dim
self.activation_fn = activation_fn
self.attention_bias = attention_bias
self.double_self_attention = double_self_attention
self.norm_elementwise_affine = norm_elementwise_affine
self.positional_embeddings = positional_embeddings
self.num_positional_embeddings = num_positional_embeddings
self.only_cross_attention = only_cross_attention

self.set_free_noise_properties(context_length, context_stride, weighting_scheme)

# We keep these boolean flags for backward-compatibility.
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
self.use_layer_norm = norm_type == "layer_norm"
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"

if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)

self.norm_type = norm_type
self.num_embeds_ada_norm = num_embeds_ada_norm

if positional_embeddings and (num_positional_embeddings is None):
raise ValueError(
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
)

if positional_embeddings == "sinusoidal":
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
else:
self.pos_embed = None

# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
Member Author (a-r-r-o-w) commented:

I've made some assumptions here and removed all the branches handling different configurations of layer norms and attention. BasicTransformerBlock supports multiple `norm_type` values, but I believe the SD 1.5 checkpoints used in AnimateDiff always use LayerNorm (only 99% sure). Let me know if any other norm types must be handled.
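As a sketch only (not part of this diff), the assumption could be made explicit with a guard in the conversion path; `check_norm_type` is a hypothetical helper name:

def check_norm_type(norm_type: str) -> None:
    # FreeNoiseTransformerBlock hard-codes nn.LayerNorm, matching the SD 1.5 / AnimateDiff
    # checkpoints; anything else should fail loudly rather than silently misbehave.
    if norm_type != "layer_norm":
        raise NotImplementedError(
            f"FreeNoiseTransformerBlock currently assumes norm_type='layer_norm', got '{norm_type}'."
        )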


self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
)

# 2. Cross-Attn
if cross_attention_dim is not None or double_self_attention:
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)

self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
) # is self-attn if encoder_hidden_states is none

# 3. Feed-forward
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)

self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)

# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0

def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
frame_indices = []
for i in range(0, num_frames - self.context_length + 1, self.context_stride):
window_start = i
window_end = min(num_frames, i + self.context_length)
frame_indices.append((window_start, window_end))

return frame_indices

def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
Member Author (a-r-r-o-w) commented:

The original FreeNoise implementation proposes pyramid weighted averaging (see Eq. 9 of the paper). However, the diffusion community has found other weighting schemes that also seem to work well in practice. While I haven't tested them deeply, I would like to keep the implementation open to extension in the future. For now, let's roll with the original scheme unless we can qualitatively test the other methods before the next release.
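For intuition, here is a toy, standalone sketch of the weighting idea with made-up shapes (illustration only; the real computation runs over attention outputs inside forward): each sliding window contributes its features scaled by the pyramid weights, and overlapping frames are divided by the total weight they received.

import torch


def pyramid_weights(n: int) -> torch.Tensor:
    # Same scheme as _get_frame_weights below: e.g. n=4 -> [1, 2, 2, 1].
    half = list(range(1, n // 2 + 1))
    w = half + half[::-1] if n % 2 == 0 else half + [n // 2 + 1] + half[::-1]
    return torch.tensor(w, dtype=torch.float32)


context_length, context_stride, num_frames = 4, 2, 8
weights = pyramid_weights(context_length)  # tensor([1., 2., 2., 1.])
features = torch.randn(num_frames, 320)    # toy per-frame features

accumulated = torch.zeros_like(features)
counts = torch.zeros(num_frames, 1)
for start in range(0, num_frames - context_length + 1, context_stride):
    end = start + context_length
    accumulated[start:end] += features[start:end] * weights[:, None]
    counts[start:end] += weights[:, None]

blended = accumulated / counts             # weighted average over overlapping windows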

if weighting_scheme == "pyramid":
if num_frames % 2 == 0:
# num_frames = 4 => [1, 2, 2, 1]
weights = list(range(1, num_frames // 2 + 1))
weights = weights + weights[::-1]
else:
# num_frames = 5 => [1, 2, 3, 2, 1]
weights = list(range(1, num_frames // 2 + 1))
weights = weights + [num_frames // 2 + 1] + weights[::-1]
else:
raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")

return weights

def set_free_noise_properties(
Member Author (a-r-r-o-w) commented:

Not sure what to name it; feel free to suggest alternatives. It's a helper that lets FreeNoiseMixin change these properties dynamically at inference time on already-initialized FreeNoiseTransformerBlocks, without redoing the entire initialization.
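For illustration, an enable-style method on the mixin could reuse this setter to retune FreeNoise at inference time without rebuilding the blocks; the name and signature below are illustrative, not necessarily the final API:

def enable_free_noise(unet, context_length=16, context_stride=4, weighting_scheme="pyramid"):
    for module in unet.modules():
        if isinstance(module, FreeNoiseTransformerBlock):
            # Only the scheduling properties change; weights and submodules stay untouched.
            module.set_free_noise_properties(context_length, context_stride, weighting_scheme)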

self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid"
) -> None:
self.context_length = context_length
self.context_stride = context_stride
self.weighting_scheme = weighting_scheme

def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None:
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
*args,
**kwargs,
) -> torch.Tensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}

# hidden_states: [B x H x W, F, C]
device = hidden_states.device
dtype = hidden_states.dtype

num_frames = hidden_states.size(1)
frame_indices = self._get_frame_indices(num_frames)
frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
is_last_frame_batch_complete = frame_indices[-1][1] == num_frames

# Handle out-of-bounds case if num_frames isn't perfectly divisible by context_length
# For example, with num_frames=26, context_length=16, context_stride=4, we expect the ranges:
# [(0, 16), (4, 20), (8, 24), (10, 26)]
if not is_last_frame_batch_complete:
if num_frames < self.context_length:
raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}")
last_frame_batch_length = num_frames - frame_indices[-1][1]
frame_indices.append((num_frames - self.context_length, num_frames))
Member Author (a-r-r-o-w) commented:

The original implementation does not seem to support this case. Essentially, if the chosen `context_length` and `context_stride`, or the number of frames to process, make perfect frame-wise batching impossible, we still attend over a full `context_length` window of frames BUT only accumulate the results for the frames not covered by earlier windows. I tested it and it works well on cases like num_frames=26, context_length=16, context_stride=4.
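A quick standalone check of the window layout for this case (num_frames=26, context_length=16, context_stride=4), mirroring _get_frame_indices plus the fallback window; illustration only:

num_frames, context_length, context_stride = 26, 16, 4

windows = [
    (i, i + context_length)
    for i in range(0, num_frames - context_length + 1, context_stride)
]
if windows[-1][1] != num_frames:
    last_frame_batch_length = num_frames - windows[-1][1]  # 2 trailing frames: 24 and 25
    windows.append((num_frames - context_length, num_frames))

print(windows)  # [(0, 16), (4, 20), (8, 24), (10, 26)]
# The appended (10, 26) window runs a full 16-frame attention pass, but only frames 24 and 25
# are accumulated from it; the other frames were already covered by earlier windows.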


num_times_accumulated = torch.zeros((1, num_frames, 1), device=device)
accumulated_values = torch.zeros_like(hidden_states)

for i, (frame_start, frame_end) in enumerate(frame_indices):
# The reason for slicing here is to ensure the weights match the chunk length when
# (frame_end - frame_start) is smaller than `context_length`, e.g. for
# frame_indices=[(0, 16), (16, 20)] if the user provided a video with 19 frames
# (a non-multiple of `context_length`).
weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
weights *= frame_weights

hidden_states_chunk = hidden_states[:, frame_start:frame_end]

# Notice that normalization is always applied before the real computation in the following blocks.
# 1. Self-Attention
# assert self.norm_type == "layer_norm"
norm_hidden_states = self.norm1(hidden_states_chunk)

if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)

attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)

hidden_states_chunk = attn_output + hidden_states_chunk
if hidden_states_chunk.ndim == 4:
hidden_states_chunk = hidden_states_chunk.squeeze(1)

# 2. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = self.norm2(hidden_states_chunk)

if self.pos_embed is not None and self.norm_type != "ada_norm_single":
norm_hidden_states = self.pos_embed(norm_hidden_states)

attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states_chunk = attn_output + hidden_states_chunk

if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
accumulated_values[:, -last_frame_batch_length:] += (
hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
)
num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length:]
else:
accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
num_times_accumulated[:, frame_start:frame_end] += weights
Member Author (a-r-r-o-w) commented:

Specialized logic to handle the non-fitting frame batch case described above. Let me know if this should be made more readable, and feel free to suggest improvements.


hidden_states = torch.where(
num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
).to(dtype)

# 3. Feed-forward
norm_hidden_states = self.norm3(hidden_states)

if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
else:
ff_output = self.ff(norm_hidden_states)

hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)

return hidden_states


class MotionModules(nn.Module):
def __init__(
self,
@@ -231,7 +529,7 @@ def forward(self, sample):
pass


class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
r"""
A modified conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a
sample shaped output.