
Commit 26c7df5

Fix type mismatch error, add tests for negative prompts (#823)
1 parent: e001fed

File tree: 5 files changed (+138, −9 lines)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 2 additions & 2 deletions
@@ -234,8 +234,8 @@ def __call__(
             uncond_tokens = [""]
         elif type(prompt) is not type(negative_prompt):
             raise TypeError(
-                "`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                " {type(prompt)}."
+                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                f" {type(prompt)}."
             )
         elif isinstance(negative_prompt, str):
             uncond_tokens = [negative_prompt]
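
The same two-line change recurs in the three other pipelines below: the error message used `{type(...)}` placeholders in a plain string literal, so Python never interpolated them. A minimal standalone demonstration (variable values are illustrative):

prompt = "a cat"
negative_prompt = ["no dogs"]

# Plain string: the braces are literal text, so the raised error
# never reports the actual offending types.
plain = "`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} != {type(prompt)}."
print(plain)  # ... but got {type(negative_prompt)} != {type(prompt)}.

# f-string: the expressions inside the braces are evaluated.
fixed = f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} != {type(prompt)}."
print(fixed)  # ... but got <class 'list'> != <class 'str'>.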

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 4 additions & 3 deletions
@@ -195,7 +195,6 @@ def __call__(
         """
         if isinstance(prompt, str):
             batch_size = 1
-            prompt = [prompt]
         elif isinstance(prompt, list):
             batch_size = len(prompt)
         else:
@@ -250,8 +249,8 @@ def __call__(
             uncond_tokens = [""]
         elif type(prompt) is not type(negative_prompt):
             raise TypeError(
-                "`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                " {type(prompt)}."
+                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                f" {type(prompt)}."
             )
         elif isinstance(negative_prompt, str):
             uncond_tokens = [negative_prompt]
@@ -285,6 +284,8 @@ def __call__(
         init_latents = init_latent_dist.sample(generator=generator)
         init_latents = 0.18215 * init_latents

+        if isinstance(prompt, str):
+            prompt = [prompt]
         if len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] == 0:
             # expand init_latents for batch_size
             deprecation_message = (
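
The img2img hunks do more than the f-string fix: `prompt = [prompt]` used to run before the `negative_prompt` type comparison, so a `str` prompt paired with a `str` negative prompt was compared as `list` vs `str` and raised the TypeError — the type mismatch named in the commit title. Moving the normalization after the check (just before the latents are expanded) resolves it. A minimal sketch of the ordering bug, as a hypothetical standalone function:

def old_order(prompt, negative_prompt):
    if isinstance(prompt, str):
        prompt = [prompt]  # normalized *before* the comparison below
    if negative_prompt is not None and type(prompt) is not type(negative_prompt):
        raise TypeError("`negative_prompt` should be the same type to `prompt`")
    return prompt

def new_order(prompt, negative_prompt):
    if negative_prompt is not None and type(prompt) is not type(negative_prompt):
        raise TypeError("`negative_prompt` should be the same type to `prompt`")
    if isinstance(prompt, str):
        prompt = [prompt]  # normalization deferred until after the comparison
    return prompt

# old_order("a cat", "no dogs")  -> TypeError: list compared against str
# new_order("a cat", "no dogs")  -> ["a cat"]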

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 2 additions & 2 deletions
@@ -266,8 +266,8 @@ def __call__(
             uncond_tokens = [""]
         elif type(prompt) is not type(negative_prompt):
             raise TypeError(
-                "`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                " {type(prompt)}."
+                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                f" {type(prompt)}."
             )
         elif isinstance(negative_prompt, str):
             uncond_tokens = [negative_prompt]

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py

Lines changed: 2 additions & 2 deletions
@@ -108,8 +108,8 @@ def __call__(
             uncond_tokens = [""] * batch_size
         elif type(prompt) is not type(negative_prompt):
             raise TypeError(
-                "`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                " {type(prompt)}."
+                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                f" {type(prompt)}."
            )
         elif isinstance(negative_prompt, str):
             uncond_tokens = [negative_prompt] * batch_size
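
All four pipelines accept `negative_prompt` the same way. For reference, a user-level call would look roughly like this; the model id and device handling are assumptions for illustration, not part of this commit:

import torch
from diffusers import StableDiffusionPipeline

# Model id is illustrative; any Stable Diffusion checkpoint works the same way.
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

image = pipe(
    "A painting of a squirrel eating a burger",
    negative_prompt="french fries",  # must be the same type as the prompt (str here)
    guidance_scale=6.0,
).images[0]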

tests/test_pipelines.py

Lines changed: 128 additions & 0 deletions
@@ -575,6 +575,46 @@ def test_stable_diffusion_attention_chunk(self):

         assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 1e-4

+    def test_stable_diffusion_negative_prompt(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+        negative_prompt = "french fries"
+        generator = torch.Generator(device=device).manual_seed(0)
+        output = sd_pipe(
+            prompt,
+            negative_prompt=negative_prompt,
+            generator=generator,
+            guidance_scale=6.0,
+            num_inference_steps=2,
+            output_type="np",
+        )
+
+        image = output.images
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 128, 128, 3)
+        expected_slice = np.array([0.4851, 0.4617, 0.4765, 0.5127, 0.4845, 0.5153, 0.5141, 0.4886, 0.4719])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
     def test_score_sde_ve_pipeline(self):
         unet = self.dummy_uncond_unet
         scheduler = ScoreSdeVeScheduler()
@@ -704,6 +744,48 @@ def test_stable_diffusion_img2img(self):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

+    def test_stable_diffusion_img2img_negative_prompt(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        init_image = self.dummy_image.to(device)
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionImg2ImgPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+        negative_prompt = "french fries"
+        generator = torch.Generator(device=device).manual_seed(0)
+        output = sd_pipe(
+            prompt,
+            negative_prompt=negative_prompt,
+            generator=generator,
+            guidance_scale=6.0,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+        )
+        image = output.images
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array([0.4065, 0.3783, 0.4050, 0.5266, 0.4781, 0.4252, 0.4203, 0.4692, 0.4365])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
     def test_stable_diffusion_img2img_multiple_init_images(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         unet = self.dummy_cond_unet
@@ -861,6 +943,52 @@ def test_stable_diffusion_inpaint(self):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

+    def test_stable_diffusion_inpaint_negative_prompt(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
+        init_image = Image.fromarray(np.uint8(image)).convert("RGB")
+        mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionInpaintPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+        negative_prompt = "french fries"
+        generator = torch.Generator(device=device).manual_seed(0)
+        output = sd_pipe(
+            prompt,
+            negative_prompt=negative_prompt,
+            generator=generator,
+            guidance_scale=6.0,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+            mask_image=mask_image,
+        )
+
+        image = output.images
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array([0.4765, 0.5339, 0.4541, 0.6240, 0.5439, 0.4055, 0.5503, 0.5891, 0.5150])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
     def test_stable_diffusion_num_images_per_prompt(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         unet = self.dummy_cond_unet