Implement FlaxModelMixin #493
Conversation
The documentation is not available anymore as the PR was closed or merged.
I love this. I'll test it with my changes in #485.
@pcuenca please let me know if you have any questions.
src/diffusers/modeling_flax_utils.py (Outdated)

```python
"""
_missing_keys = set()
config_name = CONFIG_NAME
ignore_for_config = ["parent", "name"]
```
I don't know if this will be an issue in practice, but what happens if the model subclass needs to use `ignore_for_config` for its own variables? In that case, it would need to also append `"parent"` and `"name"` to the list. Would it be better if `FlaxModelMixin` knew about these special names without adding them to the list?
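To make the concern concrete, here is a minimal hypothetical subclass (not from this PR) that wants to keep an extra argument of its own out of the config and therefore has to repeat the Flax-internal names:

```python
import flax.linen as nn
import jax.numpy as jnp

from diffusers.configuration_utils import ConfigMixin, flax_register_to_config
from diffusers.modeling_flax_utils import FlaxModelMixin


@flax_register_to_config
class FlaxMyModel(nn.Module, FlaxModelMixin, ConfigMixin):
    hidden_size: int = 8
    dtype: jnp.dtype = jnp.float32  # e.g. an arg that should not end up in the config

    # The subclass must re-list "parent" and "name" just to add "dtype",
    # because the class attribute is overridden rather than merged.
    ignore_for_config = ["parent", "name", "dtype"]
```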
> but what happens if the model subclass needs to use ignore_for_config for its own variables?

What could be a use case for ignoring the init args for the config? But also, this is a special case for Flax, so maybe we could hardcode it for now.
Sorry, just to clarify, hardcode how? Using a different list instead of `ignore_for_config` as I was suggesting, or do you have something else in mind? :)
Playing around with this a bit now to get a better feeling
Think "parent", "name" are so Flax specific that we should just hide them in the decorator method
Looks very cool in general, just left some minor nits.
src/diffusers/modeling_flax_utils.py (Outdated)

```python
"""
_missing_keys = set()
config_name = CONFIG_NAME
ignore_for_config = ["parent", "name"]
```
> but what happens if the model subclass needs to use ignore_for_config for its own variables?
> what could be a use-case to ignore the init args for config?
> But also this is a special case for flax so maybe we could hardcode this for now.

The answer, i.e.:

```python
model = FlaxUnet(.....)
params = model.init_weights()

model, params = FlaxUnet.from_pretrained(.....)
params = model.init_weights(params)
```

Refer to the linked PR to see how it can be implemented. This way we stay close to the JAX/Flax philosophy: nothing is implicitly initialized, the user should be responsible for handling the initialization. Let me know what you think :)
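For illustration, a minimal sketch of what such an explicit `init_weights` could look like on a toy module (the dummy input shape and the module itself are assumptions, not the actual `FlaxUnet` code):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class ToyFlaxModel(nn.Module):
    hidden_size: int = 8

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.hidden_size)(x)

    def init_weights(self, rng):
        # Nothing is initialized implicitly: the user calls this explicitly
        # and keeps the returned params separate from the module instance.
        dummy_input = jnp.zeros((1, 4), dtype=jnp.float32)
        return self.init({"params": rng}, dummy_input)["params"]


model = ToyFlaxModel()
params = model.init_weights(jax.random.PRNGKey(0))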
My initial opinion before I read this comment was that […]. There are other libraries that abstract away weight loading; see for example flaxmodels here.
Co-authored-by: Suraj Patil <surajp815@gmail.com>
2 questions:

Question 1:

```python
x = random.normal(key1, (...))    # Dummy input
params = FlaxModel.init(key2, x)  # Initialization call
```

Since `FlaxModel(nn.Module)`, is it because doing the "normal flax" way (i.e. the snippet above) […]?

Question 2: […]
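For context, a runnable version of the "normal flax" pattern the question refers to, using a plain `nn.Dense` as a stand-in for `FlaxModel`:

```python
import flax.linen as nn
from jax import random

key1, key2 = random.split(random.PRNGKey(0))

model = nn.Dense(features=4)
x = random.normal(key1, (2, 3))   # Dummy input
params = model.init(key2, x)      # Initialization call; params live outside the module
y = model.apply(params, x)        # Parameters are passed explicitly on every call
```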
Super nice PR - will play around with it a bit now!
I very much agree with @patil-suraj's comments above. Note however that for diffusion models there is no "let's just add a head and fine-tune this" logic => therefore if one does […]
@mishig25 to answer your second question: indeed we might not even need an […]
@patrickvonplaten let's do a follow-up PR to tackle […]. I will fix the docstrings now.
Works like a charm! You can test it with:

```python
from diffusers.modeling_flax_utils import FlaxModelMixin
from diffusers.configuration_utils import flax_register_to_config, ConfigMixin

import flax.linen as nn
import jax
from jax.random import PRNGKey
import jax.numpy as jnp


@flax_register_to_config
class Adder(nn.Module, FlaxModelMixin, ConfigMixin):
    vocab_size: int
    hidden_size: int = 8
    initializer_range: float = 0.4

    def setup(self):
        self.word_embeddings = nn.Embed(
            self.config.vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )

    def __call__(self, input_tensor):
        return self.word_embeddings(input_tensor) + 6


adder = Adder(5, hidden_size=10)

ones = jnp.ones((1, 1), dtype="i4")
rngs = {"params": PRNGKey(0)}
params = adder.init(rngs, ones)

adder.save_pretrained("dummy", params)
adder_new, params = Adder.from_pretrained("./dummy")
```
Update

I've removed the […]. Some tweaks to the constructor were necessary, but apart from this it looks great! Awesome job @mishig25
Once the weight name is changed to `"diffusion_flax_model.msgpack"`, let's merge this PR to unblock a couple of other PRs :-)

TODOs

- Add functionality that automatically detects missing keys; in this case, I think we should just throw an error in `from_pretrained(...)` for now. This way we can follow @pcuenca's suggestion and don't have to manually run `init(...)` after having run `from_pretrained(...)`
- Tests - we should add some tests here
Hmm, double checking - maybe I added a bug somewhere; this warning shouldn't happen. Thanks for the heads-up
The removal of […]
@pcuenca - actually reverting some logic a bit. Will add a […]
@patil-suraj docstrings are updated. Please let me know if/when I should merge.
@mishig25 can you add the imports to the init file and create dummy imports?
It's good to go in my opinion. Great work!
```python
>>> from diffusers import FlaxUNet2DConditionModel

>>> # load model
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4")
```
I think this will probably require a `revision` argument if we store the weights in a different branch. But we can do that later when we push the weights.
Sounds good, let's update the docstrings when "CompVis/stable-diffusion-v1-4" gets updated.
Good point @pcuenca! We can also work with allow/reject regexes so that PyTorch wouldn't download Flax weights - this way everything could be in the same branch. Note that this is only relevant for the pipeline for now, though.
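As an illustration of the allow/reject idea, a sketch using `huggingface_hub.snapshot_download` (the exact file patterns are assumptions, not part of this PR):

```python
from huggingface_hub import snapshot_download

# Download only configs and Flax weights, skipping PyTorch binaries,
# even when everything lives in the same branch of the repo.
local_dir = snapshot_download(
    "CompVis/stable-diffusion-v1-4",
    allow_patterns=["*.json", "*.txt", "*.msgpack"],
    ignore_patterns=["*.bin", "*.ckpt"],
)
```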
src/diffusers/modeling_flax_utils.py (Outdated)

```python
      `CompVis/stable-diffusion-v1-4`.
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`],
  e.g., `./my_model_directory/`.
- A path or url to a *pt index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In this case,
```
Does this work? Just curious, I don't know if it does.
Removed it in 82311a5 because we don't currently have a conversion script between flax <> pt.
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
@mishig25 run: […]
@kashif nothing happened
@mishig25 perhaps first update your branch, and then I believe you need to put the imports inside the […]
@kashif thanks a lot for the pointers! I've done so after rebasing from […]
```python
@@ -63,6 +63,7 @@
    from .utils.dummy_torch_and_transformers_and_onnx_objects import *  # noqa F403


if is_flax_available():
    from .modeling_flax_utils import FlaxModelMixin
```
Perfect!
Once tests are green - I think we can merge 🥳
* Implement `FlaxModelMixin`
* Rm unused method `framework`
* Update src/diffusers/modeling_flax_utils.py (Co-authored-by: Suraj Patil <surajp815@gmail.com>)
* some more changes
* make style
* Add comment
* Update src/diffusers/modeling_flax_utils.py (Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>)
* Rm unneeded comment
* Update docstrings
* correct ignore kwargs
* make style
* Update docstring examples
* Make style
* Update src/diffusers/modeling_flax_utils.py (Co-authored-by: Pedro Cuenca <pedro@huggingface.co>)
* Rm incorrect docstring
* Add FlaxModelMixin to __init__.py
* make fix-copies

Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Implemented `FlaxModelMixin`

`FlaxModelMixin` is the Flax equivalent of `ModelMixin` (i.e. a mixin that has `save_pretrained`, `from_pretrained`, & various checkpoint util functions). After being discussed in google/flax#2454 (comment), design-wise, `diffusers.FlaxModelMixin` differs from `transformers.FlaxPreTrainedModel` in that the model class should be "state-less" (google/flax#2454).
see this comment for testing: #493 (comment)
Example usages:

- Class definition (find example here)
- `from_pretrained` (find example here)
- call `apply` just like any other `flax.linen.Module` (example here)

A question
In `FlaxModelMixin.from_pretrained`, I have NOT added a section about checking `missing_keys` & `mismatched_keys` like in the `transformers` Flax `from_pretrained`.

In short, the `missing_keys` & `mismatched_keys` check works by comparing:

```python
random_state = ModelCls.init_weights()
pretrained_state = from_pretrained()
# compare random_state vs pretrained_state
```
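A minimal sketch of how that comparison could be done (assuming an explicit `init_weights` method as discussed above; this is not code from this PR):

```python
import jax
from flax.core.frozen_dict import unfreeze
from flax.traverse_util import flatten_dict


def find_missing_and_unexpected(model, loaded_params, rng=jax.random.PRNGKey(0)):
    # Compare the key sets of randomly initialized params with the loaded ones.
    random_keys = set(flatten_dict(unfreeze(model.init_weights(rng))).keys())
    loaded_keys = set(flatten_dict(unfreeze(loaded_params)).keys())

    missing_keys = random_keys - loaded_keys     # required by the model, absent from the checkpoint
    unexpected_keys = loaded_keys - random_keys  # present in the checkpoint, unused by the model
    return missing_keys, unexpected_keys
```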
Therefore, to add it, it will require: […] and probably calling `init_weights` inside `from_pretrained`, since we do NOT have a `FlaxModelMixin.__init__`.

Please let me know if I should add this functionality. WDYT?
TODOs
cc: @kashif somehow I can't add you as a reviewer