Allow resolutions that are not multiples of 64 #505


Merged: 8 commits into huggingface:main on Sep 30, 2022

Conversation

@jachiam (Contributor) commented Sep 13, 2022

Hello! This is my first time making a PR for an open source project in a while, so I apologize if I'm a bit rusty at it.

This should be a fix for #255

Generally, it permits heights/widths for generated images that are not multiples of 64 (though they are still required to be multiples of 8). The core logic is just to change the F.interpolate call in Upsample2D to accept a target size instead of a hardcoded scale factor of 2. The remainder of the PR tracks what the right size should be and passes it where it needs to go.
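To make that concrete, here's a minimal sketch of the idea (a simplified standalone module, not the PR's verbatim code; the real Upsample2D has more configuration):

import torch.nn.functional as F
from torch import nn

class Upsample2D(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, hidden_states, output_size=None):
        if output_size is None:
            # Old behavior: always double the spatial dimensions.
            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        else:
            # New behavior: upsample to the exact size of the matching skip connection.
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
        return self.conv(hidden_states)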

@HuggingFaceDocBuilderDev commented Sep 13, 2022

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten (Contributor)

Very cool that you opened your first PR on diffusers @jachiam :-)

Could you give an example use case that works with your PR but didn't before, which we could try out locally?
Also cc @anton-l @patil-suraj

@jachiam (Contributor, Author) commented Sep 13, 2022

Sure. This PR would make it easier to generate images using the StableDiffusionPipeline in an approximately 4:3 or 16:9 aspect ratio at a desired size (instead of always requiring both height and width to be multiples of 64), for instance using 512 as the height and hence either 680 or 904 as the width. Currently using height, width = (512, 680) or (512, 904) will throw an error when the upsampled output of an upblock doesn't have the same dimensions as the corresponding downblock hidden state.

@jachiam (Contributor, Author) commented Sep 14, 2022

I'm also moderately (though not 1000%) confident that this leaves all behavior mathematically unchanged from before the PR for all height/width combos that are multiples of 64, so it should be a strict improvement, with no regressions or degradations for existing use cases.
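That claim is easy to sanity-check in isolation, since nearest-neighbor upsampling to exactly twice the input size is the same operation as upsampling with a scale factor of 2 (a standalone sketch, not a test from this PR):

import torch
import torch.nn.functional as F

x = torch.randn(1, 320, 16, 16)
doubled = F.interpolate(x, scale_factor=2.0, mode="nearest")
sized = F.interpolate(x, size=(32, 32), mode="nearest")
assert torch.equal(doubled, sized)  # identical whenever the target size is exactly 2x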

@patrickvonplaten (Contributor)

Sorry, I think I wasn't very clear before 😅

Could you maybe add a code-snippet / test that shows the new behavior and how it wasn't possible before?

@jachiam (Contributor, Author) commented Sep 14, 2022

Ah, sure. Here's a simple snippet:

import torch
from torch import autocast
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, use_auth_token=True)
pipe.to("cuda")
prompt = "this can be anything!"
with autocast("cuda"):
    images = pipe(prompt, height=512, width=680, num_inference_steps=30)["sample"]

The existing code fails with this error message:

RuntimeError                              Traceback (most recent call last)
<ipython-input-6-d959bf6051e2> in <module>
      1 prompt = "this can be anything!"
      2 with autocast("cuda"):
----> 3         images = pipe(prompt, height=512, width=680, num_inference_steps=30)["sample"]

5 frames
/usr/local/lib/python3.7/dist-packages/diffusers/models/unet_blocks.py in forward(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states)
   1078             res_hidden_states = res_hidden_states_tuple[-1]
   1079             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
-> 1080             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
   1081 
   1082             hidden_states = resnet(hidden_states, temb)

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 44 but got size 43 for tensor number 1 in the list.

My code fixes this.

Here's where I think this error comes from and how my code helps. We have to consider what happens in the guts of the image generation pipeline and the neural network. First, initial latents are generated; latents are sized to be 1/8th the desired height and width of the generated images. Then, each step of the U-net applies four downblocks followed by four upblocks to its input. Each downsampling step produces an output that is approximately, but not always exactly, half the input size: if the input has an odd spatial size, the output is the input size divided by 2, rounded up. With a width of 680, as above, the latent width is going to be 680/8 = 85, and the hidden states from the downblocks have sizes:

1 torch.Size([2, 320, 64, 85])
2 torch.Size([2, 320, 64, 85])
3 torch.Size([2, 320, 32, 43])
4 torch.Size([2, 640, 32, 43])
5 torch.Size([2, 640, 32, 43])
6 torch.Size([2, 640, 16, 22])
7 torch.Size([2, 1280, 16, 22])
8 torch.Size([2, 1280, 16, 22])
9 torch.Size([2, 1280, 8, 11])
10 torch.Size([2, 1280, 8, 11])
11 torch.Size([2, 1280, 8, 11])

When the upblocks get applied, they go "up" by applying an upsampling step. The upsampling is hardcoded to be a factor of 2x, so here it would take 11->22, 22->44, 44->88. But in order for the U-net to function correctly, the downblock hidden states have to match the upblock hidden states in the height and width dimensions (otherwise you get the error shown). So when an upblock produces an output with width 44 and the corresponding hidden state from the downblock has width 43, we anger the malevolent tensor shape gods and get an error.

Note that this error happens any time there is an odd value in the cascade of sizes resulting from successive application of downblocks. This means the latent sizes all have to be divisible by 8 for the network to produce correct results, which in turn means the height/width values must be divisible by 64.

But this is easily fixed. Instead of upsampling with a hardcoded factor of 2x, we just upsample directly to the desired size. Then things work fine. For cases where the original width/height was divisible by 64, this is identical to using the hardcoded factor of 2x.
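To make the size arithmetic concrete, here's a tiny standalone sketch (the helper name is just for illustration); each stride-2 downsampling step maps a spatial size of n to ceil(n / 2):

import math

def downblock_sizes(latent_size, num_downsamples=3):
    # Each stride-2 downsampling step maps size n to ceil(n / 2).
    sizes = [latent_size]
    for _ in range(num_downsamples):
        sizes.append(math.ceil(sizes[-1] / 2))
    return sizes

print(downblock_sizes(85))  # [85, 43, 22, 11]  (width 680: odd sizes appear)
print(downblock_sizes(64))  # [64, 32, 16, 8]   (width 512: all even, no problem)

# Hardcoded 2x upsampling replays the cascade in reverse, but incorrectly for 680:
# 11 -> 22 -> 44, while the skip connection to concatenate with has width 43.
# Upsampling directly to the recorded sizes (11 -> 22 -> 43 -> 85) avoids this.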

@keturn (Contributor) commented Sep 14, 2022

I had noticed that type of error and was puzzled by it myself.

Something I didn't realize until working through your explanation just now is that there are two sets of downblocks & upblocks. One set in the autoencoder to get it from 680 to 85, and then the other one you're describing in the UNet from 85 to 11 (10.625).

That being the case, it seems this warning was wrong, and it really should have been 64:

if height % 8 != 0 or width % 8 != 0:
    raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

I don't know the math or the models well enough to say what happens if you run the blocks at some other scaling factor, especially one they weren't trained on.

A few other questions:

  • If you want to do this, would it be better to do it in the UNet or the autoencoder?
  • Doesn't a case like 512×680 leave you with different scaling factors in x and y? Because the x wants to keep that same ×8×8 scaling while the y needs ×2×2×(85/44)×8.

I can believe that the result looks close enough if you eyeball it, because the rounding error is always going to be, what, no more than 1/64th of the overall size? But for now I am suspicious.

@jachiam (Contributor, Author) commented Sep 14, 2022

That being the case, it seems this warning was wrong, and it really should have been 64:

if height % 8 != 0 or width % 8 != 0:
    raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

Indeed, that error is wrong and it should have said 64.

I don't know the math or the models well enough to say what happens if you run the blocks at some other scaling factor, especially one they weren't trained on.

A few other questions:

  • If you want to do this, would it be better to do it in the UNet or the autoencoder?

I can believe that the result looks close enough if you eyeball it, because the rounding error is always going to be, what, no more than 1/64th of the overall size? But for now I am suspicious.

I'm not sure whether it would be better to do this in the UNet or the autoencoder. It's plausible to me that both should be fine, allowing a full range of resolutions with reasonable results, though resolutions closest to the training distribution, with exact multiples of 8/64, will probably be best.

As for how much quality degradation, if any, to expect for off-distribution upsample factors: I'd guess fairly little? This ever-so-slightly changes the boundary on the upsampled tensor (by leaving off one last row or column or both), but otherwise everything that's going through successive NN modules is still on-distribution. So far I have not seen any weirdness or other strange image artifacts. Though again, worth noting: for height/width that are multiples of 64, this modification makes no mathematical change to the behavior of the code. It only changes behavior for sizes that aren't multiples of 64, where the model currently just breaks, and the new code makes it not break.

  • Doesn't a case like 512×680 leave you with different scaling factors in x and y? Because the x wants to keep that same ×8×8 scaling while the y needs ×2×2×(85/44)×8.

I am not sure how this poses an issue for upsampling to the desired size. They are different upsampling factors along different dimensions, sure, but you're passing a specific output tensor size to the F.interpolate call instead of a single scale factor.
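Concretely (a standalone sketch): passing a target size lets F.interpolate apply a different effective factor along each axis, so the mismatched x/y factors are handled for free.

import torch
import torch.nn.functional as F

x = torch.randn(1, 640, 32, 43)  # e.g. a hidden state from the 512x680 case above
y = F.interpolate(x, size=(64, 85), mode="nearest")
print(y.shape)  # torch.Size([1, 640, 64, 85])
# Effective factors: 2.0 in height, 85/43 ~ 1.98 in width.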

@jachiam (Contributor, Author) commented Sep 16, 2022

@patrickvonplaten, regarding your concerns above: I recognize that this introduces code complexity that is hard to read and understand. But the bug itself is hard to understand, because it depends on an obscure interaction of tensor shapes from the downsampler, upsampler, and resblock pieces, and this is the simplest fix for it. I don't believe there is a remedy for the complexity short of an extensive comment explaining the bug and why this fixes it. As far as I can tell, passing the size of the downsample resblock outputs to the upsampler is the only way to ensure the upsampler produces correct results.
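For reference, here's a rough sketch of the control flow I mean (simplified pseudocode, not the PR's verbatim diff): while running the up blocks, peek at the next group of skip connections and hand their spatial size to the upsampler.

upsample_size = None
for i, upsample_block in enumerate(up_blocks):
    num_resnets = len(upsample_block.resnets)
    res_samples = down_block_res_samples[-num_resnets:]
    down_block_res_samples = down_block_res_samples[:-num_resnets]

    if i < len(up_blocks) - 1:
        # The upsampled output must match the skip connections that the
        # *next* up block will concatenate with, so use their spatial size.
        upsample_size = down_block_res_samples[-1].shape[2:]

    sample = upsample_block(
        hidden_states=sample,
        temb=emb,
        res_hidden_states_tuple=res_samples,
        upsample_size=upsample_size,
    )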

Let me know if it would be helpful/faster to hop on a video call to disentangle this, I'd be happy to chat.

@jachiam (Contributor, Author) commented Sep 19, 2022

Hello! Just curious whether there's any interest in accepting the PR, and whether there are any action items on my end to make it acceptable?

@vvsotnikov (Contributor)

Hi @jachiam! Thanks for your PR, I'm eager to see it merged 😃
Seems like there's an issue with the maximum line length in a few places:

black --check --preview --diff -l 119 -t py36 unet_blocks.py unet_2d_condition.py 
--- unet_2d_condition.py	2022-09-20 13:55:12.525445 +0000
+++ unet_2d_condition.py	2022-09-20 14:10:22.826067 +0000
@@ -256,18 +256,15 @@
                 sample = upsample_block(
                     hidden_states=sample,
                     temb=emb,
                     res_hidden_states_tuple=res_samples,
                     encoder_hidden_states=encoder_hidden_states,
-                    upsample_size=upsample_size
+                    upsample_size=upsample_size,
                 )
             else:
                 sample = upsample_block(
-                    hidden_states=sample, 
-                    temb=emb, 
-                    res_hidden_states_tuple=res_samples,
-                    upsample_size=upsample_size
+                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                 )
         # 6. post-process
         # make sure hidden states is in float32
         # when running in half-precision
         sample = self.conv_norm_out(sample.float()).type(sample.dtype)
would reformat unet_2d_condition.py
--- unet_blocks.py	2022-09-20 13:56:35.881426 +0000
+++ unet_blocks.py	2022-09-20 14:10:23.556576 +0000
@@ -1070,11 +1070,13 @@
             )
 
         for attn in self.attentions:
             attn._set_attention_slice(slice_size)
 
-    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
+    def forward(
+        self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None
+    ):
         for resnet, attn in zip(self.resnets, self.attentions):
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
would reformat unet_blocks.py

Oh no! 💥 💔 💥
2 files would be reformatted.

@jachiam (Contributor, Author) commented Sep 20, 2022

@vvsotnikov OK, great! I ran black and it looks like all checks have passed. All good?

@vvsotnikov (Contributor)

@patrickvonplaten is there anything else needed to get this one merged?

@patrickvonplaten (Contributor)

Hey, very sorry to be so slow here! This is an important PR and I would also very much like to get it merged soon! I hope to be able to look into it by the end of the week.

@patrickvonplaten self-assigned this Sep 27, 2022
Sanster added a commit to Sanster/IOPaint that referenced this pull request Sep 27, 2022
@patrickvonplaten (Contributor)

Hey @vvsotnikov and @jachiam,

Very sorry to have been so late here. I went into the PR and fiddled a bit with the code to try to make it a bit more readable.
The core functionality is kept as is - I tried it out with image sizes that are not multiples of 64 and it works very well! Thanks a lot for adding this :-)

I also added one test.
Hope this was OK! @jachiam, could you take a look and check whether the current version seems sensible to you?

More than happy to merge then!

@jachiam (Contributor, Author) commented Sep 29, 2022

Rad! I've skimmed it but haven't run it yet - it looks fine to me, since as far as I can tell the core logic didn't change anywhere. As long as it's still working for you and you've made a passing test for it, it's almost certainly fine! I'd be excited to see it merged. :)

@patrickvonplaten merged commit a784be2 into huggingface:main Sep 30, 2022
@patrickvonplaten (Contributor)

Put out a tweet here: https://twitter.com/PatrickPlaten/status/1577430173611610112 - hope that's fine 😅

prathikr pushed a commit to prathikr/diffusers that referenced this pull request Oct 26, 2022
* Allow resolutions that are not multiples of 64

* ran black

* fix bug

* add test

* more explanation

* more comments

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
cookieranger pushed a commit to cookieranger/lama-cleaner that referenced this pull request Jan 16, 2023
keeneyetact added a commit to keeneyetact/lama that referenced this pull request Jul 4, 2023
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023