Commit c466106

a-r-r-o-w authored and sayakpaul committed
[core] CogVideoX memory optimizations in VAE encode (#9340)
fake context parallel cache, vae encode tiling (cherry picked from commit bf890bc)
1 parent 4d901f9 commit c466106

File tree

1 file changed: +101 −4 lines changed


src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py

Lines changed: 101 additions & 4 deletions
@@ -999,6 +999,7 @@ def __init__(
         # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
         # number of temporal frames.
         self.num_latent_frames_batch_size = 2
+        self.num_sample_frames_batch_size = 8

         # We make the minimum height and width of sample for tiling half that of the generally supported
         self.tile_sample_min_height = sample_height // 2
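For intuition on the new constant: the CogVideoX VAE compresses time by a factor of 4, so a batch of 8 sample frames corresponds to the 2 latent frames that the existing comment says the model was trained on. A minimal sketch of that correspondence (the `temporal_compression_ratio` name and the ratio itself are assumptions here, not taken from this diff):

# Sketch: relate the two batch-size constants under CogVideoX's 4x temporal
# compression (the ratio is an assumption, not part of this diff).
temporal_compression_ratio = 4
num_sample_frames_batch_size = 8
num_latent_frames_batch_size = num_sample_frames_batch_size // temporal_compression_ratio
assert num_latent_frames_batch_size == 2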
@@ -1081,6 +1082,29 @@ def disable_slicing(self) -> None:
         """
         self.use_slicing = False

+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, num_frames, height, width = x.shape
+
+        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+            return self.tiled_encode(x)
+
+        frame_batch_size = self.num_sample_frames_batch_size
+        enc = []
+        for i in range(num_frames // frame_batch_size):
+            remaining_frames = num_frames % frame_batch_size
+            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
+            end_frame = frame_batch_size * (i + 1) + remaining_frames
+            x_intermediate = x[:, :, start_frame:end_frame]
+            x_intermediate = self.encoder(x_intermediate)
+            if self.quant_conv is not None:
+                x_intermediate = self.quant_conv(x_intermediate)
+            enc.append(x_intermediate)
+
+        self._clear_fake_context_parallel_cache()
+        enc = torch.cat(enc, dim=2)
+
+        return enc
+
     @apply_forward_hook
     def encode(
         self, x: torch.Tensor, return_dict: bool = True
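The index arithmetic in `_encode` folds any leftover frames (`num_frames % frame_batch_size`) into the first window, shifting every later window by that remainder so the whole clip is covered. A standalone sketch of the windows it produces, assuming a 49-frame input:

# Sketch: the frame windows _encode iterates over for a 49-frame input.
num_frames = 49
frame_batch_size = 8
remaining_frames = num_frames % frame_batch_size  # 1 leftover frame
windows = []
for i in range(num_frames // frame_batch_size):  # 6 windows
    start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
    end_frame = frame_batch_size * (i + 1) + remaining_frames
    windows.append((start_frame, end_frame))
# windows == [(0, 9), (9, 17), (17, 25), (25, 33), (33, 41), (41, 49)]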
@@ -1094,13 +1118,17 @@ def encode(
                 Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

         Returns:
-                The latent representations of the encoded images. If `return_dict` is True, a
+                The latent representations of the encoded videos. If `return_dict` is True, a
                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-        h = self.encoder(x)
-        if self.quant_conv is not None:
-            h = self.quant_conv(h)
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x)
+
         posterior = DiagonalGaussianDistribution(h)
+
         if not return_dict:
             return (posterior,)
         return AutoencoderKLOutput(latent_dist=posterior)
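Taken together, the slicing and tiling paths can be exercised from user code roughly as below. This is a usage sketch: the checkpoint id, dtype, and the assumption that `enable_slicing`/`enable_tiling` set the `use_slicing`/`use_tiling` flags checked in `encode` are not shown in this diff.

import torch
from diffusers import AutoencoderKLCogVideoX

# Usage sketch (assumptions: checkpoint id and dtype; enable_slicing /
# enable_tiling are taken to toggle the use_slicing / use_tiling flags).
vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
).to("cuda")
vae.enable_slicing()  # encode batch elements one at a time
vae.enable_tiling()   # spatially tile large inputs via tiled_encode

video = torch.randn(2, 3, 49, 480, 720, dtype=torch.float16, device="cuda")
with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()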
@@ -1172,6 +1200,75 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
         )
         return b

+    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+        r"""Encode a batch of images using a tiled encoder.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
+        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
+        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+        output, but they should be much less noticeable.
+
+        Args:
+            x (`torch.Tensor`): Input batch of videos.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded videos.
+        """
+        # For a rough memory estimate, take a look at the `tiled_decode` method.
+        batch_size, num_channels, num_frames, height, width = x.shape
+
+        overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
+        overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
+        blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
+        blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
+        row_limit_height = self.tile_latent_min_height - blend_extent_height
+        row_limit_width = self.tile_latent_min_width - blend_extent_width
+        frame_batch_size = self.num_sample_frames_batch_size
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, overlap_height):
+            row = []
+            for j in range(0, width, overlap_width):
+                time = []
+                for k in range(num_frames // frame_batch_size):
+                    remaining_frames = num_frames % frame_batch_size
+                    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
+                    end_frame = frame_batch_size * (k + 1) + remaining_frames
+                    tile = x[
+                        :,
+                        :,
+                        start_frame:end_frame,
+                        i : i + self.tile_sample_min_height,
+                        j : j + self.tile_sample_min_width,
+                    ]
+                    tile = self.encoder(tile)
+                    if self.quant_conv is not None:
+                        tile = self.quant_conv(tile)
+                    time.append(tile)
+                self._clear_fake_context_parallel_cache()
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent_width)
+                result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+            result_rows.append(torch.cat(result_row, dim=4))
+
+        enc = torch.cat(result_rows, dim=3)
+        return enc
+
     def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         r"""
         Decode a batch of images using a tiled decoder.
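The `blend_v`/`blend_h` helpers that `tiled_encode` reuses cross-fade the overlapping borders of adjacent tiles linearly. A minimal sketch of the horizontal case for the 5-D (batch, channel, frame, height, width) layout used here; this is a paraphrase for illustration, not the code from this diff:

import torch

def blend_h_sketch(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Sketch of a horizontal blend for (B, C, T, H, W) tiles: linearly fade
    # from the right edge of `a` into the left edge of `b` over blend_extent
    # columns. Note that `b` is modified in place.
    blend_extent = min(a.shape[4], b.shape[4], blend_extent)
    for x in range(blend_extent):
        w = x / blend_extent
        b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - w) + b[:, :, :, :, x] * w
    return b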
