
Implement FlaxModelMixin #493


Merged: mishig25 merged 20 commits into main from flax_model_mixin on Sep 14, 2022

Conversation

@mishig25 (Contributor) commented Sep 13, 2022

Implemented FlaxModelMixin

FlaxModelMixin is the Flax equivalent of ModelMixin, i.e. a mixin that provides save_pretrained, from_pretrained, and various checkpoint utility functions.

As discussed in google/flax#2454 (comment), diffusers.FlaxModelMixin differs from transformers.FlaxPreTrainedModel in one design aspect: the model class should be "state-less" (see google/flax#2454).

see this comment for testing: #493 (comment)

Example usages:

Class definition (find example here)

@flax_register_to_config
class UNet2D(nn.Module, FlaxModelMixin, ConfigMixin):
    sample_size: int = 32
    in_channels: int = 4
    out_channels: int = 4
    down_block_types: Tuple = ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")
    ...
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        ...

from_pretrained (find example here)

unet, unet_params = UNet2D.from_pretrained(f"{flax_path}/unet", dtype=dtype)

call apply just like any other flax.linen.Module (example here)

# predict the noise residual
noise_pred = self.unet.apply(
    {"params": inference_state.unet_params},
    jnp.array(latents_input),
    jnp.array(timestep, dtype=jnp.int32),
    encoder_hidden_states=context,
    rngs={},
)

A question

In FlaxModelMixin.from_pretrained, I have NOT added a section that checks missing_keys & mismatched_keys like transformers' Flax from_pretrained does.

In short, the missing_keys & mismatched_keys check works by:

  1. random_state = ModelCls.init_weights()
  2. pretrained_state = from_pretrained()
  3. compare keys between random_state VS pretrained_state

Therefore, adding it would require:

# add a new `input_shape` attribute to the model class
@flax_register_to_config
class UNet2D(nn.Module, FlaxModelMixin, ConfigMixin):
    input_shape: Tuple = (1, 32, 32, 4)

# require an implementation of init_weights in the mixin
class FlaxModelMixin:
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict:
        raise NotImplementedError(f"init_weights has to be implemented for {self}")

and probably call init_weights inside from_pretrained, since we do NOT have FlaxModelMixin.__init__.
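
For reference, a minimal sketch of what the key comparison itself could look like (a hypothetical helper, not part of this PR; it assumes both states are nested dicts of weight arrays):

# Hypothetical sketch of the missing/mismatched key check (not part of this PR).
# Assumes `random_state` and `pretrained_state` are nested dicts of weight arrays.
from flax.traverse_util import flatten_dict

def check_keys(random_state, pretrained_state):
    flat_random = flatten_dict(random_state)
    flat_pretrained = flatten_dict(pretrained_state)

    missing_keys = set(flat_random) - set(flat_pretrained)      # expected by the model, absent from the checkpoint
    unexpected_keys = set(flat_pretrained) - set(flat_random)   # present in the checkpoint, unused by the model
    mismatched_keys = [
        k
        for k in set(flat_random) & set(flat_pretrained)
        if flat_random[k].shape != flat_pretrained[k].shape
    ]
    return missing_keys, unexpected_keys, mismatched_keys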

Please let me know if I should add this functionality. wdyt

TODOS

  • all docstrings need to be changed
  • soft flax dependency check?

cc: @kashif somehow I can't add you as a reviewer

@HuggingFaceDocBuilderDev commented Sep 13, 2022

The documentation is not available anymore as the PR was closed or merged.

@pcuenca (Member) commented Sep 13, 2022

I love this. I'll test it with my changes in #485.

mishig25 marked this pull request as ready for review September 13, 2022 09:43
@mishig25 (Contributor, Author) commented Sep 13, 2022

@pcuenca please let me know if you have any questions. The "Example usages" section in the description should contain all the necessary info. Moreover, you can just run example.py in this branch if you need: patil-suraj/stable-diffusion-jax#10

mishig25 mentioned this pull request Sep 13, 2022
"""
_missing_keys = set()
config_name = CONFIG_NAME
ignore_for_config = ["parent", "name"]
Member:

I don't know if this will be an issue in practice, but what happens if the model subclass needs to use ignore_for_config for its own variables? In that case, it would need to also append "parent" and "name" to the list.

Would it be better if FlaxModelMixin knows about these special names without adding them to the list?

Contributor:

"but what happens if the model subclass needs to use ignore_for_config for its own variables?"
What could be a use-case for ignoring the init args for the config?

But also this is a special case for flax so maybe we could hardcode this for now.

Member:

Sorry, just to clarify, hardcode how? Using a different list instead of ignore_for_config as I was suggesting, or do you have something else in mind? :)

Contributor:

Playing around with this a bit now to get a better feeling

Contributor:

Think "parent", "name" are so Flax specific that we should just hide them in the decorator method

@patil-suraj (Contributor) left a comment

Looks very cool in general, just left some minor nits.

"""
_missing_keys = set()
config_name = CONFIG_NAME
ignore_for_config = ["parent", "name"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but what happens if the model subclass needs to use ignore_for_config for its own variables?
what could be a use-case to ignore the init args for config ?

But also this is a special case for flax so maybe we could hardcode this for now.

@patil-suraj (Contributor) commented Sep 14, 2022

A question
In FlaxModelMixin.from_pretrained, I have NOT added a section that checks missing_keys & mismatched_keys like transformers' Flax from_pretrained does.
In short, the missing_keys & mismatched_keys check works by:
1. random_state = ModelCls.init_weights()
2. pretrained_state = from_pretrained()
3. compare keys between random_state VS pretrained_state

The answer:
Think here we should default to the _do_init=False API in transformers huggingface/transformers#16148

i.e

  • each model should define init_weights method
  • when a model is initialized, weights will never be initialized automatically, the user will always have to call init_weights
model = FlaxUnet(.....)
params = model.init_weights()
  • in from_pretrained we should just do jax.eval_shape to get the list of param names for the model. This way we can detect missing keys without actually initializing the weights. This helps avoid memory fragmentation. And then we always force the user to call init_weights.
model, params =  FlaxUnet.from_pretrained(.....)
params = model.init_weights(params)

Refer to the linked PR to see how it can be implemented.

This way we stay close to the JAX/Flax philosophy: nothing is implicitly initialized, and the user is responsible for handling the initialization.
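
Roughly, the eval_shape step could look something like this (a minimal sketch, assuming the model exposes an init_weights(rng) method that returns a params pytree; not the final API):

# Minimal sketch of detecting missing keys without materializing random weights.
# Assumption: the model defines `init_weights(rng)` returning a params pytree.
import jax
from flax.traverse_util import flatten_dict

def find_missing_keys(model, loaded_params, rng):
    # jax.eval_shape traces init_weights abstractly: we only get shapes/dtypes,
    # so no device memory is allocated for the random weights.
    shape_tree = jax.eval_shape(model.init_weights, rng)
    expected_keys = set(flatten_dict(shape_tree))
    loaded_keys = set(flatten_dict(loaded_params))
    return expected_keys - loaded_keys

from_pretrained could then warn or raise if this set is non-empty, and the user would call init_weights only to fill in whatever is actually missing.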

Let me know what you think :)

@pcuenca (Member) commented Sep 14, 2022

  • in from_pretrained we should just do jax.eval_shape to get the list of param names for the model. This way we can detect missing keys without actually initializing the weights. This helps avoid memory fragmentation. And then we always force the user to call init_weights.
model, params =  FlaxUnet.from_pretrained(.....)
params = model.init_weights(params)

This way we stay close to the JAX/Flax philosophy: nothing is implicitly initialized, and the user is responsible for handling the initialization.

Let me know what you think :)

My initial opinion before I read this comment was that from_pretrained should take care of everything so the user does not need to invoke init_weights. I think this aligns with the philosophy behind from_pretrained at Hugging Face, where it does what it can to get you what you need. In addition, I suppose many users of Flax diffusers will be new to Flax. They won't necessarily be familiar with Flax nuances, and probably just want to use a model. I think it's easy to convey that the model parameters are stored separately from the model, but it's harder to explain why they need to follow two steps in order to use them. If we remove one step, that's one less step that could get forgotten or misinterpreted.

There are other libraries that abstract away weight loading, see for example flaxmodels here.

mishig25 and others added 2 commits September 14, 2022 08:23
Co-authored-by: Suraj Patil <surajp815@gmail.com>
@mishig25 (Contributor, Author) commented Sep 14, 2022

  • each model should define init_weights method
  • when a model is initialized, weights will never be initialized automatically, the user will always have to call init_weights
model = FlaxUnet(.....)
params = model.init_weights()

2 questions:

Question 1:
@patil-suraj, since FlaxModel is just linen.Module, why would a user not do:

x = random.normal(key1, (...)) # Dummy input
params = FlaxModel.init(key2, x) # Initialization call since FlaxModel(nn.Module)

is it because doing it the "normal flax" way (i.e. FlaxModel.init(key2, x)) would be a bit complicated for the user (getting the correct input shape, etc.), so we are wrapping it up in a FlaxModel.init_weights call?

Question 2:
In the case of implementing the init_weights method, we do need to introduce an input_shape property on FlaxModel, right? (Although in this case it is not used in FlaxUNet2DConditionModel, I guess it will be used in other FlaxModels.)

@patrickvonplaten (Contributor)

Super nice PR - will play around with it a bit now!

@patrickvonplaten (Contributor)

@mishig25 for now I'd just add an init_weights

  • in from_pretrained we should just do jax.eval_shape to get the list of param names for the model. This way we can detect missing keys without actually initializing the weights.

Big +1 here. @mishig25 also happy to add this functionality though in a follow up PR to not blow this one up

@patrickvonplaten (Contributor)

I very much agree with @patil-suraj's comments above. Note, however, that for diffusion models there is no "let's just add a head and fine-tune this" logic. Therefore, if one does .from_pretrained(...), we automatically check that no keys are missing, and then there is no need to call init_weights anymore. So one usually doesn't have to call init_weights after having run from_pretrained(...).

@patrickvonplaten (Contributor)

@mishig25 to answer your second question: indeed, we might not even need an init_weights function ourselves! We could just use the "native" init function. It's just a bit annoying to have to define the input shape/tensor oneself just for the init. But happy to tackle this in a follow-up PR.

@mishig25 (Contributor, Author) commented Sep 14, 2022

@patrickvonplaten let's do a follow-up PR to tackle init_weights (if needed) so that we can get going with FlaxModelMixin on the main branch.

I will fix the docstrings now.
Also, do I need to add soft dependency checks, i.e. check whether flax is installed?
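
(For reference, such a soft check is usually just an importlib lookup; the sketch below is illustrative, and the exact helper name and location in diffusers are assumptions:)

# Illustrative soft dependency check; the name mirrors the `is_flax_available()`
# helper used in the __init__ later in this thread, but this implementation is an assumption.
import importlib.util

_flax_available = importlib.util.find_spec("flax") is not None

def is_flax_available():
    return _flax_available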

@patrickvonplaten (Contributor)

Works like a charm! You can test it with:

from diffusers.modeling_flax_utils import FlaxModelMixin
from diffusers.configuration_utils import flax_register_to_config, ConfigMixin
import flax.linen as nn
import jax
from jax.random import PRNGKey
import jax.numpy as jnp


@flax_register_to_config
class Adder(nn.Module, FlaxModelMixin, ConfigMixin):
    vocab_size: int
    hidden_size: int = 8
    initializer_range: float = 0.4

    def setup(self):
        self.word_embeddings = nn.Embed(
            self.config.vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )

    def __call__(self, input_tensor):
        return self.word_embeddings(input_tensor) + 6


adder = Adder(5, hidden_size=10)
ones = jnp.ones((1, 1), dtype="i4")
rngs = {"params": PRNGKey(0)}
params = adder.init(rngs, ones)

adder.save_pretrained("dummy", params)
adder_new, params = Adder.from_pretrained("./dummy")

@patrickvonplaten (Contributor)

Update

I've removed ignore_for_config for now, as we've used it mostly for the schedulers' __init__ method, and I don't think Flax models will need it soon. Also, as @patil-suraj suggested, ("parent", "name") are now hard-coded, just as it's done in Flax's original codebase.

Some tweaks to the constructor were necessary, but apart from this it looks great! Awesome job @mishig25

@patrickvonplaten (Contributor) left a comment

Once the weight name is changed to "diffusion_flax_model.msgpack" let's merge this PR to unblock a couple of other PRs :-)

TODOs

  • Add functionality that automatically detects missing keys; in that case, I think we should just throw an error in from_pretrained(...) for now. This way we can follow @pcuenca's suggestion and users don't have to manually run init(...) after having run from_pretrained(...).
  • Tests - we should add some tests here

@patrickvonplaten (Contributor)

"…ow get this warning when loading the pretraine…"

Hmm, double-checking, maybe I added a bug somewhere; this warning shouldn't happen. Thanks for the heads-up!

@pcuenca (Member) commented Sep 14, 2022

Hmm, double-checking, maybe I added a bug somewhere; this warning shouldn't happen. Thanks for the heads-up!

The removal of ignore_for_config maybe? They are not excluded from the expected list any more.

@patrickvonplaten (Contributor)

@pcuenca - actually reverting some logic a bit. Will add a _flax_internal_args class attribute since we now need the name in two places.

@mishig25 (Contributor, Author)

@patil-suraj docstrings are updated.

Please let me know if/when I should merge

@kashif (Contributor) commented Sep 14, 2022

@mishig25 can you add the imports to the __init__ file and create dummy imports?

@pcuenca (Member) left a comment

It's good to go in my opinion. Great work!

>>> from diffusers import FlaxUNet2DConditionModel

>>> # load model
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4")
Member:

I think this will probably require a revision argument if we store the weights in a different branch. But we can do that later when we push the weights.
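
For example, something along these lines (the flax branch name is purely hypothetical, and this assumes from_pretrained forwards a revision kwarg to the Hub download):

# Hypothetical usage if the Flax weights end up on a separate branch;
# the "flax" revision name and the `revision` kwarg forwarding are assumptions.
model, params = FlaxUNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="flax"
)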

Contributor (Author):

Sounds good, let's update the docstrings when "CompVis/stable-diffusion-v1-4" gets updated.

Contributor:

Good point @pcuenca! We can also work with allow/reject regexes so that PyTorch wouldn't download Flax weights; this way everything could be in the same branch. Note that this is only relevant for the pipeline for now, though.

`CompVis/stable-diffusion-v1-4`.
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`],
e.g., `./my_model_directory/`.
- A path or url to a *pt index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In this case,
Member:

Does this work? Just curious, I don't know if it does.

Contributor (Author):

Removed it in 82311a5 because we don't currently have a conversion script between Flax <> PT.

Mishig Davaadorj and others added 3 commits September 14, 2022 15:50
@mishig25 (Contributor, Author) commented Sep 14, 2022

@mishig25 can you add the imports to the __init__ file and create dummy imports?

@kashif Added FlaxModelMixin to __init__.py in a94968f.
However, I am not sure where we would add the Flax dummy object?

@kashif (Contributor) commented Sep 14, 2022

@mishig25 run: make fix-copies
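
(make fix-copies regenerates the dummy-object stubs; the generated stub typically looks something like the sketch below, where the exact file path and helper names are assumptions based on how the other dummy files look:)

# Illustrative pattern of a generated Flax dummy stub (e.g. utils/dummy_flax_objects.py);
# the module path and the DummyObject/requires_backends helpers are assumptions here.
from ..utils import DummyObject, requires_backends


class FlaxModelMixin(metaclass=DummyObject):
    _backends = ["flax"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])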

@mishig25 (Contributor, Author) commented Sep 14, 2022

@kashif nothing happened

(base) mishig@vorace:~/diffusers$ make fix-copies
python utils/check_dummies.py --fix_and_overwrite
(base) mishig@vorace:~/diffusers$ git status
On branch flax_model_mixin
Your branch is up to date with 'origin/flax_model_mixin'.

nothing to commit, working tree clean
(base) mishig@vorace:~/diffusers$ 

@kashif (Contributor) commented Sep 14, 2022

@mishig25 perhaps first update your branch, and then I believe you need to put the imports inside the is_flax_available... block. Have a look at another __init__ file.

@mishig25 (Contributor, Author)

@kashif thanks a lot for the pointers! I've done so after rebasing from main

@@ -63,6 +63,7 @@
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403

if is_flax_available():
    from .modeling_flax_utils import FlaxModelMixin
Contributor:

Perfect!

@patrickvonplaten (Contributor)

Once tests are green - I think we can merge 🥳

mishig25 merged commit 83a7bb2 into main Sep 14, 2022
mishig25 deleted the flax_model_mixin branch September 14, 2022 14:34
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* Implement `FlaxModelMixin`

* Rm unused method `framework`

* Update src/diffusers/modeling_flax_utils.py

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* some more changes

* make style

* Add comment

* Update src/diffusers/modeling_flax_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Rm unneeded comment

* Update docstrings

* correct ignore kwargs

* make style

* Update docstring examples

* Make style

* Update src/diffusers/modeling_flax_utils.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Rm incorrect docstring

* Add FlaxModelMixin to __init__.py

* make fix-copies

Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>