diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 776551c7136d..acdddaac4d26 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -66,6 +66,7 @@
from .modeling_flax_utils import FlaxModelMixin
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .models.vae_flax import FlaxAutoencoderKL
+ from .pipeline_flax_utils import FlaxDiffusionPipeline
from .schedulers import (
FlaxDDIMScheduler,
FlaxDDPMScheduler,
@@ -76,3 +77,8 @@
)
else:
from .utils.dummy_flax_objects import * # noqa F403
+
+if is_flax_available() and is_transformers_available():
+ from .pipelines import FlaxStableDiffusionPipeline
+else:
+ from .utils.dummy_flax_and_transformers_objects import * # noqa F403
diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py
index d1dfacf36265..ed62b5fe579a 100644
--- a/src/diffusers/modeling_flax_utils.py
+++ b/src/diffusers/modeling_flax_utils.py
@@ -306,16 +306,16 @@ def from_pretrained(
# Load model
if os.path.isdir(pretrained_model_name_or_path):
- if os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
- # Load from a Flax checkpoint
- model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
- # At this stage we don't have a weight file so we will raise an error.
- elif from_pt:
+ if from_pt:
if not os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
)
model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+ elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
+ # Load from a Flax checkpoint
+ model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
+ # At this stage we don't have a weight file so we will raise an error.
elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
raise EnvironmentError(
f"{WEIGHTS_NAME} file found in directory {pretrained_model_name_or_path}. Please load the model"
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index e0ac5c8d548b..1242ad6fca7f 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -12,6 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .unet_2d import UNet2DModel
-from .unet_2d_condition import UNet2DConditionModel
-from .vae import AutoencoderKL, VQModel
+from ..utils import is_flax_available, is_torch_available
+
+
+if is_torch_available():
+ from .unet_2d import UNet2DModel
+ from .unet_2d_condition import UNet2DConditionModel
+ from .vae import AutoencoderKL, VQModel
+
+if is_flax_available():
+ from .unet_2d_condition_flax import FlaxUNet2DConditionModel
+ from .vae_flax import FlaxAutoencoderKL
diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py
index 525be4818dcc..461fb8b0ac33 100644
--- a/src/diffusers/models/attention_flax.py
+++ b/src/diffusers/models/attention_flax.py
@@ -144,7 +144,6 @@ def setup(self):
def __call__(self, hidden_states, context, deterministic=True):
batch, height, width, channels = hidden_states.shape
- # import ipdb; ipdb.set_trace()
residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_in(hidden_states)
diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py
new file mode 100644
index 000000000000..fca793c0dfc2
--- /dev/null
+++ b/src/diffusers/pipeline_flax_utils.py
@@ -0,0 +1,476 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import inspect
+import os
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+import flax
+import PIL
+from flax.core.frozen_dict import FrozenDict
+from huggingface_hub import snapshot_download
+from PIL import Image
+from tqdm.auto import tqdm
+
+from .configuration_utils import ConfigMixin
+from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
+from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerMixin
+from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging
+
+
+if is_transformers_available():
+ from transformers import FlaxPreTrainedModel
+
+INDEX_FILE = "diffusion_flax_model.bin"
+
+
+logger = logging.get_logger(__name__)
+
+
+# Registry of component base classes a pipeline knows how to persist.
+# Layout: library name -> base class name -> [name of the save method, name of the load method].
+# `save_pretrained` reads index 0 of each pair and `from_pretrained` reads index 1.
+LOADABLE_CLASSES = {
+ "diffusers": {
+ "FlaxModelMixin": ["save_pretrained", "from_pretrained"],
+ "SchedulerMixin": ["save_config", "from_config"],
+ "FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"],
+ },
+ "transformers": {
+ "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
+ "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
+ "FlaxPreTrainedModel": ["save_pretrained", "from_pretrained"],
+ "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
+ },
+}
+
+# Flattened view of LOADABLE_CLASSES (base class name -> [save method, load method]) merged across all libraries.
+ALL_IMPORTABLE_CLASSES = {}
+for library in LOADABLE_CLASSES:
+ ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
+
+
+class DummyChecker:
+ # Placeholder object: `from_pretrained` substitutes it for the `safety_checker` component (and for its params),
+ # and `save_pretrained` skips any component that is an instance of it.
+ def __init__(self):
+ # Marker attribute; it is not read anywhere in this module.
+ self.dummy = True
+
+
+def import_flax_or_no_model(module, class_name):
+    """
+    Look up `class_name` on `module`, preferring its `Flax`-prefixed variant.
+
+    Args:
+        module: Object (typically an imported module) whose attributes are searched.
+        class_name (`str`): Name of the class without the `"Flax"` prefix.
+
+    Returns:
+        The attribute named `"Flax" + class_name` if `module` has one, otherwise the attribute named `class_name`.
+
+    Raises:
+        ValueError: If `module` has neither attribute.
+    """
+    try:
+        # 1. First make sure that if a Flax object is present, import this one
+        return getattr(module, "Flax" + class_name)
+    except AttributeError:
+        pass
+
+    # The fallback needs its own `try`: a second `except` clause on the first `try` would never see an
+    # exception raised from inside the first handler.
+    try:
+        # 2. If this doesn't work, it's not a model and we don't append "Flax"
+        return getattr(module, class_name)
+    except AttributeError:
+        raise ValueError(f"Neither Flax{class_name} nor {class_name} exist in {module}") from None
+
+
+@flax.struct.dataclass
+class FlaxImagePipelineOutput(BaseOutput):
+ """
+ Output class for image pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. The PIL images or the numpy array represent the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
+
+
+class FlaxDiffusionPipeline(ConfigMixin):
+ r"""
+ Base class for all Flax diffusion pipelines.
+
+ [`FlaxDiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion
+ pipelines and handles methods for loading, downloading and saving models as well as a few methods common to all
+ pipelines to:
+
+ - enabling/disabling the progress bar for the denoising iteration
+
+ Class attributes:
+
+ - **config_name** ([`str`]) -- name of the config file that will store the class and module names of all
+ components of the diffusion pipeline.
+ """
+ config_name = "model_index.json"
+
+    def register_modules(self, **kwargs):
+        """
+        Register pipeline components on the pipeline and record them in its config.
+
+        For every `name=module` keyword argument, a `(library, class_name)` pair is stored in the config under
+        `name` (through `register_to_config`) and `module` is set as the attribute `name` on the pipeline.
+
+        Args:
+            kwargs: Mapping of component name to the already instantiated component.
+        """
+        # import it here to avoid circular import
+        from diffusers import pipelines
+
+        for name, module in kwargs.items():
+            module_path = module.__module__.split(".")
+
+            # retrieve library
+            library = module_path[0]
+
+            # check if the module is a pipeline module, i.e. its parent package is exposed by `pipelines`.
+            # A class defined in a top-level module (e.g. `__main__`) has no parent package, so indexing
+            # `[-2]` unconditionally would raise an IndexError.
+            pipeline_dir = module_path[-2] if len(module_path) > 1 else None
+            is_pipeline_module = pipeline_dir is not None and hasattr(pipelines, pipeline_dir)
+
+            # if library is not in LOADABLE_CLASSES, then it is a custom module.
+            # Or if it's a pipeline module, then the module is inside the pipeline
+            # folder so we set the library to module name.
+            # Without a parent package the top-level module name is kept as the library.
+            if pipeline_dir is not None and (library not in LOADABLE_CLASSES or is_pipeline_module):
+                library = pipeline_dir
+
+            # save model index config
+            self.register_to_config(**{name: (library, module.__class__.__name__)})
+
+            # set models
+            setattr(self, name, module)
+
+    def save_pretrained(self, save_directory: Union[str, os.PathLike], params: Union[Dict, FrozenDict]):
+        """
+        Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
+        a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading
+        method. The pipeline can easily be re-loaded using the [`~FlaxDiffusionPipeline.from_pretrained`] class
+        method.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to which to save. Will be created if it doesn't exist.
+            params (`Dict` or `FrozenDict`):
+                Parameters of the pipeline components, keyed by component name. `params[name]` is forwarded to the
+                save method of every component whose save method accepts a `params` argument.
+
+        Raises:
+            ValueError: If a component does not derive from any base class listed in `LOADABLE_CLASSES`.
+        """
+        # TODO: handle inference_state
+        self.save_config(save_directory)
+
+        # Everything left in the config after dropping the bookkeeping keys is a component name.
+        model_index_dict = dict(self.config)
+        model_index_dict.pop("_class_name")
+        model_index_dict.pop("_diffusers_version")
+        model_index_dict.pop("_module", None)
+
+        for pipeline_component_name in model_index_dict:
+            sub_model = getattr(self, pipeline_component_name)
+
+            # TODO(Patrick, Suraj): to delete after
+            if isinstance(sub_model, DummyChecker):
+                continue
+
+            model_cls = sub_model.__class__
+            save_method_name = None
+            # search for the model's base class in LOADABLE_CLASSES
+            for library_name, library_classes in LOADABLE_CLASSES.items():
+                library = importlib.import_module(library_name)
+                for base_class, save_load_methods in library_classes.items():
+                    if issubclass(model_cls, getattr(library, base_class)):
+                        # if we found a suitable base class in LOADABLE_CLASSES then grab its save method
+                        save_method_name = save_load_methods[0]
+                        break
+                if save_method_name is not None:
+                    break
+
+            if save_method_name is None:
+                # Fail with an explicit message instead of the opaque TypeError from `getattr(obj, None)`.
+                raise ValueError(
+                    f"Cannot save pipeline component `{pipeline_component_name}` of type {model_cls}: it does not"
+                    f" derive from any of the loadable base classes {list(ALL_IMPORTABLE_CLASSES)}."
+                )
+
+            save_method = getattr(sub_model, save_method_name)
+            component_directory = os.path.join(save_directory, pipeline_component_name)
+
+            # Only some save methods take the parameters explicitly; pass them when the signature asks for them.
+            if "params" in inspect.signature(save_method).parameters:
+                save_method(component_directory, params=params[pipeline_component_name])
+            else:
+                save_method(component_directory)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+        r"""
+        Instantiate a Flax diffusion pipeline from pre-trained pipeline weights.
+
+        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+        task.
+
+        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+        weights are discarded.
+
+        Parameters:
+            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+                Can be either:
+
+                    - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on
+                      https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like
+                      `CompVis/ldm-text2im-large-256`.
+                    - A path to a *directory* containing pipeline weights saved using
+                      [`~FlaxDiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`.
+            dtype (`str` or `jnp.dtype`, *optional*):
+                Override the default `jnp.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
+                will be automatically derived from the model's weights.
+            cache_dir (`str` or `os.PathLike`, *optional*, defaults to `DIFFUSERS_CACHE`):
+                Directory in which downloaded files are cached.
+            from_pt (`bool`, *optional*, defaults to `False`):
+                Forwarded to the load method of the model components. When `True`, component classes named in the
+                config are first looked up with a `Flax` prefix.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+                file exists.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            output_loading_info(`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+            local_files_only(`bool`, *optional*, defaults to `False`):
+                Whether or not to only look at local files (i.e., do not try to download the model).
+            use_auth_token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+                when running `huggingface-cli login` (stored in `~/.huggingface`).
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+                identifier allowed by git.
+            mirror (`str`, *optional*):
+                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+                Please refer to the mirror site for more information. specify the folder name here.
+
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
+                specific pipeline class. The overwritten components are then directly passed to the pipelines
+                `__init__` method. See example below for more information.
+
+        Returns:
+            `tuple`: `(pipeline, params)` where `pipeline` is the instantiated pipeline and `params` is a dictionary
+            keyed by component name that holds the loaded parameters of each model component and the state created by
+            each scheduler.
+
+
+
+        Passing `use_auth_token=True` is required when you want to use a private model, *e.g.*
+        `"CompVis/stable-diffusion-v1-4"`
+
+
+
+
+
+        Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
+        this method in a firewalled environment.
+
+
+
+        Examples:
+
+        ```py
+        >>> from diffusers import FlaxDiffusionPipeline
+
+        >>> # Download pipeline from huggingface.co and cache.
+        >>> pipeline, params = FlaxDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
+
+        >>> # Download pipeline that requires an authorization token
+        >>> # For more information on access tokens, please refer to this section
+        >>> # of the documentation (https://huggingface.co/docs/hub/security-tokens)
+        >>> pipeline, params = FlaxDiffusionPipeline.from_pretrained(
+        ...     "CompVis/stable-diffusion-v1-4", use_auth_token=True
+        ... )
+
+        >>> # Download pipeline, but overwrite scheduler
+        >>> from diffusers import LMSDiscreteScheduler
+
+        >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+        >>> pipeline, params = FlaxDiffusionPipeline.from_pretrained(
+        ...     "CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True
+        ... )
+        ```
+        """
+        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+        resume_download = kwargs.pop("resume_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", False)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        revision = kwargs.pop("revision", None)
+        from_pt = kwargs.pop("from_pt", False)
+        dtype = kwargs.pop("dtype", None)
+        # NOTE(review): `force_download`, `output_loading_info` and `mirror` are documented above but are not read
+        # here; if passed they stay in `kwargs` and reach `extract_init_dict` - confirm this is intended.
+
+        # 1. Download the checkpoints and configs
+        # use snapshot download here to get it working from from_pretrained
+        if not os.path.isdir(pretrained_model_name_or_path):
+            config_dict = cls.get_config_dict(
+                pretrained_model_name_or_path,
+                cache_dir=cache_dir,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+            )
+            # make sure we only download sub-folders and `diffusers` filenames
+            folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
+            allow_patterns = [os.path.join(k, "*") for k in folder_names]
+            allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name]
+
+            # download all allow_patterns
+            cached_folder = snapshot_download(
+                pretrained_model_name_or_path,
+                cache_dir=cache_dir,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                allow_patterns=allow_patterns,
+            )
+        else:
+            cached_folder = pretrained_model_name_or_path
+
+        config_dict = cls.get_config_dict(cached_folder)
+
+        # 2. Load the pipeline class, if using custom module then load it from the hub
+        # if we load from explicit class, let's use it
+        if cls != FlaxDiffusionPipeline:
+            pipeline_class = cls
+        else:
+            diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
+            pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
+
+        # some modules can be passed directly to the init
+        # in this case they are already instantiated in `kwargs`
+        # extract them here
+        expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys())
+        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
+
+        init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
+
+        init_kwargs = {}
+
+        # inference_params
+        params = {}
+
+        # import it here to avoid circular import
+        from diffusers import pipelines
+
+        # 3. Load each module in the pipeline
+        for name, (library_name, class_name) in init_dict.items():
+            # TODO(Patrick, Suraj) - delete later
+            if class_name == "DummyChecker":
+                library_name = "stable_diffusion"
+                class_name = "StableDiffusionSafetyChecker"
+
+            is_pipeline_module = hasattr(pipelines, library_name)
+            loaded_sub_model = None
+
+            # if the model is in a pipeline module, then we load it from the pipeline
+            if name in passed_class_obj:
+                # 1. check that passed_class_obj has correct parent class
+                if not is_pipeline_module:
+                    library = importlib.import_module(library_name)
+                    class_obj = getattr(library, class_name)
+                    importable_classes = LOADABLE_CLASSES[library_name]
+                    class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
+
+                    expected_class_obj = None
+                    # iterate over the candidate classes only, so `class_name` from the config is not shadowed
+                    for class_candidate in class_candidates.values():
+                        if issubclass(class_obj, class_candidate):
+                            expected_class_obj = class_candidate
+
+                    if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
+                        raise ValueError(
+                            f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
+                            f" {expected_class_obj}"
+                        )
+                else:
+                    logger.warning(
+                        f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
+                        " has the correct type"
+                    )
+
+                # set passed class object
+                loaded_sub_model = passed_class_obj[name]
+            elif is_pipeline_module:
+                pipeline_module = getattr(pipelines, library_name)
+                if from_pt:
+                    class_obj = import_flax_or_no_model(pipeline_module, class_name)
+                else:
+                    class_obj = getattr(pipeline_module, class_name)
+
+                importable_classes = ALL_IMPORTABLE_CLASSES
+                class_candidates = {c: class_obj for c in importable_classes.keys()}
+            else:
+                # else we just import it from the library.
+                library = importlib.import_module(library_name)
+                if from_pt:
+                    class_obj = import_flax_or_no_model(library, class_name)
+                else:
+                    class_obj = getattr(library, class_name)
+
+                importable_classes = LOADABLE_CLASSES[library_name]
+                class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
+
+            if loaded_sub_model is None:
+                load_method_name = None
+                for candidate_name, class_candidate in class_candidates.items():
+                    if issubclass(class_obj, class_candidate):
+                        load_method_name = importable_classes[candidate_name][1]
+
+                if load_method_name is None:
+                    # Fail with an explicit message instead of the opaque TypeError from `getattr(cls, None)`.
+                    raise ValueError(
+                        f"Cannot load pipeline component `{name}`: {class_obj} does not derive from any of the"
+                        f" loadable base classes {list(importable_classes)}."
+                    )
+
+                load_method = getattr(class_obj, load_method_name)
+
+                # check if the module is in a subdirectory
+                if os.path.isdir(os.path.join(cached_folder, name)):
+                    loadable_folder = os.path.join(cached_folder, name)
+                else:
+                    # otherwise load the component from the top-level folder. This must set `loadable_folder`:
+                    # assigning the path to `loaded_sub_model` would leave `loadable_folder` unbound (or stale
+                    # from the previous component) for the load calls below.
+                    loadable_folder = cached_folder
+
+                if issubclass(class_obj, FlaxModelMixin):
+                    # TODO(Patrick, Suraj) - Fix this as soon as Safety checker is fixed here
+                    if name == "safety_checker":
+                        loaded_sub_model = DummyChecker()
+                        loaded_params = DummyChecker()
+                    else:
+                        loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
+                    params[name] = loaded_params
+                elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
+                    # make sure we don't initialize the weights to save time
+                    if from_pt:
+                        # TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here
+                        loaded_sub_model = load_method(loadable_folder, from_pt=from_pt)
+                        loaded_params = loaded_sub_model.params
+                        del loaded_sub_model._params
+                    else:
+                        loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False)
+                    params[name] = loaded_params
+                elif issubclass(class_obj, SchedulerMixin):
+                    loaded_sub_model = load_method(loadable_folder)
+                    params[name] = loaded_sub_model.create_state()
+                else:
+                    loaded_sub_model = load_method(loadable_folder)
+
+            init_kwargs[name] = loaded_sub_model  # UNet(...), # DiffusionSchedule(...)
+
+        model = pipeline_class(**init_kwargs, dtype=dtype)
+        return model, params
+
+    @staticmethod
+    def numpy_to_pil(images):
+        """
+        Turn a single numpy image or a batch of numpy images into a list of PIL images.
+
+        A 3-dimensional input is treated as one image and gets a leading batch axis. Values are multiplied by 255,
+        rounded and cast to `uint8` before conversion (presumably inputs are floats in `[0, 1]` - confirm with
+        callers).
+        """
+        batch = images[None, ...] if images.ndim == 3 else images
+        as_uint8 = (batch * 255).round().astype("uint8")
+        return [Image.fromarray(frame) for frame in as_uint8]
+
+ # TODO: make it compatible with jax.lax
+    def progress_bar(self, iterable):
+        """
+        Wrap `iterable` in a `tqdm` progress bar configured by `self._progress_bar_config`.
+
+        The config attribute is created as an empty dict on first use.
+
+        Raises:
+            ValueError: If `self._progress_bar_config` exists but is not a `dict`.
+        """
+        try:
+            bar_config = self._progress_bar_config
+        except AttributeError:
+            # first use: start from an empty configuration
+            bar_config = self._progress_bar_config = {}
+
+        if not isinstance(bar_config, dict):
+            raise ValueError(
+                f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
+            )
+
+        return tqdm(iterable, **bar_config)
+
+ def set_progress_bar_config(self, **kwargs):
+ """Store `kwargs` as the keyword arguments that `progress_bar` passes to `tqdm`."""
+ self._progress_bar_config = kwargs
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 3e2aeb4fb2b7..8e3c8592a258 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -1,4 +1,4 @@
-from ..utils import is_onnx_available, is_transformers_available
+from ..utils import is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
from .ddim import DDIMPipeline
from .ddpm import DDPMPipeline
from .latent_diffusion_uncond import LDMPipeline
@@ -7,7 +7,7 @@
from .stochastic_karras_ve import KarrasVePipeline
-if is_transformers_available():
+if is_torch_available() and is_transformers_available():
from .latent_diffusion import LDMTextToImagePipeline
from .stable_diffusion import (
StableDiffusionImg2ImgPipeline,
@@ -17,3 +17,6 @@
if is_transformers_available() and is_onnx_available():
from .stable_diffusion import StableDiffusionOnnxPipeline
+
+if is_transformers_available() and is_flax_available():
+ from .stable_diffusion import FlaxStableDiffusionPipeline
diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py
index e41043e0ad53..1016ce69e450 100644
--- a/src/diffusers/pipelines/stable_diffusion/__init__.py
+++ b/src/diffusers/pipelines/stable_diffusion/__init__.py
@@ -37,4 +37,24 @@ class StableDiffusionPipelineOutput(BaseOutput):
from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline
if is_transformers_available() and is_flax_available():
+ import flax
+
+ @flax.struct.dataclass
+ class FlaxStableDiffusionPipelineOutput(BaseOutput):
+ """
+ Output class for Stable Diffusion pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. The PIL images or the numpy array represent the denoised images of the diffusion pipeline.
+ nsfw_content_detected (`List[bool]`)
+ List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
+ # NOTE(review): annotated as a list, but `FlaxStableDiffusionPipeline.__call__` currently passes a plain `False`.
+ nsfw_content_detected: List[bool]
+
+ from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
new file mode 100644
index 000000000000..6cca376791ac
--- /dev/null
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
@@ -0,0 +1,219 @@
+from typing import Dict, List, Optional, Union
+
+import jax
+import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict
+from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
+
+from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
+from ...pipeline_flax_utils import FlaxDiffusionPipeline
+from ...schedulers import FlaxDDIMScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler
+from . import FlaxStableDiffusionPipelineOutput
+from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
+
+
+class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`FlaxAutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`FlaxCLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel),
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`FlaxSchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`].
+ safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+    def __init__(
+        self,
+        vae: FlaxAutoencoderKL,
+        text_encoder: FlaxCLIPTextModel,
+        tokenizer: CLIPTokenizer,
+        unet: FlaxUNet2DConditionModel,
+        scheduler: Union[FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler],
+        safety_checker: FlaxStableDiffusionSafetyChecker,
+        feature_extractor: CLIPFeatureExtractor,
+        dtype: jnp.dtype = jnp.float32,
+    ):
+        """Remember `dtype` and register every component on the pipeline via `register_modules`."""
+        super().__init__()
+
+        # the scheduler is switched to its "np" format before it is registered
+        np_scheduler = scheduler.set_format("np")
+        self.dtype = dtype
+
+        components = dict(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=np_scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+        )
+        self.register_modules(**components)
+
+    def prepare_inputs(self, prompt: Union[str, List[str]]):
+        """
+        Tokenize `prompt` and return the resulting `input_ids`.
+
+        The tokenizer is asked to pad and truncate to `self.tokenizer.model_max_length` and to return NumPy tensors.
+
+        Raises:
+            ValueError: If `prompt` is neither a `str` nor a `list`.
+        """
+        if isinstance(prompt, (str, list)):
+            tokenizer_options = {
+                "padding": "max_length",
+                "max_length": self.tokenizer.model_max_length,
+                "truncation": True,
+                "return_tensors": "np",
+            }
+            return self.tokenizer(prompt, **tokenizer_options).input_ids
+
+        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ def __call__(
+ self,
+ prompt_ids: jnp.array,
+ params: Union[Dict, FrozenDict],
+ prng_seed: jax.random.PRNGKey,
+ num_inference_steps: Optional[int] = 50,
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ guidance_scale: Optional[float] = 7.5,
+ latents: Optional[jnp.array] = None,
+ return_dict: bool = True,
+ debug: bool = False,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt_ids (`jnp.array`):
+ Token ids of the prompt or prompts that guide the image generation, e.g. the output of
+ `prepare_inputs`. The first dimension is used as the batch size.
+ params (`Dict` or `FrozenDict`):
+ Parameters and state of the pipeline components. The entries `"text_encoder"`, `"unet"`, `"vae"` and
+ `"scheduler"` are read.
+ prng_seed (`jax.random.PRNGKey`):
+ Random key used to sample the initial latents when `latents` is not provided.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image. Must be divisible by 8. Note that the latents shape is
+ currently built from `self.unet.sample_size`; `height` is only validated.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image. Must be divisible by 8. Note that the latents shape is
+ currently built from `self.unet.sample_size`; `width` is only validated.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Higher guidance scale encourages to generate images that
+ are closely linked to the text prompt, usually at the expense of lower image quality. Classifier-free
+ guidance is currently always applied, whatever the value (see the TODO in the body).
+ latents (`jnp.array`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling with `prng_seed`. Must have shape `(batch_size,
+ unet.in_channels, unet.sample_size, unet.sample_size)`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
+ a plain tuple.
+ debug (`bool`, *optional*, defaults to `False`):
+ If `True`, run the denoising steps in a Python `for` loop instead of `jax.lax.fori_loop`.
+ kwargs:
+ Accepted but not used by this method.
+
+ Returns:
+ [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is the array of generated images and the second
+ element is the nsfw flag. The safety checker call is currently commented out, so the flag is always
+ `False`.
+
+ Raises:
+ ValueError: If `height` or `width` is not divisible by 8, or if `latents` is given with an unexpected
+ shape.
+ """
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ # get prompt text embeddings
+ text_embeddings = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]
+
+ # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
+ # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
+ batch_size = prompt_ids.shape[0]
+
+ # embed one empty prompt per sample, padded to the same length as `prompt_ids`
+ max_length = prompt_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0]
+ # unconditional embeddings first, text embeddings second; `jnp.split` in `loop_body` relies on this order
+ context = jnp.concatenate([uncond_embeddings, text_embeddings])
+
+ # TODO: check it because the shape is different from PyTorch StableDiffusionPipeline
+ latents_shape = (
+ batch_size,
+ self.unet.in_channels,
+ self.unet.sample_size,
+ self.unet.sample_size,
+ )
+ if latents is None:
+ latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype)
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+
+ # One denoising step; the `(step, carry) -> carry` signature is what `jax.lax.fori_loop` expects.
+ def loop_body(step, args):
+ latents, scheduler_state = args
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ latents_input = jnp.concatenate([latents] * 2)
+
+ t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
+ timestep = jnp.broadcast_to(t, latents_input.shape[0])
+
+ # predict the noise residual
+ noise_pred = self.unet.apply(
+ {"params": params["unet"]},
+ jnp.array(latents_input),
+ jnp.array(timestep, dtype=jnp.int32),
+ encoder_hidden_states=context,
+ rngs={},
+ ).sample
+ # perform guidance
+ noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
+ return latents, scheduler_state
+
+ scheduler_state = self.scheduler.set_timesteps(params["scheduler"], num_inference_steps=num_inference_steps)
+
+ if debug:
+ # run with python for loop
+ for i in range(num_inference_steps):
+ latents, scheduler_state = loop_body(i, (latents, scheduler_state))
+ else:
+ latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))
+
+ # scale and decode the image latents with vae
+ latents = 1 / 0.18215 * latents
+ # TODO: check when flax vae gets merged into main
+ image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
+
+ # rescale to [0, 1] and reorder the axes with (0, 2, 3, 1) - presumably NCHW -> NHWC, confirm against the VAE
+ image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
+
+ # image = jnp.asarray(image).transpose(0, 2, 3, 1)
+ # run safety checker
+ # TODO: check when flax safety checker gets merged into main
+ # safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
+ # image, has_nsfw_concept = self.safety_checker(
+ # images=image, clip_input=safety_cheker_input.pixel_values, params=params["safety_params"]
+ # )
+ has_nsfw_concept = False
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return FlaxStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py
index 015b79b2780d..dd5a87df654a 100644
--- a/src/diffusers/schedulers/scheduling_ddim_flax.py
+++ b/src/diffusers/schedulers/scheduling_ddim_flax.py
@@ -21,7 +21,6 @@
import flax
import jax.numpy as jnp
-from jax import random
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
@@ -60,11 +59,12 @@ def alpha_bar(time_step):
class DDIMSchedulerState:
# setable values
timesteps: jnp.ndarray
+ alphas_cumprod: jnp.ndarray
num_inference_steps: Optional[int] = None
@classmethod
- def create(cls, num_train_timesteps: int):
- return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1])
+ def create(cls, num_train_timesteps: int, alphas_cumprod: jnp.ndarray):
+ return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1], alphas_cumprod=alphas_cumprod)
@dataclass
@@ -112,13 +112,9 @@ def __init__(
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
- trained_betas: Optional[jnp.ndarray] = None,
- clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
):
- if trained_betas is not None:
- self.betas = jnp.asarray(trained_betas)
if beta_schedule == "linear":
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
elif beta_schedule == "scaled_linear":
@@ -131,19 +127,24 @@ def __init__(
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
- self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
+
+ # HACK for now: the cumulative alphas are kept on a private attribute and handed to the state in `create_state` - clean up later (PVP)
+ self._alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
- self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+ self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else float(self._alphas_cumprod[0])
- self.state = DDIMSchedulerState.create(num_train_timesteps=num_train_timesteps)
+ def create_state(self):
+ return DDIMSchedulerState.create(
+ num_train_timesteps=self.config.num_train_timesteps, alphas_cumprod=self._alphas_cumprod
+ )
- def _get_variance(self, timestep, prev_timestep):
- alpha_prod_t = self.alphas_cumprod[timestep]
- alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ def _get_variance(self, timestep, prev_timestep, alphas_cumprod):
+ alpha_prod_t = alphas_cumprod[timestep]
+ alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
@@ -177,9 +178,6 @@ def step(
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
- key: random.KeyArray,
- eta: float = 0.0,
- use_clipped_model_output: bool = False,
return_dict: bool = True,
) -> Union[FlaxSchedulerOutput, Tuple]:
"""
@@ -221,41 +219,28 @@ def step(
# 1. get previous step value (=t-1)
prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
+ alphas_cumprod = state.alphas_cumprod
+
# 2. compute alphas, betas
- alpha_prod_t = self.alphas_cumprod[timestep]
- alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ alpha_prod_t = alphas_cumprod[timestep]
+ alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
- # 4. Clip "predicted x_0"
- if self.config.clip_sample:
- pred_original_sample = jnp.clip(pred_original_sample, -1, 1)
-
- # 5. compute variance: "sigma_t(η)" -> see formula (16)
+ # 4. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
- variance = self._get_variance(timestep, prev_timestep)
- std_dev_t = eta * variance ** (0.5)
+ variance = self._get_variance(timestep, prev_timestep, alphas_cumprod)
+ std_dev_t = variance ** (0.5)
- if use_clipped_model_output:
- # the model_output is always re-derived from the clipped x_0 in Glide
- model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
-
- # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
- # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
- if eta > 0:
- key = random.split(key, num=1)
- noise = random.normal(key=key, shape=model_output.shape)
- variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
-
- prev_sample = prev_sample + variance
-
if not return_dict:
return (prev_sample, state)
diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py
index efc3858ca75a..4c8c43810b6f 100644
--- a/src/diffusers/schedulers/scheduling_pndm_flax.py
+++ b/src/diffusers/schedulers/scheduling_pndm_flax.py
@@ -148,7 +148,8 @@ def __init__(
# mainly at formula (9), (12), (13) and the Algorithm 2.
self.pndm_order = 4
- self.state = PNDMSchedulerState.create(num_train_timesteps=num_train_timesteps)
+ def create_state(self):
+ return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) -> PNDMSchedulerState:
"""
diff --git a/src/diffusers/utils/dummy_flax_and_transformers_objects.py b/src/diffusers/utils/dummy_flax_and_transformers_objects.py
new file mode 100644
index 000000000000..51ee3b184816
--- /dev/null
+++ b/src/diffusers/utils/dummy_flax_and_transformers_objects.py
@@ -0,0 +1,11 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+# flake8: noqa
+
+from ..utils import DummyObject, requires_backends
+
+
+class FlaxStableDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["flax", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax", "transformers"])
diff --git a/src/diffusers/utils/dummy_flax_objects.py b/src/diffusers/utils/dummy_flax_objects.py
index 9615afb6f920..1e3ac002a609 100644
--- a/src/diffusers/utils/dummy_flax_objects.py
+++ b/src/diffusers/utils/dummy_flax_objects.py
@@ -11,42 +11,56 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
-class FlaxDDIMScheduler(metaclass=DummyObject):
+class FlaxUNet2DConditionModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
-class FlaxDDPMScheduler(metaclass=DummyObject):
+class FlaxAutoencoderKL(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
-class FlaxKarrasVeScheduler(metaclass=DummyObject):
+class FlaxDiffusionPipeline(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
-class FlaxLMSDiscreteScheduler(metaclass=DummyObject):
+class FlaxDDIMScheduler(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
-class FlaxPNDMScheduler(metaclass=DummyObject):
+class FlaxDDPMScheduler(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
-class FlaxUNet2DConditionModel(metaclass=DummyObject):
+class FlaxKarrasVeScheduler(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+
+class FlaxLMSDiscreteScheduler(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+
+class FlaxPNDMScheduler(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):