-
Notifications
You must be signed in to change notification settings - Fork 6.1k
[core] FreeNoise #8948
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[core] FreeNoise #8948
Changes from 23 commits
80e530f
441d321
5d0f4c3
2e97ba7
690dad6
610f433
10b65b3
a41f843
f6897ae
024e2da
1bb0984
1b7bc00
dc96a8d
691facf
5a60a62
58c2ddc
7000186
c5db39f
594d2d2
fb9ca34
77ee296
52884b3
1e2ef4d
5d5a7ea
3d9b183
44e40a2
a61ffff
d82228e
037ee07
ac3d8c6
d19ddb4
12cc84a
6f48356
1f0ccfd
6a4aab8
2f77c69
8564dc3
b32b1d7
045ae36
2d9aa42
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -272,6 +272,17 @@ def __init__( | |
attention_out_bias: bool = True, | ||
): | ||
super().__init__() | ||
self.dim = dim | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These changes were made to initialize the FreeNoiseTransformerBlock correctly. I'm not sure how else we could determine these attributes in a "simple" way without accessing the interal pytorch dimensions which adds many many extra LOC after |
||
self.num_attention_heads = num_attention_heads | ||
self.attention_head_dim = attention_head_dim | ||
self.dropout = dropout | ||
self.cross_attention_dim = cross_attention_dim | ||
self.activation_fn = activation_fn | ||
self.attention_bias = attention_bias | ||
self.double_self_attention = double_self_attention | ||
self.norm_elementwise_affine = norm_elementwise_affine | ||
self.positional_embeddings = positional_embeddings | ||
self.num_positional_embeddings = num_positional_embeddings | ||
self.only_cross_attention = only_cross_attention | ||
|
||
# We keep these boolean flags for backward-compatibility. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,16 +11,18 @@ | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Any, Dict, Optional, Tuple, Union | ||
from typing import Any, Dict, List, Optional, Tuple, Union | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torch.utils.checkpoint | ||
|
||
from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config | ||
from ...loaders import FromOriginalModelMixin, UNet2DConditionLoadersMixin | ||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From #8995 to fix LoRA in UNetMotionModel. Once that's merged, these changes shouldn't be visible here |
||
from ...utils import logging | ||
from ...utils.torch_utils import maybe_allow_in_graph | ||
from ..attention import FeedForward, _chunked_feed_forward | ||
from ..attention_processor import ( | ||
ADDED_KV_ATTENTION_PROCESSORS, | ||
CROSS_ATTENTION_PROCESSORS, | ||
|
@@ -33,7 +35,7 @@ | |
IPAdapterAttnProcessor, | ||
IPAdapterAttnProcessor2_0, | ||
) | ||
from ..embeddings import TimestepEmbedding, Timesteps | ||
from ..embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps | ||
from ..modeling_utils import ModelMixin | ||
from ..transformers.transformer_temporal import TransformerTemporalModel | ||
from .unet_2d_blocks import UNetMidBlock2DCrossAttn | ||
|
@@ -53,6 +55,302 @@ | |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name | ||
|
||
|
||
@maybe_allow_in_graph | ||
class FreeNoiseTransformerBlock(nn.Module): | ||
r""" | ||
A FreeNoise Transformer block. | ||
|
||
Parameters: | ||
dim (`int`): The number of channels in the input and output. | ||
num_attention_heads (`int`): The number of heads to use for multi-head attention. | ||
attention_head_dim (`int`): The number of channels in each head. | ||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. | ||
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. | ||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. | ||
num_embeds_ada_norm (: | ||
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. | ||
attention_bias (: | ||
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. | ||
only_cross_attention (`bool`, *optional*): | ||
Whether to use only cross-attention layers. In this case two cross attention layers are used. | ||
double_self_attention (`bool`, *optional*): | ||
Whether to use two self-attention layers. In this case no cross attention layers are used. | ||
upcast_attention (`bool`, *optional*): | ||
Whether to upcast the attention computation to float32. This is useful for mixed precision training. | ||
norm_elementwise_affine (`bool`, *optional*, defaults to `True`): | ||
Whether to use learnable elementwise affine parameters for normalization. | ||
norm_type (`str`, *optional*, defaults to `"layer_norm"`): | ||
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. | ||
final_dropout (`bool` *optional*, defaults to False): | ||
Whether to apply a final dropout after the last feed-forward layer. | ||
attention_type (`str`, *optional*, defaults to `"default"`): | ||
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. | ||
positional_embeddings (`str`, *optional*, defaults to `None`): | ||
The type of positional embeddings to apply to. | ||
num_positional_embeddings (`int`, *optional*, defaults to `None`): | ||
The maximum number of positional embeddings to apply. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
dim: int, | ||
num_attention_heads: int, | ||
attention_head_dim: int, | ||
dropout: float = 0.0, | ||
cross_attention_dim: Optional[int] = None, | ||
activation_fn: str = "geglu", | ||
num_embeds_ada_norm: Optional[int] = None, | ||
attention_bias: bool = False, | ||
only_cross_attention: bool = False, | ||
double_self_attention: bool = False, | ||
upcast_attention: bool = False, | ||
norm_elementwise_affine: bool = True, | ||
norm_type: str = "layer_norm", | ||
norm_eps: float = 1e-5, | ||
final_dropout: bool = False, | ||
positional_embeddings: Optional[str] = None, | ||
num_positional_embeddings: Optional[int] = None, | ||
ff_inner_dim: Optional[int] = None, | ||
ff_bias: bool = True, | ||
attention_out_bias: bool = True, | ||
context_length: int = 16, | ||
context_stride: int = 4, | ||
weighting_scheme: str = "pyramid", | ||
): | ||
super().__init__() | ||
self.dim = dim | ||
self.num_attention_heads = num_attention_heads | ||
self.attention_head_dim = attention_head_dim | ||
self.dropout = dropout | ||
self.cross_attention_dim = cross_attention_dim | ||
self.activation_fn = activation_fn | ||
self.attention_bias = attention_bias | ||
self.double_self_attention = double_self_attention | ||
self.norm_elementwise_affine = norm_elementwise_affine | ||
self.positional_embeddings = positional_embeddings | ||
self.num_positional_embeddings = num_positional_embeddings | ||
self.only_cross_attention = only_cross_attention | ||
|
||
self.set_free_noise_properties(context_length, context_stride, weighting_scheme) | ||
|
||
# We keep these boolean flags for backward-compatibility. | ||
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" | ||
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" | ||
self.use_ada_layer_norm_single = norm_type == "ada_norm_single" | ||
self.use_layer_norm = norm_type == "layer_norm" | ||
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" | ||
|
||
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: | ||
raise ValueError( | ||
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" | ||
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." | ||
) | ||
|
||
self.norm_type = norm_type | ||
self.num_embeds_ada_norm = num_embeds_ada_norm | ||
|
||
if positional_embeddings and (num_positional_embeddings is None): | ||
raise ValueError( | ||
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." | ||
) | ||
|
||
if positional_embeddings == "sinusoidal": | ||
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) | ||
else: | ||
self.pos_embed = None | ||
|
||
# Define 3 blocks. Each block has its own normalization layer. | ||
# 1. Self-Attn | ||
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've made some assumptions here and removed all the branches handling different configurations of layernorms and attention. We support multiple |
||
|
||
self.attn1 = Attention( | ||
query_dim=dim, | ||
heads=num_attention_heads, | ||
dim_head=attention_head_dim, | ||
dropout=dropout, | ||
bias=attention_bias, | ||
cross_attention_dim=cross_attention_dim if only_cross_attention else None, | ||
upcast_attention=upcast_attention, | ||
out_bias=attention_out_bias, | ||
) | ||
|
||
# 2. Cross-Attn | ||
if cross_attention_dim is not None or double_self_attention: | ||
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) | ||
|
||
self.attn2 = Attention( | ||
query_dim=dim, | ||
cross_attention_dim=cross_attention_dim if not double_self_attention else None, | ||
heads=num_attention_heads, | ||
dim_head=attention_head_dim, | ||
dropout=dropout, | ||
bias=attention_bias, | ||
upcast_attention=upcast_attention, | ||
out_bias=attention_out_bias, | ||
) # is self-attn if encoder_hidden_states is none | ||
|
||
# 3. Feed-forward | ||
self.ff = FeedForward( | ||
dim, | ||
dropout=dropout, | ||
activation_fn=activation_fn, | ||
final_dropout=final_dropout, | ||
inner_dim=ff_inner_dim, | ||
bias=ff_bias, | ||
) | ||
|
||
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) | ||
|
||
# let chunk size default to None | ||
self._chunk_size = None | ||
self._chunk_dim = 0 | ||
|
||
def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]: | ||
frame_indices = [] | ||
for i in range(0, num_frames - self.context_length + 1, self.context_stride): | ||
window_start = i | ||
window_end = min(num_frames, i + self.context_length) | ||
frame_indices.append((window_start, window_end)) | ||
|
||
a-r-r-o-w marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return frame_indices | ||
|
||
def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The original FreeNoise implementation proposes using a pyramid weighted averaging (see Eq. 9 of the paper. However, the diffusion community found different weighting schemes that also seem to work well in practice. While I haven't tested it deeply, I would like to keep the implementation to extension in the future. For now, let's roll with the original unless we can test different methods qualitatively before next release |
||
if weighting_scheme == "pyramid": | ||
if num_frames % 2 == 0: | ||
# num_frames = 4 => [1, 2, 2, 1] | ||
weights = list(range(1, num_frames // 2 + 1)) | ||
weights = weights + weights[::-1] | ||
else: | ||
# num_frames = 5 => [1, 2, 3, 2, 1] | ||
weights = list(range(1, num_frames // 2 + 1)) | ||
weights = weights + [num_frames // 2 + 1] + weights[::-1] | ||
else: | ||
raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}") | ||
|
||
return weights | ||
|
||
def set_free_noise_properties( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure what to name it, feel free to suggest. It's a helper function to change properties dynamically at inference from FreeNoiseMixin for already initialized FreeNoiseTransformerBlocks without doing the entire initialization part again |
||
self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid" | ||
) -> None: | ||
self.context_length = context_length | ||
self.context_stride = context_stride | ||
self.weighting_scheme = weighting_scheme | ||
|
||
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None: | ||
# Sets chunk feed-forward | ||
self._chunk_size = chunk_size | ||
self._chunk_dim = dim | ||
|
||
def forward( | ||
self, | ||
hidden_states: torch.Tensor, | ||
attention_mask: Optional[torch.Tensor] = None, | ||
encoder_hidden_states: Optional[torch.Tensor] = None, | ||
encoder_attention_mask: Optional[torch.Tensor] = None, | ||
cross_attention_kwargs: Dict[str, Any] = None, | ||
*args, | ||
**kwargs, | ||
) -> torch.Tensor: | ||
if cross_attention_kwargs is not None: | ||
if cross_attention_kwargs.get("scale", None) is not None: | ||
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") | ||
|
||
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} | ||
|
||
# hidden_states: [B x H x W, F, C] | ||
device = hidden_states.device | ||
dtype = hidden_states.dtype | ||
|
||
num_frames = hidden_states.size(1) | ||
frame_indices = self._get_frame_indices(num_frames) | ||
frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme) | ||
frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1) | ||
is_last_frame_batch_complete = frame_indices[-1][1] == num_frames | ||
|
||
# Handle out-of-bounds case if num_frames isn't perfectly divisible by context_length | ||
# For example, num_frames=25, context_length=16, context_stride=4, then we expect the ranges: | ||
# [(0, 16), (4, 20), (8, 24), (10, 26)] | ||
if not is_last_frame_batch_complete: | ||
if num_frames < self.context_length: | ||
raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}") | ||
last_frame_batch_length = num_frames - frame_indices[-1][1] | ||
frame_indices.append((num_frames - self.context_length, num_frames)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The original implementation here does not seem to support this case. Essentially, if we select a context_length and context_stride, or have to process number of frames, such that perfect frame-wise batching is not possible, we try and process the full context_length amount of frames BUT only accumulate on the unaccomodated frames. I tested it and it works well on cases like |
||
|
||
num_times_accumulated = torch.zeros((1, num_frames, 1), device=device) | ||
accumulated_values = torch.zeros_like(hidden_states) | ||
|
||
for i, (frame_start, frame_end) in enumerate(frame_indices): | ||
# The reason for slicing here is to ensure that if (frame_end - frame_start) is to handle | ||
# cases like frame_indices=[(0, 16), (16, 20)], if the user provided a video with 19 frames, or | ||
# essentially a non-multiple of `context_length`. | ||
weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end]) | ||
weights *= frame_weights | ||
|
||
hidden_states_chunk = hidden_states[:, frame_start:frame_end] | ||
|
||
# Notice that normalization is always applied before the real computation in the following blocks. | ||
# 1. Self-Attention | ||
# assert self.norm_type == "layer_norm" | ||
norm_hidden_states = self.norm1(hidden_states_chunk) | ||
|
||
if self.pos_embed is not None: | ||
norm_hidden_states = self.pos_embed(norm_hidden_states) | ||
|
||
attn_output = self.attn1( | ||
norm_hidden_states, | ||
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, | ||
attention_mask=attention_mask, | ||
**cross_attention_kwargs, | ||
) | ||
|
||
hidden_states_chunk = attn_output + hidden_states_chunk | ||
if hidden_states_chunk.ndim == 4: | ||
hidden_states_chunk = hidden_states_chunk.squeeze(1) | ||
|
||
# 2. Cross-Attention | ||
if self.attn2 is not None: | ||
norm_hidden_states = self.norm2(hidden_states_chunk) | ||
|
||
if self.pos_embed is not None and self.norm_type != "ada_norm_single": | ||
norm_hidden_states = self.pos_embed(norm_hidden_states) | ||
|
||
attn_output = self.attn2( | ||
norm_hidden_states, | ||
encoder_hidden_states=encoder_hidden_states, | ||
attention_mask=encoder_attention_mask, | ||
**cross_attention_kwargs, | ||
) | ||
hidden_states_chunk = attn_output + hidden_states_chunk | ||
|
||
if i == len(frame_indices) - 1 and not is_last_frame_batch_complete: | ||
accumulated_values[:, -last_frame_batch_length:] += ( | ||
hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:] | ||
) | ||
num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length] | ||
else: | ||
accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights | ||
num_times_accumulated[:, frame_start:frame_end] += weights | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Specialized logic to handle unfitting frame batch case as described above. LMK if this needs to be more readable and possible suggestions |
||
|
||
hidden_states = torch.where( | ||
num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values | ||
).to(dtype) | ||
|
||
# 3. Feed-forward | ||
norm_hidden_states = self.norm3(hidden_states) | ||
|
||
if self._chunk_size is not None: | ||
# "feed_forward_chunk_size" can be used to save memory | ||
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) | ||
else: | ||
ff_output = self.ff(norm_hidden_states) | ||
|
||
hidden_states = ff_output + hidden_states | ||
if hidden_states.ndim == 4: | ||
hidden_states = hidden_states.squeeze(1) | ||
|
||
return hidden_states | ||
|
||
|
||
class MotionModules(nn.Module): | ||
def __init__( | ||
self, | ||
|
@@ -231,7 +529,7 @@ def forward(self, sample): | |
pass | ||
|
||
|
||
class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): | ||
class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): | ||
r""" | ||
A modified conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a | ||
sample shaped output. | ||
|
Uh oh!
There was an error while loading. Please reload this page.