|
14 | 14 | # See the License for the specific language governing permissions and
|
15 | 15 | # limitations under the License.
|
16 | 16 | """ ConfigMixinuration base class and utilities."""
|
| 17 | +import dataclasses |
17 | 18 | import functools
|
18 | 19 | import inspect
|
19 | 20 | import json
|
@@ -271,6 +272,11 @@ def extract_init_dict(cls, config_dict, **kwargs):
|
271 | 272 | # remove general kwargs if present in dict
|
272 | 273 | if "kwargs" in expected_keys:
|
273 | 274 | expected_keys.remove("kwargs")
|
| 275 | + # remove flax interal keys |
| 276 | + if hasattr(cls, "_flax_internal_args"): |
| 277 | + for arg in cls._flax_internal_args: |
| 278 | + expected_keys.remove(arg) |
| 279 | + |
274 | 280 | # remove keys to be ignored
|
275 | 281 | if len(cls.ignore_for_config) > 0:
|
276 | 282 | expected_keys = expected_keys - set(cls.ignore_for_config)
|
@@ -401,3 +407,44 @@ def inner_init(self, *args, **kwargs):
|
401 | 407 | getattr(self, "register_to_config")(**new_kwargs)
|
402 | 408 |
|
403 | 409 | return inner_init
|
| 410 | + |
| 411 | + |
| 412 | +def flax_register_to_config(cls): |
| 413 | + original_init = cls.__init__ |
| 414 | + |
| 415 | + @functools.wraps(original_init) |
| 416 | + def init(self, *args, **kwargs): |
| 417 | + if not isinstance(self, ConfigMixin): |
| 418 | + raise RuntimeError( |
| 419 | + f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does " |
| 420 | + "not inherit from `ConfigMixin`." |
| 421 | + ) |
| 422 | + |
| 423 | + # Ignore private kwargs in the init. Retrieve all passed attributes |
| 424 | + init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} |
| 425 | + |
| 426 | + # Retrieve default values |
| 427 | + fields = dataclasses.fields(self) |
| 428 | + default_kwargs = {} |
| 429 | + for field in fields: |
| 430 | + # ignore flax specific attributes |
| 431 | + if field.name in self._flax_internal_args: |
| 432 | + continue |
| 433 | + if type(field.default) == dataclasses._MISSING_TYPE: |
| 434 | + default_kwargs[field.name] = None |
| 435 | + else: |
| 436 | + default_kwargs[field.name] = getattr(self, field.name) |
| 437 | + |
| 438 | + # Make sure init_kwargs override default kwargs |
| 439 | + new_kwargs = {**default_kwargs, **init_kwargs} |
| 440 | + |
| 441 | + # Get positional arguments aligned with kwargs |
| 442 | + for i, arg in enumerate(args): |
| 443 | + name = fields[i].name |
| 444 | + new_kwargs[name] = arg |
| 445 | + |
| 446 | + getattr(self, "register_to_config")(**new_kwargs) |
| 447 | + original_init(self, *args, **kwargs) |
| 448 | + |
| 449 | + cls.__init__ = init |
| 450 | + return cls |
0 commit comments