[core] FreeNoise #8948
Changes from all commits
src/diffusers/pipelines/animatediff/pipeline_animatediff.py
```diff
@@ -42,6 +42,7 @@
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
 from ..free_init_utils import FreeInitMixin
+from ..free_noise_utils import AnimateDiffFreeNoiseMixin
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from .pipeline_output import AnimateDiffPipelineOutput
```
```diff
@@ -72,6 +73,7 @@ class AnimateDiffPipeline(
     IPAdapterMixin,
     StableDiffusionLoraLoaderMixin,
     FreeInitMixin,
+    AnimateDiffFreeNoiseMixin,
 ):
     r"""
     Pipeline for text-to-video generation.
```
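The hunks below use two members of this mixin: a `free_noise_enabled` flag and a `_prepare_latents_free_noise` helper. A hedged skeleton of that surface, so the later hunks read on their own (the `enable_free_noise` signature is our assumption, not code from this PR):

```python
# Hypothetical skeleton of AnimateDiffFreeNoiseMixin's surface as used in this diff.
# Only `free_noise_enabled` and `_prepare_latents_free_noise` appear in the PR hunks;
# `enable_free_noise` and its parameters are illustrative assumptions.
class FreeNoiseMixinSketch:
    def enable_free_noise(self, context_length: int = 16, context_stride: int = 4) -> None:
        self._free_noise_enabled = True
        self._free_noise_context_length = context_length
        self._free_noise_context_stride = context_stride

    @property
    def free_noise_enabled(self) -> bool:
        return getattr(self, "_free_noise_enabled", False)

    def _prepare_latents_free_noise(self, *args, **kwargs):
        raise NotImplementedError  # see the Equation (7) sketch after prepare_latents below
```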
```diff
@@ -394,15 +396,20 @@ def prepare_ip_adapter_image_embeds(

         return ip_adapter_image_embeds

-    # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
-    def decode_latents(self, latents):
+    def decode_latents(self, latents, vae_batch_size: int = 16):
         latents = 1 / self.vae.config.scaling_factor * latents

         batch_size, channels, num_frames, height, width = latents.shape
         latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)

-        image = self.vae.decode(latents).sample
-        video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
+        video = []
+        for i in range(0, latents.shape[0], vae_batch_size):
+            batch_latents = latents[i : i + vae_batch_size]
+            batch_latents = self.vae.decode(batch_latents).sample
+            video.append(batch_latents)
+
+        video = torch.cat(video)
+        video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         video = video.float()
         return video
```
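The loop above bounds peak decoder memory by `vae_batch_size` frames rather than all `batch_size * num_frames` at once, without changing the output. A toy equivalence check of that chunking pattern (`fake_decode` is our stand-in for `vae.decode(...).sample`):

```python
import torch

def fake_decode(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for self.vae.decode(x).sample; any per-frame map works for this check.
    return 2.0 * x

latents = torch.randn(64, 4, 8, 8)  # 64 frames of 4-channel latents
vae_batch_size = 16

chunks = [fake_decode(latents[i : i + vae_batch_size]) for i in range(0, latents.shape[0], vae_batch_size)]
video = torch.cat(chunks)

assert torch.allclose(video, fake_decode(latents))  # identical result, lower peak memory
```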
```diff
@@ -495,22 +502,28 @@ def check_inputs(
                     f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
                 )

-    # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
     def prepare_latents(
         self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
     ):
+        # If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169)
+        if self.free_noise_enabled:
+            latents = self._prepare_latents_free_noise(
+                batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents
+            )
+
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
         shape = (
             batch_size,
             num_channels_latents,
             num_frames,
             height // self.vae_scale_factor,
             width // self.vae_scale_factor,
         )
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )

         if latents is None:
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
```
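Equation (7) of FreeNoise reschedules the initial noise so that frames beyond the first temporal window reuse earlier noise frames in shuffled order; overlapping attention windows then see correlated noise, which is what keeps long videos temporally coherent. A hedged sketch of the idea (the function name and the `context_length`/`context_stride` parameters are illustrative; `_prepare_latents_free_noise` in the PR is the authoritative implementation):

```python
import torch

def free_noise_latents_sketch(
    batch: int, channels: int, num_frames: int, height: int, width: int,
    context_length: int = 16, context_stride: int = 4, generator: torch.Generator = None,
) -> torch.Tensor:
    # Fresh noise everywhere; frames past the first window get overwritten below.
    latents = torch.randn(batch, channels, num_frames, height, width, generator=generator)
    for start in range(context_length, num_frames, context_stride):
        count = min(context_stride, num_frames - start)
        # Eq. (7): reuse frames from the initial window, in shuffled order.
        src = torch.randperm(context_length, generator=generator)[:count]
        latents[:, :, start : start + count] = latents[:, :, src]
    return latents
```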
```diff
@@ -569,6 +582,7 @@ def __call__(
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        vae_batch_size: int = 16,
         **kwargs,
     ):
         r"""
```

Review comment on `vae_batch_size`: Let's use naming/logic similar to SVD for batch decoding (see diffusers/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py, line 607 at 73acebb).

Reply: This is also used in the VAE encode for animatediff_video2video, by the way, but we can rename it to that.
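For comparison, the SVD pipeline referenced above calls the analogous argument `decode_chunk_size`. Since the reply notes the same loop also serves the VAE encode path in animatediff_video2video, one generic helper could cover both directions; a hedged sketch (the helper and its name are ours, only the `decode_chunk_size` naming comes from the SVD reference):

```python
import torch
from typing import Callable

# Hypothetical helper: run any per-frame VAE op (encode or decode) in chunks.
# `decode_chunk_size` follows the SVD naming suggested in the review comment.
def apply_in_chunks(
    fn: Callable[[torch.Tensor], torch.Tensor], frames: torch.Tensor, decode_chunk_size: int = 16
) -> torch.Tensor:
    outputs = [fn(frames[i : i + decode_chunk_size]) for i in range(0, frames.shape[0], decode_chunk_size)]
    return torch.cat(outputs)
```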
```diff
@@ -637,6 +651,8 @@ def __call__(
                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
+            vae_batch_size (`int`, defaults to `16`):
+                The number of frames to decode at a time when calling the `decode_latents` method.

         Examples:
```
```diff
@@ -808,7 +824,7 @@ def __call__(
         if output_type == "latent":
             video = latents
         else:
-            video_tensor = self.decode_latents(latents)
+            video_tensor = self.decode_latents(latents, vae_batch_size)
             video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

         # 10. Offload all models
```
Review comment: We need to support frame-wise chunking in all intermediate layers, including the ResNet blocks, if we want to save memory; otherwise this blows up. For now, we can process 64 frames on a 24 GB card by taking care of the VAE encode/decode. I suggest we get the FreeNoise functionality in first and take care of optimizing memory in the different internal blocks later.
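Putting the pieces together, usage would look roughly like this. This is a hedged sketch: `enable_free_noise(context_length, context_stride)` reflects the mixin API as it later merged and may differ at this commit, and the checkpoint IDs are the usual AnimateDiff examples, not mandated by this PR:

```python
import torch
from diffusers import AnimateDiffPipeline, MotionAdapter
from diffusers.utils import export_to_gif

# Assumed API: enable_free_noise() comes from AnimateDiffFreeNoiseMixin.
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

pipe.enable_free_noise(context_length=16, context_stride=4)

video = pipe(
    prompt="a panda surfing on a wave",
    num_frames=64,       # well past the 16-frame training window
    vae_batch_size=16,   # this PR: decode 16 frames per VAE call
).frames[0]
export_to_gif(video, "freenoise_panda.gif")
```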