Skip to content

Commit f7ebe56

Browse files
authored
Warning for too long prompts in DiffusionPipelines (Resolve #447) (#472)
* Return encoded texts by DiffusionPipelines * Updated README to show how to use encoded_text_input * Reverted examples in README.md * Reverted all * Warning for long prompts * Fix bugs * Formatted
1 parent 57b70c5 commit f7ebe56

File tree

4 files changed

+61
-17
lines changed

4 files changed

+61
-17
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,14 @@
1010
from ...models import AutoencoderKL, UNet2DConditionModel
1111
from ...pipeline_utils import DiffusionPipeline
1212
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
13+
from ...utils import logging
1314
from . import StableDiffusionPipelineOutput
1415
from .safety_checker import StableDiffusionSafetyChecker
1516

1617

18+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
19+
20+
1721
class StableDiffusionPipeline(DiffusionPipeline):
1822
r"""
1923
Pipeline for text-to-image generation using Stable Diffusion.
@@ -188,22 +192,30 @@ def __call__(
188192
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
189193

190194
# get prompt text embeddings
191-
text_input = self.tokenizer(
195+
text_inputs = self.tokenizer(
192196
prompt,
193197
padding="max_length",
194198
max_length=self.tokenizer.model_max_length,
195-
truncation=True,
196199
return_tensors="pt",
197200
)
198-
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
201+
text_input_ids = text_inputs.input_ids
202+
203+
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
204+
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
205+
logger.warning(
206+
"The following part of your input was truncated because CLIP can only handle sequences up to"
207+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
208+
)
209+
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
210+
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
199211

200212
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
201213
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
202214
# corresponds to doing no classifier free guidance.
203215
do_classifier_free_guidance = guidance_scale > 1.0
204216
# get unconditional embeddings for classifier free guidance
205217
if do_classifier_free_guidance:
206-
max_length = text_input.input_ids.shape[-1]
218+
max_length = text_input_ids.shape[-1]
207219
uncond_input = self.tokenizer(
208220
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
209221
)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,14 @@
1212
from ...models import AutoencoderKL, UNet2DConditionModel
1313
from ...pipeline_utils import DiffusionPipeline
1414
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
15+
from ...utils import logging
1516
from . import StableDiffusionPipelineOutput
1617
from .safety_checker import StableDiffusionSafetyChecker
1718

1819

20+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
21+
22+
1923
def preprocess(image):
2024
w, h = image.size
2125
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
@@ -216,22 +220,30 @@ def __call__(
216220
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
217221

218222
# get prompt text embeddings
219-
text_input = self.tokenizer(
223+
text_inputs = self.tokenizer(
220224
prompt,
221225
padding="max_length",
222226
max_length=self.tokenizer.model_max_length,
223-
truncation=True,
224227
return_tensors="pt",
225228
)
226-
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
229+
text_input_ids = text_inputs.input_ids
230+
231+
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
232+
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
233+
logger.warning(
234+
"The following part of your input was truncated because CLIP can only handle sequences up to"
235+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
236+
)
237+
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
238+
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
227239

228240
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
229241
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
230242
# corresponds to doing no classifier free guidance.
231243
do_classifier_free_guidance = guidance_scale > 1.0
232244
# get unconditional embeddings for classifier free guidance
233245
if do_classifier_free_guidance:
234-
max_length = text_input.input_ids.shape[-1]
246+
max_length = text_input_ids.shape[-1]
235247
uncond_input = self.tokenizer(
236248
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
237249
)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -254,22 +254,30 @@ def __call__(
254254
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
255255

256256
# get prompt text embeddings
257-
text_input = self.tokenizer(
257+
text_inputs = self.tokenizer(
258258
prompt,
259259
padding="max_length",
260260
max_length=self.tokenizer.model_max_length,
261-
truncation=True,
262261
return_tensors="pt",
263262
)
264-
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
263+
text_input_ids = text_inputs.input_ids
264+
265+
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
266+
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
267+
logger.warning(
268+
"The following part of your input was truncated because CLIP can only handle sequences up to"
269+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
270+
)
271+
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
272+
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
265273

266274
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
267275
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
268276
# corresponds to doing no classifier free guidance.
269277
do_classifier_free_guidance = guidance_scale > 1.0
270278
# get unconditional embeddings for classifier free guidance
271279
if do_classifier_free_guidance:
272-
max_length = text_input.input_ids.shape[-1]
280+
max_length = text_input_ids.shape[-1]
273281
uncond_input = self.tokenizer(
274282
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
275283
)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,13 @@
88
from ...onnx_utils import OnnxRuntimeModel
99
from ...pipeline_utils import DiffusionPipeline
1010
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
11+
from ...utils import logging
1112
from . import StableDiffusionPipelineOutput
1213

1314

15+
logger = logging.get_logger(__name__)
16+
17+
1418
class StableDiffusionOnnxPipeline(DiffusionPipeline):
1519
vae_decoder: OnnxRuntimeModel
1620
text_encoder: OnnxRuntimeModel
@@ -66,22 +70,30 @@ def __call__(
6670
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
6771

6872
# get prompt text embeddings
69-
text_input = self.tokenizer(
73+
text_inputs = self.tokenizer(
7074
prompt,
7175
padding="max_length",
7276
max_length=self.tokenizer.model_max_length,
73-
truncation=True,
74-
return_tensors="np",
77+
return_tensors="pt",
7578
)
76-
text_embeddings = self.text_encoder(input_ids=text_input.input_ids.astype(np.int32))[0]
79+
text_input_ids = text_inputs.input_ids
80+
81+
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
82+
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
83+
logger.warning(
84+
"The following part of your input was truncated because CLIP can only handle sequences up to"
85+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
86+
)
87+
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
88+
text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
7789

7890
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
7991
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
8092
# corresponds to doing no classifier free guidance.
8193
do_classifier_free_guidance = guidance_scale > 1.0
8294
# get unconditional embeddings for classifier free guidance
8395
if do_classifier_free_guidance:
84-
max_length = text_input.input_ids.shape[-1]
96+
max_length = text_input_ids.shape[-1]
8597
uncond_input = self.tokenizer(
8698
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
8799
)

0 commit comments

Comments
 (0)