Skip to content

Give more customizable options for safety checker #815

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Oct 13, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 26 additions & 14 deletions src/diffusers/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,23 +112,27 @@ def register_modules(self, **kwargs):

for name, module in kwargs.items():
# retrieve library
library = module.__module__.split(".")[0]
if module is None:
register_dict = {name: (None, None)}
else:
library = module.__module__.split(".")[0]

# check if the module is a pipeline module
pipeline_dir = module.__module__.split(".")[-2]
path = module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)

# check if the module is a pipeline module
pipeline_dir = module.__module__.split(".")[-2]
path = module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name.
if library not in LOADABLE_CLASSES or is_pipeline_module:
library = pipeline_dir

# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name.
if library not in LOADABLE_CLASSES or is_pipeline_module:
library = pipeline_dir
# retrieve class_name
class_name = module.__class__.__name__

# retrieve class_name
class_name = module.__class__.__name__
register_dict = {name: (library, class_name)}

register_dict = {name: (library, class_name)}

# save model index config
self.register_to_config(**register_dict)
Expand Down Expand Up @@ -422,6 +426,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

is_pipeline_module = hasattr(pipelines, library_name)
loaded_sub_model = None
sub_model_should_be_defined = True

# if the model is in a pipeline module, then we load it from the pipeline
if name in passed_class_obj:
Expand All @@ -442,6 +447,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
f" {expected_class_obj}"
)
elif passed_class_obj[name] is None:
logger.warn(
f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
f" that this might lead to problems when using {pipeline_class} and is generally not"
" recommended to do."
)
sub_model_should_be_defined = False
else:
logger.warn(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
Expand All @@ -462,7 +474,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}

if loaded_sub_model is None:
if loaded_sub_model is None and sub_model_should_be_defined:
load_method_name = None
for class_name, class_candidate in class_candidates.items():
if issubclass(class_obj, class_candidate):
Expand Down
6 changes: 3 additions & 3 deletions src/diffusers/pipelines/stable_diffusion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List, Union
from typing import List, Optional, Union

import numpy as np

Expand All @@ -20,11 +20,11 @@ class StableDiffusionPipelineOutput(BaseOutput):
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
nsfw_content_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content.
            (nsfw) content. If the safety checker is disabled, `None` will be returned.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@patrickvonplaten one question/request -- would it be possible to have a setting that still returns whether or not an image is nsfw, but does not black out that image? This would allow devs who are building on top of this library to do something like show a popup to a user (e.g. 'you are about to view a NSFW image, do you want to proceed')

One possible way to implement this is to pass a flag into the safety checker module that disables the 'return a black image' part of the checker. Another way is to expose the safety checker class so that end users can add the checker in at the end (solely to get the bool[] indicating nsfw-ness)

"""

images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: List[bool]
nsfw_content_detected: Optional[List[bool]]


if is_transformers_available() and is_torch_available():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)

if safety_checker is None:
logger.warn(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Please"
            " make sure you have very good reasons for this and have considered the consequences of doing so. The"
            " `diffusers` team does not recommend disabling the safety checker under ANY circumstances and"
            " strongly suggests keeping it enabled by not passing `safety_checker=None` to"
            " `from_pretrained`. For more information, please have a look at"
            " https://github.com/huggingface/diffusers/pull/254"
)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think referencing another PR dilutes the message. I'd propose something like:

The diffusers team and Hugging Face strongly recommend keeping the safety filter enabled in all circumstances. Please make sure you abide by the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public.

I think it's worthwhile to make this as clear as we can, maybe @natolambert, @mmitchellai, @yjernite can provide better wording.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would even say:

Ensure that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling it only for use cases that involve analyzing network behavior or auditing its results.

I mean, I think proposing a use case in which removing the safety filter is acceptable helps justify why we are proposing this PR to allow removing it.

self.register_modules(
vae=vae,
text_encoder=text_encoder,
Expand Down Expand Up @@ -335,10 +345,15 @@ def __call__(
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()

safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
)
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
self.device
)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
)
else:
has_nsfw_concept = None

if output_type == "pil":
image = self.numpy_to_pil(image)
Expand Down
11 changes: 11 additions & 0 deletions tests/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,17 @@ def test_from_pretrained_error_message_uninstalled_packages(self):
assert isinstance(pipe, StableDiffusionPipeline)
assert isinstance(pipe.scheduler, LMSDiscreteScheduler)

def test_stable_diffusion_no_safety_checker(self):
pipe = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None
)
assert isinstance(pipe, StableDiffusionPipeline)
assert isinstance(pipe.scheduler, LMSDiscreteScheduler)
assert pipe.safety_checker is None

image = pipe("example prompt", num_inference_steps=2).images[0]
assert image is not None

def test_stable_diffusion_k_lms(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
Expand Down