diff --git a/docs/source/api/configuration.mdx b/docs/source/api/configuration.mdx
index 5c435dc8e1f1..45176f55b018 100644
--- a/docs/source/api/configuration.mdx
+++ b/docs/source/api/configuration.mdx
@@ -10,19 +10,14 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-# Models
+# Configuration
-Diffusers contains pretrained models for popular algorithms and modules for creating the next set of diffusion models.
-The primary function of these models is to denoise an input sample, by modeling the distribution $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$.
-The models are built on the base class ['ModelMixin'] that is a `torch.nn.module` with basic functionality for saving and loading models both locally and from the HuggingFace hub.
+In Diffusers, schedulers of type [`schedulers.scheduling_utils.SchedulerMixin`] and models of type [`ModelMixin`] inherit from [`ConfigMixin`], which conveniently takes care of storing all parameters that are
+passed to their respective `__init__` methods in a JSON configuration file.
-## API
+TODO(PVP) - add example and better info here
-Models should provide the `def forward` function and initialization of the model.
-All saving, loading, and utilities should be in the base ['ModelMixin'] class.
-
-## Examples
-
-- The ['UNetModel'] was proposed in [TODO](https://arxiv.org/) and has been used in paper1, paper2, paper3.
-- Extensions of the ['UNetModel'] include the ['UNetGlideModel'] that uses attention and timestep embeddings for the [GLIDE](https://arxiv.org/abs/2112.10741) paper, the ['UNetGradTTS'] model from this [paper](https://arxiv.org/abs/2105.06337) for text-to-speech, ['UNetLDMModel'] for latent-diffusion models in this [paper](https://arxiv.org/abs/2112.10752), and the ['TemporalUNet'] used for time-series prediciton in this reinforcement learning [paper](https://arxiv.org/abs/2205.09991).
-- TODO: mention VAE / SDE score estimation
\ No newline at end of file
+## ConfigMixin
+[[autodoc]] ConfigMixin
+ - from_config
+ - save_config
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 215bfdf3e082..5d7015e5eb1e 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -9,6 +9,7 @@
__version__ = "0.3.0.dev0"
+from .configuration_utils import ConfigMixin
from .modeling_utils import ModelMixin
from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
from .onnx_utils import OnnxRuntimeModel
diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
index 053ccd6429e0..fbe75f3f1441 100644
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -37,9 +37,16 @@
class ConfigMixin:
r"""
- Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
- methods for loading/downloading/saving configurations.
-
+ Base class for all configuration classes. Stores all configuration parameters under `self.config`. Also handles all
+ methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
+ - [`~ConfigMixin.from_config`]
+ - [`~ConfigMixin.save_config`]
+
+ Class attributes:
+ - **config_name** (`str`) -- A filename under which the config should be stored when calling
+ [`~ConfigMixin.save_config`] (should be overridden by subclass).
+ - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
+ overridden by subclass).
"""
config_name = None
ignore_for_config = []
@@ -74,8 +81,6 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool
Args:
save_directory (`str` or `os.PathLike`):
Directory where the configuration JSON file will be saved (will be created if it does not exist).
- kwargs:
- Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
if os.path.isfile(save_directory):
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
@@ -90,6 +95,64 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool
@classmethod
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
+ r"""
+ Instantiate a Python class from a pre-defined JSON-file.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ Can be either:
+
+ - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
+ organization name, like `google/ddpm-celebahq-256`.
+ - A path to a *directory* containing a configuration file saved using [`~ConfigMixin.save_config`], e.g.,
+ `./my_model_directory/`.
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
+ Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
+ as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
+ checkpoint with 3 labels).
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info (`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `transformers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ mirror (`str`, *optional*):
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+ Please refer to the mirror site for more information.
+
+
+
+ Passing `use_auth_token=True` is required when you want to use a private model.
+
+
+
+
+
+ Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
+ use this method in a firewalled environment.
+
+
+
+ """
config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
@@ -298,10 +361,10 @@ def __setitem__(self, name, value):
def register_to_config(init):
- """
- Decorator to apply on the init of classes inheriting from `ConfigMixin` so that all the arguments are automatically
- sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that shouldn't be
- registered in the config, use the `ignore_for_config` class variable
+ r"""
+ Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
+ automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
+ shouldn't be registered in the config, use the `ignore_for_config` class variable.
Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
"""
diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py
index 39e326deb1b6..fb613614a878 100644
--- a/src/diffusers/modeling_utils.py
+++ b/src/diffusers/modeling_utils.py
@@ -119,8 +119,6 @@ class ModelMixin(torch.nn.Module):
[`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
and saving models.
- Class attributes:
-
- **config_name** ([`str`]) -- A filename under which the model should be stored when calling
[`~modeling_utils.ModelMixin.save_pretrained`].
"""
@@ -200,10 +198,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
Can be either:
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
- Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
- user or organization name, like `dbmdz/bert-base-german-cased`.
- - A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`],
- e.g., `./my_model_directory/`.
+ Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
+ - A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`], e.g.,
+ `./my_model_directory/`.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
@@ -236,9 +233,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
Please refer to the mirror site for more information.
- kwargs (remaining dictionary of keyword arguments, *optional*):
- Can be used to update the [`ConfigMixin`] of the model (after it being loaded).
-
Passing `use_auth_token=True`` is required when you want to use a private model.