[core] FreeNoise #8948


Merged: 40 commits, Aug 7, 2024
Commits (40)
80e530f  initial work draft for freenoise; needs massive cleanup (a-r-r-o-w, Jul 23, 2024)
441d321  fix freeinit bug (a-r-r-o-w, Jul 24, 2024)
5d0f4c3  add animatediff controlnet implementation (a-r-r-o-w, Jul 24, 2024)
2e97ba7  Merge branch 'main' into freenoise (a-r-r-o-w, Jul 24, 2024)
690dad6  Merge branch 'main' into freenoise (a-r-r-o-w, Jul 27, 2024)
610f433  revert attention changes (a-r-r-o-w, Jul 27, 2024)
10b65b3  add freenoise (a-r-r-o-w, Jul 27, 2024)
a41f843  remove old helper functions (a-r-r-o-w, Jul 27, 2024)
f6897ae  add decode batch size param to all pipelines (a-r-r-o-w, Jul 27, 2024)
024e2da  make style (a-r-r-o-w, Jul 27, 2024)
1bb0984  fix copied from comments (a-r-r-o-w, Jul 27, 2024)
1b7bc00  make fix-copies (a-r-r-o-w, Jul 27, 2024)
dc96a8d  make style (a-r-r-o-w, Jul 27, 2024)
691facf  copy animatediff controlnet implementation from #8972 (a-r-r-o-w, Jul 27, 2024)
5a60a62  add experimental support for num_frames not perfectly fitting context… (a-r-r-o-w, Jul 27, 2024)
58c2ddc  make unet motion model lora work again based on #8995 (a-r-r-o-w, Jul 27, 2024)
7000186  copy load video utils from #8972 (a-r-r-o-w, Jul 27, 2024)
c5db39f  copied from AnimateDiff::prepare_latents (a-r-r-o-w, Jul 28, 2024)
594d2d2  address the case where last batch of frames does not match length of … (a-r-r-o-w, Jul 28, 2024)
fb9ca34  decode_batch_size->vae_batch_size; batch vae encode support in animat… (a-r-r-o-w, Jul 28, 2024)
77ee296  revert sparsectrl and sdxl freenoise changes (a-r-r-o-w, Jul 28, 2024)
52884b3  revert pia (a-r-r-o-w, Jul 28, 2024)
1e2ef4d  add freenoise tests (a-r-r-o-w, Jul 28, 2024)
5d5a7ea  Merge branch 'main' into freenoise (a-r-r-o-w, Jul 30, 2024)
3d9b183  make fix-copies (a-r-r-o-w, Jul 30, 2024)
44e40a2  improve docstrings (a-r-r-o-w, Jul 30, 2024)
a61ffff  add freenoise tests to animatediff controlnet (a-r-r-o-w, Jul 30, 2024)
d82228e  update tests (a-r-r-o-w, Jul 30, 2024)
037ee07  Update src/diffusers/models/unets/unet_motion_model.py (a-r-r-o-w, Jul 30, 2024)
ac3d8c6  Merge branch 'main' into freenoise (a-r-r-o-w, Aug 2, 2024)
d19ddb4  add freenoise to animatediff pag (a-r-r-o-w, Aug 2, 2024)
12cc84a  address review comments (a-r-r-o-w, Aug 2, 2024)
6f48356  make style (a-r-r-o-w, Aug 2, 2024)
1f0ccfd  update tests (a-r-r-o-w, Aug 2, 2024)
6a4aab8  make fix-copies (a-r-r-o-w, Aug 2, 2024)
2f77c69  update (DN6, Aug 3, 2024)
8564dc3  fix error message (a-r-r-o-w, Aug 3, 2024)
b32b1d7  remove copied from comment (a-r-r-o-w, Aug 3, 2024)
045ae36  fix imports in tests (a-r-r-o-w, Aug 3, 2024)
2d9aa42  update (DN6, Aug 6, 2024)
326 changes: 325 additions & 1 deletion src/diffusers/models/attention.py

Large diffs are not rendered by default.
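(The attention.py diff, 325 added lines, is collapsed above. Judging by the PR's scope it implements FreeNoise's windowed attention over long frame sequences; the sketch below is a hypothetical illustration of that overlap-and-blend pattern with uniform weights, not the actual FreeNoiseTransformerBlock code.)

import torch

def windowed_blend(process_block, hidden_states, context_length=16, context_stride=4):
    # Hypothetical sketch: run a per-window transformer block over overlapping
    # frame windows and average the overlapping outputs, so a block trained on
    # 16 frames can serve arbitrarily long sequences.
    num_frames = hidden_states.shape[1]  # assumed layout: (batch, frames, tokens, channels)
    accum = torch.zeros_like(hidden_states)
    counts = torch.zeros(num_frames, device=hidden_states.device)
    starts = list(range(0, max(num_frames - context_length, 0) + 1, context_stride))
    if starts[-1] + context_length < num_frames:
        starts.append(max(num_frames - context_length, 0))  # cover the tail window
    for s in starts:
        window = hidden_states[:, s : s + context_length]
        accum[:, s : s + context_length] += process_block(window)
        counts[s : s + context_length] += 1
    return accum / counts.view(1, -1, 1, 1)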

7 changes: 4 additions & 3 deletions src/diffusers/models/unets/unet_motion_model.py
@@ -343,6 +343,7 @@ def custom_forward(*inputs):

else:
hidden_states = resnet(hidden_states, temb)

hidden_states = motion_module(hidden_states, num_frames=num_frames)

output_states = output_states + (hidden_states,)
@@ -536,6 +537,7 @@ def custom_forward(*inputs):
)[0]
else:
hidden_states = resnet(hidden_states, temb)

hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -761,6 +763,7 @@ def custom_forward(*inputs):
)[0]
else:
hidden_states = resnet(hidden_states, temb)

hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -921,9 +924,9 @@ def custom_forward(*inputs):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)

else:
hidden_states = resnet(hidden_states, temb)

hidden_states = motion_module(hidden_states, num_frames=num_frames)

if self.upsamplers is not None:
@@ -1923,7 +1926,6 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)

-    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
Sets the attention processor to use [feed forward
@@ -1953,7 +1955,6 @@ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int
for module in self.children():
fn_recursive_feed_forward(module, chunk_size, dim)

-    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
def disable_forward_chunking(self) -> None:
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
if hasattr(module, "set_chunk_feed_forward"):
38 changes: 27 additions & 11 deletions src/diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -42,6 +42,7 @@
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
from ..free_noise_utils import AnimateDiffFreeNoiseMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput

@@ -72,6 +73,7 @@ class AnimateDiffPipeline(
IPAdapterMixin,
StableDiffusionLoraLoaderMixin,
FreeInitMixin,
AnimateDiffFreeNoiseMixin,
):
r"""
Pipeline for text-to-video generation.
@@ -394,15 +396,20 @@ def prepare_ip_adapter_image_embeds(

return ip_adapter_image_embeds

-    # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
-    def decode_latents(self, latents):
+    def decode_latents(self, latents, vae_batch_size: int = 16):
Member Author commented:

Need to support frame-wise chunking in all intermediate layers, including the ResNet blocks, if we want to save memory; otherwise this blows up. For now, we can process 64 frames on a 24 GB card by taking care of the VAE encode/decode. I suggest we get the FreeNoise functionality in first, and take care of optimizing memory in the different internal blocks later.
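(For illustration, a minimal usage sketch of the long-video generation with batched decode described above; the checkpoints are placeholders from common AnimateDiff examples and the `enable_free_noise` call signature is assumed from the mixin added in this PR, so treat the exact names as unverified.)

import torch
from diffusers import AnimateDiffPipeline, MotionAdapter
from diffusers.utils import export_to_gif

# Load a motion adapter and an SD1.5 base (illustrative checkpoints).
adapter = MotionAdapter.from_pretrained(
    "guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16
)
pipe = AnimateDiffPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# Assumed mixin entry point: denoise 64 frames in overlapping 16-frame windows.
pipe.enable_free_noise(context_length=16, context_stride=4)

output = pipe(
    prompt="a panda surfing on a wave, high quality",
    num_frames=64,
    vae_batch_size=16,  # decode 16 frames per VAE forward pass to bound memory
)
export_to_gif(output.frames[0], "animation.gif")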

latents = 1 / self.vae.config.scaling_factor * latents

batch_size, channels, num_frames, height, width = latents.shape
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)

-        image = self.vae.decode(latents).sample
-        video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
+        video = []
+        for i in range(0, latents.shape[0], vae_batch_size):
+            batch_latents = latents[i : i + vae_batch_size]
+            batch_latents = self.vae.decode(batch_latents).sample
+            video.append(batch_latents)
+
+        video = torch.cat(video)
+        video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
video = video.float()
return video
@@ -495,22 +502,28 @@ def check_inputs(
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)

-    # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
     def prepare_latents(
         self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
     ):
+        # If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169)
+        if self.free_noise_enabled:
+            latents = self._prepare_latents_free_noise(
+                batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents
+            )
+
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
         shape = (
             batch_size,
             num_channels_latents,
             num_frames,
             height // self.vae_scale_factor,
             width // self.vae_scale_factor,
         )
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )

if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
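(For reference, a simplified, self-contained sketch of the Equation (7) noise rescheduling referenced above; the function name and shuffle details are illustrative, not the exact `_prepare_latents_free_noise` implementation.)

import torch

def free_noise_latents_sketch(shape, context_length=16, context_stride=4, generator=None):
    # shape: (batch, channels, num_frames, height, width)
    batch, channels, num_frames, height, width = shape
    latents = torch.randn(shape, generator=generator)
    for start in range(context_length, num_frames, context_stride):
        # Reuse a shuffled copy of the preceding window so overlapping
        # inference windows denoise correlated rather than independent noise.
        window = torch.arange(start - context_length, start)
        shuffled = window[torch.randperm(context_length, generator=generator)]
        count = min(context_stride, num_frames - start)
        latents[:, :, start : start + count] = latents[:, :, shuffled[:count]]
    return latents

# e.g. free_noise_latents_sketch((1, 4, 64, 64, 64)) yields 64 frames of
# correlated noise built from a single 16-frame noise budget.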
@@ -569,6 +582,7 @@ def __call__(
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
vae_batch_size: int = 16,
Collaborator commented:

Let's use naming/logic similar to SVD for batch decoding:

frames = self.decode_latents(latents, num_frames, decode_chunk_size)

Member Author replied:

This is also used in the VAE encode for animatediff_video2video, by the way, but we can rename it to match that.
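(A side note on the shared logic under discussion: both the VAE encode and decode paths reduce to the same chunk-map-concatenate pattern; the helper below is a hypothetical sketch, not part of this PR.)

import torch

def map_in_chunks(fn, frames: torch.Tensor, chunk_size: int = 16) -> torch.Tensor:
    # Apply a memory-heavy per-frame function (e.g. VAE encode or decode)
    # in chunks along the frame dimension, then stitch the results together.
    outputs = [fn(frames[i : i + chunk_size]) for i in range(0, frames.shape[0], chunk_size)]
    return torch.cat(outputs)

# decode path: map_in_chunks(lambda x: vae.decode(x).sample, flat_latents, 16)
# encode path: map_in_chunks(lambda x: retrieve_latents(vae.encode(x)), frames, 16)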

**kwargs,
):
r"""
@@ -637,6 +651,8 @@ def __call__(
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
vae_batch_size (`int`, defaults to `16`):
The number of frames to decode at a time when calling the `decode_latents` method.

Examples:

@@ -808,7 +824,7 @@ def __call__(
if output_type == "latent":
video = latents
else:
-            video_tensor = self.decode_latents(latents)
+            video_tensor = self.decode_latents(latents, vae_batch_size)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

# 10. Offload all models
src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py
@@ -30,6 +30,7 @@
from ...video_processor import VideoProcessor
from ..controlnet.multicontrolnet import MultiControlNetModel
from ..free_init_utils import FreeInitMixin
from ..free_noise_utils import AnimateDiffFreeNoiseMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput

@@ -109,6 +110,7 @@ class AnimateDiffControlNetPipeline(
IPAdapterMixin,
StableDiffusionLoraLoaderMixin,
FreeInitMixin,
AnimateDiffFreeNoiseMixin,
):
r"""
Pipeline for text-to-video generation with ControlNet guidance.
@@ -432,15 +434,16 @@ def prepare_ip_adapter_image_embeds(

return ip_adapter_image_embeds

-    def decode_latents(self, latents, decode_batch_size: int = 16):
+    # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
+    def decode_latents(self, latents, vae_batch_size: int = 16):
latents = 1 / self.vae.config.scaling_factor * latents

batch_size, channels, num_frames, height, width = latents.shape
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)

video = []
-        for i in range(0, latents.shape[0], decode_batch_size):
-            batch_latents = latents[i : i + decode_batch_size]
+        for i in range(0, latents.shape[0], vae_batch_size):
+            batch_latents = latents[i : i + vae_batch_size]
batch_latents = self.vae.decode(batch_latents).sample
video.append(batch_latents)

@@ -608,22 +611,29 @@ def check_inputs(
if end > 1.0:
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

-    # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
+    # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.prepare_latents
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
+        # If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169)
+        if self.free_noise_enabled:
+            latents = self._prepare_latents_free_noise(
+                batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents
+            )
+
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
         shape = (
             batch_size,
             num_channels_latents,
             num_frames,
             height // self.vae_scale_factor,
             width // self.vae_scale_factor,
         )
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )

if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
@@ -718,7 +728,7 @@ def __call__(
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-        decode_batch_size: int = 16,
+        vae_batch_size: int = 16,
):
r"""
The call function to the pipeline for generation.
@@ -1054,7 +1064,7 @@ def __call__(
if output_type == "latent":
video = latents
else:
-            video_tensor = self.decode_latents(latents, decode_batch_size)
+            video_tensor = self.decode_latents(latents, vae_batch_size)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

# 10. Offload all models
src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
@@ -35,6 +35,7 @@
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
from ..free_noise_utils import AnimateDiffFreeNoiseMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput

@@ -176,6 +177,7 @@ class AnimateDiffVideoToVideoPipeline(
IPAdapterMixin,
StableDiffusionLoraLoaderMixin,
FreeInitMixin,
AnimateDiffFreeNoiseMixin,
):
r"""
Pipeline for video-to-video generation.
@@ -498,15 +500,29 @@ def prepare_ip_adapter_image_embeds(

return ip_adapter_image_embeds

-    # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
-    def decode_latents(self, latents):
+    def encode_video(self, video, generator, vae_batch_size: int = 16) -> torch.Tensor:
+        latents = []
+        for i in range(0, len(video), vae_batch_size):
+            batch_video = video[i : i + vae_batch_size]
+            batch_video = retrieve_latents(self.vae.encode(batch_video), generator=generator)
+            latents.append(batch_video)
+        return torch.cat(latents)
+
+    # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
+    def decode_latents(self, latents, vae_batch_size: int = 16):
latents = 1 / self.vae.config.scaling_factor * latents

batch_size, channels, num_frames, height, width = latents.shape
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)

-        image = self.vae.decode(latents).sample
-        video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
+        video = []
+        for i in range(0, latents.shape[0], vae_batch_size):
+            batch_latents = latents[i : i + vae_batch_size]
+            batch_latents = self.vae.decode(batch_latents).sample
+            video.append(batch_latents)
+
+        video = torch.cat(video)
+        video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
video = video.float()
return video
@@ -622,6 +638,7 @@ def prepare_latents(
device,
generator,
latents=None,
vae_batch_size: int = 16,
):
if latents is None:
num_frames = video.shape[1]
@@ -656,13 +673,10 @@
)

                init_latents = [
-                    retrieve_latents(self.vae.encode(video[i]), generator=generator[i]).unsqueeze(0)
-                    for i in range(batch_size)
+                    self.encode_video(video[i], generator[i], vae_batch_size).unsqueeze(0) for i in range(batch_size)
                ]
            else:
-                init_latents = [
-                    retrieve_latents(self.vae.encode(vid), generator=generator).unsqueeze(0) for vid in video
-                ]
+                init_latents = [self.encode_video(vid, generator, vae_batch_size).unsqueeze(0) for vid in video]

init_latents = torch.cat(init_latents, dim=0)

@@ -747,6 +761,7 @@ def __call__(
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
vae_batch_size: int = 16,
):
r"""
The call function to the pipeline for generation.
@@ -822,6 +837,8 @@ def __call__(
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
vae_batch_size (`int`, defaults to `16`):
The number of frames to decode at a time when calling the `decode_latents` method.

Examples:

@@ -923,6 +940,7 @@ def __call__(
device=device,
generator=generator,
latents=latents,
vae_batch_size=vae_batch_size,
)

# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
@@ -990,7 +1008,7 @@ def __call__(
if output_type == "latent":
video = latents
else:
-            video_tensor = self.decode_latents(latents)
+            video_tensor = self.decode_latents(latents, vae_batch_size)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

# 10. Offload all models