Skip to content

Commit 59b8506

Browse files
ryanrussellPrathik Rao
authored andcommitted
refactor: pipelines readability improvements (huggingface#622)
* refactor: pipelines readability improvements Signed-off-by: Ryan Russell <git@ryanrussell.org> * docs: remove todo comment from flax pipeline Signed-off-by: Ryan Russell <git@ryanrussell.org> Signed-off-by: Ryan Russell <git@ryanrussell.org>
1 parent 78ae5a2 commit 59b8506

File tree

5 files changed

+24
-25
lines changed

5 files changed

+24
-25
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
3434
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
3535
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`].
3636
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
37-
Classification module that estimates whether generated images could be considered offsensive or harmful.
37+
Classification module that estimates whether generated images could be considered offensive or harmful.
3838
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
3939
feature_extractor ([`CLIPFeatureExtractor`]):
4040
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
@@ -149,7 +149,6 @@ def __call__(
149149
uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0]
150150
context = jnp.concatenate([uncond_embeddings, text_embeddings])
151151

152-
# TODO: check it because the shape is different from Pytorhc StableDiffusionPipeline
153152
latents_shape = (
154153
batch_size,
155154
self.unet.in_channels,
@@ -206,9 +205,9 @@ def loop_body(step, args):
206205
# image = jnp.asarray(image).transpose(0, 2, 3, 1)
207206
# run safety checker
208207
# TODO: check when flax safety checker gets merged into main
209-
# safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
208+
# safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
210209
# image, has_nsfw_concept = self.safety_checker(
211-
# images=image, clip_input=safety_cheker_input.pixel_values, params=params["safety_params"]
210+
# images=image, clip_input=safety_checker_input.pixel_values, params=params["safety_params"]
212211
# )
213212
has_nsfw_concept = False
214213

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
3636
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
3737
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
3838
safety_checker ([`StableDiffusionSafetyChecker`]):
39-
Classification module that estimates whether generated images could be considered offsensive or harmful.
39+
Classification module that estimates whether generated images could be considered offensive or harmful.
4040
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
4141
feature_extractor ([`CLIPFeatureExtractor`]):
4242
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
@@ -278,8 +278,8 @@ def __call__(
278278
image = image.cpu().permute(0, 2, 3, 1).numpy()
279279

280280
# run safety checker
281-
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
282-
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
281+
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
282+
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
283283

284284
if output_type == "pil":
285285
image = self.numpy_to_pil(image)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
4848
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
4949
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
5050
safety_checker ([`StableDiffusionSafetyChecker`]):
51-
Classification module that estimates whether generated images could be considered offsensive or harmful.
51+
Classification module that estimates whether generated images could be considered offensive or harmful.
5252
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
5353
feature_extractor ([`CLIPFeatureExtractor`]):
5454
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
@@ -288,8 +288,8 @@ def __call__(
288288
image = image.cpu().permute(0, 2, 3, 1).numpy()
289289

290290
# run safety checker
291-
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
292-
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
291+
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
292+
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
293293

294294
if output_type == "pil":
295295
image = self.numpy_to_pil(image)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
6666
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
6767
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
6868
safety_checker ([`StableDiffusionSafetyChecker`]):
69-
Classification module that estimates whether generated images could be considered offsensive or harmful.
69+
Classification module that estimates whether generated images could be considered offensive or harmful.
7070
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
7171
feature_extractor ([`CLIPFeatureExtractor`]):
7272
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
@@ -328,8 +328,8 @@ def __call__(
328328
image = image.cpu().permute(0, 2, 3, 1).numpy()
329329

330330
# run safety checker
331-
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
332-
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
331+
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
332+
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
333333

334334
if output_type == "pil":
335335
image = self.numpy_to_pil(image)

src/diffusers/pipelines/stable_diffusion/safety_checker.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,20 +48,20 @@ def forward(self, clip_input, images):
4848
# at the cost of increasing the possibility of filtering benign images
4949
adjustment = 0.0
5050

51-
for concet_idx in range(len(special_cos_dist[0])):
52-
concept_cos = special_cos_dist[i][concet_idx]
53-
concept_threshold = self.special_care_embeds_weights[concet_idx].item()
54-
result_img["special_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3)
55-
if result_img["special_scores"][concet_idx] > 0:
56-
result_img["special_care"].append({concet_idx, result_img["special_scores"][concet_idx]})
51+
for concept_idx in range(len(special_cos_dist[0])):
52+
concept_cos = special_cos_dist[i][concept_idx]
53+
concept_threshold = self.special_care_embeds_weights[concept_idx].item()
54+
result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
55+
if result_img["special_scores"][concept_idx] > 0:
56+
result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
5757
adjustment = 0.01
5858

59-
for concet_idx in range(len(cos_dist[0])):
60-
concept_cos = cos_dist[i][concet_idx]
61-
concept_threshold = self.concept_embeds_weights[concet_idx].item()
62-
result_img["concept_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3)
63-
if result_img["concept_scores"][concet_idx] > 0:
64-
result_img["bad_concepts"].append(concet_idx)
59+
for concept_idx in range(len(cos_dist[0])):
60+
concept_cos = cos_dist[i][concept_idx]
61+
concept_threshold = self.concept_embeds_weights[concept_idx].item()
62+
result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
63+
if result_img["concept_scores"][concept_idx] > 0:
64+
result_img["bad_concepts"].append(concept_idx)
6565

6666
result.append(result_img)
6767

0 commit comments

Comments
 (0)