
Add from_pt argument in .from_pretrained #527


Merged

Conversation

@younesbelkada (Contributor) commented Sep 16, 2022

What does this PR do?

This PR addresses #523 by adding a `from_pt` argument to the `.from_pretrained` function of the `FlaxModelMixin` class.
For now I followed the same logic as in `transformers` for this argument. However, it seems that conversion tests are not implemented yet. Should we add these tests in this PR, or should they be addressed in a future PR?

After this PR, as in `transformers`, any model can be loaded in Flax by doing:

model = ...from_pretrained(..., from_pt=True)

cc @patrickvonplaten @mishig25 @patil-suraj @pcuenca

I think we might need to add tests for this to make sure everything works as expected

- add `from_pt` argument in `from_pretrained` function
- add `modeling_flax_pytorch_utils.py` file
@HuggingFaceDocBuilderDev commented Sep 16, 2022

The documentation is not available anymore as the PR was closed or merged.

@mishig25 marked this pull request as draft September 16, 2022 11:59
- fix a small nit - to not enter in the second if condition
@patrickvonplaten (Contributor) commented Sep 16, 2022

That's a very good first draft! Thanks a lot :-)

We will also have to rename all the weight names of https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition_flax.py to make it compatible as stated by @patil-suraj here: #502

@younesbelkada In short, can you use the following code to iterate fast here?

  1. Create a dummy stable diffusion unet:
#!/usr/bin/env python3
from diffusers import UNet2DConditionModel, FlaxUNet2DConditionModel
import tempfile
import torch

model = UNet2DConditionModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)

with tempfile.TemporaryDirectory() as tmpdirname:
    model.save_pretrained(tmpdirname)
    flax_model = FlaxUNet2DConditionModel.from_pretrained(tmpdirname, from_pt=True)

sample = torch.rand(1, 4, model.config.sample_size, model.config.sample_size)
time = 1
text_emb = torch.rand(1, 77, model.config.sample_size)


## TODO convert all to Flax inputs

output = model(sample, time, text_emb).sample


## TODO make sure outputs match
flax_output = flax_model(flax_sample, time, flax_text_emb).sample

@patil-suraj (Contributor) commented Sep 16, 2022

Awesome!

I think we might need to add tests for this to make sure everything works as expected

We should definitely add tests for this!

I think we will need to wait a bit: we need to update the Flax UNet model to have the same weight names as in PT (cf. #502).
Once that's ready we can use it for the tests.

Also note that we need to handle nn.ModuleList a bit differently here, because of the way Flax treats module lists.

In torch, module list keys are created like model.layers.0..., model.layers.1....
In flax, it's model.layers_0..., model.layers_1....

In transformers this is handled by creating a wrapper module (the ...LayerCollection modules) to wrap the list inside the model and create the same naming structure for the keys. But that adds a lot of boilerplate, so I'm not in favor of having it here.

We could instead use regex to rename the module list keys while converting to map them to correct names. For example:

import re

regex = r"\w+[.]\d+"

def rename_key(key):
    pats = re.findall(regex, key)
    for pat in pats:
        key = key.replace(pat, "_".join(pat.split(".")))
    return key

This will rename the model.layers.0... keys to flax format model.layers_0....
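As a quick, self-contained sanity check (the keys below are illustrative only, not tied to any real diffusers module), the suggested `rename_key` behaves like this:

```python
import re

regex = r"\w+[.]\d+"

def rename_key(key):
    # Replace every `name.<digits>` segment with `name_<digits>` (Flax-style naming).
    for pat in re.findall(regex, key):
        key = key.replace(pat, "_".join(pat.split(".")))
    return key

print(rename_key("model.layers.0.attention.weight"))  # model.layers_0.attention.weight
print(rename_key("to_out.0.weight"))                  # to_out_0.weight
print(rename_key("time_embedding.weight"))            # unchanged: no `.<digit>` segment
```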

@patrickvonplaten (Contributor)

@younesbelkada I think you can already get started on the PR here: make it work by renaming the keys in https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition_flax.py and changing the conversion function as explained by Suraj, and then we can look into tests.

@younesbelkada (Contributor Author)

Perfect, thanks for all the pointers!
Yes, I have just noticed that the Flax weights were named layer_name_{index}, which is different from transformers; I will rename them as suggested and test everything.
Also, as @mishig25 told me offline, there is no custom head logic, so the function that I have copied will surely change!

@younesbelkada (Contributor Author)

A few keys are left before the full match!
Just wondering, how would you address the dropout issue? For now I just renamed the Sequential layer here to `to_out_0` to match the PT weights.

younesbelkada and others added 4 commits September 16, 2022 14:00
- modify FlaxUnet modules
- first conversion script
- more keys to be matched
- now all keys match
- change module names for correct matching
- upsample module name changed
@patil-suraj (Contributor) commented Sep 16, 2022

Just wondering how would you address the dropout issue? For now I just renamed the Sequential layer here to to_out_0 to match the PT weights

Aah, didn't realise this. Changing the name here makes sense.

Also, now that we have JAX in diffusers, let's avoid using dropout or non-module functions in nn.Sequential. cc @patrickvonplaten @anton-l @pcuenca

@@ -55,7 +55,7 @@ def setup(self):
self.attentions = attentions

if self.add_downsample:
self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
Contributor

This should be a list like in PyTorch

@younesbelkada (Contributor Author) Sep 16, 2022

I can see that there is always one element in the list in the PyTorch modeling file; is it expected that this will change in the future to insert more modules?

The only solution I see is to wrap it in a LayerCollection module, but this goes against what we want to achieve here. Also, I agree that this solution adds a lot of boilerplate.

Contributor

Good point! @patrickvonplaten why did we use a list for downsamplers?

- test pass with atol and rtol= `4e-02`
@@ -214,6 +214,13 @@ def __call__(
When returning a tuple, the first element is the sample tensor.
"""
# 1. time
timesteps = timestep
Contributor

Why?

Contributor Author

This is because I copy-pasted this logic; will change it!

Contributor Author

Proposed a fix in 7527ab1

@younesbelkada (Contributor Author) commented Sep 16, 2022

Here is a working v1 of the conversion script! The test now passes with atol=rtol=4e-03. Here is the script I use to check the correctness of the logits:

#!/usr/bin/env python3
from diffusers import UNet2DConditionModel, FlaxUNet2DConditionModel
import tempfile
import torch
import numpy as np
import jax.numpy as jnp

model = UNet2DConditionModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)

model = model.eval()

with tempfile.TemporaryDirectory() as tmpdirname:
    model.save_pretrained(tmpdirname)
    flax_model, flax_params = FlaxUNet2DConditionModel.from_pretrained(tmpdirname, from_pt=True)

sample = torch.rand(1, 4, model.config.sample_size, model.config.sample_size)
time = 1
text_emb = torch.rand(1, 77, model.config.sample_size)


# Step 1: Infer with the PT model
torch_output = model(sample, time, text_emb).sample


# Step 2: Infer with JAX model
flax_sample = jnp.array(sample.numpy())
flax_text_emb = jnp.array(text_emb.numpy())

flax_output = flax_model.apply({"params":flax_params}, jnp.transpose(flax_sample, (0, 2, 3, 1)), time, flax_text_emb).sample

# Step 3: Check that the values are close
converted_flax_output = torch.from_numpy(np.array(flax_output))
converted_flax_output = converted_flax_output.permute(0, 3, 1, 2)

torch.testing.assert_allclose(converted_flax_output, torch_output, rtol=4e-03, atol=4e-03)

# This is not really stable since any module that has the name 'scale'
# Will be affected. Maybe just check pt_tuple_key[-2] ?
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
if any("norm" in str_ for str_ in pt_tuple_key) and pt_tuple_key[-1] == "weight":
Contributor Author

This is quite hacky, we might need to change it

Contributor

This is what I used in my repo; could we try adapting it like this? Here's the full script: https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py

if (
    "norm" in pt_key
    and (pt_tuple_key[-1] == "bias")
    and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict)
    and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
):
    pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
    pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
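A self-contained toy illustration of this renaming rule (the helper name `rename_norm_key` and the tiny tuple-keyed state dict are hypothetical, for demonstration only; the real conversion code operates on full model state dicts):

```python
def rename_norm_key(pt_tuple_key, random_flax_state_dict):
    # PT norm layers store affine parameters as `weight`/`bias` (or `gamma`),
    # while Flax norm layers use `scale`/`bias`.
    pt_key = ".".join(pt_tuple_key)
    if (
        "norm" in pt_key
        and pt_tuple_key[-1] == "bias"
        and pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict
        and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict
    ):
        return pt_tuple_key[:-1] + ("scale",)
    if pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
        return pt_tuple_key[:-1] + ("scale",)
    return pt_tuple_key

flax_state = {("norm1", "scale"): 0, ("norm1", "bias"): 0, ("conv", "kernel"): 0}
print(rename_norm_key(("norm1", "weight"), flax_state))  # ('norm1', 'scale')
print(rename_norm_key(("conv", "weight"), flax_state))   # ('conv', 'weight'), untouched
```

Note that the renaming only triggers when the target `("scale",)` key actually exists in the randomly initialized Flax state dict, which avoids renaming embedding `weight`s by accident.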

Contributor Author

Thanks a lot for the snippet! Adapted the changes in 5973e43.
Now it also supports converting embedding layers (which was not the case before your comment).

@patil-suraj (Contributor)

I think the test is flaky because the UNet is not in eval mode.

@younesbelkada (Contributor Author) commented Sep 16, 2022

Yes, nice catch! With eval mode the test passes with 4e-03 instead of 4e-02.

Edited the message above

@younesbelkada marked this pull request as ready for review September 16, 2022 15:33
- add TODO for embedding layers
@patil-suraj (Contributor) left a comment

Very cool, thanks for working on this!

For tests, we need to add the common test mixin for flax and then we can add tests for this. This could be done in a follow-up PR.

timesteps = jnp.array([timesteps], dtype=jnp.int32)
elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
timesteps = timesteps.astype(dtype=jnp.float32)
timesteps = timesteps[None]
Contributor

Maybe use expand_dims, as it will be cleaner.
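A minimal sketch of the suggestion, shown here with NumPy (whose `expand_dims` behaves the same as `jax.numpy.expand_dims` for this case):

```python
import numpy as np

timesteps = np.array(1.0, dtype=np.float32)  # 0-dimensional "scalar" array
assert timesteps.ndim == 0

# Equivalent to `timesteps = timesteps[None]`, but the intent is explicit:
timesteps = np.expand_dims(timesteps, axis=0)
print(timesteps.shape)  # (1,)
```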

@younesbelkada (Contributor Author) Sep 16, 2022

Fixed in 1facd9f! (leaving this open in case someone needs to jump in)

- use `jnp.expand_dims` for converting `timesteps` in case it is a 0-dimensional array
- add better test to check for keys conversion
@younesbelkada (Contributor Author) commented Sep 16, 2022

Thanks a lot for the comments @patil-suraj! They should all be addressed now 💪
I am still a bit worried that the tests only pass with atol=rtol=4e-03; as far as I can see from the PyTorch script, there is no stochasticity involved in the forward pass.

Also, from the experience in huggingface/transformers#17779 (comment), another factor of divergence could be the use of transposed convolutions, which do not give the same result between PyTorch and JAX/Flax, see: jax-ml/jax#5772

EDIT: from my tests it seems that we are not using nn.ConvTranspose2d (i.e., use_conv_transpose is set to False), but let's keep this in mind for future models.

@patrickvonplaten (Contributor)

Very nice! Will take a closer look tomorrow.

Could you try to apply the same changes to https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/vae_flax.py tomorrow? :-)

@younesbelkada (Contributor Author) commented Sep 19, 2022

@patrickvonplaten @patil-suraj
I just tried with the VAE and can confirm that the test passes with a much lower tolerance than the UNet (2e-05).
Here is the script I use to compare:

#!/usr/bin/env python3
from diffusers import AutoencoderKL, FlaxAutoencoderKL
import tempfile
import torch
import numpy as np
import jax.numpy as jnp

model = AutoencoderKL(
    in_channels=3,
    out_channels=3,
    sample_size=32,
    latent_channels=4,
    norm_num_groups=32,
)

model = model.eval()

with tempfile.TemporaryDirectory() as tmpdirname:
    model.save_pretrained(tmpdirname)
    flax_model, flax_params = FlaxAutoencoderKL.from_pretrained(tmpdirname, from_pt=True)

sample = torch.rand(1, 3, model.config.sample_size, model.config.sample_size)
time = 1


# Step 1: Infer with the PT model
torch_output = model(sample).sample


# Step 2: Infer with JAX model
flax_sample = jnp.array(sample.numpy())
flax_output = flax_model.apply({"params":flax_params}, flax_sample).sample

# Step 3: Check that the values are close
converted_flax_output = torch.from_numpy(np.array(flax_output))

torch.testing.assert_allclose(converted_flax_output, torch_output, rtol=2e-05, atol=2e-05)

One small detail though: the shapes are slightly inconsistent between the VAE and the UNet. The VAE expects inputs in the format n_channels, img_w, img_h, whereas the UNet expects img_w, img_h, n_channels. It seems that the VAE does the permutation operations internally. Do you think we should change one of the models' expected input shapes to make them consistent across models? Happy to do that, just let me know ;)
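For context, the two layouts discussed above differ only by a transpose; a minimal NumPy sketch of the conversion done in the comparison scripts in this thread (`jnp.transpose` and `.permute` behave the same way):

```python
import numpy as np

x_nchw = np.zeros((1, 4, 32, 32))            # PyTorch layout: (batch, channels, height, width)
x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))  # Flax layout: (batch, height, width, channels)
print(x_nhwc.shape)  # (1, 32, 32, 4)

# Converting a channels-last output back for comparison against PyTorch:
x_back = np.transpose(x_nhwc, (0, 3, 1, 2))
print(x_back.shape)  # (1, 4, 32, 32)
```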

- output `img_w x img_h x n_channels` from the VAE
@younesbelkada (Contributor Author) commented Sep 19, 2022

Added a commit 4cad1ae for shape consistency, which we can revert if the current input/output shapes are expected. The test now passes with this script (similar to the UNet one):

#!/usr/bin/env python3
from diffusers import AutoencoderKL, FlaxAutoencoderKL
import tempfile
import torch
import numpy as np
import jax.numpy as jnp

model = AutoencoderKL(
    in_channels=3,
    out_channels=3,
    sample_size=32,
    latent_channels=4,
    norm_num_groups=32,
)

model = model.eval()

with tempfile.TemporaryDirectory() as tmpdirname:
    model.save_pretrained(tmpdirname)
    flax_model, flax_params = FlaxAutoencoderKL.from_pretrained(tmpdirname, from_pt=True)

sample = torch.rand(1, 3, model.config.sample_size, model.config.sample_size)
time = 1


# Step 1: Infer with the PT model
torch_output = model(sample).sample


# Step 2: Infer with JAX model
flax_sample = jnp.array(sample.numpy())

flax_output = flax_model.apply({"params":flax_params}, jnp.transpose(flax_sample, (0, 2, 3, 1))).sample

# Step 3: Check that the values are close
converted_flax_output = torch.from_numpy(np.array(flax_output))
converted_flax_output = converted_flax_output.permute(0, 3, 1, 2)


torch.testing.assert_allclose(converted_flax_output, torch_output, rtol=2e-05, atol=2e-05)

@younesbelkada (Contributor Author) commented Sep 20, 2022

@patil-suraj as discussed offline I have just fixed the channels issue!
Now the test passes with:

#!/usr/bin/env python3
from diffusers import UNet2DConditionModel, FlaxUNet2DConditionModel
import tempfile
import torch
import numpy as np
import jax.numpy as jnp

model = UNet2DConditionModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)

model = model.eval()

with tempfile.TemporaryDirectory() as tmpdirname:
    model.save_pretrained(tmpdirname)
    flax_model, flax_params = FlaxUNet2DConditionModel.from_pretrained(tmpdirname, from_pt=True)

sample = torch.rand(1, 4, model.config.sample_size, model.config.sample_size)
time = 1
text_emb = torch.rand(1, 77, model.config.sample_size)


# Step 1: Infer with the PT model
torch_output = model(sample, time, text_emb).sample


# Step 2: Infer with JAX model
flax_sample = jnp.array(sample.numpy())
flax_text_emb = jnp.array(text_emb.numpy())

flax_output = flax_model.apply({"params":flax_params}, flax_sample, time, flax_text_emb).sample

# Step 3: Check that the values are close
converted_flax_output = torch.from_numpy(np.array(flax_output))

torch.testing.assert_allclose(converted_flax_output, torch_output, rtol=4e-03, atol=4e-03)

Channels first for the UNet now 🔥!

@patil-suraj (Contributor) left a comment

Thanks a lot!

Works perfectly for the UNet model. For FlaxAutoencoderKL we need to rename some keys; will open a follow-up PR for that. Merging!

@patil-suraj merged commit 0902449 into huggingface:main Sep 20, 2022
PhaneeshB added a commit to nod-ai/diffusers that referenced this pull request Mar 1, 2023
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* first commit:

- add `from_pt` argument in `from_pretrained` function
- add `modeling_flax_pytorch_utils.py` file

* small nit

- fix a small nit - to not enter in the second if condition

* major changes

- modify FlaxUnet modules
- first conversion script
- more keys to be matched

* keys match

- now all keys match
- change module names for correct matching
- upsample module name changed

* working v1

- test pass with atol and rtol= `4e-02`

* replace unused arg

* make quality

* add small docstring

* add more comments

- add TODO for embedding layers

* small change

- use `jnp.expand_dims` for converting `timesteps` in case it is a 0-dimensional array

* add more conditions on conversion

- add better test to check for keys conversion

* make shapes consistent

- output `img_w x img_h x n_channels` from the VAE

* Revert "make shapes consistent"

This reverts commit 4cad1ae.

* fix unet shape

- channels first!