Skip to content

Commit 78db11d

Browse files
Flax safety checker (#825)
* Remove set_format in Flax pipeline. * Remove DummyChecker. * Run safety_checker in pipeline. * Don't pmap on every call. We could have decorated `generate` with `pmap`, but I wanted to keep it in case someone wants to invoke it in non-parallel mode. * Remove commented line Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Replicate outside __call__, prepare for optional jitting. * Remove unnecessary clipping. As suggested by @kashif. * Do not jit unless requested. * Send all args to generate. * make style * Remove unused imports. * Fix docstring. Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
1 parent e713346 commit 78db11d

File tree

2 files changed

+135
-78
lines changed

2 files changed

+135
-78
lines changed

src/diffusers/pipeline_flax_utils.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,6 @@
6262
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
6363

6464

65-
class DummyChecker:
66-
def __init__(self):
67-
self.dummy = True
68-
69-
7065
def import_flax_or_no_model(module, class_name):
7166
try:
7267
# 1. First make sure that if a Flax object is present, import this one
@@ -177,10 +172,6 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], params: Union
177172
if save_method_name is not None:
178173
break
179174

180-
# TODO(Patrick, Suraj): to delete after
181-
if isinstance(sub_model, DummyChecker):
182-
continue
183-
184175
save_method = getattr(sub_model, save_method_name)
185176
expects_params = "params" in set(inspect.signature(save_method).parameters.keys())
186177

@@ -194,7 +185,7 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], params: Union
194185
@classmethod
195186
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
196187
r"""
197-
Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights.
188+
Instantiate a Flax diffusion pipeline from pre-trained pipeline weights.
198189
199190
The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
200191
@@ -349,11 +340,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
349340

350341
# 3. Load each module in the pipeline
351342
for name, (library_name, class_name) in init_dict.items():
352-
# TODO(Patrick, Suraj) - delete later
353-
if class_name == "DummyChecker":
354-
library_name = "stable_diffusion"
355-
class_name = "FlaxStableDiffusionSafetyChecker"
356-
357343
is_pipeline_module = hasattr(pipelines, library_name)
358344
loaded_sub_model = None
359345

@@ -422,11 +408,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
422408
loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
423409
params[name] = loaded_params
424410
elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
425-
# make sure we don't initialize the weights to save time
426-
if name == "safety_checker":
427-
loaded_sub_model = DummyChecker()
428-
loaded_params = {}
429-
elif from_pt:
411+
if from_pt:
430412
# TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here
431413
loaded_sub_model = load_method(loadable_folder, from_pt=from_pt)
432414
loaded_params = loaded_sub_model.params

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

Lines changed: 133 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
1+
from functools import partial
12
from typing import Dict, List, Optional, Union
23

4+
import numpy as np
5+
36
import jax
47
import jax.numpy as jnp
58
from flax.core.frozen_dict import FrozenDict
9+
from flax.jax_utils import unreplicate
10+
from flax.training.common_utils import shard
11+
from PIL import Image
612
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
713

814
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
@@ -77,60 +83,44 @@ def prepare_inputs(self, prompt: Union[str, List[str]]):
7783
)
7884
return text_input.input_ids
7985

80-
def __call__(
86+
def _get_safety_scores(self, features, params):
87+
special_cos_dist, cos_dist = self.safety_checker(features, params)
88+
return (special_cos_dist, cos_dist)
89+
90+
def _run_safety_checker(self, images, safety_model_params, jit=False):
91+
# safety_model_params should already be replicated when jit is True
92+
pil_images = [Image.fromarray(image) for image in images]
93+
features = self.feature_extractor(pil_images, return_tensors="np").pixel_values
94+
95+
if jit:
96+
features = shard(features)
97+
special_cos_dist, cos_dist = _p_get_safety_scores(self, features, safety_model_params)
98+
special_cos_dist = unshard(special_cos_dist)
99+
cos_dist = unshard(cos_dist)
100+
safety_model_params = unreplicate(safety_model_params)
101+
else:
102+
special_cos_dist, cos_dist = self._get_safety_scores(features, safety_model_params)
103+
104+
images, has_nsfw = self.safety_checker.filtered_with_scores(
105+
special_cos_dist,
106+
cos_dist,
107+
images,
108+
safety_model_params,
109+
)
110+
return images, has_nsfw
111+
112+
def _generate(
81113
self,
82114
prompt_ids: jnp.array,
83115
params: Union[Dict, FrozenDict],
84116
prng_seed: jax.random.PRNGKey,
85-
num_inference_steps: Optional[int] = 50,
86-
height: Optional[int] = 512,
87-
width: Optional[int] = 512,
88-
guidance_scale: Optional[float] = 7.5,
117+
num_inference_steps: int = 50,
118+
height: int = 512,
119+
width: int = 512,
120+
guidance_scale: float = 7.5,
89121
latents: Optional[jnp.array] = None,
90-
return_dict: bool = True,
91122
debug: bool = False,
92-
**kwargs,
93123
):
94-
r"""
95-
Function invoked when calling the pipeline for generation.
96-
97-
Args:
98-
prompt (`str` or `List[str]`):
99-
The prompt or prompts to guide the image generation.
100-
height (`int`, *optional*, defaults to 512):
101-
The height in pixels of the generated image.
102-
width (`int`, *optional*, defaults to 512):
103-
The width in pixels of the generated image.
104-
num_inference_steps (`int`, *optional*, defaults to 50):
105-
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
106-
expense of slower inference.
107-
guidance_scale (`float`, *optional*, defaults to 7.5):
108-
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
109-
`guidance_scale` is defined as `w` of equation 2. of [Imagen
110-
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
111-
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
112-
usually at the expense of lower image quality.
113-
generator (`torch.Generator`, *optional*):
114-
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
115-
deterministic.
116-
latents (`jnp.array`, *optional*):
117-
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
118-
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
119-
tensor will ge generated by sampling using the supplied random `generator`.
120-
output_type (`str`, *optional*, defaults to `"pil"`):
121-
The output format of the generate image. Choose between
122-
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
123-
return_dict (`bool`, *optional*, defaults to `True`):
124-
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
125-
a plain tuple.
126-
127-
Returns:
128-
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
129-
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
130-
`tuple. When returning a tuple, the first element is a list with the generated images, and the second
131-
element is a list of `bool`s denoting whether the corresponding generated image likely represents
132-
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
133-
"""
134124
if height % 8 != 0 or width % 8 != 0:
135125
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
136126

@@ -203,21 +193,106 @@ def loop_body(step, args):
203193

204194
# scale and decode the image latents with vae
205195
latents = 1 / 0.18215 * latents
206-
# TODO: check when flax vae gets merged into main
207196
image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
208197

209198
image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
199+
return image
210200

211-
# image = jnp.asarray(image).transpose(0, 2, 3, 1)
212-
# run safety checker
213-
# TODO: check when flax safety checker gets merged into main
214-
# safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
215-
# image, has_nsfw_concept = self.safety_checker(
216-
# images=image, clip_input=safety_checker_input.pixel_values, params=params["safety_params"]
217-
# )
218-
has_nsfw_concept = False
201+
def __call__(
202+
self,
203+
prompt_ids: jnp.array,
204+
params: Union[Dict, FrozenDict],
205+
prng_seed: jax.random.PRNGKey,
206+
num_inference_steps: int = 50,
207+
height: int = 512,
208+
width: int = 512,
209+
guidance_scale: float = 7.5,
210+
latents: jnp.array = None,
211+
return_dict: bool = True,
212+
jit: bool = False,
213+
debug: bool = False,
214+
**kwargs,
215+
):
216+
r"""
217+
Function invoked when calling the pipeline for generation.
218+
219+
Args:
220+
prompt (`str` or `List[str]`):
221+
The prompt or prompts to guide the image generation.
222+
height (`int`, *optional*, defaults to 512):
223+
The height in pixels of the generated image.
224+
width (`int`, *optional*, defaults to 512):
225+
The width in pixels of the generated image.
226+
num_inference_steps (`int`, *optional*, defaults to 50):
227+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
228+
expense of slower inference.
229+
guidance_scale (`float`, *optional*, defaults to 7.5):
230+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
231+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
232+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
233+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
234+
usually at the expense of lower image quality.
235+
generator (`torch.Generator`, *optional*):
236+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
237+
deterministic.
238+
latents (`jnp.array`, *optional*):
239+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
240+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
241+
tensor will ge generated by sampling using the supplied random `generator`.
242+
output_type (`str`, *optional*, defaults to `"pil"`):
243+
The output format of the generate image. Choose between
244+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
245+
jit (`bool`, defaults to `False`):
246+
Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
247+
exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
248+
return_dict (`bool`, *optional*, defaults to `True`):
249+
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
250+
a plain tuple.
251+
252+
Returns:
253+
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
254+
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
255+
`tuple. When returning a tuple, the first element is a list with the generated images, and the second
256+
element is a list of `bool`s denoting whether the corresponding generated image likely represents
257+
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
258+
"""
259+
if jit:
260+
images = _p_generate(
261+
self, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
262+
)
263+
else:
264+
images = self._generate(
265+
prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
266+
)
267+
268+
safety_params = params["safety_checker"]
269+
images = (images * 255).round().astype("uint8")
270+
images = np.asarray(images).reshape(-1, height, width, 3)
271+
images, has_nsfw_concept = self._run_safety_checker(images, safety_params, jit)
219272

220273
if not return_dict:
221-
return (image, has_nsfw_concept)
274+
return (images, has_nsfw_concept)
275+
276+
return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
277+
278+
279+
# TODO: maybe use a config dict instead of so many static argnums
280+
@partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9))
281+
def _p_generate(
282+
pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
283+
):
284+
return pipe._generate(
285+
prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
286+
)
287+
288+
289+
@partial(jax.pmap, static_broadcasted_argnums=(0,))
290+
def _p_get_safety_scores(pipe, features, params):
291+
return pipe._get_safety_scores(features, params)
292+
222293

223-
return FlaxStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
294+
def unshard(x: jnp.ndarray):
295+
# einops.rearrange(x, 'd b ... -> (d b) ...')
296+
num_devices, batch_size = x.shape[:2]
297+
rest = x.shape[2:]
298+
return x.reshape(num_devices * batch_size, *rest)

0 commit comments

Comments
 (0)