Commit bf890bc

fake context parallel cache, vae encode tiling
1 parent 1b781ba
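
In practice, the features this commit adds to the CogVideoX VAE would be exercised roughly as follows. This is a sketch, not code from the commit: the checkpoint id, tensor shape, and dtype are illustrative, and `enable_slicing()`/`enable_tiling()` are assumed to exist with the same semantics as on the other diffusers autoencoders.

```python
# Sketch only: exercising the sliced + tiled VAE encode path.
# Checkpoint id, shape, and dtype are placeholders, not from the commit.
import torch
from diffusers import AutoencoderKLCogVideoX

vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
).to("cuda")
vae.enable_slicing()  # encode batch elements one at a time (use_slicing path)
vae.enable_tiling()   # fall back to tiled_encode for large spatial sizes (use_tiling path)

# (batch, channels, frames, height, width)
video = torch.randn(1, 3, 49, 480, 720, dtype=torch.float16, device="cuda")
with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()
```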

2 files changed (+112, -19 lines)


src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py

Lines changed: 101 additions & 4 deletions
```diff
@@ -999,6 +999,7 @@ def __init__(
         # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
         # number of temporal frames.
         self.num_latent_frames_batch_size = 2
+        self.num_sample_frames_batch_size = 8

         # We make the minimum height and width of sample for tiling half that of the generally supported
         self.tile_sample_min_height = sample_height // 2
@@ -1081,6 +1082,29 @@ def disable_slicing(self) -> None:
         """
         self.use_slicing = False

+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, num_frames, height, width = x.shape
+
+        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+            return self.tiled_encode(x)
+
+        frame_batch_size = self.num_sample_frames_batch_size
+        enc = []
+        for i in range(num_frames // frame_batch_size):
+            remaining_frames = num_frames % frame_batch_size
+            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
+            end_frame = frame_batch_size * (i + 1) + remaining_frames
+            x_intermediate = x[:, :, start_frame:end_frame]
+            x_intermediate = self.encoder(x_intermediate)
+            if self.quant_conv is not None:
+                x_intermediate = self.quant_conv(x_intermediate)
+            enc.append(x_intermediate)
+
+        self._clear_fake_context_parallel_cache()
+        enc = torch.cat(enc, dim=2)
+
+        return enc
+
     @apply_forward_hook
     def encode(
         self, x: torch.Tensor, return_dict: bool = True
@@ -1094,13 +1118,17 @@ def encode(
                 Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

         Returns:
-                The latent representations of the encoded images. If `return_dict` is True, a
+                The latent representations of the encoded videos. If `return_dict` is True, a
                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-        h = self.encoder(x)
-        if self.quant_conv is not None:
-            h = self.quant_conv(h)
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x)
+
         posterior = DiagonalGaussianDistribution(h)
+
         if not return_dict:
             return (posterior,)
         return AutoencoderKLOutput(latent_dist=posterior)
@@ -1172,6 +1200,75 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
         )
         return b

+    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+        r"""Encode a batch of images using a tiled encoder.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
+        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
+        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+        output, but they should be much less noticeable.
+
+        Args:
+            x (`torch.Tensor`): Input batch of videos.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded videos.
+        """
+        # For a rough memory estimate, take a look at the `tiled_decode` method.
+        batch_size, num_channels, num_frames, height, width = x.shape
+
+        overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
+        overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
+        blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
+        blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
+        row_limit_height = self.tile_latent_min_height - blend_extent_height
+        row_limit_width = self.tile_latent_min_width - blend_extent_width
+        frame_batch_size = self.num_sample_frames_batch_size
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, overlap_height):
+            row = []
+            for j in range(0, width, overlap_width):
+                time = []
+                for k in range(num_frames // frame_batch_size):
+                    remaining_frames = num_frames % frame_batch_size
+                    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
+                    end_frame = frame_batch_size * (k + 1) + remaining_frames
+                    tile = x[
+                        :,
+                        :,
+                        start_frame:end_frame,
+                        i : i + self.tile_sample_min_height,
+                        j : j + self.tile_sample_min_width,
+                    ]
+                    tile = self.encoder(tile)
+                    if self.quant_conv is not None:
+                        tile = self.quant_conv(tile)
+                    time.append(tile)
+                self._clear_fake_context_parallel_cache()
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent_width)
+                result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+            result_rows.append(torch.cat(result_row, dim=4))
+
+        enc = torch.cat(result_rows, dim=3)
+        return enc
+
     def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         r"""
         Decode a batch of images using a tiled decoder.
```
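
For reference, the temporal chunking that both `_encode` and `tiled_encode` use can be checked in isolation: frames are processed in chunks of `num_sample_frames_batch_size`, with the first chunk absorbing the `num_frames % frame_batch_size` remainder so that the concatenation along `dim=2` covers every frame exactly once. The sketch below is illustrative only; `frame_chunks` is a made-up helper, not part of the diff.

```python
# Illustrative sketch of the start/end frame arithmetic from `_encode`/`tiled_encode`.
def frame_chunks(num_frames: int, frame_batch_size: int):
    chunks = []
    for i in range(num_frames // frame_batch_size):
        remaining_frames = num_frames % frame_batch_size
        start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
        end_frame = frame_batch_size * (i + 1) + remaining_frames
        chunks.append((start_frame, end_frame))
    return chunks

# 49 frames with num_sample_frames_batch_size = 8: the first chunk takes 9 frames,
# the remaining five take 8 each, tiling [0, 49) without gaps or overlap.
print(frame_chunks(49, 8))  # [(0, 9), (9, 17), (17, 25), (25, 33), (33, 41), (41, 49)]
```

Note that `_clear_fake_context_parallel_cache()` is called after each temporal loop: once per call in `_encode`, and once per spatial tile in `tiled_encode`.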

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py

Lines changed: 11 additions & 15 deletions
```diff
@@ -341,7 +341,6 @@ def prepare_latents(
         video: Optional[torch.Tensor] = None,
         batch_size: int = 1,
         num_channels_latents: int = 16,
-        num_frames: int = 13,
         height: int = 60,
         width: int = 90,
         dtype: Optional[torch.dtype] = None,
@@ -350,13 +349,16 @@ def prepare_latents(
         latents: Optional[torch.Tensor] = None,
         timestep: Optional[torch.Tensor] = None,
     ):
+        num_frames = (video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1)
+
         shape = (
             batch_size,
-            (num_frames - 1) // self.vae_scale_factor_temporal + 1,
+            num_frames,
             num_channels_latents,
             height // self.vae_scale_factor_spatial,
             width // self.vae_scale_factor_spatial,
         )
+
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -432,6 +434,8 @@ def check_inputs(
         strength,
         negative_prompt,
         callback_on_step_end_tensor_inputs,
+        video=None,
+        latents=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
     ):
@@ -479,6 +483,9 @@ def check_inputs(
                     f" {negative_prompt_embeds.shape}."
                 )

+        if video is not None and latents is not None:
+            raise ValueError("Only one of `video` or `latents` should be provided")
+
     def fuse_qkv_projections(self) -> None:
         r"""Enables fused QKV projections."""
         self.fusing_transformer = True
@@ -539,7 +546,6 @@ def __call__(
         negative_prompt: Optional[Union[str, List[str]]] = None,
         height: int = 480,
         width: int = 720,
-        num_frames: int = 49,
         num_inference_steps: int = 50,
         timesteps: Optional[List[int]] = None,
         strength: float = 0.8,
@@ -576,11 +582,6 @@ def __call__(
                 The height in pixels of the generated image. This is set to 1024 by default for the best results.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The width in pixels of the generated image. This is set to 1024 by default for the best results.
-            num_frames (`int`, defaults to `48`):
-                Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
-                contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
-                num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
-                needs to be satisfied is that of divisibility mentioned above.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
@@ -639,11 +640,6 @@ def __call__(
                 `tuple`. When returning a tuple, the first element is a list with the generated images.
         """

-        if num_frames > 49:
-            raise ValueError(
-                "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
-            )
-
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

@@ -700,16 +696,16 @@ def __call__(
         latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
         self._num_timesteps = len(timesteps)

-        # 5. Prepare latents.
+        # 5. Prepare latents
         if latents is None:
             video = self.video_processor.preprocess_video(video, height=height, width=width)
             video = video.to(device=device, dtype=prompt_embeds.dtype)
+
         latent_channels = self.transformer.config.in_channels
         latents = self.prepare_latents(
             video,
             batch_size * num_videos_per_prompt,
             latent_channels,
-            num_frames,
             height,
             width,
             prompt_embeds.dtype,
```