From b9ca4060d527cfd29b1a1a23cf905e79e5105cd6 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Mon, 19 Sep 2022 09:28:00 +0000 Subject: [PATCH 01/15] WIP: flax FlaxDiffusionPipeline & FlaxStableDiffusionPipeline --- src/diffusers/models/__init__.py | 1 + src/diffusers/pipeline_flax_utils.py | 428 ++++++++++++++++++ .../pipelines/stable_diffusion/__init__.py | 28 ++ .../pipeline_flax_stable_diffusion.py | 245 ++++++++++ 4 files changed, 702 insertions(+) create mode 100644 src/diffusers/pipeline_flax_utils.py create mode 100644 src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index e0ac5c8d548b..3c3656a572f0 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -14,4 +14,5 @@ from .unet_2d import UNet2DModel from .unet_2d_condition import UNet2DConditionModel +from .unet_2d_condition_flax import FlaxUNet2DConditionModel from .vae import AutoencoderKL, VQModel diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py new file mode 100644 index 000000000000..ae5b72703548 --- /dev/null +++ b/src/diffusers/pipeline_flax_utils.py @@ -0,0 +1,428 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import inspect +import os +from typing import List, Optional, Union + +import numpy as np + +import diffusers +import flax +import jax.numpy as jnp +import PIL +from huggingface_hub import snapshot_download +from PIL import Image +from tqdm.auto import tqdm + +from .configuration_utils import ConfigMixin +from .modeling_flax_utils import FLAX_WEIGHTS_NAME +from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME +from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, logging + + +INDEX_FILE = "diffusion_flax_model.bin" + + +logger = logging.get_logger(__name__) + + +LOADABLE_CLASSES = { + "diffusers": { + "FlaxModelMixin": ["save_pretrained", "from_pretrained"], + "SchedulerMixin": ["save_config", "from_config"], + "FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"], + }, + "transformers": { + "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"], + "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"], + "FlaxPreTrainedModel": ["save_pretrained", "from_pretrained"], + "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"], + }, +} + +ALL_IMPORTABLE_CLASSES = {} +for library in LOADABLE_CLASSES: + ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) + + +@flax.structs.dataclass +class FlaxImagePipelineOutput(BaseOutput): + """ + Output class for image pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. 
+    """
+
+    images: Union[List[PIL.Image.Image], np.ndarray]
+
+
+class FlaxDiffusionPipeline(ConfigMixin):
+    r"""
+    Base class for all Flax diffusion pipelines.
+
+    [`FlaxDiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion
+    pipelines and handles methods for loading, downloading and saving models as well as a few methods common to all
+    pipelines to:
+
+        - enabling/disabling the progress bar for the denoising iteration
+
+    Class attributes:
+
+        - **config_name** ([`str`]) -- name of the config file that will store the class and module names of all
+          components of the diffusion pipeline.
+    """
+    config_name = "model_index.json"
+
+    def register_modules(self, **kwargs):
+        # import it here to avoid circular import
+        from diffusers import pipelines
+
+        for name, module in kwargs.items():
+            # retrieve library
+            library = module.__module__.split(".")[0]
+
+            # check if the module is a pipeline module
+            pipeline_dir = module.__module__.split(".")[-2]
+            path = module.__module__.split(".")
+            is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
+
+            # if library is not in LOADABLE_CLASSES, then it is a custom module.
+            # Or if it's a pipeline module, then the module is inside the pipeline
+            # folder so we set the library to module name.
+            if library not in LOADABLE_CLASSES or is_pipeline_module:
+                library = pipeline_dir
+
+            # retrieve class_name
+            class_name = module.__class__.__name__
+
+            register_dict = {name: (library, class_name)}
+
+            # save model index config
+            self.register_to_config(**register_dict)
+
+            # set models
+            setattr(self, name, module)
+
+    def save_pretrained(self, save_directory: Union[str, os.PathLike]):
+        """
+        Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
+        a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading
+        method. The pipeline can easily be re-loaded using the [`~FlaxDiffusionPipeline.from_pretrained`] class
+        method.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to which to save. Will be created if it doesn't exist.
+        """
+        self.save_config(save_directory)
+
+        model_index_dict = dict(self.config)
+        model_index_dict.pop("_class_name")
+        model_index_dict.pop("_diffusers_version")
+        model_index_dict.pop("_module", None)
+
+        for pipeline_component_name in model_index_dict.keys():
+            sub_model = getattr(self, pipeline_component_name)
+            model_cls = sub_model.__class__
+
+            save_method_name = None
+            # search for the model's base class in LOADABLE_CLASSES
+            for library_name, library_classes in LOADABLE_CLASSES.items():
+                library = importlib.import_module(library_name)
+                for base_class, save_load_methods in library_classes.items():
+                    class_candidate = getattr(library, base_class)
+                    if issubclass(model_cls, class_candidate):
+                        # if we found a suitable base class in LOADABLE_CLASSES then grab its save method
+                        save_method_name = save_load_methods[0]
+                        break
+                if save_method_name is not None:
+                    break
+
+            save_method = getattr(sub_model, save_method_name)
+            save_method(os.path.join(save_directory, pipeline_component_name))
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+        r"""
+        Instantiate a Flax diffusion pipeline from pre-trained pipeline weights.
+
+        The pipeline is set in evaluation mode by default (Dropout modules are deactivated).
+ + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on + https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like + `CompVis/ldm-text2im-large-256`. + - A path to a *directory* containing pipeline weights saved using + [`~FlaxDiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`. + dtype (`str` or `jnp.dtype`, *optional*): + Override the default `jnp.dtype` and load the model under this dtype. If `"auto"` is passed the dtype + will be automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. specify the folder name here. + + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the + specific pipeline class. The overritten components are then directly passed to the pipelines `__init__` + method. See example below for more information. + + + + Passing `use_auth_token=True`` is required when you want to use a private model, *e.g.* + `"CompVis/stable-diffusion-v1-4"` + + + + + + Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use + this method in a firewalled environment. + + + + Examples: + + ```py + >>> from diffusers import FlaxDiffusionPipeline + + >>> # Download pipeline from huggingface.co and cache. 
+ >>> pipeline = FlaxDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") + + >>> # Download pipeline that requires an authorization token + >>> # For more information on access tokens, please refer to this section + >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens) + >>> pipeline = FlaxDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True) + + >>> # Download pipeline, but overwrite scheduler + >>> from diffusers import LMSDiscreteScheduler + + >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + >>> pipeline = FlaxDiffusionPipeline.from_pretrained( + ... "CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True + ... ) + ``` + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + inference_state_dict = kwargs.pop("inference_state_dict", None) + dtype = kwargs.pop("dtype", None) + + # 1. Download the checkpoints and configs + # use snapshot download here to get it working from from_pretrained + if not os.path.isdir(pretrained_model_name_or_path): + config_dict = cls.get_config_dict( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + ) + # make sure we only download sub-folders and `diffusers` filenames + folder_names = [k for k in config_dict.keys() if not k.startswith("_")] + allow_patterns = [os.path.join(k, "*") for k in folder_names] + allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name] + + # download all allow_patterns + cached_folder = snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + allow_patterns=allow_patterns, + ) + else: + cached_folder = pretrained_model_name_or_path + + config_dict = cls.get_config_dict(cached_folder) + + # 2. Load the pipeline class, if using custom module then load it from the hub + # if we load from explicit class, let's use it + if cls != FlaxDiffusionPipeline: + pipeline_class = cls + else: + diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) + pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) + + # some modules can be passed directly to the init + # in this case they are already instantiated in `kwargs` + # extract them here + expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) + passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + + init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) + + init_kwargs = {} + + # import it here to avoid circular import + from diffusers import pipelines + + # 3. Load each module in the pipeline + for name, (library_name, class_name) in init_dict.items(): + is_pipeline_module = hasattr(pipelines, library_name) + loaded_sub_model = None + + # if the model is in a pipeline module, then we load it from the pipeline + if name in passed_class_obj: + # 1. 
check that passed_class_obj has correct parent class + if not is_pipeline_module: + library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + importable_classes = LOADABLE_CLASSES[library_name] + class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} + + expected_class_obj = None + for class_name, class_candidate in class_candidates.items(): + if issubclass(class_obj, class_candidate): + expected_class_obj = class_candidate + + if not issubclass(passed_class_obj[name].__class__, expected_class_obj): + raise ValueError( + f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" + f" {expected_class_obj}" + ) + else: + logger.warn( + f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" + " has the correct type" + ) + + # set passed class object + loaded_sub_model = passed_class_obj[name] + elif is_pipeline_module: + pipeline_module = getattr(pipelines, library_name) + class_obj = getattr(pipeline_module, class_name) + importable_classes = ALL_IMPORTABLE_CLASSES + class_candidates = {c: class_obj for c in importable_classes.keys()} + else: + # else we just import it from the library. + library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + importable_classes = LOADABLE_CLASSES[library_name] + class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} + + if loaded_sub_model is None: + load_method_name = None + for class_name, class_candidate in class_candidates.items(): + if issubclass(class_obj, class_candidate): + load_method_name = importable_classes[class_name][1] + + load_method = getattr(class_obj, load_method_name) + + loading_kwargs = {} + if issubclass(class_obj, flax.linen.Module): + loading_kwargs["dtype"] = dtype + + # check if the module is in a subdirectory + if os.path.isdir(os.path.join(cached_folder, name)): + if issubclass(class_obj, flax.linen.Module): + loaded_sub_model, loaded_params = load_method( + os.path.join(cached_folder, name), **loading_kwargs + ) + params_key = f"{name}_params" + if params_key not in inference_state_dict: + inference_state_dict[params_key] = loaded_params + else: + loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs) + else: + # else load from the root directory + if issubclass(class_obj, flax.linen.Module): + loaded_sub_model, loaded_params = load_method(cached_folder, **loading_kwargs) + params_key = f"{name}_params" + if params_key not in inference_state_dict: + inference_state_dict[params_key] = loaded_params + else: + loaded_sub_model = load_method(cached_folder, **loading_kwargs) + + init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) + + # 4. Instantiate the pipeline + # TODO: fix hard-coded `StableDifusion.InferenceState`, it should be inferred as `{XYZ_Pipeline}.InferenceState` + from .pipelines.stable_diffusion import InferenceState + + inference_state = InferenceState(**inference_state_dict) + model = pipeline_class(**init_kwargs, dtype=dtype, inference_state=inference_state) + return model + + @staticmethod + def numpy_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] 
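+        # scale float images in [0, 1] to uint8 values in [0, 255] before building PIL images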
+ images = (images * 255).round().astype("uint8") + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + # TODO: make it compatible with jax.lax + def progress_bar(self, iterable): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." + ) + + return tqdm(iterable, **self._progress_bar_config) + + def set_progress_bar_config(self, **kwargs): + self._progress_bar_config = kwargs diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 5ffda93f1721..e8eeca5cae06 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -3,9 +3,11 @@ import numpy as np +import flax import PIL from PIL import Image +from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState from ...utils import BaseOutput, is_onnx_available, is_transformers_available @@ -27,6 +29,32 @@ class StableDiffusionPipelineOutput(BaseOutput): nsfw_content_detected: List[bool] +@flax.struct.dataclass +class FlaxStableDiffusionPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_content_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: List[bool] + + +@flax.struct.dataclass +class InferenceState: + text_encoder_params: flax.core.FrozenDict + unet_params: flax.core.FrozenDict + vae_params: flax.core.FrozenDict + scheduler_state: PNDMSchedulerState + + if is_transformers_available(): from .pipeline_stable_diffusion import StableDiffusionPipeline from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py new file mode 100644 index 000000000000..b3679d33dd75 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -0,0 +1,245 @@ +import inspect +import warnings +from typing import List, Optional, Union + +import jax +import jax.numpy as jnp +from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel + +from ...configuration_utils import FrozenDict +from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel +from ...pipeline_flax_utils import FlaxDiffusionPipeline +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from . import FlaxStableDiffusionPipelineOutput, InferenceState +from .flax_safety_checker import FlaxStableDiffusionSafetyChecker + + +class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
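+
+    A minimal usage sketch (illustrative only while this PR is in flight — it assumes access to the
+    `CompVis/stable-diffusion-v1-4` weights with a user access token, and the exact call signature may
+    still change in later commits):
+
+    ```py
+    >>> import jax
+    >>> from diffusers import FlaxStableDiffusionPipeline
+
+    >>> pipeline = FlaxStableDiffusionPipeline.from_pretrained(
+    ...     "CompVis/stable-diffusion-v1-4", use_auth_token=True
+    ... )
+    >>> # __call__ returns a FlaxStableDiffusionPipelineOutput; `images` holds the generated PIL images
+    >>> images = pipeline("a photo of an astronaut riding a horse", jax.random.PRNGKey(0)).images
+    ```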
+
+    Args:
+        vae ([`FlaxAutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`FlaxCLIPTextModel`]):
+            Frozen text-encoder. Stable Diffusion uses the text portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel),
+            specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        tokenizer (`CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
+            Classification module that estimates whether generated images could be considered offensive or harmful.
+            Please refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+        feature_extractor ([`CLIPFeatureExtractor`]):
+            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+    """
+
+    def __init__(
+        self,
+        vae: FlaxAutoencoderKL,
+        text_encoder: FlaxCLIPTextModel,
+        tokenizer: CLIPTokenizer,
+        unet: FlaxUNet2DConditionModel,
+        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        safety_checker: FlaxStableDiffusionSafetyChecker,
+        feature_extractor: CLIPFeatureExtractor,
+        inference_state: InferenceState,
+        dtype: jnp.dtype = jnp.float32,
+    ):
+        super().__init__()
+        scheduler = scheduler.set_format("np")
+        self.dtype = dtype
+        self.inference_state = inference_state
+
+        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+            warnings.warn(
+                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
+                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+                " file",
+                DeprecationWarning,
+            )
+            new_config = dict(scheduler.config)
+            new_config["steps_offset"] = 1
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+        )
+
+    def __call__(
+        self,
+        prompt: Union[str, List[str]],
+        prng_seed: jax.random.PRNGKey,
+        height: Optional[int] = 512,
+        width: Optional[int] = 512,
+        num_inference_steps: Optional[int] = 50,
+        guidance_scale: Optional[float] = 7.5,
+        eta: Optional[float] = 0.0,
+        latents: Optional[jnp.array] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        debug: bool = False,
+        **kwargs,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`):
+                The prompt or prompts to guide the image generation.
+            height (`int`, *optional*, defaults to 512):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to 512):
+                The width in pixels of the generated image.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator`, *optional*):
+                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+                deterministic.
+            latents (`jnp.array`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
+                a plain tuple.
+
+        Returns:
+            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
+            element is a list of `bool`s denoting whether the corresponding generated image likely represents
+            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
+        """
+        if isinstance(prompt, str):
+            batch_size = 1
+        elif isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        inference_state = self.inference_state
+
+        # get prompt text embeddings
+        text_input = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="np",
+        )
+        text_embeddings = self.text_encoder(text_input.input_ids, params=inference_state.text_encoder_params)[0]
+
+        # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
+        # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
+        max_length = text_input.input_ids.shape[-1]
+        uncond_input = self.tokenizer(
+            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
+        )
+        uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=inference_state.text_encoder_params)[0]
+        context = jnp.concatenate([uncond_embeddings, text_embeddings])
+
+        # TODO: check it because the shape is different from the PyTorch StableDiffusionPipeline
+        latents_shape = (
+            text_input.input_ids.shape[0],
+            self.unet.sample_size,
+            self.unet.sample_size,
+            self.unet.in_channels,
+        )
+        if latents is None:
+            latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype)
+        else:
+            if latents.shape != latents_shape:
+                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+
+        def loop_body(step, args):
+            latents, scheduler_state = args
+            # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = jnp.concatenate([latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + # predict the noise residual + noise_pred = self.unet.apply( + {"params": inference_state.unet_params}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + rngs={}, + ).sample + # perform guidance + noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents) + latents = latents["prev_sample"] + return latents, scheduler_state + + scheduler_state = inference_state.scheduler_state + num_inference_steps = len(scheduler_state.timesteps) + if debug: + # run with python for loop + for i in range(num_inference_steps): + latents, scheduler_state = loop_body(i, (latents, scheduler_state)) + else: + latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) + + # scale and decode the image latents with vae + latents = 1 / 0.18215 * latents + # TODO: check when flax vae gets merged into main + image = self.vae.decode(latents, params=inference_state.vae_params).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + # run safety checker + # TODO: check when flax safety checker gets merged into main + safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np") + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_cheker_input.pixel_values, params=inference_state.safety_params + ) + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return FlaxStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) From 30abc633dca496d1827b8f80bcb2a5e8038cd5e0 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Mon, 19 Sep 2022 09:32:21 +0000 Subject: [PATCH 02/15] todo comment --- src/diffusers/pipeline_flax_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index ae5b72703548..37d4e1418937 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -122,6 +122,7 @@ def register_modules(self, **kwargs): setattr(self, name, module) def save_pretrained(self, save_directory: Union[str, os.PathLike]): + # TODO: handle inference_state """ Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to a directory. 
A pipeline variable can be saved and loaded if its class implements both a save and loading From 4b2becb89057aa0f7d21f7715cb0902fb9a08cdc Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Mon, 19 Sep 2022 15:38:50 +0000 Subject: [PATCH 03/15] Fix imports --- src/diffusers/models/__init__.py | 1 + src/diffusers/pipeline_flax_utils.py | 2 +- .../stable_diffusion/pipeline_flax_stable_diffusion.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 3c3656a572f0..a6007d15b9db 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -16,3 +16,4 @@ from .unet_2d_condition import UNet2DConditionModel from .unet_2d_condition_flax import FlaxUNet2DConditionModel from .vae import AutoencoderKL, VQModel +from .vae_flax import FlaxAutoencoderKL \ No newline at end of file diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index 37d4e1418937..d61472949a73 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -60,7 +60,7 @@ ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) -@flax.structs.dataclass +@flax.struct.dataclass class FlaxImagePipelineOutput(BaseOutput): """ Output class for image pipelines. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index b3679d33dd75..896ffabbeec3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -11,7 +11,7 @@ from ...pipeline_flax_utils import FlaxDiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from . 
import FlaxStableDiffusionPipelineOutput, InferenceState -from .flax_safety_checker import FlaxStableDiffusionSafetyChecker +from .safety_checker_flax import FlaxStableDiffusionSafetyChecker class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): From 7f0e4297943ae4c01ccd8505801e2d963b1eecdc Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Mon, 19 Sep 2022 15:45:33 +0000 Subject: [PATCH 04/15] Fix imports --- src/diffusers/__init__.py | 8 ++++++++ src/diffusers/pipelines/__init__.py | 5 ++++- src/diffusers/pipelines/stable_diffusion/__init__.py | 1 + 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 776551c7136d..9b71955a7107 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -74,5 +74,13 @@ FlaxPNDMScheduler, FlaxScoreSdeVeScheduler, ) + from .pipeline_flax_utils import FlaxDiffusionPipeline else: from .utils.dummy_flax_objects import * # noqa F403 + +if is_flax_available() and is_transformers_available(): + from .pipelines import FlaxStableDiffusionPipeline +else: + pass + # TODO: dummy_flax_and_transformers_objects + # from .utils.dummy_flax_and_transformers_objects import * # noqa F403 \ No newline at end of file diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 3e2aeb4fb2b7..0db153c864ed 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -1,4 +1,4 @@ -from ..utils import is_onnx_available, is_transformers_available +from ..utils import is_onnx_available, is_transformers_available, is_flax_available from .ddim import DDIMPipeline from .ddpm import DDPMPipeline from .latent_diffusion_uncond import LDMPipeline @@ -17,3 +17,6 @@ if is_transformers_available() and is_onnx_available(): from .stable_diffusion import StableDiffusionOnnxPipeline + +if is_flax_available(): + from .stable_diffusion import FlaxStableDiffusionPipeline \ No newline at end of file diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index eb2b51155a39..7c30ff0e3d20 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -66,3 +66,4 @@ class InferenceState: if is_transformers_available() and is_flax_available(): from .safety_checker_flax import FlaxStableDiffusionSafetyChecker + from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline From d9e2ae18623686ccb1974a9d1a42af4e515b41ba Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 19 Sep 2022 15:56:17 +0000 Subject: [PATCH 05/15] add dummies --- .../dummy_flax_and_transformers_objects.py | 11 +++++++++++ src/diffusers/utils/dummy_flax_objects.py | 18 ++++++++++++++++-- 2 files changed, 27 insertions(+), 2 deletions(-) create mode 100644 src/diffusers/utils/dummy_flax_and_transformers_objects.py diff --git a/src/diffusers/utils/dummy_flax_and_transformers_objects.py b/src/diffusers/utils/dummy_flax_and_transformers_objects.py new file mode 100644 index 000000000000..51ee3b184816 --- /dev/null +++ b/src/diffusers/utils/dummy_flax_and_transformers_objects.py @@ -0,0 +1,11 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. 
+# flake8: noqa + +from ..utils import DummyObject, requires_backends + + +class FlaxStableDiffusionPipeline(metaclass=DummyObject): + _backends = ["flax", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax", "transformers"]) diff --git a/src/diffusers/utils/dummy_flax_objects.py b/src/diffusers/utils/dummy_flax_objects.py index 9615afb6f920..424e4f3bf6d8 100644 --- a/src/diffusers/utils/dummy_flax_objects.py +++ b/src/diffusers/utils/dummy_flax_objects.py @@ -11,6 +11,20 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) +class FlaxUNet2DConditionModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxAutoencoderKL(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxDDIMScheduler(metaclass=DummyObject): _backends = ["flax"] @@ -46,14 +60,14 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) -class FlaxUNet2DConditionModel(metaclass=DummyObject): +class FlaxScoreSdeVeScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) -class FlaxScoreSdeVeScheduler(metaclass=DummyObject): +class FlaxDiffusionPipeline(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): From d51e8816edbae173c0ece6b8f1bf9495828c5e1d Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Mon, 19 Sep 2022 15:59:49 +0000 Subject: [PATCH 06/15] Fix empty init --- src/diffusers/pipeline_flax_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index d61472949a73..0fb3a43e8d2e 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -258,7 +258,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P local_files_only = kwargs.pop("local_files_only", False) use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) - inference_state_dict = kwargs.pop("inference_state_dict", None) + inference_state_dict = kwargs.pop("inference_state_dict", dict()) dtype = kwargs.pop("dtype", None) # 1. 
Download the checkpoints and configs From 7aab68d6b18f77881ed5a3415aaf2174dd7329cf Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 19 Sep 2022 20:44:12 +0000 Subject: [PATCH 07/15] make pipeline work --- src/diffusers/__init__.py | 6 +- src/diffusers/models/__init__.py | 2 +- src/diffusers/pipeline_flax_utils.py | 61 +++++++------- src/diffusers/pipelines/__init__.py | 8 +- .../pipelines/stable_diffusion/__init__.py | 2 +- .../pipeline_flax_stable_diffusion.py | 83 +++++++++---------- .../schedulers/scheduling_ddim_flax.py | 61 +++++--------- .../schedulers/scheduling_pndm_flax.py | 3 +- 8 files changed, 100 insertions(+), 126 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 9b71955a7107..acdddaac4d26 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -66,6 +66,7 @@ from .modeling_flax_utils import FlaxModelMixin from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel from .models.vae_flax import FlaxAutoencoderKL + from .pipeline_flax_utils import FlaxDiffusionPipeline from .schedulers import ( FlaxDDIMScheduler, FlaxDDPMScheduler, @@ -74,13 +75,10 @@ FlaxPNDMScheduler, FlaxScoreSdeVeScheduler, ) - from .pipeline_flax_utils import FlaxDiffusionPipeline else: from .utils.dummy_flax_objects import * # noqa F403 if is_flax_available() and is_transformers_available(): from .pipelines import FlaxStableDiffusionPipeline else: - pass - # TODO: dummy_flax_and_transformers_objects - # from .utils.dummy_flax_and_transformers_objects import * # noqa F403 \ No newline at end of file + from .utils.dummy_flax_and_transformers_objects import * # noqa F403 diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index a6007d15b9db..b5fe089e05f0 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -16,4 +16,4 @@ from .unet_2d_condition import UNet2DConditionModel from .unet_2d_condition_flax import FlaxUNet2DConditionModel from .vae import AutoencoderKL, VQModel -from .vae_flax import FlaxAutoencoderKL \ No newline at end of file +from .vae_flax import FlaxAutoencoderKL diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index d61472949a73..8092b29bc582 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -17,24 +17,26 @@ import importlib import inspect import os -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union import numpy as np -import diffusers import flax -import jax.numpy as jnp import PIL +from flax.core.frozen_dict import FrozenDict from huggingface_hub import snapshot_download from PIL import Image from tqdm.auto import tqdm from .configuration_utils import ConfigMixin -from .modeling_flax_utils import FLAX_WEIGHTS_NAME -from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME -from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, logging +from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin +from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerMixin +from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging +if is_transformers_available(): + from transformers import FlaxPreTrainedModel + INDEX_FILE = "diffusion_flax_model.bin" @@ -121,7 +123,7 @@ def register_modules(self, **kwargs): # set models setattr(self, name, module) - def save_pretrained(self, save_directory: Union[str, os.PathLike]): + def save_pretrained(self, save_directory: Union[str, 
os.PathLike], params: Union[Dict, FrozenDict]):
         # TODO: handle inference_state
         """
         Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
@@ -258,7 +260,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         local_files_only = kwargs.pop("local_files_only", False)
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
-        inference_state_dict = kwargs.pop("inference_state_dict", None)
         dtype = kwargs.pop("dtype", None)
 
         # 1. Download the checkpoints and configs
@@ -312,6 +313,9 @@
 
         init_kwargs = {}
 
+        # inference_params
+        params = {}
+
         # import it here to avoid circular import
         from diffusers import pipelines
 
@@ -373,34 +377,27 @@
 
             # check if the module is in a subdirectory
             if os.path.isdir(os.path.join(cached_folder, name)):
-                if issubclass(class_obj, flax.linen.Module):
-                    loaded_sub_model, loaded_params = load_method(
-                        os.path.join(cached_folder, name), **loading_kwargs
-                    )
-                    params_key = f"{name}_params"
-                    if params_key not in inference_state_dict:
-                        inference_state_dict[params_key] = loaded_params
-                else:
-                    loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
+                loadable_folder = os.path.join(cached_folder, name)
             else:
-                # else load from the root directory
-                if issubclass(class_obj, flax.linen.Module):
-                    loaded_sub_model, loaded_params = load_method(cached_folder, **loading_kwargs)
-                    params_key = f"{name}_params"
-                    if params_key not in inference_state_dict:
-                        inference_state_dict[params_key] = loaded_params
-                else:
-                    loaded_sub_model = load_method(cached_folder, **loading_kwargs)
+                loadable_folder = cached_folder
+
+            if issubclass(class_obj, FlaxModelMixin):
+                loaded_sub_model, loaded_params = load_method(loadable_folder, **loading_kwargs)
+                params[name] = loaded_params
+            elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
+                # make sure we don't initialize the weights to save time
+                loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False, **loading_kwargs)
+                params[name] = loaded_params
+            elif issubclass(class_obj, SchedulerMixin):
+                loaded_sub_model = load_method(loadable_folder, **loading_kwargs)
+                params[name] = loaded_sub_model.create_state()
+            else:
+                loaded_sub_model = load_method(loadable_folder, **loading_kwargs)
 
             init_kwargs[name] = loaded_sub_model  # UNet(...), # DiffusionSchedule(...)
 
-        # 4. 
Instantiate the pipeline - # TODO: fix hard-coded `StableDifusion.InferenceState`, it should be inferred as `{XYZ_Pipeline}.InferenceState` - from .pipelines.stable_diffusion import InferenceState - - inference_state = InferenceState(**inference_state_dict) - model = pipeline_class(**init_kwargs, dtype=dtype, inference_state=inference_state) - return model + model = pipeline_class(**init_kwargs, dtype=dtype) + return model, params @staticmethod def numpy_to_pil(images): diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 0db153c864ed..8e3c8592a258 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -1,4 +1,4 @@ -from ..utils import is_onnx_available, is_transformers_available, is_flax_available +from ..utils import is_flax_available, is_onnx_available, is_torch_available, is_transformers_available from .ddim import DDIMPipeline from .ddpm import DDPMPipeline from .latent_diffusion_uncond import LDMPipeline @@ -7,7 +7,7 @@ from .stochastic_karras_ve import KarrasVePipeline -if is_transformers_available(): +if is_torch_available() and is_transformers_available(): from .latent_diffusion import LDMTextToImagePipeline from .stable_diffusion import ( StableDiffusionImg2ImgPipeline, @@ -18,5 +18,5 @@ if is_transformers_available() and is_onnx_available(): from .stable_diffusion import StableDiffusionOnnxPipeline -if is_flax_available(): - from .stable_diffusion import FlaxStableDiffusionPipeline \ No newline at end of file +if is_transformers_available() and is_flax_available(): + from .stable_diffusion import FlaxStableDiffusionPipeline diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 7c30ff0e3d20..378dd8e9a99c 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -65,5 +65,5 @@ class InferenceState: from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline if is_transformers_available() and is_flax_available(): - from .safety_checker_flax import FlaxStableDiffusionSafetyChecker from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline + from .safety_checker_flax import FlaxStableDiffusionSafetyChecker diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 896ffabbeec3..98546e206e32 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -1,9 +1,12 @@ import inspect import warnings -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union + +import numpy as np import jax import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel from ...configuration_utils import FrozenDict @@ -51,13 +54,11 @@ def __init__( scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: FlaxStableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, - inference_state: InferenceState, dtype: jnp.dtype = jnp.float32, ): super().__init__() scheduler = scheduler.set_format("np") self.dtype = dtype - self.inference_state = inference_state if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: warnings.warn( @@ -83,17 +84,29 @@ def __init__( 
feature_extractor=feature_extractor, ) + def prepare_prompts(self, prompt: Union[str, List[str]]): + if not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + return text_input.input_ids + def __call__( self, - prompt: Union[str, List[str]], + prompt_ids: jnp.array, + params: Union[Dict, FrozenDict], prng_seed: jax.random.PRNGKey, + num_inference_steps: Optional[int] = 50, height: Optional[int] = 512, width: Optional[int] = 512, - num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, - eta: Optional[float] = 0.0, latents: Optional[jnp.array] = None, - output_type: Optional[str] = "pil", return_dict: bool = True, debug: bool = False, **kwargs, @@ -117,9 +130,6 @@ def __call__( Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. @@ -141,40 +151,26 @@ def __call__( element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - inference_state = self.inference_state - # get prompt text embeddings - text_input = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - text_embeddings = self.text_encoder(text_input.input_ids, params=inference_state.text_encoder_params)[0] + text_embeddings = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` - max_length = text_input.input_ids.shape[-1] + batch_size = prompt_ids.shape[0] + + max_length = prompt_ids.shape[-1] uncond_input = self.tokenizer( [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=inference_state.text_encoder_params)[0] + uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0] context = jnp.concatenate([uncond_embeddings, text_embeddings]) # TODO: check it because the shape is different from Pytorhc StableDiffusionPipeline latents_shape = ( - text_input.input_ids.shape[0], + batch_size, self.unet.sample_size, self.unet.sample_size, self.unet.in_channels, @@ -197,7 +193,7 @@ def loop_body(step, args): # predict the noise residual noise_pred = self.unet.apply( - {"params": 
inference_state.unet_params},
+                {"params": params["unet"]},
                 jnp.array(latents_input),
                 jnp.array(timestep, dtype=jnp.int32),
                 encoder_hidden_states=context,
@@ -208,12 +204,11 @@ def loop_body(step, args):
             noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
 
             # compute the previous noisy sample x_t -> x_t-1
-            latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents)
-            latents = latents["prev_sample"]
+            latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
             return latents, scheduler_state
 
-        scheduler_state = inference_state.scheduler_state
-        num_inference_steps = len(scheduler_state.timesteps)
+        scheduler_state = self.scheduler.set_timesteps(params["scheduler"], num_inference_steps=num_inference_steps)
+
         if debug:
             # run with python for loop
             for i in range(num_inference_steps):
@@ -224,20 +219,18 @@ def loop_body(step, args):
 
         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
         # TODO: check when flax vae gets merged into main
-        image = self.vae.decode(latents, params=inference_state.vae_params).sample
+        image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
 
-        image = (image / 2 + 0.5).clamp(0, 1)
-        image = image.cpu().permute(0, 2, 3, 1).numpy()
+        image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
+        # image = jnp.asarray(image).transpose(0, 2, 3, 1)
 
         # run safety checker
         # TODO: check when flax safety checker gets merged into main
-        safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
-        image, has_nsfw_concept = self.safety_checker(
-            images=image, clip_input=safety_cheker_input.pixel_values, params=inference_state.safety_params
-        )
-
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        # safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
+        # image, has_nsfw_concept = self.safety_checker(
+        #     images=image, clip_input=safety_checker_input.pixel_values, params=params["safety_params"]
+        # )
+        has_nsfw_concept = False
 
         if not return_dict:
             return (image, has_nsfw_concept)
diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py
index 015b79b2780d..dd5a87df654a 100644
--- a/src/diffusers/schedulers/scheduling_ddim_flax.py
+++ b/src/diffusers/schedulers/scheduling_ddim_flax.py
@@ -21,7 +21,6 @@
 import flax
 import jax.numpy as jnp
-from jax import random
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from .scheduling_utils import SchedulerMixin, SchedulerOutput
@@ -60,11 +59,12 @@ def alpha_bar(time_step):
 class DDIMSchedulerState:
     # setable values
     timesteps: jnp.ndarray
+    alphas_cumprod: jnp.ndarray
     num_inference_steps: Optional[int] = None
 
     @classmethod
-    def create(cls, num_train_timesteps: int):
-        return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1])
+    def create(cls, num_train_timesteps: int, alphas_cumprod: jnp.ndarray):
+        return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1], alphas_cumprod=alphas_cumprod)
 
 
 @dataclass
@@ -112,13 +112,9 @@ def __init__(
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[jnp.ndarray] = None,
-        clip_sample: bool = True,
         set_alpha_to_one: bool = True,
         steps_offset: int = 0,
     ):
-        if trained_betas is not None:
-            self.betas = jnp.asarray(trained_betas)
         if beta_schedule == "linear":
             self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
         elif beta_schedule == "scaled_linear":
@@ -131,19 +127,24 @@ def __init__(
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
 
         self.alphas = 1.0 - self.betas
-        self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
+
+        # HACK for now - clean up later (PVP)
+        self._alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
 
         # At every step in ddim, we are looking into the previous alphas_cumprod
         # For the final step, there is no previous alphas_cumprod because we are already at 0
         # `set_alpha_to_one` decides whether we set this parameter simply to one or
         # whether we use the final alpha of the "non-previous" one.
-        self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+        self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else float(self._alphas_cumprod[0])
 
-        self.state = DDIMSchedulerState.create(num_train_timesteps=num_train_timesteps)
+    def create_state(self):
+        return DDIMSchedulerState.create(
+            num_train_timesteps=self.config.num_train_timesteps, alphas_cumprod=self._alphas_cumprod
+        )
 
-    def _get_variance(self, timestep, prev_timestep):
-        alpha_prod_t = self.alphas_cumprod[timestep]
-        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+    def _get_variance(self, timestep, prev_timestep, alphas_cumprod):
+        alpha_prod_t = alphas_cumprod[timestep]
+        alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
 
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
@@ -177,9 +178,6 @@ def step(
         model_output: jnp.ndarray,
         timestep: int,
         sample: jnp.ndarray,
-        key: random.KeyArray,
-        eta: float = 0.0,
-        use_clipped_model_output: bool = False,
         return_dict: bool = True,
     ) -> Union[FlaxSchedulerOutput, Tuple]:
         """
@@ -221,41 +219,28 @@ def step(
         # 1. get previous step value (=t-1)
         prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
 
+        alphas_cumprod = state.alphas_cumprod
+
         # 2. compute alphas, betas
-        alpha_prod_t = self.alphas_cumprod[timestep]
-        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+        alpha_prod_t = alphas_cumprod[timestep]
+        alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
 
         beta_prod_t = 1 - alpha_prod_t
 
         # 3. compute predicted original sample from predicted noise also called
         # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
         pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
 
-        # 4. Clip "predicted x_0"
-        if self.config.clip_sample:
-            pred_original_sample = jnp.clip(pred_original_sample, -1, 1)
-
-        # 5. compute variance: "sigma_t(η)" -> see formula (16)
+        # 4. compute variance: "sigma_t(η)" -> see formula (16)
         # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
-        variance = self._get_variance(timestep, prev_timestep)
-        std_dev_t = eta * variance ** (0.5)
+        variance = self._get_variance(timestep, prev_timestep, alphas_cumprod)
+        std_dev_t = variance ** (0.5)
 
-        if use_clipped_model_output:
-            # the model_output is always re-derived from the clipped x_0 in Glide
-            model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
-
-        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+        # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
         pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
 
-        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+        # 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
         prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
 
-        if eta > 0:
-            key = random.split(key, num=1)
-            noise = random.normal(key=key, shape=model_output.shape)
-            variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
-
-            prev_sample = prev_sample + variance
-
         if not return_dict:
             return (prev_sample, state)
diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py
index efc3858ca75a..4c8c43810b6f 100644
--- a/src/diffusers/schedulers/scheduling_pndm_flax.py
+++ b/src/diffusers/schedulers/scheduling_pndm_flax.py
@@ -148,7 +148,8 @@ def __init__(
         # mainly at formula (9), (12), (13) and the Algorithm 2.
         self.pndm_order = 4
 
-        self.state = PNDMSchedulerState.create(num_train_timesteps=num_train_timesteps)
+    def create_state(self):
+        return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
 
     def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) -> PNDMSchedulerState:
         """
From 47d77393b7d06f5ce8f5cb3e8ce97a43e82cb8e6 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Mon, 19 Sep 2022 20:46:44 +0000
Subject: [PATCH 08/15] up

---
 .../stable_diffusion/pipeline_flax_stable_diffusion.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
index 98546e206e32..c85f2c7faf04 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
@@ -1,19 +1,15 @@
-import inspect
 import warnings
 from typing import Dict, List, Optional, Union
 
-import numpy as np
-
 import jax
 import jax.numpy as jnp
 from flax.core.frozen_dict import FrozenDict
 from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
 
-from ...configuration_utils import FrozenDict
 from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
 from ...pipeline_flax_utils import FlaxDiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from . import FlaxStableDiffusionPipelineOutput, InferenceState
+from . import FlaxStableDiffusionPipelineOutput
 from .safety_checker_flax import FlaxStableDiffusionSafetyChecker

From 0c2a868ec4e939a0fca3f590c39280a6eff1a686 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Tue, 20 Sep 2022 07:43:26 +0000
Subject: [PATCH 09/15] Use Flax schedulers (typing, docstring)

---
 .../stable_diffusion/pipeline_flax_stable_diffusion.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
index c85f2c7faf04..92c135f872e3 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
@@ -8,7 +8,7 @@
 from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
 from ...pipeline_flax_utils import FlaxDiffusionPipeline
-from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...schedulers import FlaxDDIMScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler
 from . import FlaxStableDiffusionPipelineOutput
 from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
 
@@ -31,9 +31,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
             Tokenizer of class
             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
         unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
-        scheduler ([`SchedulerMixin`]):
+        scheduler ([`FlaxSchedulerMixin`]):
             A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
-            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+            [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`].
         safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
             Classification module that estimates whether generated images could be considered offsensive or harmful.
             Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
@@ -47,7 +47,7 @@ def __init__(
         text_encoder: FlaxCLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: FlaxUNet2DConditionModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        scheduler: Union[FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler],
         safety_checker: FlaxStableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
         dtype: jnp.dtype = jnp.float32,

From 69b1d7accd3dcc93e3d32b13dbeecde9fca212eb Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Tue, 20 Sep 2022 09:14:45 +0000
Subject: [PATCH 10/15] Wrap model imports inside availability checks.

---
 src/diffusers/models/__init__.py          | 15 ++++++++++-----
 src/diffusers/utils/dummy_flax_objects.py | 14 +++++++-------
 2 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index b5fe089e05f0..d58e4d77ff73 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -12,8 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .unet_2d import UNet2DModel
-from .unet_2d_condition import UNet2DConditionModel
-from .unet_2d_condition_flax import FlaxUNet2DConditionModel
-from .vae import AutoencoderKL, VQModel
-from .vae_flax import FlaxAutoencoderKL
+from ..utils import is_torch_available, is_flax_available
+
+if is_torch_available():
+    from .unet_2d import UNet2DModel
+    from .unet_2d_condition import UNet2DConditionModel
+    from .vae import AutoencoderKL, VQModel
+
+if is_flax_available():
+    from .unet_2d_condition_flax import FlaxUNet2DConditionModel
+    from .vae_flax import FlaxAutoencoderKL
diff --git a/src/diffusers/utils/dummy_flax_objects.py b/src/diffusers/utils/dummy_flax_objects.py
index 424e4f3bf6d8..1e3ac002a609 100644
--- a/src/diffusers/utils/dummy_flax_objects.py
+++ b/src/diffusers/utils/dummy_flax_objects.py
@@ -25,49 +25,49 @@ def __init__(self, *args, **kwargs):
         requires_backends(self, ["flax"])
 
 
-class FlaxDDIMScheduler(metaclass=DummyObject):
+class FlaxDiffusionPipeline(metaclass=DummyObject):
     _backends = ["flax"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["flax"])
 
 
-class FlaxDDPMScheduler(metaclass=DummyObject):
+class FlaxDDIMScheduler(metaclass=DummyObject):
     _backends = ["flax"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["flax"])
 
 
-class FlaxKarrasVeScheduler(metaclass=DummyObject):
+class FlaxDDPMScheduler(metaclass=DummyObject):
     _backends = ["flax"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["flax"])
 
 
-class FlaxLMSDiscreteScheduler(metaclass=DummyObject):
+class FlaxKarrasVeScheduler(metaclass=DummyObject):
     _backends = ["flax"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["flax"])
 
 
-class FlaxPNDMScheduler(metaclass=DummyObject):
+class FlaxLMSDiscreteScheduler(metaclass=DummyObject):
     _backends = ["flax"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["flax"])
 
 
-class FlaxScoreSdeVeScheduler(metaclass=DummyObject):
+class FlaxPNDMScheduler(metaclass=DummyObject):
     _backends = ["flax"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["flax"])
 
 
-class FlaxDiffusionPipeline(metaclass=DummyObject):
+class FlaxScoreSdeVeScheduler(metaclass=DummyObject):
     _backends = ["flax"]
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["flax"])
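Gating the imports as above keeps `import diffusers` working in a torch-only environment: every Flax class still resolves, but to a `DummyObject` whose constructor calls `requires_backends` and fails with an installation hint. A hypothetical torch-only session, sketched under that assumption:

from diffusers import FlaxDDIMScheduler  # resolves to the dummy when flax is missing
from diffusers.utils import is_flax_available

if not is_flax_available():
    try:
        FlaxDDIMScheduler()
    except ImportError as err:
        # requires_backends() raises here, telling the user to install flax
        print(err)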
From 82a5cf340acad8998f58ae33a41f02c09fa5cba5 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Tue, 20 Sep 2022 17:28:52 +0000
Subject: [PATCH 11/15] more updates

---
 src/diffusers/models/attention_flax.py |  1 -
 src/diffusers/pipeline_flax_utils.py   | 51 +++++++++++++++----
 .../pipeline_flax_stable_diffusion.py   |  2 +-
 3 files changed, 41 insertions(+), 13 deletions(-)

diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py
index 525be4818dcc..461fb8b0ac33 100644
--- a/src/diffusers/models/attention_flax.py
+++ b/src/diffusers/models/attention_flax.py
@@ -144,7 +144,6 @@ def setup(self):
     def __call__(self, hidden_states, context, deterministic=True):
         batch, height, width, channels = hidden_states.shape
-        # import ipdb; ipdb.set_trace()
         residual = hidden_states
         hidden_states = self.norm(hidden_states)
         hidden_states = self.proj_in(hidden_states)
diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py
index 8092b29bc582..6b01275de85d 100644
--- a/src/diffusers/pipeline_flax_utils.py
+++ b/src/diffusers/pipeline_flax_utils.py
@@ -62,6 +62,20 @@
 ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
 
 
+def import_flax_or_no_model(module, class_name):
+    try:
+        # 1. First make sure that if a Flax object is present, import this one
+        class_obj = getattr(module, "Flax" + class_name)
+    except AttributeError:
+        try:
+            # 2. If this doesn't work, it's not a model and we don't append "Flax"
+            class_obj = getattr(module, class_name)
+        except AttributeError:
+            raise ValueError(f"Neither Flax{class_name} nor {class_name} exist in {module}")
+
+    return class_obj
+
+
 @flax.struct.dataclass
 class FlaxImagePipelineOutput(BaseOutput):
     """
@@ -260,6 +273,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         local_files_only = kwargs.pop("local_files_only", False)
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
+        from_pt = kwargs.pop("from_pt", False)
         dtype = kwargs.pop("dtype", None)
 
         # 1. Download the checkpoints and configs
@@ -353,13 +367,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 loaded_sub_model = passed_class_obj[name]
             elif is_pipeline_module:
                 pipeline_module = getattr(pipelines, library_name)
-                class_obj = getattr(pipeline_module, class_name)
+                if from_pt:
+                    class_obj = import_flax_or_no_model(pipeline_module, class_name)
+                else:
+                    class_obj = getattr(pipeline_module, class_name)
+
                 importable_classes = ALL_IMPORTABLE_CLASSES
                 class_candidates = {c: class_obj for c in importable_classes.keys()}
             else:
                 # else we just import it from the library.
                 library = importlib.import_module(library_name)
-                class_obj = getattr(library, class_name)
+                if from_pt:
+                    class_obj = import_flax_or_no_model(library, class_name)
+                else:
+                    class_obj = getattr(library, class_name)
+
                 importable_classes = LOADABLE_CLASSES[library_name]
                 class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
 
@@ -371,10 +393,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
 
             load_method = getattr(class_obj, load_method_name)
 
-            loading_kwargs = {}
-            if issubclass(class_obj, flax.linen.Module):
-                loading_kwargs["dtype"] = dtype
-
             # check if the module is in a subdirectory
             if os.path.isdir(os.path.join(cached_folder, name)):
                 loadable_folder = os.path.join(cached_folder, name)
@@ -382,17 +400,28 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 loaded_sub_model = cached_folder
 
             if issubclass(class_obj, FlaxModelMixin):
-                loaded_sub_model, loaded_params = load_method(loadable_folder, **loading_kwargs)
-                params[name] = loaded_params
+                # TODO(Patrick, Suraj) - Fix this as soon as Safety checker is fixed here
+                if name == "safety_checker":
+                    loaded_sub_model = None
+                    loaded_params = None
+                else:
+                    loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
+                params[name] = loaded_params
             elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
                 # make sure we don't initialize the weights to save time
-                loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False, **loading_kwargs)
+                if from_pt:
+                    # TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here
+                    loaded_sub_model = load_method(loadable_folder, from_pt=from_pt)
+                    loaded_params = loaded_sub_model.params
+                    del loaded_sub_model._params
+                else:
+                    loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False)
                 params[name] = loaded_params
             elif issubclass(class_obj, SchedulerMixin):
-                loaded_sub_model = load_method(loadable_folder, **loading_kwargs)
+                loaded_sub_model = load_method(loadable_folder)
                 params[name] = loaded_sub_model.create_state()
             else:
-                loaded_sub_model = load_method(loadable_folder, **loading_kwargs)
+                loaded_sub_model = load_method(loadable_folder)
 
             init_kwargs[name] = loaded_sub_model  # UNet(...), # DiffusionSchedule(...)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
index c85f2c7faf04..10b908de71fb 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
@@ -80,7 +80,7 @@ def __init__(
             feature_extractor=feature_extractor,
         )
 
-    def prepare_prompts(self, prompt: Union[str, List[str]]):
+    def prepare_inputs(self, prompt: Union[str, List[str]]):
         if not isinstance(prompt, (str, list)):
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
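With `from_pt` threaded through `from_pretrained` and the tokenization helper renamed to `prepare_inputs`, the intended entry point now looks roughly as follows. This is a sketch under the assumptions of this patch series — the checkpoint id is a placeholder, the `(pipeline, params)` return pair follows the params dict built above, and the positional order of the `__call__` arguments mirrors the pipeline code:

import jax

from diffusers import FlaxStableDiffusionPipeline

# from_pt=True routes class resolution through import_flax_or_no_model()
# and converts the PyTorch weights; params come back as a separate pytree
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",  # placeholder checkpoint id
    from_pt=True,
)

prompt_ids = pipeline.prepare_inputs("a photograph of an astronaut riding a horse")
prng_seed = jax.random.PRNGKey(0)

# returns a FlaxStableDiffusionPipelineOutput unless return_dict=False
output = pipeline(prompt_ids, params, prng_seed, num_inference_steps=50)
images = output.images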
- """ - - images: Union[List[PIL.Image.Image], np.ndarray] - nsfw_content_detected: List[bool] - - -@flax.struct.dataclass -class InferenceState: - text_encoder_params: flax.core.FrozenDict - unet_params: flax.core.FrozenDict - vae_params: flax.core.FrozenDict - scheduler_state: PNDMSchedulerState - - if is_transformers_available(): from .pipeline_stable_diffusion import StableDiffusionPipeline from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline @@ -65,5 +38,23 @@ class InferenceState: from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline if is_transformers_available() and is_flax_available(): + import flax + + @flax.struct.dataclass + class FlaxStableDiffusionPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_content_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content. + """ + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: List[bool] + from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline from .safety_checker_flax import FlaxStableDiffusionSafetyChecker From abb22502e53afd0678eaba4dc635ab152c608fcc Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 20 Sep 2022 17:37:08 +0000 Subject: [PATCH 13/15] make style --- src/diffusers/models/__init__.py | 3 ++- src/diffusers/pipeline_flax_utils.py | 2 +- src/diffusers/pipelines/stable_diffusion/__init__.py | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index d58e4d77ff73..1242ad6fca7f 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ..utils import is_torch_available, is_flax_available +from ..utils import is_flax_available, is_torch_available + if is_torch_available(): from .unet_2d import UNet2DModel diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index 6f096b814e53..768b6714ef21 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -402,8 +402,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if issubclass(class_obj, FlaxModelMixin): # TODO(Patrick, Suraj) - Fix this as soon as Safety checker is fixed here if name == "safety_checker": - class DummyChecker: + class DummyChecker: def __init__(self): self.dummy = True diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index bd3e3e03c83e..859fda136baf 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -53,6 +53,7 @@ class FlaxStableDiffusionPipelineOutput(BaseOutput): List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content. 
""" + images: Union[List[PIL.Image.Image], np.ndarray] nsfw_content_detected: List[bool] From 61342a2c4c0174e2383482a99e8b1d9fe20949a4 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 20 Sep 2022 19:07:50 +0000 Subject: [PATCH 14/15] more fixes --- src/diffusers/pipeline_flax_utils.py | 28 +++++++++++++++---- .../pipelines/stable_diffusion/__init__.py | 1 - .../pipeline_flax_stable_diffusion.py | 17 +---------- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index 768b6714ef21..fca793c0dfc2 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -62,6 +62,11 @@ ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) +class DummyChecker: + def __init__(self): + self.dummy = True + + def import_flax_or_no_model(module, class_name): try: # 1. First make sure that if a Flax object is present, import this one @@ -172,8 +177,19 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], params: Union if save_method_name is not None: break + # TODO(Patrick, Suraj): to delete after + if isinstance(sub_model, DummyChecker): + continue + save_method = getattr(sub_model, save_method_name) - save_method(os.path.join(save_directory, pipeline_component_name)) + expects_params = "params" in set(inspect.signature(save_method).parameters.keys()) + + if expects_params: + save_method( + os.path.join(save_directory, pipeline_component_name), params=params[pipeline_component_name] + ) + else: + save_method(os.path.join(save_directory, pipeline_component_name)) @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): @@ -335,6 +351,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # 3. 
Load each module in the pipeline for name, (library_name, class_name) in init_dict.items(): + # TODO(Patrick, Suraj) - delete later + if class_name == "DummyChecker": + library_name = "stable_diffusion" + class_name = "StableDiffusionSafetyChecker" + is_pipeline_module = hasattr(pipelines, library_name) loaded_sub_model = None @@ -402,11 +423,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if issubclass(class_obj, FlaxModelMixin): # TODO(Patrick, Suraj) - Fix this as soon as Safety checker is fixed here if name == "safety_checker": - - class DummyChecker: - def __init__(self): - self.dummy = True - loaded_sub_model = DummyChecker() loaded_params = DummyChecker() else: diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 859fda136baf..1016ce69e450 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -6,7 +6,6 @@ import PIL from PIL import Image -from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_transformers_available diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 45ef05b260ce..6cca376791ac 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -1,4 +1,3 @@ -import warnings from typing import Dict, List, Optional, Union import jax @@ -56,20 +55,6 @@ def __init__( scheduler = scheduler.set_format("np") self.dtype = dtype - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - warnings.warn( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 istead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file", - DeprecationWarning, - ) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) - self.register_modules( vae=vae, text_encoder=text_encoder, @@ -167,9 +152,9 @@ def __call__( # TODO: check it because the shape is different from Pytorhc StableDiffusionPipeline latents_shape = ( batch_size, + self.unet.in_channels, self.unet.sample_size, self.unet.sample_size, - self.unet.in_channels, ) if latents is None: latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype) From 182e485ee51b57cf8dfb12eec4b7abcd4cb80537 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 20 Sep 2022 19:24:00 +0000 Subject: [PATCH 15/15] up --- src/diffusers/modeling_flax_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index d1dfacf36265..ed62b5fe579a 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -306,16 +306,16 @@ def from_pretrained( # Load model if os.path.isdir(pretrained_model_name_or_path): - if os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)): - # Load from a Flax checkpoint - model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME) - # At this stage we don't have a weight file so we will raise an error. - elif from_pt: + if from_pt: if not os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME): raise EnvironmentError( f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " ) model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)): + # Load from a Flax checkpoint + model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME) + # At this stage we don't have a weight file so we will raise an error. elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME): raise EnvironmentError( f"{WEIGHTS_NAME} file found in directory {pretrained_model_name_or_path}. Please load the model"