|
| 1 | +from functools import partial |
1 | 2 | from typing import Dict, List, Optional, Union
|
2 | 3 |
|
| 4 | +import numpy as np |
| 5 | + |
3 | 6 | import jax
|
4 | 7 | import jax.numpy as jnp
|
5 | 8 | from flax.core.frozen_dict import FrozenDict
|
| 9 | +from flax.jax_utils import unreplicate |
| 10 | +from flax.training.common_utils import shard |
| 11 | +from PIL import Image |
6 | 12 | from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
|
7 | 13 |
|
8 | 14 | from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
|
@@ -77,60 +83,44 @@ def prepare_inputs(self, prompt: Union[str, List[str]]):
|
77 | 83 | )
|
78 | 84 | return text_input.input_ids
|
79 | 85 |
|
80 |
| - def __call__( |
| 86 | + def _get_safety_scores(self, features, params): |
| 87 | + special_cos_dist, cos_dist = self.safety_checker(features, params) |
| 88 | + return (special_cos_dist, cos_dist) |
| 89 | + |
| 90 | + def _run_safety_checker(self, images, safety_model_params, jit=False): |
| 91 | + # safety_model_params should already be replicated when jit is True |
| 92 | + pil_images = [Image.fromarray(image) for image in images] |
| 93 | + features = self.feature_extractor(pil_images, return_tensors="np").pixel_values |
| 94 | + |
| 95 | + if jit: |
| 96 | + features = shard(features) |
| 97 | + special_cos_dist, cos_dist = _p_get_safety_scores(self, features, safety_model_params) |
| 98 | + special_cos_dist = unshard(special_cos_dist) |
| 99 | + cos_dist = unshard(cos_dist) |
| 100 | + safety_model_params = unreplicate(safety_model_params) |
| 101 | + else: |
| 102 | + special_cos_dist, cos_dist = self._get_safety_scores(features, safety_model_params) |
| 103 | + |
| 104 | + images, has_nsfw = self.safety_checker.filtered_with_scores( |
| 105 | + special_cos_dist, |
| 106 | + cos_dist, |
| 107 | + images, |
| 108 | + safety_model_params, |
| 109 | + ) |
| 110 | + return images, has_nsfw |
| 111 | + |
| 112 | + def _generate( |
81 | 113 | self,
|
82 | 114 | prompt_ids: jnp.array,
|
83 | 115 | params: Union[Dict, FrozenDict],
|
84 | 116 | prng_seed: jax.random.PRNGKey,
|
85 |
| - num_inference_steps: Optional[int] = 50, |
86 |
| - height: Optional[int] = 512, |
87 |
| - width: Optional[int] = 512, |
88 |
| - guidance_scale: Optional[float] = 7.5, |
| 117 | + num_inference_steps: int = 50, |
| 118 | + height: int = 512, |
| 119 | + width: int = 512, |
| 120 | + guidance_scale: float = 7.5, |
89 | 121 | latents: Optional[jnp.array] = None,
|
90 |
| - return_dict: bool = True, |
91 | 122 | debug: bool = False,
|
92 |
| - **kwargs, |
93 | 123 | ):
|
94 |
| - r""" |
95 |
| - Function invoked when calling the pipeline for generation. |
96 |
| -
|
97 |
| - Args: |
98 |
| - prompt (`str` or `List[str]`): |
99 |
| - The prompt or prompts to guide the image generation. |
100 |
| - height (`int`, *optional*, defaults to 512): |
101 |
| - The height in pixels of the generated image. |
102 |
| - width (`int`, *optional*, defaults to 512): |
103 |
| - The width in pixels of the generated image. |
104 |
| - num_inference_steps (`int`, *optional*, defaults to 50): |
105 |
| - The number of denoising steps. More denoising steps usually lead to a higher quality image at the |
106 |
| - expense of slower inference. |
107 |
| - guidance_scale (`float`, *optional*, defaults to 7.5): |
108 |
| - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). |
109 |
| - `guidance_scale` is defined as `w` of equation 2. of [Imagen |
110 |
| - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > |
111 |
| - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, |
112 |
| - usually at the expense of lower image quality. |
113 |
| - generator (`torch.Generator`, *optional*): |
114 |
| - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation |
115 |
| - deterministic. |
116 |
| - latents (`jnp.array`, *optional*): |
117 |
| - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image |
118 |
| - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents |
119 |
| - tensor will ge generated by sampling using the supplied random `generator`. |
120 |
| - output_type (`str`, *optional*, defaults to `"pil"`): |
121 |
| - The output format of the generate image. Choose between |
122 |
| - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. |
123 |
| - return_dict (`bool`, *optional*, defaults to `True`): |
124 |
| - Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of |
125 |
| - a plain tuple. |
126 |
| -
|
127 |
| - Returns: |
128 |
| - [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: |
129 |
| - [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a |
130 |
| - `tuple. When returning a tuple, the first element is a list with the generated images, and the second |
131 |
| - element is a list of `bool`s denoting whether the corresponding generated image likely represents |
132 |
| - "not-safe-for-work" (nsfw) content, according to the `safety_checker`. |
133 |
| - """ |
134 | 124 | if height % 8 != 0 or width % 8 != 0:
|
135 | 125 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
136 | 126 |
|
@@ -203,21 +193,106 @@ def loop_body(step, args):
|
203 | 193 |
|
204 | 194 | # scale and decode the image latents with vae
|
205 | 195 | latents = 1 / 0.18215 * latents
|
206 |
| - # TODO: check when flax vae gets merged into main |
207 | 196 | image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
|
208 | 197 |
|
209 | 198 | image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
|
| 199 | + return image |
210 | 200 |
|
211 |
| - # image = jnp.asarray(image).transpose(0, 2, 3, 1) |
212 |
| - # run safety checker |
213 |
| - # TODO: check when flax safety checker gets merged into main |
214 |
| - # safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np") |
215 |
| - # image, has_nsfw_concept = self.safety_checker( |
216 |
| - # images=image, clip_input=safety_checker_input.pixel_values, params=params["safety_params"] |
217 |
| - # ) |
218 |
| - has_nsfw_concept = False |
| 201 | + def __call__( |
| 202 | + self, |
| 203 | + prompt_ids: jnp.array, |
| 204 | + params: Union[Dict, FrozenDict], |
| 205 | + prng_seed: jax.random.PRNGKey, |
| 206 | + num_inference_steps: int = 50, |
| 207 | + height: int = 512, |
| 208 | + width: int = 512, |
| 209 | + guidance_scale: float = 7.5, |
| 210 | + latents: jnp.array = None, |
| 211 | + return_dict: bool = True, |
| 212 | + jit: bool = False, |
| 213 | + debug: bool = False, |
| 214 | + **kwargs, |
| 215 | + ): |
| 216 | + r""" |
| 217 | + Function invoked when calling the pipeline for generation. |
| 218 | +
|
| 219 | + Args: |
| 220 | + prompt (`str` or `List[str]`): |
| 221 | + The prompt or prompts to guide the image generation. |
| 222 | + height (`int`, *optional*, defaults to 512): |
| 223 | + The height in pixels of the generated image. |
| 224 | + width (`int`, *optional*, defaults to 512): |
| 225 | + The width in pixels of the generated image. |
| 226 | + num_inference_steps (`int`, *optional*, defaults to 50): |
| 227 | + The number of denoising steps. More denoising steps usually lead to a higher quality image at the |
| 228 | + expense of slower inference. |
| 229 | + guidance_scale (`float`, *optional*, defaults to 7.5): |
| 230 | + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). |
| 231 | + `guidance_scale` is defined as `w` of equation 2. of [Imagen |
| 232 | + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > |
| 233 | + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, |
| 234 | + usually at the expense of lower image quality. |
| 235 | + generator (`torch.Generator`, *optional*): |
| 236 | + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation |
| 237 | + deterministic. |
| 238 | + latents (`jnp.array`, *optional*): |
| 239 | + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image |
| 240 | + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents |
| 241 | + tensor will ge generated by sampling using the supplied random `generator`. |
| 242 | + output_type (`str`, *optional*, defaults to `"pil"`): |
| 243 | + The output format of the generate image. Choose between |
| 244 | + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. |
| 245 | + jit (`bool`, defaults to `False`): |
| 246 | + Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument |
| 247 | + exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release. |
| 248 | + return_dict (`bool`, *optional*, defaults to `True`): |
| 249 | + Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of |
| 250 | + a plain tuple. |
| 251 | +
|
| 252 | + Returns: |
| 253 | + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: |
| 254 | + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a |
| 255 | + `tuple. When returning a tuple, the first element is a list with the generated images, and the second |
| 256 | + element is a list of `bool`s denoting whether the corresponding generated image likely represents |
| 257 | + "not-safe-for-work" (nsfw) content, according to the `safety_checker`. |
| 258 | + """ |
| 259 | + if jit: |
| 260 | + images = _p_generate( |
| 261 | + self, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug |
| 262 | + ) |
| 263 | + else: |
| 264 | + images = self._generate( |
| 265 | + prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug |
| 266 | + ) |
| 267 | + |
| 268 | + safety_params = params["safety_checker"] |
| 269 | + images = (images * 255).round().astype("uint8") |
| 270 | + images = np.asarray(images).reshape(-1, height, width, 3) |
| 271 | + images, has_nsfw_concept = self._run_safety_checker(images, safety_params, jit) |
219 | 272 |
|
220 | 273 | if not return_dict:
|
221 |
| - return (image, has_nsfw_concept) |
| 274 | + return (images, has_nsfw_concept) |
| 275 | + |
| 276 | + return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) |
| 277 | + |
| 278 | + |
| 279 | +# TODO: maybe use a config dict instead of so many static argnums |
| 280 | +@partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9)) |
| 281 | +def _p_generate( |
| 282 | + pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug |
| 283 | +): |
| 284 | + return pipe._generate( |
| 285 | + prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug |
| 286 | + ) |
| 287 | + |
| 288 | + |
| 289 | +@partial(jax.pmap, static_broadcasted_argnums=(0,)) |
| 290 | +def _p_get_safety_scores(pipe, features, params): |
| 291 | + return pipe._get_safety_scores(features, params) |
| 292 | + |
222 | 293 |
|
223 |
| - return FlaxStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) |
| 294 | +def unshard(x: jnp.ndarray): |
| 295 | + # einops.rearrange(x, 'd b ... -> (d b) ...') |
| 296 | + num_devices, batch_size = x.shape[:2] |
| 297 | + rest = x.shape[2:] |
| 298 | + return x.reshape(num_devices * batch_size, *rest) |
0 commit comments