Skip to content

Commit 83a7bb2

Browse files
Mishig Davaadorjpatil-surajpatrickvonplatenpcuenca
authored
Implement FlaxModelMixin (#493)
* Implement `FlaxModelMixin` * Rm unused method `framework` * Update src/diffusers/modeling_flax_utils.py Co-authored-by: Suraj Patil <surajp815@gmail.com> * some more changes * make style * Add comment * Update src/diffusers/modeling_flax_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Rm unneeded comment * Update docstrings * correct ignore kwargs * make style * Update docstring examples * Make style * Update src/diffusers/modeling_flax_utils.py Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * Rm incorrect docstring * Add FlaxModelMixin to __init__.py * make fix-copies Co-authored-by: Suraj Patil <surajp815@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
1 parent 8b45096 commit 83a7bb2

File tree

4 files changed

+516
-0
lines changed

4 files changed

+516
-0
lines changed

src/diffusers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
6464

6565
if is_flax_available():
66+
from .modeling_flax_utils import FlaxModelMixin
6667
from .schedulers import FlaxPNDMScheduler
6768
else:
6869
from .utils.dummy_flax_objects import * # noqa F403

src/diffusers/configuration_utils.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
""" ConfigMixinuration base class and utilities."""
17+
import dataclasses
1718
import functools
1819
import inspect
1920
import json
@@ -271,6 +272,11 @@ def extract_init_dict(cls, config_dict, **kwargs):
271272
# remove general kwargs if present in dict
272273
if "kwargs" in expected_keys:
273274
expected_keys.remove("kwargs")
275+
# remove flax interal keys
276+
if hasattr(cls, "_flax_internal_args"):
277+
for arg in cls._flax_internal_args:
278+
expected_keys.remove(arg)
279+
274280
# remove keys to be ignored
275281
if len(cls.ignore_for_config) > 0:
276282
expected_keys = expected_keys - set(cls.ignore_for_config)
@@ -401,3 +407,44 @@ def inner_init(self, *args, **kwargs):
401407
getattr(self, "register_to_config")(**new_kwargs)
402408

403409
return inner_init
410+
411+
412+
def flax_register_to_config(cls):
413+
original_init = cls.__init__
414+
415+
@functools.wraps(original_init)
416+
def init(self, *args, **kwargs):
417+
if not isinstance(self, ConfigMixin):
418+
raise RuntimeError(
419+
f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
420+
"not inherit from `ConfigMixin`."
421+
)
422+
423+
# Ignore private kwargs in the init. Retrieve all passed attributes
424+
init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
425+
426+
# Retrieve default values
427+
fields = dataclasses.fields(self)
428+
default_kwargs = {}
429+
for field in fields:
430+
# ignore flax specific attributes
431+
if field.name in self._flax_internal_args:
432+
continue
433+
if type(field.default) == dataclasses._MISSING_TYPE:
434+
default_kwargs[field.name] = None
435+
else:
436+
default_kwargs[field.name] = getattr(self, field.name)
437+
438+
# Make sure init_kwargs override default kwargs
439+
new_kwargs = {**default_kwargs, **init_kwargs}
440+
441+
# Get positional arguments aligned with kwargs
442+
for i, arg in enumerate(args):
443+
name = fields[i].name
444+
new_kwargs[name] = arg
445+
446+
getattr(self, "register_to_config")(**new_kwargs)
447+
original_init(self, *args, **kwargs)
448+
449+
cls.__init__ = init
450+
return cls

0 commit comments

Comments
 (0)