[bug] Precedence of operations in VAE should be slicing -> tiling #9342
Merged

Changes shown are from 9 of the 13 commits:
- 55a1abd (a-r-r-o-w): bugfix: precedence of operations should be slicing -> tiling
- af060b7 (a-r-r-o-w): Merge branch 'main' into vae/bugfix-slicing-tiling
- 93e4d23 (a-r-r-o-w): fix typo
- 5ac6473 (a-r-r-o-w): fix another typo
- 47d34c3 (a-r-r-o-w): Merge branch 'main' into vae/bugfix-slicing-tiling
- 4e2df5c (a-r-r-o-w): Merge branch 'main' into vae/bugfix-slicing-tiling
- be9797a (a-r-r-o-w): Merge branch 'main' into vae/bugfix-slicing-tiling
- ea25b69 (a-r-r-o-w): deprecate current implementation of tiled_encode and use new impl
- 170d2d2 (sayakpaul): Merge branch 'main' into vae/bugfix-slicing-tiling
- 0206822 (a-r-r-o-w): Update src/diffusers/models/autoencoders/autoencoder_kl.py
- 896c8f7 (a-r-r-o-w): Update src/diffusers/models/autoencoders/autoencoder_kl.py
- 1c02d76 (a-r-r-o-w): Merge branch 'main' into vae/bugfix-slicing-tiling
- 2bc58fa (a-r-r-o-w): Merge branch 'main' into vae/bugfix-slicing-tiling
File changed: src/diffusers/models/autoencoders/autoencoder_kl.py
```diff
@@ -18,6 +18,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders.single_file_model import FromOriginalModelMixin
+from ...utils import deprecate
 from ...utils.accelerate_utils import apply_forward_hook
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
```
```diff
@@ -245,6 +246,18 @@ def set_default_attn_processor(self):
         self.set_attn_processor(processor)
 
+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, height, width = x.shape
+
+        if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
+            return self._tiled_encode(x)
+
+        enc = self.encoder(x)
+        if self.quant_conv is not None:
+            enc = self.quant_conv(enc)
+
+        return enc
+
     @apply_forward_hook
     def encode(
         self, x: torch.Tensor, return_dict: bool = True
```
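Taken together with the `encode` changes below, the new control flow can be summarized with a small standalone toy sketch (all names here are illustrative stand-ins, not the diffusers implementation):

```python
import torch

TILE = 4  # stand-in for self.tile_sample_min_size

def plain_encode(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0  # stand-in for encoder + quant_conv

def tiled_encode(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0  # stand-in for the overlapped-tile path

def _encode(x: torch.Tensor) -> torch.Tensor:
    _, _, height, width = x.shape
    if height > TILE or width > TILE:  # tiling check now happens per input
        return tiled_encode(x)
    return plain_encode(x)

def encode(x: torch.Tensor, use_slicing: bool = True) -> torch.Tensor:
    # Slicing is applied first; each single-image slice then goes through the
    # tiling check, so both options take effect when both are enabled.
    if use_slicing and x.shape[0] > 1:
        return torch.cat([_encode(s) for s in x.split(1)])
    return _encode(x)

out = encode(torch.randn(2, 3, 8, 8))  # both slices take the tiled path
assert out.shape == (2, 3, 8, 8)
```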
```diff
@@ -261,21 +274,13 @@ def encode(
             The latent representations of the encoded images. If `return_dict` is True, a
             [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
-            return self.tiled_encode(x, return_dict=return_dict)
-
         if self.use_slicing and x.shape[0] > 1:
-            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
             h = torch.cat(encoded_slices)
         else:
-            h = self.encoder(x)
-
-        if self.quant_conv is not None:
-            moments = self.quant_conv(h)
-        else:
-            moments = h
+            h = self._encode(x)
 
-        posterior = DiagonalGaussianDistribution(moments)
+        posterior = DiagonalGaussianDistribution(h)
 
         if not return_dict:
             return (posterior,)
```
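From the user side, the fixed precedence matters when both features are enabled on the same model. A quick way to exercise the new path (the checkpoint name and tensor sizes below are just examples):

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
vae.enable_slicing()  # split batches into single-image slices
vae.enable_tiling()   # tile large images during encode/decode

# Batch of large images: with this fix, the batch is split into slices first,
# and each slice is then encoded tile-by-tile.
images = torch.randn(2, 3, 1024, 1024)
with torch.no_grad():
    posterior = vae.encode(images).latent_dist
latents = posterior.sample()
```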
```diff
@@ -337,6 +342,60 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
             b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
         return b
 
+    def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+        r"""Encode a batch of images using a tiled encoder.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
+        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
+        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+        output, but they should be much less noticeable.
+
+        Args:
+            x (`torch.Tensor`): Input batch of images.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded images.
+        """
+        deprecation_message = (
+            "The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the "
+            "implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able "
+            "to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value."
+        )
+        deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False)
```

Review comment on the deprecation above (a-r-r-o-w marked this conversation as resolved): We should test this deprecation too.
```diff
+
+        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+        row_limit = self.tile_latent_min_size - blend_extent
+
+        # Split the image into 512x512 tiles and encode them separately.
+        rows = []
+        for i in range(0, x.shape[2], overlap_size):
+            row = []
+            for j in range(0, x.shape[3], overlap_size):
+                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+                tile = self.encoder(tile)
+                if self.config.use_quant_conv:
+                    tile = self.quant_conv(tile)
+                row.append(tile)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent)
+                result_row.append(tile[:, :, :row_limit, :row_limit])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        enc = torch.cat(result_rows, dim=2)
+        return enc
+
     def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
         r"""Encode a batch of images using a tiled encoder.
```
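For intuition on the tile arithmetic in `_tiled_encode`, here is a hedged worked example. The config values below are typical SD-VAE numbers assumed for illustration, not read from this PR; the real values come from the model config:

```python
# Assumed, illustrative config values.
tile_sample_min_size = 512   # tile edge in pixel space
tile_latent_min_size = 64    # tile edge in latent space (8x downsampling)
tile_overlap_factor = 0.25

overlap_size = int(tile_sample_min_size * (1 - tile_overlap_factor))  # 384: stride between tile origins
blend_extent = int(tile_latent_min_size * tile_overlap_factor)        # 16: latent rows/cols crossfaded
row_limit = tile_latent_min_size - blend_extent                       # 48: latent rows/cols kept per tile

# A 1024x1024 input is covered by tile origins 0, 384, 768 on each axis: a 3x3
# grid of (up to) 512x512 tiles with 128 pixels of overlap between neighbors.
print(list(range(0, 1024, overlap_size)))  # [0, 384, 768]
```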
```diff
@@ -356,6 +415,13 @@ def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
             If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
             `tuple` is returned.
         """
+        deprecation_message = (
+            "The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the "
+            "implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able "
+            "to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value."
+        )
+        deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False)
+
         overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
         blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
         row_limit = self.tile_latent_min_size - blend_extent
```
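The deprecation points callers toward building the distribution themselves. A hedged migration sketch, reusing the `vae` and `images` from the earlier example (the `DiagonalGaussianDistribution` import path assumes the current diffusers module layout):

```python
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution

# Deprecated path: tiled_encode still accepts return_dict and emits a deprecation warning.
posterior = vae.tiled_encode(images, return_dict=True).latent_dist

# Future-proof path per the deprecation message: take the raw moments and build
# the distribution yourself.
moments = vae._tiled_encode(images)
posterior = DiagonalGaussianDistribution(moments)
```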
Conversations
Comment: I think the algorithm changed a bit for `use_slicing`. Previously, `quant_conv` was applied once, after combining the encoder outputs from all slices; currently, `quant_conv` is applied to each slice. I'm pretty sure the result would be the same; I wonder if there is any implication for performance?
Reply: I think the performance should be the same, since it is just one convolution layer applied to the compressed outputs of the encoder. I can get some numbers soon.
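For reference, one way to get such numbers (a hedged sketch with a toy conv standing in for `quant_conv`, not the PR's benchmark):

```python
import torch
from torch.utils.benchmark import Timer

conv = torch.nn.Conv2d(8, 8, kernel_size=1)
x = torch.randn(8, 8, 64, 64)

# Conv applied once to the whole batch vs. once per slice, as in the new code path.
t_full = Timer("conv(x)", globals={"conv": conv, "x": x}).timeit(100)
t_sliced = Timer(
    "torch.cat([conv(s) for s in x.split(1)])",
    globals={"torch": torch, "conv": conv, "x": x},
).timeit(100)
print(t_full)
print(t_sliced)
```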
Comment: We could perhaps add a test to ensure this? That should clear up the confusion.
Comment: @a-r-r-o-w do you think it could make sense to add a fast test here, or not really?
Reply: I think it's okay without a test here. The functionality is effectively the same and only affects the batch dimension across this conv layer, which will never alter outputs, since the conv doesn't operate on that dimension.

I know that understanding the changes here is quite easy, but I feel I should leave a comment making the explanation a bit clearer and more elaborate for anyone stumbling upon this in the future.

Previously, slicing worked individually and tiling worked individually. When both were enabled, only tiling would be in effect, meaning the VAE would chop `[B, C, H, W]` into 4 tiles of shape `[B, C, H // 2, W // 2]` (assuming 2x2 perfect tiles), process each tile individually, and perform blending. This is incorrect, as slicing is completely ignored. The correct processing size, ensuring slicing also takes effect, would be 4 x B tiles of shape `[1, C, H // 2, W // 2]`, which is what this PR does.
Reply: Thanks for explaining!
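The batch-dimension argument above is easy to check empirically; a minimal sketch with a toy conv standing in for the VAE's `quant_conv`:

```python
import torch

conv = torch.nn.Conv2d(8, 8, kernel_size=1)
x = torch.randn(4, 8, 16, 16)

# Conv2d never mixes entries along the batch dim, so applying it per slice
# matches applying it once to the full batch.
per_slice = torch.cat([conv(s) for s in x.split(1)])
full_batch = conv(x)
print(torch.allclose(per_slice, full_batch, atol=1e-6))  # True
```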