Add `from_pt` argument in `.from_pretrained` #527
Conversation
- add `from_pt` argument in `from_pretrained` function
- add `modeling_flax_pytorch_utils.py` file
The documentation is not available anymore as the PR was closed or merged.
That's a very good first draft, thanks a lot :-) We will also have to rename all the weight names in https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition_flax.py to make it compatible, as stated by @patil-suraj in #502. @younesbelkada, in short, can you use the following code to iterate fast here?
#!/usr/bin/env python3
from diffusers import UNet2DConditionModel, FlaxUNet2DConditionModel
import tempfile
import torch

model = UNet2DConditionModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)

with tempfile.TemporaryDirectory() as tmpdirname:
    model.save_pretrained(tmpdirname)
    flax_model = FlaxUNet2DConditionModel.from_pretrained(tmpdirname, from_pt=True)

sample = torch.rand(1, 4, model.config.sample_size, model.config.sample_size)
time = 1
text_emb = torch.rand(1, 77, model.config.sample_size)

## TODO convert all to Flax inputs
output = model(sample, time, text_emb).sample

## TODO make sure outputs match
flax_output = flax_model(flax_sample, time, flax_text_emb).sample
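For reference, a rough sketch of how the two TODOs might be filled in, following the same pattern as the VAE script further down this thread (the `(model, params)` unpacking from `from_pretrained` and the `4e-02` tolerance from the commit message are assumptions, not part of the original snippet):

```python
import numpy as np
import jax.numpy as jnp

# assumes the `with` block above unpacked both model and params, i.e.:
#     flax_model, flax_params = FlaxUNet2DConditionModel.from_pretrained(tmpdirname, from_pt=True)

# convert the PyTorch inputs to JAX arrays
flax_sample = jnp.array(sample.numpy())
flax_text_emb = jnp.array(text_emb.numpy())

# run the Flax model and compare against the PyTorch output
flax_output = flax_model.apply({"params": flax_params}, flax_sample, time, flax_text_emb).sample
np.testing.assert_allclose(np.array(flax_output), output.detach().numpy(), rtol=4e-2, atol=4e-2)
```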
Awesome!
We should definitely add tests for this! Think we will need to wait a bit, as we need to update the flax unet model to have the same weight names as in pt, cf #502.
Also note that we need to handle the module-list keys. In torch, module lists create keys with dot-separated indices (e.g. `down_blocks.0`), while in flax the corresponding `setup` attributes use underscores (e.g. `down_blocks_0`). We could instead use

import re

regex = r"\w+[.]\d+"

def rename_key(key):
    pats = re.findall(regex, key)
    for pat in pats:
        key = key.replace(pat, "_".join(pat.split(".")))
    return key

This will rename the torch module-list keys so that they match the flax naming.
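As a quick illustration (the key below is just an example, not taken from the actual state dict), the helper converts PyTorch module-list indices into the underscore form used by the Flax `setup`:

```python
rename_key("down_blocks.0.attentions.1.proj_in.weight")
# -> "down_blocks_0.attentions_1.proj_in.weight"
```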
@younesbelkada I think you can already get started on the PR here: make it work by renaming the keys in https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition_flax.py and changing the conversion function as explained by Suraj, and then we look into tests.
Perfect, thanks for all the pointers!
A few keys left before the full match!
- modify FlaxUnet modules
- first conversion script
- more keys to be matched
- now all keys match
- change module names for correct matching
- upsample module name changed
Aah, didn't realise this. Changing the name here makes sense. Also, now that we have JAX in diffusers, let's avoid using
@@ -55,7 +55,7 @@ def setup(self):
         self.attentions = attentions

         if self.add_downsample:
-            self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
+            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
This should be a list like in PyTorch
I can see that there is always one element in the list in the PyTorch modeling file; is it expected that this will change in the future to insert more modules? The only solution I see is to wrap it in a `LayerCollection` module, but this goes against what we want to achieve here. Also, I agree that this solution adds a lot of boilerplate.
Good point! @patrickvonplaten, why did we use a list for downsamplers?
- tests pass with atol and rtol = `4e-02`
@@ -214,6 +214,13 @@ def __call__(
         When returning a tuple, the first element is the sample tensor.
         """
         # 1. time
+        timesteps = timestep
Why?
This is because I copy-pasted the logic behind:

timesteps = timestep
Proposed a fix in 7527ab1
Here is a working v1 of the conversion script! The test now passes with atol and rtol = `4e-02`.
# This is not really stable since any module that has the name 'scale'
# will be affected. Maybe just check pt_tuple_key[-2]?
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
if any("norm" in str_ for str_ in pt_tuple_key) and pt_tuple_key[-1] == "weight":
This is quite hacky; we might need to change it.
This is what I used in my repo, could we try adapting it like this? Here's the full script: https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
if (
    "norm" in pt_key
    and (pt_tuple_key[-1] == "bias")
    and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict)
    and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
):
    pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
    pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
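To make the intent of this renaming concrete, a small illustration (the key below is a made-up example, not taken from the actual state dict): Flax norm layers store their affine weight under `scale`, so a PyTorch norm `weight` should be remapped while a conv `weight` stays untouched.

```python
# hypothetical PyTorch key for a group-norm weight, for illustration only
pt_tuple_key = ("down_blocks_0", "resnets_0", "norm1", "weight")

# Flax GroupNorm/LayerNorm call this parameter "scale", so the conversion should yield:
flax_key = pt_tuple_key[:-1] + ("scale",)  # ("down_blocks_0", "resnets_0", "norm1", "scale")
```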
Thanks a lot for the snippet! Adapted the changes in 5973e43.
Now it also supports converting embedding layers (which was not the case before your comment).
Think the test is flaky because the unet is not in `eval` mode.
Yes, nice catch! With eval mode the test passes; edited the message above.
- add TODO for embedding layers
Very cool, thanks for working on this!
For tests, we need to add the common test mixin for flax and then we can add tests for this. This could be done in a follow-up PR.
    timesteps = jnp.array([timesteps], dtype=jnp.int32)
elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
    timesteps = timesteps.astype(dtype=jnp.float32)
    timesteps = timesteps[None]
Maybe use `jnp.expand_dims` here, as it will be cleaner.
Fixed in 1facd9f! (Leaving this open in case someone needs to jump in.)
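For reference, a minimal sketch of the suggested `jnp.expand_dims` change (the scalar value here is just an example):

```python
import jax.numpy as jnp

timesteps = jnp.array(1.0, dtype=jnp.float32)   # 0-dimensional array
timesteps = jnp.expand_dims(timesteps, axis=0)  # shape (1,), same result as timesteps[None]
```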
- use `jnp.expand_dims` for converting `timesteps` in case it is a 0-dimensional array
- add better test to check for keys conversion
Thanks a lot for the comments @patil-suraj! Should have addressed them now 💪
Also, from the experience in huggingface/transformers#17779 (comment), another factor of divergence could be the use of transposed convolutions, which do not give the same result between PyTorch and JAX/Flax, see: jax-ml/jax#5772
EDIT: from my tests it seems that we are not using transposed convolutions here.
Very nice! Will take a closer look tomorrow. Could you try to apply the same changes to https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/vae_flax.py tomorrow? :-)
@patrickvonplaten @patil-suraj

#!/usr/bin/env python3
from diffusers import AutoencoderKL, FlaxAutoencoderKL
import tempfile
import torch
import numpy as np
import jax.numpy as jnp

model = AutoencoderKL(
    in_channels=3,
    out_channels=3,
    sample_size=32,
    latent_channels=4,
    norm_num_groups=32,
)
model = model.eval()

with tempfile.TemporaryDirectory() as tmpdirname:
    model.save_pretrained(tmpdirname)
    flax_model, flax_params = FlaxAutoencoderKL.from_pretrained(tmpdirname, from_pt=True)

sample = torch.rand(1, 3, model.config.sample_size, model.config.sample_size)

# Step 1: Infer with the PT model
torch_output = model(sample).sample

# Step 2: Infer with the JAX model
flax_sample = jnp.array(sample.numpy())
flax_output = flax_model.apply({"params": flax_params}, flax_sample).sample

# Step 3: Check that the values are close
converted_flax_output = torch.from_numpy(np.array(flax_output))
torch.testing.assert_allclose(converted_flax_output, torch_output, rtol=2e-05, atol=2e-05)

One small detail though: I can see that the shapes are slightly inconsistent between the VAE and the UNet; the first model expects a shape of the format ...
- output `img_w x img_h x n_channels` from the VAE
Added a commit 4cad1ae for shape consistency that we could revert if the current input/output shapes are expected; the test now passes with this script (similar to the UNet one).
This reverts commit 4cad1ae.
- channels first!
@patil-suraj as discussed offline, I have just fixed the channels issue!
Channels first for the UNet now 🔥!
Thanks a lot!
Works perfectly for the unet model. For `FlaxAutoencoderKL` we need to rename some keys; will open a follow-up PR for that. Merging!
* first commit:
  - add `from_pt` argument in `from_pretrained` function
  - add `modeling_flax_pytorch_utils.py` file
* small nit
  - fix a small nit to not enter in the second if condition
* major changes
  - modify FlaxUnet modules
  - first conversion script
  - more keys to be matched
* keys match
  - now all keys match
  - change module names for correct matching
  - upsample module name changed
* working v1
  - tests pass with atol and rtol = `4e-02`
* replace unused arg
* make quality
* add small docstring
* add more comments
  - add TODO for embedding layers
* small change
  - use `jnp.expand_dims` for converting `timesteps` in case it is a 0-dimensional array
* add more conditions on conversion
  - add better test to check for keys conversion
* make shapes consistent
  - output `img_w x img_h x n_channels` from the VAE
* Revert "make shapes consistent"
  - This reverts commit 4cad1ae.
* fix unet shape
  - channels first!
What does this PR do?
This PR addresses #523 by attempting to add a `from_pt` argument inside the `.from_pretrained` function of the `FlaxModelMixin` class.

For now I followed the same logic as in `transformers` regarding this argument. However, it seems that conversion tests are not implemented yet. Do we add these tests in this PR, or should that be addressed in a future PR?

After this PR, as in `transformers`, any model could be loaded in Flax by doing the following (see the sketch at the end of this description).

cc @patrickvonplaten @mishig25 @patil-suraj @pcuenca
I think we might need to add tests for this to make sure everything works as expected
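For illustration, a rough sketch of the loading pattern this PR enables (a sketch only: the checkpoint path is a placeholder, and the `(model, params)` return value follows the scripts earlier in this thread):

```python
from diffusers import FlaxUNet2DConditionModel

# "path/to/pytorch_checkpoint" is a placeholder; it should point to a directory
# containing PyTorch weights saved with `save_pretrained`.
model, params = FlaxUNet2DConditionModel.from_pretrained("path/to/pytorch_checkpoint", from_pt=True)
```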