Flax safety checker #825
Merged
15 commits
3c838a2 Remove set_format in Flax pipeline. (pcuenca)
a444010 Remove DummyChecker. (pcuenca)
3ca68c4 Run safety_checker in pipeline. (pcuenca)
9d84107 Merge branch 'main' into flax-safety-checker (patrickvonplaten)
750e20f Don't pmap on every call. (pcuenca)
dcd27fd Remove commented line (pcuenca)
a0680ed Merge branch 'flax-safety-checker' of github.com:huggingface/diffuser… (pcuenca)
d65d1a2 Replicate outside __call__, prepare for optional jitting. (pcuenca)
4239bff Remove unnecessary clipping. (pcuenca)
866600b Do not jit unless requested. (pcuenca)
86cb5a1 Send all args to generate. (pcuenca)
1cd8bb5 Merge remote-tracking branch 'origin/main' into flax-safety-checker (pcuenca)
fe2817b make style (pcuenca)
b255e9a Remove unused imports. (pcuenca)
2533c50 Fix docstring. (pcuenca)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,13 @@ | |
|
||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
from jax import pmap | ||
from flax.core.frozen_dict import FrozenDict | ||
from flax.jax_utils import replicate, unreplicate | ||
from flax.training.common_utils import shard | ||
|
||
from PIL import Image | ||
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel | ||
|
||
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel | ||
|
@@ -52,7 +58,8 @@ def __init__( | |
dtype: jnp.dtype = jnp.float32, | ||
): | ||
super().__init__() | ||
scheduler = scheduler.set_format("np") | ||
# TODO: review and adapt to new scheduler API | ||
# scheduler = scheduler.set_format("np") | ||
self.dtype = dtype | ||
|
||
self.register_modules( | ||
|
@@ -78,7 +85,30 @@ def prepare_inputs(self, prompt: Union[str, List[str]]): | |
) | ||
return text_input.input_ids | ||
|
||
def __call__( | ||
def get_safety_scores(self, features, params): | ||
special_cos_dist, cos_dist = self.safety_checker(features, params) | ||
return (special_cos_dist, cos_dist) | ||
|
||
def run_safety_checker(self, images, safety_model_params): | ||
# safety_model_params should already be replicated | ||
pil_images = [Image.fromarray(image) for image in images] | ||
jnp_images = jnp.array(images) | ||
jnp_images = shard(jnp_images) | ||
features = self.feature_extractor(pil_images, return_tensors="np").pixel_values | ||
# features = jnp.transpose(features, (0, 2, 3, 1)) | ||
|
||
features = shard(features) | ||
|
||
p_safety_scores = jax.pmap(self.get_safety_scores) | ||
special_cos_dist, cos_dist = p_safety_scores(features, safety_model_params) | ||
images, has_nsfw = self.safety_checker.filtered_with_scores( | ||
unshard(special_cos_dist), | ||
unshard(cos_dist), | ||
images, | ||
params = unreplicate(safety_model_params), | ||
) | ||
return images, has_nsfw | ||
|
||
def generate( | ||
self, | ||
prompt_ids: jnp.array, | ||
params: Union[Dict, FrozenDict], | ||
|
@@ -88,50 +118,8 @@ def __call__( | |
width: Optional[int] = 512, | ||
guidance_scale: Optional[float] = 7.5, | ||
latents: Optional[jnp.array] = None, | ||
return_dict: bool = True, | ||
debug: bool = False, | ||
**kwargs, | ||
): | ||
r""" | ||
Function invoked when calling the pipeline for generation. | ||
|
||
Args: | ||
prompt (`str` or `List[str]`): | ||
The prompt or prompts to guide the image generation. | ||
height (`int`, *optional*, defaults to 512): | ||
The height in pixels of the generated image. | ||
width (`int`, *optional*, defaults to 512): | ||
The width in pixels of the generated image. | ||
num_inference_steps (`int`, *optional*, defaults to 50): | ||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the | ||
expense of slower inference. | ||
guidance_scale (`float`, *optional*, defaults to 7.5): | ||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). | ||
`guidance_scale` is defined as `w` of equation 2. of [Imagen | ||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > | ||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, | ||
usually at the expense of lower image quality. | ||
generator (`torch.Generator`, *optional*): | ||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation | ||
deterministic. | ||
latents (`jnp.array`, *optional*): | ||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image | ||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents | ||
tensor will be generated by sampling using the supplied random `generator`. | ||
output_type (`str`, *optional*, defaults to `"pil"`): | ||
The output format of the generated image. Choose between | ||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. | ||
return_dict (`bool`, *optional*, defaults to `True`): | ||
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of | ||
a plain tuple. | ||
|
||
Returns: | ||
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: | ||
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a | ||
`tuple`. When returning a tuple, the first element is a list with the generated images, and the second | ||
element is a list of `bool`s denoting whether the corresponding generated image likely represents | ||
"not-safe-for-work" (nsfw) content, according to the `safety_checker`. | ||
""" | ||
if height % 8 != 0 or width % 8 != 0: | ||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") | ||
|
||
|
@@ -199,21 +187,92 @@ def loop_body(step, args): | |
|
||
# scale and decode the image latents with vae | ||
latents = 1 / 0.18215 * latents | ||
# TODO: check when flax vae gets merged into main | ||
image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample | ||
|
||
image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) | ||
return image | ||
|
||
|
||
def __call__( | ||
self, | ||
prompt_ids: jnp.array, | ||
params: Union[Dict, FrozenDict], | ||
prng_seed: jax.random.PRNGKey, | ||
num_inference_steps: int = 50, | ||
height: int = 512, | ||
width: int = 512, | ||
guidance_scale: float = 7.5, | ||
latents: jnp.array = None, | ||
return_dict: bool = True, | ||
debug: bool = False, | ||
**kwargs, | ||
): | ||
r""" | ||
Function invoked when calling the pipeline for generation. | ||
|
||
# image = jnp.asarray(image).transpose(0, 2, 3, 1) | ||
# run safety checker | ||
# TODO: check when flax safety checker gets merged into main | ||
# safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np") | ||
# image, has_nsfw_concept = self.safety_checker( | ||
# images=image, clip_input=safety_checker_input.pixel_values, params=params["safety_params"] | ||
# ) | ||
has_nsfw_concept = False | ||
Args: | ||
prompt (`str` or `List[str]`): | ||
The prompt or prompts to guide the image generation. | ||
height (`int`, *optional*, defaults to 512): | ||
The height in pixels of the generated image. | ||
width (`int`, *optional*, defaults to 512): | ||
The width in pixels of the generated image. | ||
num_inference_steps (`int`, *optional*, defaults to 50): | ||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the | ||
expense of slower inference. | ||
guidance_scale (`float`, *optional*, defaults to 7.5): | ||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). | ||
`guidance_scale` is defined as `w` of equation 2. of [Imagen | ||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > | ||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, | ||
usually at the expense of lower image quality. | ||
generator (`torch.Generator`, *optional*): | ||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation | ||
deterministic. | ||
latents (`jnp.array`, *optional*): | ||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image | ||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents | ||
tensor will be generated by sampling using the supplied random `generator`. | ||
output_type (`str`, *optional*, defaults to `"pil"`): | ||
The output format of the generated image. Choose between | ||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. | ||
return_dict (`bool`, *optional*, defaults to `True`): | ||
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of | ||
a plain tuple. | ||
|
||
Returns: | ||
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: | ||
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a | ||
`tuple`. When returning a tuple, the first element is a list with the generated images, and the second | ||
element is a list of `bool`s denoting whether the corresponding generated image likely represents | ||
"not-safe-for-work" (nsfw) content, according to the `safety_checker`. | ||
""" | ||
# Delegate to `self.generate` for parallel generation, then run the safety checker as an additional step | ||
|
||
# TODO: assert dimensions | ||
params = replicate(params) | ||
|
||
prompt_ids = shard(prompt_ids) | ||
prng_seed = jax.random.split(prng_seed, jax.device_count()) | ||
|
||
p_generate = pmap(self.generate, static_broadcasted_argnums=(3,)) | ||
|
||
# TODO: send latents if necessary | ||
# images = p_generate(prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug) | ||
images = p_generate(prompt_ids, params, prng_seed, num_inference_steps) | ||
|
||
|
||
safety_params = params["safety_checker"] | ||
images = jnp.clip(images, 0, 1) | ||
|
||
images = (images * 255).round().astype("uint8") | ||
images = np.asarray(images).reshape(-1, height, width, 3) | ||
images, has_nsfw_concept = self.run_safety_checker(images, safety_params) | ||
|
||
if not return_dict: | ||
return (image, has_nsfw_concept) | ||
return (images, has_nsfw_concept) | ||
|
||
return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) | ||
|
||
|
||
return FlaxStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) | ||
def unshard(x: jnp.ndarray): | ||
Review comment (on `unshard`): Let's maybe also make it private.
# einops.rearrange(x, 'd b ... -> (d b) ...') | ||
d, b = x.shape[:2] | ||
|
||
rest = x.shape[2:] | ||
return x.reshape(d*b, *rest) |
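The `unshard` helper above undoes the per-device split that `shard` applies before `pmap`. The following is only a minimal sketch of that round trip, written with NumPy so it runs without accelerators. `shard_like` is a hypothetical stand-in for `flax.training.common_utils.shard` (assumed here to reshape the leading batch axis into `(devices, per_device_batch)`), and `num_devices`, `height`, `width`, and the array contents are illustrative values, not taken from the PR.

```python
import numpy as np


def shard_like(x: np.ndarray, num_devices: int) -> np.ndarray:
    """Hypothetical stand-in for flax's `shard`: (batch, ...) -> (devices, batch // devices, ...)."""
    return x.reshape(num_devices, -1, *x.shape[1:])


def unshard(x: np.ndarray) -> np.ndarray:
    """Same reshape as the helper in the diff: (devices, batch, ...) -> (devices * batch, ...)."""
    d, b = x.shape[:2]
    rest = x.shape[2:]
    return x.reshape(d * b, *rest)


num_devices = 4  # assumed device count, for illustration only

# Pretend safety scores: 8 images, 3 scores each.
scores = np.arange(8 * 3, dtype=np.float32).reshape(8, 3)
sharded = shard_like(scores, num_devices)  # one slice of the batch per device
restored = unshard(sharded)                # back to a flat batch, original order

# The image post-processing in `__call__` flattens the device axis the same way.
height, width = 8, 8  # tiny sizes so the sketch stays cheap
images = np.linspace(
    0.0, 1.0, num_devices * 2 * height * width * 3, dtype=np.float32
).reshape(num_devices, 2, height, width, 3)
images_uint8 = (images * 255).round().astype("uint8")
flat_images = images_uint8.reshape(-1, height, width, 3)
```

Because both steps are plain reshapes of the leading axes, the element order of the batch is preserved, which is what lets the safety scores line up with the images they were computed from.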