diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 776551c7136d..acdddaac4d26 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -66,6 +66,7 @@ from .modeling_flax_utils import FlaxModelMixin from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel from .models.vae_flax import FlaxAutoencoderKL + from .pipeline_flax_utils import FlaxDiffusionPipeline from .schedulers import ( FlaxDDIMScheduler, FlaxDDPMScheduler, @@ -76,3 +77,8 @@ ) else: from .utils.dummy_flax_objects import * # noqa F403 + +if is_flax_available() and is_transformers_available(): + from .pipelines import FlaxStableDiffusionPipeline +else: + from .utils.dummy_flax_and_transformers_objects import * # noqa F403 diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index d1dfacf36265..ed62b5fe579a 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -306,16 +306,16 @@ def from_pretrained( # Load model if os.path.isdir(pretrained_model_name_or_path): - if os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)): - # Load from a Flax checkpoint - model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME) - # At this stage we don't have a weight file so we will raise an error. - elif from_pt: + if from_pt: if not os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME): raise EnvironmentError( f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " ) model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)): + # Load from a Flax checkpoint + model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME) + # At this stage we don't have a weight file so we will raise an error. 
elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME): raise EnvironmentError( f"{WEIGHTS_NAME} file found in directory {pretrained_model_name_or_path}. Please load the model" diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index e0ac5c8d548b..1242ad6fca7f 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -12,6 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .unet_2d import UNet2DModel -from .unet_2d_condition import UNet2DConditionModel -from .vae import AutoencoderKL, VQModel +from ..utils import is_flax_available, is_torch_available + + +if is_torch_available(): + from .unet_2d import UNet2DModel + from .unet_2d_condition import UNet2DConditionModel + from .vae import AutoencoderKL, VQModel + +if is_flax_available(): + from .unet_2d_condition_flax import FlaxUNet2DConditionModel + from .vae_flax import FlaxAutoencoderKL diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 525be4818dcc..461fb8b0ac33 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -144,7 +144,6 @@ def setup(self): def __call__(self, hidden_states, context, deterministic=True): batch, height, width, channels = hidden_states.shape - # import ipdb; ipdb.set_trace() residual = hidden_states hidden_states = self.norm(hidden_states) hidden_states = self.proj_in(hidden_states) diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py new file mode 100644 index 000000000000..fca793c0dfc2 --- /dev/null +++ b/src/diffusers/pipeline_flax_utils.py @@ -0,0 +1,476 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import inspect
+import os
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+import flax
+import PIL
+from flax.core.frozen_dict import FrozenDict
+from huggingface_hub import snapshot_download
+from PIL import Image
+from tqdm.auto import tqdm
+
+from .configuration_utils import ConfigMixin
+from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
+from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerMixin
+from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging
+
+
+if is_transformers_available():
+    from transformers import FlaxPreTrainedModel
+
+INDEX_FILE = "diffusion_flax_model.bin"
+
+
+logger = logging.get_logger(__name__)
+
+
+LOADABLE_CLASSES = {
+    "diffusers": {
+        "FlaxModelMixin": ["save_pretrained", "from_pretrained"],
+        "SchedulerMixin": ["save_config", "from_config"],
+        "FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"],
+    },
+    "transformers": {
+        "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
+        "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
+        "FlaxPreTrainedModel": ["save_pretrained", "from_pretrained"],
+        "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
+    },
+}
+
+ALL_IMPORTABLE_CLASSES = {}
+for library in LOADABLE_CLASSES:
+    ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
+
+
+class DummyChecker:
+    def __init__(self):
+        self.dummy = True
+
+
+def import_flax_or_no_model(module, class_name):
+    """
+    Return `Flax{class_name}` from `module` if it exists, otherwise `{class_name}`.
+    Raises `ValueError` when `module` defines neither attribute.
+    """
+    # 1. First make sure that if a Flax object is present, import this one
+    # 2. If this doesn't work, it's not a model and we don't append "Flax"
+    for candidate in ("Flax" + class_name, class_name):
+        if hasattr(module, candidate):
+            return getattr(module, candidate)
+    raise ValueError(f"Neither Flax{class_name} nor {class_name} exist in {module}")
+
+
+@flax.struct.dataclass
+class FlaxImagePipelineOutput(BaseOutput):
+    """
+    Output class for image pipelines.
+
+    Args:
+        images (`List[PIL.Image.Image]` or `np.ndarray`)
+            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+    """
+
+    images: Union[List[PIL.Image.Image], np.ndarray]
+
+
+class FlaxDiffusionPipeline(ConfigMixin):
+    r"""
+    Base class for all models.
+
+    [`FlaxDiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion
+    pipelines and handles methods for loading, downloading and saving models as well as a few methods common to all
+    pipelines to:
+
+        - enabling/disabling the progress bar for the denoising iteration
+
+    Class attributes:
+
+        - **config_name** ([`str`]) -- name of the config file that will store the class and module names of all
+          components of the diffusion pipeline.
+    """
+    config_name = "model_index.json"
+
+    def register_modules(self, **kwargs):
+        # import it here to avoid circular import
+        from diffusers import pipelines
+
+        for name, module in kwargs.items():
+            # retrieve library
+            library = module.__module__.split(".")[0]
+
+            # check if the module is a pipeline module
+            pipeline_dir = module.__module__.split(".")[-2]
+            path = module.__module__.split(".")
+            is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
+
+            # if library is not in LOADABLE_CLASSES, then it is a custom module.
+ # Or if it's a pipeline module, then the module is inside the pipeline + # folder so we set the library to module name. + if library not in LOADABLE_CLASSES or is_pipeline_module: + library = pipeline_dir + + # retrieve class_name + class_name = module.__class__.__name__ + + register_dict = {name: (library, class_name)} + + # save model index config + self.register_to_config(**register_dict) + + # set models + setattr(self, name, module) + + def save_pretrained(self, save_directory: Union[str, os.PathLike], params: Union[Dict, FrozenDict]): + # TODO: handle inference_state + """ + Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to + a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading + method. The pipeline can easily be re-loaded using the `[`~FlaxDiffusionPipeline.from_pretrained`]` class + method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. 
+ """ + self.save_config(save_directory) + + model_index_dict = dict(self.config) + model_index_dict.pop("_class_name") + model_index_dict.pop("_diffusers_version") + model_index_dict.pop("_module", None) + + for pipeline_component_name in model_index_dict.keys(): + sub_model = getattr(self, pipeline_component_name) + model_cls = sub_model.__class__ + + save_method_name = None + # search for the model's base class in LOADABLE_CLASSES + for library_name, library_classes in LOADABLE_CLASSES.items(): + library = importlib.import_module(library_name) + for base_class, save_load_methods in library_classes.items(): + class_candidate = getattr(library, base_class) + if issubclass(model_cls, class_candidate): + # if we found a suitable base class in LOADABLE_CLASSES then grab its save method + save_method_name = save_load_methods[0] + break + if save_method_name is not None: + break + + # TODO(Patrick, Suraj): to delete after + if isinstance(sub_model, DummyChecker): + continue + + save_method = getattr(sub_model, save_method_name) + expects_params = "params" in set(inspect.signature(save_method).parameters.keys()) + + if expects_params: + save_method( + os.path.join(save_directory, pipeline_component_name), params=params[pipeline_component_name] + ) + else: + save_method(os.path.join(save_directory, pipeline_component_name)) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + r""" + Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights. + + The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. 
+ + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on + https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like + `CompVis/ldm-text2im-large-256`. + - A path to a *directory* containing pipeline weights saved using + [`~FlaxDiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`. + dtype (`str` or `jnp.dtype`, *optional*): + Override the default `jnp.dtype` and load the model under this dtype. If `"auto"` is passed the dtype + will be automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). 
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+                identifier allowed by git.
+            mirror (`str`, *optional*):
+                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+                Please refer to the mirror site for more information. specify the folder name here.
+
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to overwrite loadable and saveable variables - *i.e.* the pipeline components - of the
+                specific pipeline class. The overwritten components are then directly passed to the pipeline `__init__`
+                method. See example below for more information.
+
+
+
+        Passing `use_auth_token=True` is required when you want to use a private model, *e.g.*
+        `"CompVis/stable-diffusion-v1-4"`
+
+
+
+
+
+        Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
+        this method in a firewalled environment.
+
+
+
+        Examples:
+
+        ```py
+        >>> from diffusers import FlaxDiffusionPipeline
+
+        >>> # Download pipeline from huggingface.co and cache.
+ >>> pipeline = FlaxDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") + + >>> # Download pipeline that requires an authorization token + >>> # For more information on access tokens, please refer to this section + >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens) + >>> pipeline = FlaxDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True) + + >>> # Download pipeline, but overwrite scheduler + >>> from diffusers import LMSDiscreteScheduler + + >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + >>> pipeline = FlaxDiffusionPipeline.from_pretrained( + ... "CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True + ... ) + ``` + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + from_pt = kwargs.pop("from_pt", False) + dtype = kwargs.pop("dtype", None) + + # 1. 
Download the checkpoints and configs + # use snapshot download here to get it working from from_pretrained + if not os.path.isdir(pretrained_model_name_or_path): + config_dict = cls.get_config_dict( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + ) + # make sure we only download sub-folders and `diffusers` filenames + folder_names = [k for k in config_dict.keys() if not k.startswith("_")] + allow_patterns = [os.path.join(k, "*") for k in folder_names] + allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name] + + # download all allow_patterns + cached_folder = snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + allow_patterns=allow_patterns, + ) + else: + cached_folder = pretrained_model_name_or_path + + config_dict = cls.get_config_dict(cached_folder) + + # 2. Load the pipeline class, if using custom module then load it from the hub + # if we load from explicit class, let's use it + if cls != FlaxDiffusionPipeline: + pipeline_class = cls + else: + diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) + pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) + + # some modules can be passed directly to the init + # in this case they are already instantiated in `kwargs` + # extract them here + expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) + passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + + init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) + + init_kwargs = {} + + # inference_params + params = {} + + # import it here to avoid circular import + from diffusers import pipelines + + # 3. 
Load each module in the pipeline + for name, (library_name, class_name) in init_dict.items(): + # TODO(Patrick, Suraj) - delete later + if class_name == "DummyChecker": + library_name = "stable_diffusion" + class_name = "StableDiffusionSafetyChecker" + + is_pipeline_module = hasattr(pipelines, library_name) + loaded_sub_model = None + + # if the model is in a pipeline module, then we load it from the pipeline + if name in passed_class_obj: + # 1. check that passed_class_obj has correct parent class + if not is_pipeline_module: + library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + importable_classes = LOADABLE_CLASSES[library_name] + class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} + + expected_class_obj = None + for class_name, class_candidate in class_candidates.items(): + if issubclass(class_obj, class_candidate): + expected_class_obj = class_candidate + + if not issubclass(passed_class_obj[name].__class__, expected_class_obj): + raise ValueError( + f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" + f" {expected_class_obj}" + ) + else: + logger.warn( + f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" + " has the correct type" + ) + + # set passed class object + loaded_sub_model = passed_class_obj[name] + elif is_pipeline_module: + pipeline_module = getattr(pipelines, library_name) + if from_pt: + class_obj = import_flax_or_no_model(pipeline_module, class_name) + else: + class_obj = getattr(pipeline_module, class_name) + + importable_classes = ALL_IMPORTABLE_CLASSES + class_candidates = {c: class_obj for c in importable_classes.keys()} + else: + # else we just import it from the library. 
+                library = importlib.import_module(library_name)
+                if from_pt:
+                    class_obj = import_flax_or_no_model(library, class_name)
+                else:
+                    class_obj = getattr(library, class_name)
+
+                importable_classes = LOADABLE_CLASSES[library_name]
+                class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
+
+            if loaded_sub_model is None:
+                load_method_name = None
+                for class_name, class_candidate in class_candidates.items():
+                    if issubclass(class_obj, class_candidate):
+                        load_method_name = importable_classes[class_name][1]
+
+                load_method = getattr(class_obj, load_method_name)
+
+                # check if the module is in a subdirectory
+                if os.path.isdir(os.path.join(cached_folder, name)):
+                    loadable_folder = os.path.join(cached_folder, name)
+                else:
+                    loadable_folder = cached_folder  # no sub-directory: load from the pipeline root folder
+
+                if issubclass(class_obj, FlaxModelMixin):
+                    # TODO(Patrick, Suraj) - Fix this as soon as Safety checker is fixed here
+                    if name == "safety_checker":
+                        loaded_sub_model = DummyChecker()
+                        loaded_params = DummyChecker()
+                    else:
+                        loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
+                    params[name] = loaded_params
+                elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
+                    # make sure we don't initialize the weights to save time
+                    if from_pt:
+                        # TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here
+                        loaded_sub_model = load_method(loadable_folder, from_pt=from_pt)
+                        loaded_params = loaded_sub_model.params
+                        del loaded_sub_model._params
+                    else:
+                        loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False)
+                    params[name] = loaded_params
+                elif issubclass(class_obj, SchedulerMixin):
+                    loaded_sub_model = load_method(loadable_folder)
+                    params[name] = loaded_sub_model.create_state()
+                else:
+                    loaded_sub_model = load_method(loadable_folder)
+
+            init_kwargs[name] = loaded_sub_model  # UNet(...), # DiffusionSchedule(...)
+ + model = pipeline_class(**init_kwargs, dtype=dtype) + return model, params + + @staticmethod + def numpy_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + # TODO: make it compatible with jax.lax + def progress_bar(self, iterable): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." + ) + + return tqdm(iterable, **self._progress_bar_config) + + def set_progress_bar_config(self, **kwargs): + self._progress_bar_config = kwargs diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 3e2aeb4fb2b7..8e3c8592a258 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -1,4 +1,4 @@ -from ..utils import is_onnx_available, is_transformers_available +from ..utils import is_flax_available, is_onnx_available, is_torch_available, is_transformers_available from .ddim import DDIMPipeline from .ddpm import DDPMPipeline from .latent_diffusion_uncond import LDMPipeline @@ -7,7 +7,7 @@ from .stochastic_karras_ve import KarrasVePipeline -if is_transformers_available(): +if is_torch_available() and is_transformers_available(): from .latent_diffusion import LDMTextToImagePipeline from .stable_diffusion import ( StableDiffusionImg2ImgPipeline, @@ -17,3 +17,6 @@ if is_transformers_available() and is_onnx_available(): from .stable_diffusion import StableDiffusionOnnxPipeline + +if is_transformers_available() and is_flax_available(): + from .stable_diffusion import FlaxStableDiffusionPipeline diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py 
b/src/diffusers/pipelines/stable_diffusion/__init__.py index e41043e0ad53..1016ce69e450 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -37,4 +37,24 @@ class StableDiffusionPipelineOutput(BaseOutput): from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline if is_transformers_available() and is_flax_available(): + import flax + + @flax.struct.dataclass + class FlaxStableDiffusionPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_content_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: List[bool] + + from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline from .safety_checker_flax import FlaxStableDiffusionSafetyChecker diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py new file mode 100644 index 000000000000..6cca376791ac --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -0,0 +1,219 @@ +from typing import Dict, List, Optional, Union + +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict +from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel + +from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel +from ...pipeline_flax_utils import FlaxDiffusionPipeline +from ...schedulers import FlaxDDIMScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler +from . 
import FlaxStableDiffusionPipelineOutput +from .safety_checker_flax import FlaxStableDiffusionSafetyChecker + + +class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`FlaxAutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`FlaxCLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`FlaxSchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`]. + safety_checker ([`FlaxStableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offsensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
+    """
+
+    def __init__(
+        self,
+        vae: FlaxAutoencoderKL,
+        text_encoder: FlaxCLIPTextModel,
+        tokenizer: CLIPTokenizer,
+        unet: FlaxUNet2DConditionModel,
+        scheduler: Union[FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler],
+        safety_checker: FlaxStableDiffusionSafetyChecker,
+        feature_extractor: CLIPFeatureExtractor,
+        dtype: jnp.dtype = jnp.float32,
+    ):
+        super().__init__()
+        scheduler = scheduler.set_format("np")
+        self.dtype = dtype
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+        )
+
+    def prepare_inputs(self, prompt: Union[str, List[str]]):
+        if not isinstance(prompt, (str, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        text_input = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="np",
+        )
+        return text_input.input_ids
+
+    def __call__(
+        self,
+        prompt_ids: jnp.array,
+        params: Union[Dict, FrozenDict],
+        prng_seed: jax.random.PRNGKey,
+        num_inference_steps: Optional[int] = 50,
+        height: Optional[int] = 512,
+        width: Optional[int] = 512,
+        guidance_scale: Optional[float] = 7.5,
+        latents: Optional[jnp.array] = None,
+        return_dict: bool = True,
+        debug: bool = False,
+        **kwargs,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt_ids (`jnp.array`):
+                Tokenized prompt or prompts to guide the image generation, e.g. as returned by `prepare_inputs`.
+            params (`Dict` or `FrozenDict`):
+                Component parameters/state; the keys `text_encoder`, `unet`, `vae` and `scheduler` are read.
+            prng_seed (`jax.random.PRNGKey`):
+                Random key used to sample the initial latents when `latents` is not provided.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            height (`int`, *optional*, defaults to 512):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to 512):
+                The width in pixels of the generated image.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            latents (`jnp.array`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied `prng_seed`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
+                a plain tuple.
+            debug (`bool`, *optional*, defaults to `False`):
+                Run the denoising loop as a plain Python `for` loop instead of `jax.lax.fori_loop`.
+
+        Returns:
+            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element holds the generated images and the second element is
+            the "not-safe-for-work" (nsfw) flag; the safety checker is not run yet, so this flag is currently always
+            `False`.
+ """ + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # get prompt text embeddings + text_embeddings = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] + + # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` + # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` + batch_size = prompt_ids.shape[0] + + max_length = prompt_ids.shape[-1] + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0] + context = jnp.concatenate([uncond_embeddings, text_embeddings]) + + # TODO: check it because the shape is different from Pytorhc StableDiffusionPipeline + latents_shape = ( + batch_size, + self.unet.in_channels, + self.unet.sample_size, + self.unet.sample_size, + ) + if latents is None: + latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + def loop_body(step, args): + latents, scheduler_state = args + # For classifier free guidance, we need to do two forward passes. 
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = jnp.concatenate([latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + # predict the noise residual + noise_pred = self.unet.apply( + {"params": params["unet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + rngs={}, + ).sample + # perform guidance + noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents, scheduler_state + + scheduler_state = self.scheduler.set_timesteps(params["scheduler"], num_inference_steps=num_inference_steps) + + if debug: + # run with python for loop + for i in range(num_inference_steps): + latents, scheduler_state = loop_body(i, (latents, scheduler_state)) + else: + latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) + + # scale and decode the image latents with vae + latents = 1 / 0.18215 * latents + # TODO: check when flax vae gets merged into main + image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample + + image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) + + # image = jnp.asarray(image).transpose(0, 2, 3, 1) + # run safety checker + # TODO: check when flax safety checker gets merged into main + # safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np") + # image, has_nsfw_concept = self.safety_checker( + # images=image, clip_input=safety_cheker_input.pixel_values, params=params["safety_params"] + # ) + has_nsfw_concept = False + + if not return_dict: + 
class DDIMSchedulerState:
    """Value container for the mutable part of a Flax DDIM scheduler.

    The scheduler hands this object to callers and receives it back on each
    call instead of keeping the values on the scheduler instance itself.
    """

    # Timestep indices; `create` initialises them in descending order.
    timesteps: jnp.ndarray
    # Cumulative alpha products, supplied by whoever builds the state.
    alphas_cumprod: jnp.ndarray
    # Stays None until a number of inference steps has been chosen.
    num_inference_steps: Optional[int] = None

    @classmethod
    def create(cls, num_train_timesteps: int, alphas_cumprod: jnp.ndarray):
        """Build an initial state covering every training timestep.

        Args:
            num_train_timesteps: number of training timesteps; the state's
                `timesteps` become `num_train_timesteps - 1, ..., 1, 0`.
            alphas_cumprod: array stored unchanged on the new state.

        Returns:
            A new instance of `cls` with `num_inference_steps` left at its
            default.
        """
        descending_steps = jnp.arange(num_train_timesteps)[::-1]
        return cls(timesteps=descending_steps, alphas_cumprod=alphas_cumprod)
def create_state(self):
    """Return a fresh `DDIMSchedulerState` for this scheduler.

    The state is built from the configured number of training timesteps and
    the cumulative alpha products that `__init__` stored in
    `self._alphas_cumprod`.
    """
    train_steps = self.config.num_train_timesteps
    cumulative_alphas = self._alphas_cumprod
    return DDIMSchedulerState.create(
        num_train_timesteps=train_steps,
        alphas_cumprod=cumulative_alphas,
    )
compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - # 4. Clip "predicted x_0" - if self.config.clip_sample: - pred_original_sample = jnp.clip(pred_original_sample, -1, 1) - - # 5. compute variance: "sigma_t(η)" -> see formula (16) + # 4. compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) - variance = self._get_variance(timestep, prev_timestep) - std_dev_t = eta * variance ** (0.5) + variance = self._get_variance(timestep, prev_timestep, alphas_cumprod) + std_dev_t = variance ** (0.5) - if use_clipped_model_output: - # the model_output is always re-derived from the clipped x_0 in Glide - model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) - - # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output - # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + # 6. 
def create_state(self):
    """Return a fresh `PNDMSchedulerState` for this scheduler.

    The only input is the number of training timesteps read from the
    scheduler's config.
    """
    train_steps = self.config.num_train_timesteps
    return PNDMSchedulerState.create(num_train_timesteps=train_steps)
+# flake8: noqa + +from ..utils import DummyObject, requires_backends + + +class FlaxStableDiffusionPipeline(metaclass=DummyObject): + _backends = ["flax", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax", "transformers"]) diff --git a/src/diffusers/utils/dummy_flax_objects.py b/src/diffusers/utils/dummy_flax_objects.py index 9615afb6f920..1e3ac002a609 100644 --- a/src/diffusers/utils/dummy_flax_objects.py +++ b/src/diffusers/utils/dummy_flax_objects.py @@ -11,42 +11,56 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) -class FlaxDDIMScheduler(metaclass=DummyObject): +class FlaxUNet2DConditionModel(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) -class FlaxDDPMScheduler(metaclass=DummyObject): +class FlaxAutoencoderKL(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) -class FlaxKarrasVeScheduler(metaclass=DummyObject): +class FlaxDiffusionPipeline(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) -class FlaxLMSDiscreteScheduler(metaclass=DummyObject): +class FlaxDDIMScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) -class FlaxPNDMScheduler(metaclass=DummyObject): +class FlaxDDPMScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) -class FlaxUNet2DConditionModel(metaclass=DummyObject): +class FlaxKarrasVeScheduler(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxLMSDiscreteScheduler(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxPNDMScheduler(metaclass=DummyObject): _backends = 
["flax"] def __init__(self, *args, **kwargs):