Skip to content

Flax safety checker #825

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Oct 13, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 2 additions & 20 deletions src/diffusers/pipeline_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,6 @@
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])


class DummyChecker:
    """Placeholder object whose only state is a `dummy` marker attribute."""

    def __init__(self) -> None:
        # The marker is unconditionally True on every instance.
        setattr(self, "dummy", True)


def import_flax_or_no_model(module, class_name):
try:
# 1. First make sure that if a Flax object is present, import this one
Expand Down Expand Up @@ -177,10 +172,6 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], params: Union
if save_method_name is not None:
break

# TODO(Patrick, Suraj): to delete after
if isinstance(sub_model, DummyChecker):
continue

save_method = getattr(sub_model, save_method_name)
expects_params = "params" in set(inspect.signature(save_method).parameters.keys())

Expand All @@ -194,7 +185,7 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], params: Union
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
r"""
Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights.
Instantiate a Flax diffusion pipeline from pre-trained pipeline weights.

The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).

Expand Down Expand Up @@ -349,11 +340,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

# 3. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items():
# TODO(Patrick, Suraj) - delete later
if class_name == "DummyChecker":
library_name = "stable_diffusion"
class_name = "FlaxStableDiffusionSafetyChecker"

is_pipeline_module = hasattr(pipelines, library_name)
loaded_sub_model = None

Expand Down Expand Up @@ -422,11 +408,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
params[name] = loaded_params
elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
# make sure we don't initialize the weights to save time
if name == "safety_checker":
loaded_sub_model = DummyChecker()
loaded_params = {}
elif from_pt:
if from_pt:
# TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here
loaded_sub_model = load_method(loadable_folder, from_pt=from_pt)
loaded_params = loaded_sub_model.params
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@

import jax
import jax.numpy as jnp
import numpy as np
from jax import pmap
from flax.core.frozen_dict import FrozenDict
from flax.jax_utils import replicate, unreplicate
from flax.training.common_utils import shard

from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel

from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
Expand Down Expand Up @@ -52,7 +58,8 @@ def __init__(
dtype: jnp.dtype = jnp.float32,
):
super().__init__()
scheduler = scheduler.set_format("np")
# TODO: review and adapt to new scheduler API
# scheduler = scheduler.set_format("np")
self.dtype = dtype

self.register_modules(
Expand All @@ -78,7 +85,30 @@ def prepare_inputs(self, prompt: Union[str, List[str]]):
)
return text_input.input_ids

def __call__(
def get_safety_scores(self, features, params):
    """Call the safety checker and hand back its two score outputs.

    Args:
        features: Input forwarded unchanged as the first argument of
            `self.safety_checker`.
        params: Parameters forwarded unchanged as the second argument.

    Returns:
        The pair `(special_cos_dist, cos_dist)` produced by
        `self.safety_checker(features, params)`.
    """
    scores = self.safety_checker(features, params)
    # The checker must yield exactly two values; unpacking enforces that.
    special_cos_dist, cos_dist = scores
    return special_cos_dist, cos_dist

def run_safety_checker(self, images, safety_model_params):
    """Score a batch of images on all devices and filter them with the safety checker.

    Args:
        images: Batch of images; each entry must be accepted by
            `PIL.Image.fromarray`.
        safety_model_params: Safety checker parameters. They should already
            be replicated across devices, because they are passed straight
            into a pmapped function.

    Returns:
        The `(images, has_nsfw)` pair returned by
        `self.safety_checker.filtered_with_scores`.
    """
    # The feature extractor is fed PIL images and asked for numpy pixel values.
    pil_images = [Image.fromarray(image) for image in images]
    features = self.feature_extractor(pil_images, return_tensors="np").pixel_values
    # Give the features a leading device axis so they can be fed to pmap.
    features = shard(features)

    # Compute the scores in parallel, one shard per device.
    p_safety_scores = jax.pmap(self.get_safety_scores)
    special_cos_dist, cos_dist = p_safety_scores(features, safety_model_params)

    # Filtering happens outside pmap: merge the device axis back into the
    # batch axis and use a single, un-replicated copy of the parameters.
    images, has_nsfw = self.safety_checker.filtered_with_scores(
        unshard(special_cos_dist),
        unshard(cos_dist),
        images,
        params=unreplicate(safety_model_params),
    )
    return images, has_nsfw

def generate(
self,
prompt_ids: jnp.array,
params: Union[Dict, FrozenDict],
Expand All @@ -88,50 +118,8 @@ def __call__(
width: Optional[int] = 512,
guidance_scale: Optional[float] = 7.5,
latents: Optional[jnp.array] = None,
return_dict: bool = True,
debug: bool = False,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.

Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`jnp.array`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
a plain tuple.

Returns:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images, and the second
element is a list of `bool`s denoting whether the corresponding generated image likely represents
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
"""
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

Expand Down Expand Up @@ -199,21 +187,92 @@ def loop_body(step, args):

# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
# TODO: check when flax vae gets merged into main
image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample

image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
return image


def __call__(
    self,
    prompt_ids: jnp.array,
    params: Union[Dict, FrozenDict],
    prng_seed: jax.random.PRNGKey,
    num_inference_steps: int = 50,
    height: int = 512,
    width: int = 512,
    guidance_scale: float = 7.5,
    latents: jnp.array = None,
    return_dict: bool = True,
    debug: bool = False,
    **kwargs,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Generation is delegated to `self.generate`, run in parallel on all devices with `pmap`; the safety checker
    is then run on the generated images as an additional step.

    Args:
        prompt_ids (`jnp.array`):
            Tokenized prompt ids. They are split across devices with `shard` before generation (presumably the
            batch size has to be a multiple of the device count; confirm against `shard`).
        params (`Dict` or `FrozenDict`):
            Pipeline parameters. They are replicated to every device here, so pass them un-replicated. The
            `"safety_checker"` entry is used for the safety checker.
        prng_seed (`jax.random.PRNGKey`):
            Random key; it is split into one key per device.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        height (`int`, *optional*, defaults to 512):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to 512):
            The width in pixels of the generated image.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        latents (`jnp.array`, *optional*):
            Pre-generated noisy latents. NOTE: currently accepted but not yet forwarded to `self.generate`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
            a plain tuple.
        debug (`bool`, *optional*, defaults to `False`):
            NOTE: currently accepted but not yet forwarded to `self.generate`.

    Returns:
        [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
        `tuple`. When returning a tuple, the first element holds the generated images, and the second element
        is the "not-safe-for-work" (nsfw) detection result returned by `run_safety_checker`.
    """
    # TODO: assert dimensions
    params = replicate(params)
    prompt_ids = shard(prompt_ids)
    prng_seed = jax.random.split(prng_seed, jax.device_count())

    # The step count, image size and guidance scale are plain Python values shared by every device, so they
    # are passed as static broadcasted arguments. They have to be forwarded: the reshape below relies on
    # `generate` having produced images of exactly `height` x `width`.
    p_generate = pmap(self.generate, static_broadcasted_argnums=(3, 4, 5, 6))
    # TODO: send latents if necessary
    images = p_generate(prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale)

    # The safety checker parameters are taken from the replicated params, as `run_safety_checker` expects.
    safety_params = params["safety_checker"]
    # Convert the images to uint8 and merge the device and per-device batch axes into one batch axis.
    images = jnp.clip(images, 0, 1)
    images = (images * 255).round().astype("uint8")
    images = np.asarray(images).reshape(-1, height, width, 3)
    images, has_nsfw_concept = self.run_safety_checker(images, safety_params)

    if not return_dict:
        return (images, has_nsfw_concept)

    return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
def unshard(x: jnp.ndarray):
    """Merge the two leading axes of `x` into a single axis.

    Equivalent to `einops.rearrange(x, 'd b ... -> (d b) ...')`: an array of
    shape `(d, b, *rest)` becomes `(d * b, *rest)`.

    Args:
        x: Array with at least two dimensions.

    Returns:
        `x` reshaped so that its first two axes are flattened into one.

    Raises:
        ValueError: If `x` has fewer than two dimensions.
    """
    # NOTE(review): a reviewer suggested making this helper private; the public
    # name is kept because `run_safety_checker` calls it as `unshard`.
    if len(x.shape) < 2:
        raise ValueError(f"`unshard` expects an array with at least 2 dimensions, got shape {tuple(x.shape)}.")
    d, b = x.shape[:2]
    rest = x.shape[2:]
    return x.reshape(d * b, *rest)