Skip to content

Commit 48ad9a9

Browse files
committed
[AuraFlow] fix long prompt handling (#8937)
fix
1 parent d9a9cf4 commit 48ad9a9

File tree

1 file changed: 1 addition, 1 deletion

1 file changed

+1
-1
lines changed

src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,6 @@ def encode_prompt(
 260 260         padding="max_length",
 261 261         return_tensors="pt",
 262 262     )
 263     -   text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
 264 263     text_input_ids = text_inputs["input_ids"]
 265 264     untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

@@ -273,6 +272,7 @@ def encode_prompt(
 273 272         f" {max_length} tokens: {removed_text}"
 274 273     )
 275 274
     275 +   text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
 276 276     prompt_embeds = self.text_encoder(**text_inputs)[0]
 277 277     prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
 278 278     prompt_embeds = prompt_embeds * prompt_attention_mask

0 commit comments

Comments (0)