From cde3c937321289c437346298370855838e41ec07 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 12 Oct 2022 18:49:59 +0200 Subject: [PATCH 1/7] Give more customizable options for safety checker --- src/diffusers/pipeline_utils.py | 43 ++++++++++++------- .../pipelines/stable_diffusion/__init__.py | 6 +-- .../pipeline_stable_diffusion.py | 23 ++++++++-- tests/test_pipelines.py | 11 +++++ 4 files changed, 60 insertions(+), 23 deletions(-) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index a7b6031d137a..6a29ec795a91 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -112,26 +112,29 @@ def register_modules(self, **kwargs): for name, module in kwargs.items(): # retrieve library - library = module.__module__.split(".")[0] + if module is None: + register_dict = {name: (None, None)} + else: + library = module.__module__.split(".")[0] - # check if the module is a pipeline module - pipeline_dir = module.__module__.split(".")[-2] - path = module.__module__.split(".") - is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) + # check if the module is a pipeline module + pipeline_dir = module.__module__.split(".")[-2] + path = module.__module__.split(".") + is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) - # if library is not in LOADABLE_CLASSES, then it is a custom module. - # Or if it's a pipeline module, then the module is inside the pipeline - # folder so we set the library to module name. - if library not in LOADABLE_CLASSES or is_pipeline_module: - library = pipeline_dir + # if library is not in LOADABLE_CLASSES, then it is a custom module. + # Or if it's a pipeline module, then the module is inside the pipeline + # folder so we set the library to module name. + if library not in LOADABLE_CLASSES or is_pipeline_module: + library = pipeline_dir - # retrieve class_name - class_name = module.__class__.__name__ + # retrieve class_name + class_name = module.__class__.__name__ - register_dict = {name: (library, class_name)} + register_dict = {name: (library, class_name)} - # save model index config - self.register_to_config(**register_dict) + # save model index config + self.register_to_config(**register_dict) # set models setattr(self, name, module) @@ -422,6 +425,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P is_pipeline_module = hasattr(pipelines, library_name) loaded_sub_model = None + sub_model_should_be_defined = True # if the model is in a pipeline module, then we load it from the pipeline if name in passed_class_obj: @@ -442,6 +446,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" f" {expected_class_obj}" ) + elif passed_class_obj[name] is None: + logger.warn( + f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note" + f" that this might lead to problems when using {pipeline_class} and is generally not" + " recommended to do." + ) + sub_model_should_be_defined = False else: logger.warn( f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" @@ -462,7 +473,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P importable_classes = LOADABLE_CLASSES[library_name] class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} - if loaded_sub_model is None: + if loaded_sub_model is None and sub_model_should_be_defined: load_method_name = None for class_name, class_candidate in class_candidates.items(): if issubclass(class_obj, class_candidate): diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 615fa404da0b..a799a7580332 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List, Union +from typing import List, Optional, Union import numpy as np @@ -20,11 +20,11 @@ class StableDiffusionPipelineOutput(BaseOutput): num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. nsfw_content_detected (`List[bool]`) List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content. + (nsfw) content. If safety checker is disabled `None` will be returned. """ images: Union[List[PIL.Image.Image], np.ndarray] - nsfw_content_detected: List[bool] + nsfw_content_detected: Optional[List[bool]] if is_transformers_available() and is_torch_available(): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index ca6c580ffc55..2ccb4a4b8744 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -71,6 +71,16 @@ def __init__( new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) + if safety_checker is None: + logger.warn( + f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`.Please" + " make sure you have very good reasons for this and have considered the consequences of doing so.The" + " `diffusers` team does not recommend disabling the safety under ANY circumstances and strongly" + " suggests to not disable the `safety_checker` by NOT passing `safety_checker=None` to" + " `from_pretrained`.For more information, please have a look at" + " https://github.com/huggingface/diffusers/pull/254" + ) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -335,10 +345,15 @@ def __call__( # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) - ) + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( + self.device + ) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) + ) + else: + has_nsfw_concept = None if output_type == "pil": image = self.numpy_to_pil(image) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 8004241ac15e..e8f13ded3da5 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -498,6 +498,17 @@ def test_from_pretrained_error_message_uninstalled_packages(self): assert isinstance(pipe, StableDiffusionPipeline) assert isinstance(pipe.scheduler, LMSDiscreteScheduler) + def test_stable_diffusion_no_safety_checker(self): + pipe = StableDiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None + ) + assert isinstance(pipe, StableDiffusionPipeline) + assert isinstance(pipe.scheduler, LMSDiscreteScheduler) + assert pipe.safety_checker is None + + image = pipe("example prompt", num_inference_steps=2).images[0] + assert image is not None + def test_stable_diffusion_k_lms(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator unet = self.dummy_cond_unet From 5cbc2d3f01231739615c4e4373ce2d69beaa3967 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 12 Oct 2022 18:50:55 +0200 Subject: [PATCH 2/7] Apply suggestions from code review --- src/diffusers/pipeline_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 6a29ec795a91..4ab054c90ffc 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -133,8 +133,9 @@ def register_modules(self, **kwargs): register_dict = {name: (library, class_name)} - # save model index config - self.register_to_config(**register_dict) + + # save model index config + self.register_to_config(**register_dict) # set models setattr(self, name, module) From a79e651b4d288e8665046a70001d17ec241565b4 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 12 Oct 2022 18:51:37 +0200 Subject: [PATCH 3/7] Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 2ccb4a4b8744..bd8a0670e132 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -77,7 +77,7 @@ def __init__( " make sure you have very good reasons for this and have considered the consequences of doing so.The" " `diffusers` team does not recommend disabling the safety under ANY circumstances and strongly" " suggests to not disable the `safety_checker` by NOT passing `safety_checker=None` to" - " `from_pretrained`.For more information, please have a look at" + " `from_pretrained`. For more information, please have a look at" " https://github.com/huggingface/diffusers/pull/254" ) From 7e0d96040dccca5f29361d5aa79f0b5a23dbb9db Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 12 Oct 2022 18:53:44 +0200 Subject: [PATCH 4/7] Finish --- .../pipeline_stable_diffusion_img2img.py | 23 +++++++++++++++---- .../pipeline_stable_diffusion_inpaint.py | 19 +++++++++++++-- 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 3e5ac4b33582..0436fd62052a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -83,6 +83,16 @@ def __init__( new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) + if safety_checker is None: + logger.warn( + f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`.Please" + " make sure you have very good reasons for this and have considered the consequences of doing so.The" + " `diffusers` team does not recommend disabling the safety under ANY circumstances and strongly" + " suggests to not disable the `safety_checker` by NOT passing `safety_checker=None` to" + " `from_pretrained`.For more information, please have a look at" + " https://github.com/huggingface/diffusers/pull/254" + ) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -358,10 +368,15 @@ def __call__( image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) - ) + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( + self.device + ) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) + ) + else: + has_nsfw_concept = None if output_type == "pil": image = self.numpy_to_pil(image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 30a588e754b3..06bc655018ed 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -98,6 +98,16 @@ def __init__( new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) + if safety_checker is None: + logger.warn( + f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`.Please" + " make sure you have very good reasons for this and have considered the consequences of doing so.The" + " `diffusers` team does not recommend disabling the safety under ANY circumstances and strongly" + " suggests to not disable the `safety_checker` by NOT passing `safety_checker=None` to" + " `from_pretrained`.For more information, please have a look at" + " https://github.com/huggingface/diffusers/pull/254" + ) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -382,8 +392,13 @@ def __call__( image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) - image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( + self.device + ) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) + else: + has_nsfw_concept = None if output_type == "pil": image = self.numpy_to_pil(image) From 3307ff4cbc29aa0bbef44003c4efed055874186e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 12 Oct 2022 18:57:54 +0200 Subject: [PATCH 5/7] make style --- src/diffusers/pipeline_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 4ab054c90ffc..8cc27318caff 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -133,7 +133,6 @@ def register_modules(self, **kwargs): register_dict = {name: (library, class_name)} - # save model index config self.register_to_config(**register_dict) From a9abef0c5f448f05cd087663359394c7372c2a0c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 13 Oct 2022 11:49:21 +0200 Subject: [PATCH 6/7] Apply suggestions from code review Co-authored-by: Pedro Cuenca --- src/diffusers/pipeline_utils.py | 3 +-- src/diffusers/pipelines/stable_diffusion/__init__.py | 2 +- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 4 ++-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 8cc27318caff..fa66c2aec852 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -449,8 +449,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P elif passed_class_obj[name] is None: logger.warn( f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note" - f" that this might lead to problems when using {pipeline_class} and is generally not" - " recommended to do." + f" that this might lead to problems when using {pipeline_class} and is not recommended." ) sub_model_should_be_defined = False else: diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index a799a7580332..8c07afe58fc2 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -20,7 +20,7 @@ class StableDiffusionPipelineOutput(BaseOutput): num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. nsfw_content_detected (`List[bool]`) List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content. If safety checker is disabled `None` will be returned. + (nsfw) content, or `None` if safety checking could not be performed. """ images: Union[List[PIL.Image.Image], np.ndarray] diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index bd8a0670e132..7aedd88466c4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -73,8 +73,8 @@ def __init__( if safety_checker is None: logger.warn( - f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`.Please" - " make sure you have very good reasons for this and have considered the consequences of doing so.The" + f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Please" + " make sure you have very good reasons for this and have considered the consequences of doing so. The" " `diffusers` team does not recommend disabling the safety under ANY circumstances and strongly" " suggests to not disable the `safety_checker` by NOT passing `safety_checker=None` to" " `from_pretrained`. For more information, please have a look at" From 9ee278768437b29f2169ab463eef4c7fdb5d7240 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 13 Oct 2022 15:48:11 +0200 Subject: [PATCH 7/7] up --- .../stable_diffusion/pipeline_stable_diffusion.py | 12 ++++++------ .../pipeline_stable_diffusion_img2img.py | 12 ++++++------ .../pipeline_stable_diffusion_inpaint.py | 12 ++++++------ 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index bd8a0670e132..c5b7ebd48baa 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -73,12 +73,12 @@ def __init__( if safety_checker is None: logger.warn( - f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`.Please" - " make sure you have very good reasons for this and have considered the consequences of doing so.The" - " `diffusers` team does not recommend disabling the safety under ANY circumstances and strongly" - " suggests to not disable the `safety_checker` by NOT passing `safety_checker=None` to" - " `from_pretrained`. For more information, please have a look at" - " https://github.com/huggingface/diffusers/pull/254" + f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) self.register_modules( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 0436fd62052a..b998e5c613b7 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -85,12 +85,12 @@ def __init__( if safety_checker is None: logger.warn( - f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`.Please" - " make sure you have very good reasons for this and have considered the consequences of doing so.The" - " `diffusers` team does not recommend disabling the safety under ANY circumstances and strongly" - " suggests to not disable the `safety_checker` by NOT passing `safety_checker=None` to" - " `from_pretrained`.For more information, please have a look at" - " https://github.com/huggingface/diffusers/pull/254" + f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) self.register_modules( diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 06bc655018ed..ba31840209b1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -100,12 +100,12 @@ def __init__( if safety_checker is None: logger.warn( - f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`.Please" - " make sure you have very good reasons for this and have considered the consequences of doing so.The" - " `diffusers` team does not recommend disabling the safety under ANY circumstances and strongly" - " suggests to not disable the `safety_checker` by NOT passing `safety_checker=None` to" - " `from_pretrained`.For more information, please have a look at" - " https://github.com/huggingface/diffusers/pull/254" + f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) self.register_modules(