Skip to content

handle lora scale and clip skip in lpw sd and sdxl community pipelines #8988

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 60 additions & 7 deletions examples/community/lpw_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,17 @@
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
PIL_INTERPOLATION,
USE_PEFT_BACKEND,
deprecate,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor

Expand Down Expand Up @@ -199,6 +203,7 @@ def get_unweighted_text_embeddings(
text_input: torch.Tensor,
chunk_length: int,
no_boseos_middle: Optional[bool] = True,
clip_skip: Optional[int] = None,
):
"""
When the length of tokens is a multiple of the capacity of the text encoder,
Expand All @@ -214,7 +219,21 @@ def get_unweighted_text_embeddings(
# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
text_input_chunk[:, -1] = text_input[0, -1]
text_embedding = pipe.text_encoder(text_input_chunk)[0]
if clip_skip is None:
prompt_embeds = pipe.text_encoder(text_input_chunk.to(pipe.device))
text_embedding = prompt_embeds[0]
else:
prompt_embeds = pipe.text_encoder(
text_input_chunk.to(pipe.device), output_hidden_states=True)
# Access the `hidden_states` first, that contains a tuple of
# all the hidden states from the encoder layers. Then index into
# the tuple to access the hidden states from the desired layer.
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
# We also need to apply the final LayerNorm here to not mess with the
# representations. The `last_hidden_states` that we typically use for
# obtaining the final prompt representations passes through the LayerNorm
# layer.
text_embedding = pipe.text_encoder.text_model.final_layer_norm(prompt_embeds)

if no_boseos_middle:
if i == 0:
Expand All @@ -230,7 +249,10 @@ def get_unweighted_text_embeddings(
text_embeddings.append(text_embedding)
text_embeddings = torch.concat(text_embeddings, axis=1)
else:
text_embeddings = pipe.text_encoder(text_input)[0]
if clip_skip is None:
clip_skip = 0
prompt_embeds = pipe.text_encoder(text_input, output_hidden_states=True)[-1][-(clip_skip + 1)]
text_embeddings = pipe.text_encoder.text_model.final_layer_norm(prompt_embeds)
return text_embeddings


Expand All @@ -242,6 +264,8 @@ def get_weighted_text_embeddings(
no_boseos_middle: Optional[bool] = False,
skip_parsing: Optional[bool] = False,
skip_weighting: Optional[bool] = False,
clip_skip=None,
lora_scale=None,
):
r"""
Prompts can be assigned with local weights using brackets. For example,
Expand All @@ -268,6 +292,16 @@ def get_weighted_text_embeddings(
skip_weighting (`bool`, *optional*, defaults to `False`):
Skip the weighting. When the parsing is skipped, it is forced True.
"""
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(pipe, StableDiffusionLoraLoaderMixin):
pipe._lora_scale = lora_scale

# dynamically adjust the LoRA scale
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
Comment on lines +299 to +301
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this? Because without the PEFT backend, you cannot really do LoRA inference in the recent diffusers versions. No strong opinions either.

else:
scale_lora_layers(pipe.text_encoder, lora_scale)
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
if isinstance(prompt, str):
prompt = [prompt]
Expand Down Expand Up @@ -338,6 +372,7 @@ def get_weighted_text_embeddings(
prompt_tokens,
pipe.tokenizer.model_max_length,
no_boseos_middle=no_boseos_middle,
clip_skip=clip_skip
)
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device)
if uncond_prompt is not None:
Expand All @@ -346,6 +381,7 @@ def get_weighted_text_embeddings(
uncond_tokens,
pipe.tokenizer.model_max_length,
no_boseos_middle=no_boseos_middle,
clip_skip=clip_skip
)
uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=uncond_embeddings.device)

Expand All @@ -362,6 +398,11 @@ def get_weighted_text_embeddings(
current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

if pipe.text_encoder is not None:
if isinstance(pipe, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(pipe.text_encoder, lora_scale)

if uncond_prompt is not None:
return text_embeddings, uncond_embeddings
return text_embeddings, None
Expand Down Expand Up @@ -409,11 +450,7 @@ def preprocess_mask(mask, batch_size, scale_factor=8):


class StableDiffusionLongPromptWeightingPipeline(
DiffusionPipeline,
StableDiffusionMixin,
TextualInversionLoaderMixin,
StableDiffusionLoraLoaderMixin,
FromSingleFileMixin,
DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin, FromSingleFileMixin
):
r"""
Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
Expand Down Expand Up @@ -549,6 +586,8 @@ def _encode_prompt(
max_embeddings_multiples=3,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
clip_skip: Optional[int] = None,
lora_scale: Optional[float] = None
):
r"""
Encodes the prompt into text encoder hidden states.
Expand Down Expand Up @@ -597,6 +636,8 @@ def _encode_prompt(
prompt=prompt,
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
max_embeddings_multiples=max_embeddings_multiples,
clip_skip=clip_skip,
lora_scale=lora_scale
)
if prompt_embeds is None:
prompt_embeds = prompt_embeds1
Expand Down Expand Up @@ -790,6 +831,7 @@ def __call__(
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None,
clip_skip: Optional[int] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
Expand Down Expand Up @@ -865,6 +907,9 @@ def __call__(
is_cancelled_callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. If the function returns
`True`, the inference will be cancelled.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Expand Down Expand Up @@ -903,6 +948,7 @@ def __call__(
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
lora_scale = (cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None)

# 3. Encode input prompt
prompt_embeds = self._encode_prompt(
Expand All @@ -914,6 +960,8 @@ def __call__(
max_embeddings_multiples,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
clip_skip=clip_skip,
lora_scale=lora_scale
)
dtype = prompt_embeds.dtype

Expand Down Expand Up @@ -1044,6 +1092,7 @@ def text2img(
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None,
clip_skip=None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
Expand Down Expand Up @@ -1101,6 +1150,9 @@ def text2img(
is_cancelled_callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. If the function returns
`True`, the inference will be cancelled.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Expand Down Expand Up @@ -1135,6 +1187,7 @@ def text2img(
return_dict=return_dict,
callback=callback,
is_cancelled_callback=is_cancelled_callback,
clip_skip=clip_skip,
callback_steps=callback_steps,
cross_attention_kwargs=cross_attention_kwargs,
)
Expand Down
46 changes: 40 additions & 6 deletions examples/community/lpw_stable_diffusion_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,25 @@
from diffusers.loaders import (
FromSingleFileMixin,
IPAdapterMixin,
StableDiffusionLoraLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
USE_PEFT_BACKEND,
deprecate,
is_accelerate_available,
is_accelerate_version,
is_invisible_watermark_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor

Expand Down Expand Up @@ -261,6 +265,7 @@ def get_weighted_text_embeddings_sdxl(
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
clip_skip: Optional[int] = None,
lora_scale: Optional[int] = None
):
"""
This function can process long prompt with weights, no length limitation
Expand All @@ -281,6 +286,24 @@ def get_weighted_text_embeddings_sdxl(
"""
device = device or pipe._execution_device

# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(pipe, StableDiffusionXLLoraLoaderMixin):
pipe._lora_scale = lora_scale

# dynamically adjust the LoRA scale
if pipe.text_encoder is not None:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
else:
scale_lora_layers(pipe.text_encoder, lora_scale)

if pipe.text_encoder_2 is not None:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(pipe.text_encoder_2, lora_scale)
else:
scale_lora_layers(pipe.text_encoder_2, lora_scale)
Comment on lines +301 to +305
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just copied these lines from pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py, should i just leave scale_lora_layers(pipe.text_encoder_2, lora_scale) ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh then it's okay.


if prompt_2:
prompt = f"{prompt} {prompt_2}"

Expand Down Expand Up @@ -429,6 +452,16 @@ def get_weighted_text_embeddings_sdxl(
bs_embed * num_images_per_prompt, -1
)

if pipe.text_encoder is not None:
if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(pipe.text_encoder, lora_scale)

if pipe.text_encoder_2 is not None:
if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(pipe.text_encoder_2, lora_scale)

return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds


Expand Down Expand Up @@ -549,7 +582,7 @@ class SDXLLongPromptWeightingPipeline(
StableDiffusionMixin,
FromSingleFileMixin,
IPAdapterMixin,
StableDiffusionLoraLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
):
r"""
Expand All @@ -561,8 +594,8 @@ class SDXLLongPromptWeightingPipeline(
The pipeline also inherits the following loading methods:
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings

Args:
Expand Down Expand Up @@ -743,7 +776,7 @@ def encode_prompt(

# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
self._lora_scale = lora_scale

if prompt is not None and isinstance(prompt, str):
Expand Down Expand Up @@ -1612,7 +1645,7 @@ def __call__(
image_embeds = torch.cat([negative_image_embeds, image_embeds])

# 3. Encode input prompt
(self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None)
lora_scale = (self._cross_attention_kwargs.get("scale", None) if self._cross_attention_kwargs is not None else None)

negative_prompt = negative_prompt if negative_prompt is not None else ""

Expand All @@ -1627,6 +1660,7 @@ def __call__(
neg_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
clip_skip=clip_skip,
lora_scale=lora_scale
)
dtype = prompt_embeds.dtype

Expand Down
Loading