Skip to content

Commit 8fa13a7

Browse files
authored
Fix type mismatch error, add tests for negative prompts (huggingface#823)
1 parent 2ed4ff2 commit 8fa13a7

File tree

4 files changed

+10
-9
lines changed

4 files changed

+10
-9
lines changed

pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,8 @@ def __call__(
234234
uncond_tokens = [""]
235235
elif type(prompt) is not type(negative_prompt):
236236
raise TypeError(
237-
"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
238-
" {type(prompt)}."
237+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
238+
f" {type(prompt)}."
239239
)
240240
elif isinstance(negative_prompt, str):
241241
uncond_tokens = [negative_prompt]

pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,6 @@ def __call__(
195195
"""
196196
if isinstance(prompt, str):
197197
batch_size = 1
198-
prompt = [prompt]
199198
elif isinstance(prompt, list):
200199
batch_size = len(prompt)
201200
else:
@@ -250,8 +249,8 @@ def __call__(
250249
uncond_tokens = [""]
251250
elif type(prompt) is not type(negative_prompt):
252251
raise TypeError(
253-
"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
254-
" {type(prompt)}."
252+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
253+
f" {type(prompt)}."
255254
)
256255
elif isinstance(negative_prompt, str):
257256
uncond_tokens = [negative_prompt]
@@ -285,6 +284,8 @@ def __call__(
285284
init_latents = init_latent_dist.sample(generator=generator)
286285
init_latents = 0.18215 * init_latents
287286

287+
if isinstance(prompt, str):
288+
prompt = [prompt]
288289
if len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] == 0:
289290
# expand init_latents for batch_size
290291
deprecation_message = (

pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,8 +266,8 @@ def __call__(
266266
uncond_tokens = [""]
267267
elif type(prompt) is not type(negative_prompt):
268268
raise TypeError(
269-
"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
270-
" {type(prompt)}."
269+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
270+
f" {type(prompt)}."
271271
)
272272
elif isinstance(negative_prompt, str):
273273
uncond_tokens = [negative_prompt]

pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,8 @@ def __call__(
108108
uncond_tokens = [""] * batch_size
109109
elif type(prompt) is not type(negative_prompt):
110110
raise TypeError(
111-
"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
112-
" {type(prompt)}."
111+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
112+
f" {type(prompt)}."
113113
)
114114
elif isinstance(negative_prompt, str):
115115
uncond_tokens = [negative_prompt] * batch_size

0 commit comments

Comments
 (0)