From 95073e13b2172bcadddd4ed528360c4f216d753b Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Tue, 13 Sep 2022 09:20:34 +0000 Subject: [PATCH 01/17] Implement `FlaxModelMixin` --- src/diffusers/configuration_utils.py | 40 +++ src/diffusers/modeling_flax_utils.py | 500 +++++++++++++++++++++++++++ 2 files changed, 540 insertions(+) create mode 100644 src/diffusers/modeling_flax_utils.py diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index fbe75f3f1441..bd08f25bffdf 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -401,3 +401,43 @@ def inner_init(self, *args, **kwargs): getattr(self, "register_to_config")(**new_kwargs) return inner_init + + +def flax_register_to_config(cls): + original_init = cls.__init__ + + @functools.wraps(original_init) + def init(self, *args, **kwargs): + # Ignore private kwargs in the init. + init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} + # original_init(self, *args, **init_kwargs) + if not isinstance(self, ConfigMixin): + raise RuntimeError( + f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does " + "not inherit from `ConfigMixin`." + ) + + ignore = getattr(self, "ignore_for_config", []) + # Get positional arguments aligned with kwargs + new_kwargs = {} + signature = inspect.signature(init) + parameters = { + name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore + } + for arg, name in zip(args, parameters.keys()): + new_kwargs[name] = arg + + # Then add all kwargs + new_kwargs.update( + { + k: init_kwargs.get(k, default) + for k, default in parameters.items() + if k not in ignore and k not in new_kwargs + } + ) + getattr(self, "register_to_config")(**new_kwargs) + + original_init(self, *args, **init_kwargs) + + cls.__init__ = init + return cls diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py new file mode 100644 index 000000000000..dbfe258c9c13 --- /dev/null +++ b/src/diffusers/modeling_flax_utils.py @@ -0,0 +1,500 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from pickle import UnpicklingError +from typing import Any, Dict, Union + +import jax +import jax.numpy as jnp +import msgpack.exceptions +from flax.core.frozen_dict import FrozenDict +from flax.serialization import from_bytes, to_bytes +from flax.traverse_util import flatten_dict, unflatten_dict +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError +from requests import HTTPError + +from .modeling_utils import WEIGHTS_NAME +from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging + + +FLAX_WEIGHTS_NAME = "flax_model.msgpack" # TODO should be "diffusion_flax_model.msgpack" + +logger = logging.get_logger(__name__) + + +class FlaxModelMixin: + r""" + Base class for all flax models. + + [`FlaxModelMixin`] takes care of storing the configuration of the models and handles methods for loading, + downloading and saving models. + """ + _missing_keys = set() + config_name = CONFIG_NAME + ignore_for_config = ["parent", "name"] + _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] + + @classmethod + def _from_config(cls, config, **kwargs): + """ + All context managers that the model should be initialized under go here. + """ + return cls(config, **kwargs) + + @property + def framework(self) -> str: + """ + :str: Identifies that this is a Flax model. + """ + return "flax" + + def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any: + """ + Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`. + """ + + # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27 + def conditional_cast(param): + if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating): + param = param.astype(dtype) + return param + + if mask is None: + return jax.tree_map(conditional_cast, params) + + flat_params = flatten_dict(params) + flat_mask, _ = jax.tree_flatten(mask) + + for masked, key in zip(flat_mask, flat_params.keys()): + if masked: + param = flat_params[key] + flat_params[key] = conditional_cast(param) + + return unflatten_dict(flat_params) + + def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None): + r""" + Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast + the `params` in place. + + This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full + half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed. + + Arguments: + params (`Union[Dict, FrozenDict]`): + A `PyTree` of model parameters. + mask (`Union[Dict, FrozenDict]`): + A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params + you want to cast, and should be `False` for those you want to skip. + + Examples: + + ```python + >>> from transformers import FlaxBertModel + + >>> # load model + >>> model = FlaxBertModel.from_pretrained("bert-base-cased") + >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision + >>> model.params = model.to_bf16(model.params) + >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale) + >>> # then pass the mask as follows + >>> from flax import traverse_util + + >>> model = FlaxBertModel.from_pretrained("bert-base-cased") + >>> flat_params = traverse_util.flatten_dict(model.params) + >>> mask = { + ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) + ... for path in flat_params + ... } + >>> mask = traverse_util.unflatten_dict(mask) + >>> model.params = model.to_bf16(model.params, mask) + ```""" + return self._cast_floating_to(params, jnp.bfloat16, mask) + + def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None): + r""" + Cast the floating-point `parmas` to `jax.numpy.float32`. This method can be used to explicitly convert the + model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place. + + Arguments: + params (`Union[Dict, FrozenDict]`): + A `PyTree` of model parameters. + mask (`Union[Dict, FrozenDict]`): + A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params + you want to cast, and should be `False` for those you want to skip + + Examples: + + ```python + >>> from transformers import FlaxBertModel + + >>> # Download model and configuration from huggingface.co + >>> model = FlaxBertModel.from_pretrained("bert-base-cased") + >>> # By default, the model params will be in fp32, to illustrate the use of this method, + >>> # we'll first cast to fp16 and back to fp32 + >>> model.params = model.to_f16(model.params) + >>> # now cast back to fp32 + >>> model.params = model.to_fp32(model.params) + ```""" + return self._cast_floating_to(params, jnp.float32, mask) + + def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None): + r""" + Cast the floating-point `parmas` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the + `params` in place. + + This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full + half-precision training or to save weights in float16 for inference in order to save memory and improve speed. + + Arguments: + params (`Union[Dict, FrozenDict]`): + A `PyTree` of model parameters. + mask (`Union[Dict, FrozenDict]`): + A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params + you want to cast, and should be `False` for those you want to skip + + Examples: + + ```python + >>> from transformers import FlaxBertModel + + >>> # load model + >>> model = FlaxBertModel.from_pretrained("bert-base-cased") + >>> # By default, the model params will be in fp32, to cast these to float16 + >>> model.params = model.to_fp16(model.params) + >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale) + >>> # then pass the mask as follows + >>> from flax import traverse_util + + >>> model = FlaxBertModel.from_pretrained("bert-base-cased") + >>> flat_params = traverse_util.flatten_dict(model.params) + >>> mask = { + ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) + ... for path in flat_params + ... } + >>> mask = traverse_util.unflatten_dict(mask) + >>> model.params = model.to_fp16(model.params, mask) + ```""" + return self._cast_floating_to(params, jnp.float16, mask) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + dtype: jnp.dtype = jnp.float32, + *model_args, + **kwargs, + ): + r""" + Instantiate a pretrained flax model from a pre-trained model configuration. + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a + user or organization name, like `dbmdz/bert-base-german-cased`. + - A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`], + e.g., `./my_model_directory/`. + - A path or url to a *pt index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In this case, + `from_pt` should be set to `True`. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~ModelMixin.to_fp16`] and + [`~ModelMixin.to_bf16`]. + model_args (sequence of positional arguments, *optional*): + All remaining positional arguments will be passed to the underlying model's `__init__` method. + config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*): + Can be either: + + - an instance of a class derived from [`PretrainedConfig`], + - a string or path valid as input to [`~PretrainedConfig.from_pretrained`]. + + Configuration for the model to use instead of an automatically loaded configuration. Configuration can + be automatically loaded when: + + - The model is a model provided by the library (loaded with the *model id* string of a pretrained + model). + - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the + save directory. + - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a + configuration JSON file named *config.json* is found in the directory. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + from_pt (`bool`, *optional*, defaults to `False`): + Load the model weights from a PyTorch checkpoint save file (see docstring of + `pretrained_model_name_or_path` argument). + ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): + Whether or not to raise an error if some of the weights from the checkpoint do not have the same size + as the weights of the model (if for instance, you are instantiating a model with 10 labels from a + checkpoint with 3 labels). + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + `output_attentions=True`). Behaves differently depending on whether a `config` is provided or + automatically loaded: + + - If a configuration is provided with `config`, `**kwargs` will be directly passed to the + underlying model's `__init__` method (we assume all relevant updates to the configuration have + already been done) + - If a configuration is not provided, `kwargs` will be first passed to the configuration class + initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that + corresponds to a configuration attribute will be used to override said attribute with the + supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute + will be passed to the underlying model's `__init__` function. + + Examples: + + ```python + >>> from transformers import BertConfig, FlaxBertModel + + >>> # Download model and configuration from huggingface.co and cache. + >>> model = FlaxBertModel.from_pretrained("bert-base-cased") + >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable). + >>> model = FlaxBertModel.from_pretrained("./test/saved_model/") + >>> # Loading from a PyTorch checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable). + >>> config = BertConfig.from_json_file("./pt_model/config.json") + >>> model = FlaxBertModel.from_pretrained("./pt_model/pytorch_model.bin", from_pt=True, config=config) + ```""" + config = kwargs.pop("config", None) + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + from_auto_class = kwargs.pop("_from_auto", False) + subfolder = kwargs.pop("subfolder", None) + + user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class} + + # Load config if we don't provide a configuration + config_path = config if config is not None else pretrained_model_name_or_path + model, model_kwargs = cls.from_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + # model args + dtype=dtype, + **kwargs, + ) + + # Load model + if os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)): + # Load from a Flax checkpoint + model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME) + # At this stage we don't have a weight file so we will raise an error. + elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME): + raise EnvironmentError( + f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " + "but there is a file for PyTorch weights." + ) + else: + raise EnvironmentError( + f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory " + f"{pretrained_model_name_or_path}." + ) + else: + try: + model_file = hf_hub_download( + pretrained_model_name_or_path, + filename=FLAX_WEIGHTS_NAME, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision, + ) + + except RepositoryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " + "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " + "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli " + "login` and pass `use_auth_token=True`." + ) + except RevisionNotFoundError: + raise EnvironmentError( + f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for " + "this model name. Check the model page at " + f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." + ) + except EntryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME}." + ) + except HTTPError as err: + raise EnvironmentError( + f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n" + f"{err}" + ) + except ValueError: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" + f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" + f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your" + " internet connection or see how to run the library in offline mode at" + " 'https://huggingface.co/docs/transformers/installation#offline-mode'." + ) + except EnvironmentError: + raise EnvironmentError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " + "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}." + ) + + try: + with open(model_file, "rb") as state_f: + state = from_bytes(cls, state_f.read()) + except (UnpicklingError, msgpack.exceptions.ExtraData) as e: + try: + with open(model_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please" + " install git-lfs and run `git lfs install` followed by `git lfs pull` in the" + " folder you cloned." + ) + else: + raise ValueError from e + except (UnicodeDecodeError, ValueError): + raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ") + # make sure all arrays are stored as jnp.arrays + # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4: + # https://github.com/google/flax/issues/1261 + state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state) + + # flatten dicts + state = flatten_dict(state) + + # dictionary of key: dtypes for the model params + param_dtypes = jax.tree_map(lambda x: x.dtype, state) + # extract keys of parameters not in jnp.float32 + fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16] + bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16] + + # raise a warning if any of the parameters are not in jnp.float32 + if len(fp16_params) > 0: + logger.warning( + f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from " + f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n" + "You should probably UPCAST the model weights to float32 if this was not intended. " + "See [`~ModelMixin.to_fp32`] for further information on how to do this." + ) + + if len(bf16_params) > 0: + logger.warning( + f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from " + f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n" + "You should probably UPCAST the model weights to float32 if this was not intended. " + "See [`~ModelMixin.to_fp32`] for further information on how to do this." + ) + + return model, unflatten_dict(state) + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + params: Union[Dict, FrozenDict], + is_main_process: bool = True, + **kwargs, + ): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + `[`~FlaxPreTrainedModel.from_pretrained`]` class method + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. + + + + Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`, + which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing + folder. Pass along `temp_dir=True` to use a temporary directory instead. + + + + kwargs: + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + model_to_save = self + + # Attach architecture to the config + # Save the config + if is_main_process: + model_to_save.save_config(save_directory) + + # save model + output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME) + with open(output_model_file, "wb") as f: + model_bytes = to_bytes(params) + f.write(model_bytes) + + logger.info(f"Model weights saved in {output_model_file}") From 91559f3107c7d31763094dae4076837d81cbcea1 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Wed, 14 Sep 2022 08:23:27 +0000 Subject: [PATCH 02/17] Rm unused method `framework` --- src/diffusers/modeling_flax_utils.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index dbfe258c9c13..b8deb4487c66 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -55,13 +55,6 @@ def _from_config(cls, config, **kwargs): """ return cls(config, **kwargs) - @property - def framework(self) -> str: - """ - :str: Identifies that this is a Flax model. - """ - return "flax" - def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any: """ Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`. From f7a0ab2d45b6c6d6068cf5c506f3b704fcf9a388 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Wed, 14 Sep 2022 10:24:12 +0200 Subject: [PATCH 03/17] Update src/diffusers/modeling_flax_utils.py Co-authored-by: Suraj Patil --- src/diffusers/modeling_flax_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index b8deb4487c66..4eab80085d92 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -452,7 +452,7 @@ def save_pretrained( ): """ Save a model and its configuration file to a directory, so that it can be re-loaded using the - `[`~FlaxPreTrainedModel.from_pretrained`]` class method + `[`~FlaxModelMixin.from_pretrained`]` class method Arguments: save_directory (`str` or `os.PathLike`): From 5d81bf8eef9677d058bffc68b0330dbb55a88d9e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 14 Sep 2022 14:03:54 +0200 Subject: [PATCH 04/17] some more changes --- src/diffusers/configuration_utils.py | 41 ++++++++++++++-------------- src/diffusers/modeling_flax_utils.py | 2 -- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index bd08f25bffdf..07f2f73d2cf3 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -21,6 +21,7 @@ import re from collections import OrderedDict from typing import Any, Dict, Tuple, Union +import dataclasses from huggingface_hub import hf_hub_download from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError @@ -408,36 +409,36 @@ def flax_register_to_config(cls): @functools.wraps(original_init) def init(self, *args, **kwargs): - # Ignore private kwargs in the init. - init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} - # original_init(self, *args, **init_kwargs) if not isinstance(self, ConfigMixin): raise RuntimeError( f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does " "not inherit from `ConfigMixin`." ) - ignore = getattr(self, "ignore_for_config", []) + # Ignore private kwargs in the init. Retrieve all passed attributes + init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} + + # Retrieve default values + fields = dataclasses.fields(self) + default_kwargs = {} + for field in fields: + if field.name in ("parent", "name"): + continue + if type(field.default) == dataclasses._MISSING_TYPE: + default_kwargs[field.name] = None + else: + default_kwargs[field.name] = getattr(self, field.name) + + # Make sure init_kwargs override default kwargs + new_kwargs = {**default_kwargs, **init_kwargs} + # Get positional arguments aligned with kwargs - new_kwargs = {} - signature = inspect.signature(init) - parameters = { - name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore - } - for arg, name in zip(args, parameters.keys()): + for i, arg in enumerate(args): + name = fields[i].name new_kwargs[name] = arg - # Then add all kwargs - new_kwargs.update( - { - k: init_kwargs.get(k, default) - for k, default in parameters.items() - if k not in ignore and k not in new_kwargs - } - ) getattr(self, "register_to_config")(**new_kwargs) - - original_init(self, *args, **init_kwargs) + original_init(self, *args, **kwargs) cls.__init__ = init return cls diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 4eab80085d92..ba431686d7ec 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -43,9 +43,7 @@ class FlaxModelMixin: [`FlaxModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading and saving models. """ - _missing_keys = set() config_name = CONFIG_NAME - ignore_for_config = ["parent", "name"] _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] @classmethod From 1430ab80807feccf34876053f8e11e7e4eb7d08e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 14 Sep 2022 14:04:52 +0200 Subject: [PATCH 05/17] make style --- src/diffusers/configuration_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 07f2f73d2cf3..f648c932083a 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ ConfigMixinuration base class and utilities.""" +import dataclasses import functools import inspect import json @@ -21,7 +22,6 @@ import re from collections import OrderedDict from typing import Any, Dict, Tuple, Union -import dataclasses from huggingface_hub import hf_hub_download from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError From 6a2a4c1f13beeb704b967509c3e4db7e2f4c96e0 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Wed, 14 Sep 2022 12:45:48 +0000 Subject: [PATCH 06/17] Add comment --- src/diffusers/configuration_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index f648c932083a..e567655f4f77 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -422,6 +422,7 @@ def init(self, *args, **kwargs): fields = dataclasses.fields(self) default_kwargs = {} for field in fields: + # ignore flax specific attributes if field.name in ("parent", "name"): continue if type(field.default) == dataclasses._MISSING_TYPE: From 2bf02677cd94cd5f3ae0b6e0a90180aa32fe0282 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Wed, 14 Sep 2022 14:53:18 +0200 Subject: [PATCH 07/17] Update src/diffusers/modeling_flax_utils.py Co-authored-by: Patrick von Platen --- src/diffusers/modeling_flax_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index ba431686d7ec..e06552e032f1 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -31,7 +31,7 @@ from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging -FLAX_WEIGHTS_NAME = "flax_model.msgpack" # TODO should be "diffusion_flax_model.msgpack" +FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" # TODO should be "diffusion_flax_model.msgpack" logger = logging.get_logger(__name__) From 25ab3cad11e624fd16d16283fa7224fe082881b8 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Wed, 14 Sep 2022 12:53:51 +0000 Subject: [PATCH 08/17] Rm unneeded comment --- src/diffusers/modeling_flax_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index e06552e032f1..61b66e5bb546 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -31,7 +31,7 @@ from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging -FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" # TODO should be "diffusion_flax_model.msgpack" +FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" logger = logging.get_logger(__name__) From 1e8466e49c5678ac0c719e478d32a9ad3966d82c Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Wed, 14 Sep 2022 13:05:34 +0000 Subject: [PATCH 09/17] Update docstrings --- src/diffusers/modeling_flax_utils.py | 44 ++++++---------------------- 1 file changed, 9 insertions(+), 35 deletions(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 61b66e5bb546..7a00cce850fc 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -118,7 +118,7 @@ def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None): def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None): r""" - Cast the floating-point `parmas` to `jax.numpy.float32`. This method can be used to explicitly convert the + Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place. Arguments: @@ -145,7 +145,7 @@ def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None): def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None): r""" - Cast the floating-point `parmas` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the + Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the `params` in place. This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full @@ -225,27 +225,9 @@ def from_pretrained( [`~ModelMixin.to_bf16`]. model_args (sequence of positional arguments, *optional*): All remaining positional arguments will be passed to the underlying model's `__init__` method. - config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*): - Can be either: - - - an instance of a class derived from [`PretrainedConfig`], - - a string or path valid as input to [`~PretrainedConfig.from_pretrained`]. - - Configuration for the model to use instead of an automatically loaded configuration. Configuration can - be automatically loaded when: - - - The model is a model provided by the library (loaded with the *model id* string of a pretrained - model). - - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the - save directory. - - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a - configuration JSON file named *config.json* is found in the directory. cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used. - from_pt (`bool`, *optional*, defaults to `False`): - Load the model weights from a PyTorch checkpoint save file (see docstring of - `pretrained_model_name_or_path` argument). ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): Whether or not to raise an error if some of the weights from the checkpoint do not have the same size as the weights of the model (if for instance, you are instantiating a model with 10 labels from a @@ -274,7 +256,7 @@ def from_pretrained( underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done) - If a configuration is not provided, `kwargs` will be first passed to the configuration class - initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that + initialization function ([`~ConfigMixin.from_config`]). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function. @@ -446,7 +428,6 @@ def save_pretrained( save_directory: Union[str, os.PathLike], params: Union[Dict, FrozenDict], is_main_process: bool = True, - **kwargs, ): """ Save a model and its configuration file to a directory, so that it can be re-loaded using the @@ -455,19 +436,12 @@ def save_pretrained( Arguments: save_directory (`str` or `os.PathLike`): Directory to which to save. Will be created if it doesn't exist. - push_to_hub (`bool`, *optional*, defaults to `False`): - Whether or not to push your model to the Hugging Face model hub after saving it. - - - - Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`, - which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing - folder. Pass along `temp_dir=True` to use a temporary directory instead. - - - - kwargs: - Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + params (`Union[Dict, FrozenDict]`): + A `PyTree` of model parameters. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. """ if os.path.isfile(save_directory): logger.error(f"Provided path ({save_directory}) should be a directory, not a file") From 6842d29e30a497f52331c90d277241bde9fad3e9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 14 Sep 2022 15:07:55 +0200 Subject: [PATCH 10/17] correct ignore kwargs --- src/diffusers/configuration_utils.py | 7 ++++++- src/diffusers/modeling_flax_utils.py | 1 + 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index e567655f4f77..bb66205412c3 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -272,6 +272,11 @@ def extract_init_dict(cls, config_dict, **kwargs): # remove general kwargs if present in dict if "kwargs" in expected_keys: expected_keys.remove("kwargs") + # remove flax interal keys + if hasattr(cls, "_flax_internal_args"): + for arg in cls._flax_internal_args: + expected_keys.remove(arg) + # remove keys to be ignored if len(cls.ignore_for_config) > 0: expected_keys = expected_keys - set(cls.ignore_for_config) @@ -423,7 +428,7 @@ def init(self, *args, **kwargs): default_kwargs = {} for field in fields: # ignore flax specific attributes - if field.name in ("parent", "name"): + if field.name in self._flax_internal_args: continue if type(field.default) == dataclasses._MISSING_TYPE: default_kwargs[field.name] = None diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 61b66e5bb546..06658cf60ad8 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -45,6 +45,7 @@ class FlaxModelMixin: """ config_name = CONFIG_NAME _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] + _flax_internal_args = ["name", "parent"] @classmethod def _from_config(cls, config, **kwargs): From 0f26c05ab096f5914cd3eb5f81888af59e61a9eb Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 14 Sep 2022 15:08:06 +0200 Subject: [PATCH 11/17] make style --- src/diffusers/modeling_flax_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 77121c9327fc..f4aa7f683b60 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -257,10 +257,10 @@ def from_pretrained( underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done) - If a configuration is not provided, `kwargs` will be first passed to the configuration class - initialization function ([`~ConfigMixin.from_config`]). Each key of `kwargs` that - corresponds to a configuration attribute will be used to override said attribute with the - supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute - will be passed to the underlying model's `__init__` function. + initialization function ([`~ConfigMixin.from_config`]). Each key of `kwargs` that corresponds to + a configuration attribute will be used to override said attribute with the supplied `kwargs` + value. Remaining keys that do not correspond to any configuration attribute will be passed to the + underlying model's `__init__` function. Examples: From d98e8c70565919f3983fe333aa42fdedd856e781 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Wed, 14 Sep 2022 13:14:47 +0000 Subject: [PATCH 12/17] Update docstring examples --- src/diffusers/modeling_flax_utils.py | 44 +++++++++++++--------------- 1 file changed, 20 insertions(+), 24 deletions(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 7a00cce850fc..eb5b29c3ac44 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -95,24 +95,24 @@ def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None): Examples: ```python - >>> from transformers import FlaxBertModel + >>> from diffusers import FlaxUNet2DConditionModel >>> # load model - >>> model = FlaxBertModel.from_pretrained("bert-base-cased") + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4") >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision - >>> model.params = model.to_bf16(model.params) + >>> params = model.to_bf16(params) >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale) >>> # then pass the mask as follows >>> from flax import traverse_util - >>> model = FlaxBertModel.from_pretrained("bert-base-cased") - >>> flat_params = traverse_util.flatten_dict(model.params) + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4") + >>> flat_params = traverse_util.flatten_dict(params) >>> mask = { ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) ... for path in flat_params ... } >>> mask = traverse_util.unflatten_dict(mask) - >>> model.params = model.to_bf16(model.params, mask) + >>> params = model.to_bf16(params, mask) ```""" return self._cast_floating_to(params, jnp.bfloat16, mask) @@ -131,15 +131,15 @@ def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None): Examples: ```python - >>> from transformers import FlaxBertModel + >>> from diffusers import FlaxUNet2DConditionModel >>> # Download model and configuration from huggingface.co - >>> model = FlaxBertModel.from_pretrained("bert-base-cased") + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4") >>> # By default, the model params will be in fp32, to illustrate the use of this method, >>> # we'll first cast to fp16 and back to fp32 - >>> model.params = model.to_f16(model.params) + >>> params = model.to_f16(params) >>> # now cast back to fp32 - >>> model.params = model.to_fp32(model.params) + >>> params = model.to_fp32(params) ```""" return self._cast_floating_to(params, jnp.float32, mask) @@ -161,24 +161,24 @@ def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None): Examples: ```python - >>> from transformers import FlaxBertModel + >>> from diffusers import FlaxUNet2DConditionModel >>> # load model - >>> model = FlaxBertModel.from_pretrained("bert-base-cased") + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4") >>> # By default, the model params will be in fp32, to cast these to float16 - >>> model.params = model.to_fp16(model.params) + >>> params = model.to_fp16(params) >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale) >>> # then pass the mask as follows >>> from flax import traverse_util - >>> model = FlaxBertModel.from_pretrained("bert-base-cased") - >>> flat_params = traverse_util.flatten_dict(model.params) + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4") + >>> flat_params = traverse_util.flatten_dict(params) >>> mask = { ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) ... for path in flat_params ... } >>> mask = traverse_util.unflatten_dict(mask) - >>> model.params = model.to_fp16(model.params, mask) + >>> params = model.to_fp16(params, mask) ```""" return self._cast_floating_to(params, jnp.float16, mask) @@ -205,8 +205,7 @@ def from_pretrained( Can be either: - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a - user or organization name, like `dbmdz/bert-base-german-cased`. + Valid model ids are namespaced under a user or organization name, like `CompVis/stable-diffusion-v1-4`. - A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`], e.g., `./my_model_directory/`. - A path or url to a *pt index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In this case, @@ -264,15 +263,12 @@ def from_pretrained( Examples: ```python - >>> from transformers import BertConfig, FlaxBertModel + >>> from diffusers import FlaxUNet2DConditionModel >>> # Download model and configuration from huggingface.co and cache. - >>> model = FlaxBertModel.from_pretrained("bert-base-cased") + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4") >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable). - >>> model = FlaxBertModel.from_pretrained("./test/saved_model/") - >>> # Loading from a PyTorch checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable). - >>> config = BertConfig.from_json_file("./pt_model/config.json") - >>> model = FlaxBertModel.from_pretrained("./pt_model/pytorch_model.bin", from_pt=True, config=config) + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/") ```""" config = kwargs.pop("config", None) cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) From 5d085770979b3e61ae70da88210d0dd7a61c9ddc Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Wed, 14 Sep 2022 13:21:24 +0000 Subject: [PATCH 13/17] Make style --- src/diffusers/modeling_flax_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 6b5b56164981..1abf900fa359 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -206,7 +206,8 @@ def from_pretrained( Can be either: - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - Valid model ids are namespaced under a user or organization name, like `CompVis/stable-diffusion-v1-4`. + Valid model ids are namespaced under a user or organization name, like + `CompVis/stable-diffusion-v1-4`. - A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`], e.g., `./my_model_directory/`. - A path or url to a *pt index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In this case, From 091fa42896e0b52bebc0353d241e97baf25cbc32 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Wed, 14 Sep 2022 15:50:44 +0200 Subject: [PATCH 14/17] Update src/diffusers/modeling_flax_utils.py Co-authored-by: Pedro Cuenca --- src/diffusers/modeling_flax_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 1abf900fa359..4df1039c6f31 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -102,7 +102,7 @@ def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None): >>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4") >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision >>> params = model.to_bf16(params) - >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale) + >>> # If you don't want to cast certain parameters (for example layer norm bias and scale) >>> # then pass the mask as follows >>> from flax import traverse_util From 82311a5dc3fb4a310adc70d0192588feea849ef9 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Wed, 14 Sep 2022 13:52:53 +0000 Subject: [PATCH 15/17] Rm incorrect docstring --- src/diffusers/modeling_flax_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 4df1039c6f31..4f2d25dfb168 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -210,8 +210,6 @@ def from_pretrained( `CompVis/stable-diffusion-v1-4`. - A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *pt index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In this case, - `from_pt` should be set to `True`. dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and `jax.numpy.bfloat16` (on TPUs). From a94968f434022298a9f4a2e40409ea0465561947 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Wed, 14 Sep 2022 14:07:37 +0000 Subject: [PATCH 16/17] Add FlaxModelMixin to __init__.py --- src/diffusers/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 14fb19ef408e..15a790e0c369 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -10,6 +10,7 @@ __version__ = "0.4.0.dev0" from .configuration_utils import ConfigMixin +from .modeling_flax_utils import FlaxModelMixin from .modeling_utils import ModelMixin from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel from .onnx_utils import OnnxRuntimeModel From 642279fc199e6266ceb466bbab8eb8f12d6a6cb1 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Wed, 14 Sep 2022 14:14:20 +0000 Subject: [PATCH 17/17] make fix-copies --- src/diffusers/utils/dummy_flax_objects.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/diffusers/utils/dummy_flax_objects.py b/src/diffusers/utils/dummy_flax_objects.py index b5f4362bcb6e..981dc5586ad9 100644 --- a/src/diffusers/utils/dummy_flax_objects.py +++ b/src/diffusers/utils/dummy_flax_objects.py @@ -4,6 +4,13 @@ from ..utils import DummyObject, requires_backends +class FlaxModelMixin(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxPNDMScheduler(metaclass=DummyObject): _backends = ["flax"]