From e8b29fbbb6ebd304ed8f5a256ce3cb2b0d0ff64c Mon Sep 17 00:00:00 2001
From: Aryan
Date: Sun, 1 Sep 2024 09:07:13 +0200
Subject: [PATCH] fake context parallel cache, vae encode tiling

(cherry picked from commit bf890bca0e8aed875d6a207f9b826ce894901522)
---
 .../autoencoders/autoencoder_kl_cogvideox.py  | 105 +++++++++++++++++-
 1 file changed, 101 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
index 17fa2bbf40f6..fe887b7db054 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
@@ -999,6 +999,7 @@ def __init__(
         # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
         # number of temporal frames.
         self.num_latent_frames_batch_size = 2
+        self.num_sample_frames_batch_size = 8

         # We make the minimum height and width of sample for tiling half that of the generally supported
         self.tile_sample_min_height = sample_height // 2
@@ -1081,6 +1082,29 @@ def disable_slicing(self) -> None:
         """
         self.use_slicing = False

+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, num_frames, height, width = x.shape
+
+        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+            return self.tiled_encode(x)
+
+        frame_batch_size = self.num_sample_frames_batch_size
+        enc = []
+        for i in range(num_frames // frame_batch_size):
+            remaining_frames = num_frames % frame_batch_size
+            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
+            end_frame = frame_batch_size * (i + 1) + remaining_frames
+            x_intermediate = x[:, :, start_frame:end_frame]
+            x_intermediate = self.encoder(x_intermediate)
+            if self.quant_conv is not None:
+                x_intermediate = self.quant_conv(x_intermediate)
+            enc.append(x_intermediate)
+
+        self._clear_fake_context_parallel_cache()
+        enc = torch.cat(enc, dim=2)
+
+        return enc
+
     @apply_forward_hook
     def encode(
         self, x: torch.Tensor, return_dict: bool = True
@@ -1094,13 +1118,17 @@ def encode(
                 Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

         Returns:
-                The latent representations of the encoded images. If `return_dict` is True, a
+                The latent representations of the encoded videos. If `return_dict` is True, a
                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-        h = self.encoder(x)
-        if self.quant_conv is not None:
-            h = self.quant_conv(h)
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x)
+
         posterior = DiagonalGaussianDistribution(h)
+
         if not return_dict:
             return (posterior,)
         return AutoencoderKLOutput(latent_dist=posterior)
@@ -1172,6 +1200,75 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.
         )
         return b

+    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+        r"""Encode a batch of videos using a tiled encoder.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+        steps. This is useful to keep memory use constant regardless of input size. Tiled encoding produces slightly
+        different results than non-tiled encoding because each tile is encoded independently of the others. To avoid
+        tiling artifacts, the tiles overlap and are blended together to form a smooth output. You may still see
+        tile-sized changes in the output, but they should be much less noticeable.
+
+        Args:
+            x (`torch.Tensor`): Input batch of videos.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded videos.
+        """
+        # For a rough memory estimate, take a look at the `tiled_decode` method.
+        batch_size, num_channels, num_frames, height, width = x.shape
+
+        overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
+        overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
+        blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
+        blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
+        row_limit_height = self.tile_latent_min_height - blend_extent_height
+        row_limit_width = self.tile_latent_min_width - blend_extent_width
+        frame_batch_size = self.num_sample_frames_batch_size
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, overlap_height):
+            row = []
+            for j in range(0, width, overlap_width):
+                time = []
+                for k in range(num_frames // frame_batch_size):
+                    remaining_frames = num_frames % frame_batch_size
+                    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
+                    end_frame = frame_batch_size * (k + 1) + remaining_frames
+                    tile = x[
+                        :,
+                        :,
+                        start_frame:end_frame,
+                        i : i + self.tile_sample_min_height,
+                        j : j + self.tile_sample_min_width,
+                    ]
+                    tile = self.encoder(tile)
+                    if self.quant_conv is not None:
+                        tile = self.quant_conv(tile)
+                    time.append(tile)
+                self._clear_fake_context_parallel_cache()
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile into the current
+                # tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent_width)
+                result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+            result_rows.append(torch.cat(result_row, dim=4))
+
+        enc = torch.cat(result_rows, dim=3)
+        return enc
+
     def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         r"""
         Decode a batch of images using a tiled decoder.
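The index arithmetic shared by `_encode` and `tiled_encode` is easy to misread: the remainder of `num_frames % frame_batch_size` is folded into the first chunk, so every later chunk spans exactly `frame_batch_size` frames. Below is a minimal standalone sketch of just that arithmetic, not part of the patch; the helper name `split_frames` is hypothetical.

# Standalone sketch of the frame-batching index arithmetic used in
# `_encode` and `tiled_encode` above (not part of the patch).
def split_frames(num_frames: int, frame_batch_size: int) -> list[tuple[int, int]]:
    """Return the (start_frame, end_frame) ranges as computed in the patch."""
    ranges = []
    for i in range(num_frames // frame_batch_size):
        remaining_frames = num_frames % frame_batch_size
        # The first chunk absorbs the remainder; all later chunks are
        # shifted by it and are exactly `frame_batch_size` frames long.
        start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
        end_frame = frame_batch_size * (i + 1) + remaining_frames
        ranges.append((start_frame, end_frame))
    return ranges

# CogVideoX commonly samples 49 frames, with frame_batch_size = 8:
print(split_frames(49, 8))
# [(0, 9), (9, 17), (17, 25), (25, 33), (33, 41), (41, 49)]

Note that the first chunk carries 9 frames and the rest carry 8, so all 49 frames are encoded with no frame dropped or duplicated.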
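For context, here is a rough usage sketch of what this patch enables: memory-frugal VAE encoding via slicing and tiling. It assumes a diffusers build that includes this change, the public `THUDM/CogVideoX-2b` checkpoint, and that the class exposes `enable_tiling()` alongside the `enable_slicing()` visible in the diff; treat those as assumptions rather than the definitive API at this exact revision.

import torch
from diffusers import AutoencoderKLCogVideoX

# Load the CogVideoX VAE (assumes the public THUDM/CogVideoX-2b checkpoint).
vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
).to("cuda")

# Opt in to the memory savers this patch touches: slicing loops the encode
# over the batch dimension, tiling splits large frames spatially.
vae.enable_slicing()
vae.enable_tiling()

# Dummy video batch: (batch, channels, frames, height, width).
video = torch.randn(1, 3, 49, 480, 720, dtype=torch.float16, device="cuda")

with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()
print(latents.shape)  # roughly (1, 16, 13, 60, 90) for this checkpoint

The shape comment reflects this checkpoint's 4x temporal and 8x spatial compression; the encode call itself routes through the new `_encode`/`tiled_encode` paths added above.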