@@ -999,6 +999,7 @@ def __init__(
         # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
         # number of temporal frames.
         self.num_latent_frames_batch_size = 2
+        self.num_sample_frames_batch_size = 8

         # We make the minimum height and width of sample for tiling half that of the generally supported
         self.tile_sample_min_height = sample_height // 2
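The new `num_sample_frames_batch_size` sits alongside the existing latent-frame batch size; assuming this VAE compresses time by a factor of 4 (an assumption, the factor is not stated in this hunk), 8 sample frames per encoder pass correspond to the 2 latent frames per pass already used above. A minimal sketch of that relationship:

```python
# Sketch only: the 4x temporal compression factor is an assumption, not part of this diff.
temporal_compression_ratio = 4
num_latent_frames_batch_size = 2  # latent frames handled per pass (existing attribute)

# Sample frames the encoder should consume per pass so the two batch sizes line up.
num_sample_frames_batch_size = num_latent_frames_batch_size * temporal_compression_ratio
assert num_sample_frames_batch_size == 8
```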
@@ -1081,6 +1082,29 @@ def disable_slicing(self) -> None:
        """
        self.use_slicing = False

+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, num_frames, height, width = x.shape
+
+        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+            return self.tiled_encode(x)
+
+        frame_batch_size = self.num_sample_frames_batch_size
+        enc = []
+        for i in range(num_frames // frame_batch_size):
+            remaining_frames = num_frames % frame_batch_size
+            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
+            end_frame = frame_batch_size * (i + 1) + remaining_frames
+            x_intermediate = x[:, :, start_frame:end_frame]
+            x_intermediate = self.encoder(x_intermediate)
+            if self.quant_conv is not None:
+                x_intermediate = self.quant_conv(x_intermediate)
+            enc.append(x_intermediate)
+
+        self._clear_fake_context_parallel_cache()
+        enc = torch.cat(enc, dim=2)
+
+        return enc
+
    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
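The loop in `_encode` folds any leftover frames into the first window, so every later window is exactly `frame_batch_size` frames long. A standalone sketch of just that index arithmetic (the 49-frame clip length is a hypothetical example):

```python
def frame_windows(num_frames: int, frame_batch_size: int) -> list:
    # Mirrors the loop bounds in `_encode`: the remainder is absorbed by the first window.
    remaining_frames = num_frames % frame_batch_size
    windows = []
    for i in range(num_frames // frame_batch_size):
        start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
        end_frame = frame_batch_size * (i + 1) + remaining_frames
        windows.append((start_frame, end_frame))
    return windows

# 49 frames with the new batch size of 8: one 9-frame window, then five 8-frame windows.
print(frame_windows(49, 8))  # [(0, 9), (9, 17), (17, 25), (25, 33), (33, 41), (41, 49)]
```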
@@ -1094,13 +1118,17 @@ def encode(
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
-                The latent representations of the encoded images. If `return_dict` is True, a
+                The latent representations of the encoded videos. If `return_dict` is True, a
                [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
-        h = self.encoder(x)
-        if self.quant_conv is not None:
-            h = self.quant_conv(h)
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x)
+
        posterior = DiagonalGaussianDistribution(h)
+
        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)
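With these changes, `encode` routes every call through `_encode`, and slicing splits the batch dimension so each sample is encoded on its own. A rough usage sketch, assuming this file is the CogVideoX VAE (`AutoencoderKLCogVideoX`) and using an illustrative checkpoint id:

```python
import torch
from diffusers import AutoencoderKLCogVideoX

# Illustrative checkpoint; any CogVideoX VAE weights would work the same way.
vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
).to("cuda")
vae.enable_slicing()  # sets `use_slicing`, so batched inputs are encoded one sample at a time

# (batch, channels, frames, height, width)
video = torch.randn(2, 3, 49, 480, 720, dtype=torch.float16, device="cuda")
with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()
```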
@@ -1172,6 +1200,75 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        )
        return b

+    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+        r"""Encode a batch of images using a tiled encoder.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
+        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
+        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+        output, but they should be much less noticeable.
+
+        Args:
+            x (`torch.Tensor`): Input batch of videos.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded videos.
+        """
+        # For a rough memory estimate, take a look at the `tiled_decode` method.
+        batch_size, num_channels, num_frames, height, width = x.shape
+
+        overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
+        overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
+        blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
+        blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
+        row_limit_height = self.tile_latent_min_height - blend_extent_height
+        row_limit_width = self.tile_latent_min_width - blend_extent_width
+        frame_batch_size = self.num_sample_frames_batch_size
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, overlap_height):
+            row = []
+            for j in range(0, width, overlap_width):
+                time = []
+                for k in range(num_frames // frame_batch_size):
+                    remaining_frames = num_frames % frame_batch_size
+                    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
+                    end_frame = frame_batch_size * (k + 1) + remaining_frames
+                    tile = x[
+                        :,
+                        :,
+                        start_frame:end_frame,
+                        i : i + self.tile_sample_min_height,
+                        j : j + self.tile_sample_min_width,
+                    ]
+                    tile = self.encoder(tile)
+                    if self.quant_conv is not None:
+                        tile = self.quant_conv(tile)
+                    time.append(tile)
+                self._clear_fake_context_parallel_cache()
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent_width)
+                result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+            result_rows.append(torch.cat(result_row, dim=4))
+
+        enc = torch.cat(result_rows, dim=3)
+        return enc
+
    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        r"""
        Decode a batch of images using a tiled decoder.
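When the spatial size exceeds `tile_sample_min_height`/`tile_sample_min_width`, `_encode` dispatches to the new `tiled_encode` path. A minimal sketch of exercising it, again assuming the CogVideoX VAE and treating the helper name and compression factors below as assumptions:

```python
import torch
from diffusers import AutoencoderKLCogVideoX

vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
).to("cuda")
vae.enable_tiling()  # assumed helper that turns on `use_tiling`

# Large enough spatially that `_encode` takes the tiled path.
video = torch.randn(1, 3, 49, 960, 1440, dtype=torch.float16, device="cuda")
with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()
print(latents.shape)  # e.g. (1, 16, 13, 120, 180) if the VAE compresses 4x in time and 8x spatially
```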