Skip to content

Commit e713346

Browse files
Give more customizable options for safety checker (#815)
* Give more customizable options for safety checker * Apply suggestions from code review * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py * Finish * make style * Apply suggestions from code review Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * up Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
1 parent 26c7df5 commit e713346

File tree

6 files changed

+93
-27
lines changed

6 files changed

+93
-27
lines changed

src/diffusers/pipeline_utils.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -113,23 +113,26 @@ def register_modules(self, **kwargs):
113113

114114
for name, module in kwargs.items():
115115
# retrieve library
116-
library = module.__module__.split(".")[0]
116+
if module is None:
117+
register_dict = {name: (None, None)}
118+
else:
119+
library = module.__module__.split(".")[0]
117120

118-
# check if the module is a pipeline module
119-
pipeline_dir = module.__module__.split(".")[-2]
120-
path = module.__module__.split(".")
121-
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
121+
# check if the module is a pipeline module
122+
pipeline_dir = module.__module__.split(".")[-2]
123+
path = module.__module__.split(".")
124+
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
122125

123-
# if library is not in LOADABLE_CLASSES, then it is a custom module.
124-
# Or if it's a pipeline module, then the module is inside the pipeline
125-
# folder so we set the library to module name.
126-
if library not in LOADABLE_CLASSES or is_pipeline_module:
127-
library = pipeline_dir
126+
# if library is not in LOADABLE_CLASSES, then it is a custom module.
127+
# Or if it's a pipeline module, then the module is inside the pipeline
128+
# folder so we set the library to module name.
129+
if library not in LOADABLE_CLASSES or is_pipeline_module:
130+
library = pipeline_dir
128131

129-
# retrieve class_name
130-
class_name = module.__class__.__name__
132+
# retrieve class_name
133+
class_name = module.__class__.__name__
131134

132-
register_dict = {name: (library, class_name)}
135+
register_dict = {name: (library, class_name)}
133136

134137
# save model index config
135138
self.register_to_config(**register_dict)
@@ -429,6 +432,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
429432

430433
is_pipeline_module = hasattr(pipelines, library_name)
431434
loaded_sub_model = None
435+
sub_model_should_be_defined = True
432436

433437
# if the model is in a pipeline module, then we load it from the pipeline
434438
if name in passed_class_obj:
@@ -449,6 +453,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
449453
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
450454
f" {expected_class_obj}"
451455
)
456+
elif passed_class_obj[name] is None:
457+
logger.warn(
458+
f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
459+
f" that this might lead to problems when using {pipeline_class} and is not recommended."
460+
)
461+
sub_model_should_be_defined = False
452462
else:
453463
logger.warn(
454464
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
@@ -469,7 +479,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
469479
importable_classes = LOADABLE_CLASSES[library_name]
470480
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
471481

472-
if loaded_sub_model is None:
482+
if loaded_sub_model is None and sub_model_should_be_defined:
473483
load_method_name = None
474484
for class_name, class_candidate in class_candidates.items():
475485
if issubclass(class_obj, class_candidate):

src/diffusers/pipelines/stable_diffusion/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass
2-
from typing import List, Union
2+
from typing import List, Optional, Union
33

44
import numpy as np
55

@@ -20,11 +20,11 @@ class StableDiffusionPipelineOutput(BaseOutput):
2020
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
2121
nsfw_content_detected (`List[bool]`)
2222
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
23-
(nsfw) content.
23+
(nsfw) content, or `None` if safety checking could not be performed.
2424
"""
2525

2626
images: Union[List[PIL.Image.Image], np.ndarray]
27-
nsfw_content_detected: List[bool]
27+
nsfw_content_detected: Optional[List[bool]]
2828

2929

3030
if is_transformers_available() and is_torch_available():

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,16 @@ def __init__(
7171
new_config["steps_offset"] = 1
7272
scheduler._internal_dict = FrozenDict(new_config)
7373

74+
if safety_checker is None:
75+
logger.warn(
76+
f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
77+
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
78+
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
79+
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
80+
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
81+
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
82+
)
83+
7484
self.register_modules(
7585
vae=vae,
7686
text_encoder=text_encoder,
@@ -335,10 +345,15 @@ def __call__(
335345
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
336346
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
337347

338-
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
339-
image, has_nsfw_concept = self.safety_checker(
340-
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
341-
)
348+
if self.safety_checker is not None:
349+
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
350+
self.device
351+
)
352+
image, has_nsfw_concept = self.safety_checker(
353+
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
354+
)
355+
else:
356+
has_nsfw_concept = None
342357

343358
if output_type == "pil":
344359
image = self.numpy_to_pil(image)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,16 @@ def __init__(
8383
new_config["steps_offset"] = 1
8484
scheduler._internal_dict = FrozenDict(new_config)
8585

86+
if safety_checker is None:
87+
logger.warn(
88+
f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
89+
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
90+
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
91+
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
92+
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
93+
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
94+
)
95+
8696
self.register_modules(
8797
vae=vae,
8898
text_encoder=text_encoder,
@@ -359,10 +369,15 @@ def __call__(
359369
image = (image / 2 + 0.5).clamp(0, 1)
360370
image = image.cpu().permute(0, 2, 3, 1).numpy()
361371

362-
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
363-
image, has_nsfw_concept = self.safety_checker(
364-
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
365-
)
372+
if self.safety_checker is not None:
373+
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
374+
self.device
375+
)
376+
image, has_nsfw_concept = self.safety_checker(
377+
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
378+
)
379+
else:
380+
has_nsfw_concept = None
366381

367382
if output_type == "pil":
368383
image = self.numpy_to_pil(image)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,16 @@ def __init__(
9898
new_config["steps_offset"] = 1
9999
scheduler._internal_dict = FrozenDict(new_config)
100100

101+
if safety_checker is None:
102+
logger.warn(
103+
f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
104+
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
105+
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
106+
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
107+
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
108+
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
109+
)
110+
101111
self.register_modules(
102112
vae=vae,
103113
text_encoder=text_encoder,
@@ -382,8 +392,13 @@ def __call__(
382392
image = (image / 2 + 0.5).clamp(0, 1)
383393
image = image.cpu().permute(0, 2, 3, 1).numpy()
384394

385-
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
386-
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
395+
if self.safety_checker is not None:
396+
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
397+
self.device
398+
)
399+
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
400+
else:
401+
has_nsfw_concept = None
387402

388403
if output_type == "pil":
389404
image = self.numpy_to_pil(image)

tests/test_pipelines.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,17 @@ def test_from_pretrained_error_message_uninstalled_packages(self):
498498
assert isinstance(pipe, StableDiffusionPipeline)
499499
assert isinstance(pipe.scheduler, LMSDiscreteScheduler)
500500

501+
def test_stable_diffusion_no_safety_checker(self):
502+
pipe = StableDiffusionPipeline.from_pretrained(
503+
"hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None
504+
)
505+
assert isinstance(pipe, StableDiffusionPipeline)
506+
assert isinstance(pipe.scheduler, LMSDiscreteScheduler)
507+
assert pipe.safety_checker is None
508+
509+
image = pipe("example prompt", num_inference_steps=2).images[0]
510+
assert image is not None
511+
501512
def test_stable_diffusion_k_lms(self):
502513
device = "cpu" # ensure determinism for the device-dependent torch.Generator
503514
unet = self.dummy_cond_unet

0 commit comments

Comments
 (0)