Allow SD3 DreamBooth LoRA fine-tuning on a free-tier Colab #8762
Merged
7 commits
40d8890 add experimental scripts to train SD3 transformer lora on colab (sayakpaul)
9cf038e add readme (sayakpaul)
810a67c add colab (sayakpaul)
dada56c Merge branch 'main' into colab-sd3-lora (sayakpaul)
bc7b4e6 Apply suggestions from code review (sayakpaul)
7940575 Merge branch 'main' into colab-sd3-lora (sayakpaul)
e760974 fix link in the notebook. (sayakpaul)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# Running Stable Diffusion 3 DreamBooth LoRA training under 16GB | ||
|
||
This is an **EDUCATIONAL** project that provides utilities for DreamBooth LoRA training of [Stable Diffusion 3 (SD3)](https://huggingface.co/papers/2403.03206) with under 16GB of GPU VRAM. This means you can try out this project on a free-tier Colab Notebook instance. [Here is one](./sd3_dreambooth_lora_16gb.ipynb) to get you started quickly 🤗 | ||
|
||
> [!NOTE] | ||
> As SD3 is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in: | ||
|
||
|
||
```bash | ||
huggingface-cli login | ||
``` | ||
|
||
This will also allow us to push the trained model parameters to the Hugging Face Hub platform. | ||
|
||
For setup, inference code, and details on how to run the code, please follow the Colab Notebook provided above. | ||
|
||
## How | ||
|
||
We make use of several techniques to make this possible: | ||
|
||
* Compute the embeddings from the instance prompt and serialize them for later reuse. This is implemented in the [`compute_embeddings.py`](./compute_embeddings.py) script. We use an 8bit T5 to keep memory requirements manageable. More details have been provided below. | ||
|
||
* In the `train_dreambooth_sd3_lora_miniature.py` script, we make use of: | ||
    * 8bit Adam for optimization through the `bitsandbytes` library. | ||
    * Gradient checkpointing and gradient accumulation. | ||
    * FP16 precision. | ||
    * Fused attention through `F.scaled_dot_product_attention()`, which dispatches to Flash Attention or another memory-efficient kernel depending on the hardware. | ||
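
To illustrate the gradient accumulation item above, here is a minimal, framework-free sketch. It is not taken from the training script: the one-parameter linear model, the toy data, and the helper names (`mse_grad`, `accumulated_grad`) are assumptions made for illustration. It shows that averaging the gradients of equally sized micro-batches gives the same gradient as one large batch, which is why accumulation trades extra steps for lower peak memory.

```python
def mse_grad(w, xs, ys):
    """Gradient of the mean squared error for the model y = w * x, w.r.t. w."""
    n = len(xs)
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def accumulated_grad(w, xs, ys, micro_batch_size):
    """Accumulate micro-batch gradients, each scaled by 1 / num_micro_batches."""
    assert len(xs) % micro_batch_size == 0, "sketch assumes equal-sized micro-batches"
    num_micro_batches = len(xs) // micro_batch_size
    total = 0.0
    for start in range(0, len(xs), micro_batch_size):
        end = start + micro_batch_size
        # Only one micro-batch has to be processed at a time.
        total += mse_grad(w, xs[start:end], ys[start:end]) / num_micro_batches
    return total


# Toy data (hypothetical): the true relationship is y = 2 * x.
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]

full = mse_grad(0.5, xs, ys)
accumulated = accumulated_grad(0.5, xs, ys, micro_batch_size=2)
print(full, accumulated)
```

In a real training loop the same idea usually appears as dividing each micro-batch loss by the number of accumulation steps and calling the optimizer only after the last micro-batch.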
|
||
Computing the text embeddings is arguably the most memory-intensive part in the pipeline as SD3 employs three text encoders. If we run them in FP32, it will take about 20GB of VRAM. With FP16, we are down to 12GB. | ||
|
||
For this project, we leverage an 8bit T5 (8bit as introduced in [`LLM.int8()`](https://arxiv.org/abs/2208.07339)), which reduces the memory requirements further to ~10.5GB. | ||
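
The savings above come from how many bytes each parameter occupies. The sketch below is a rough, weights-only estimate: it ignores activations, the CUDA context, and quantization overhead, so it will not reproduce the measured figures exactly. The parameter count and the `BYTES_PER_PARAM` table are illustrative assumptions, not values taken from SD3.

```python
# Bytes needed to store one parameter at each precision.
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "int8": 1}


def weight_memory_gib(num_params, precision):
    """Weights-only memory estimate in GiB (1 GiB = 1024**3 bytes)."""
    return num_params * BYTES_PER_PARAM[precision] / 1024**3


# Hypothetical encoder with 2**30 (about 1.07 billion) parameters.
num_params = 2**30
for precision in ("fp32", "fp16", "int8"):
    print(f"{precision}: {weight_memory_gib(num_params, precision):.1f} GiB")
```

Because only the T5 encoder is quantized in this project, the other two text encoders keep their original footprint, which is one reason the overall number does not simply halve again.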
|
||
|
||
## Gotchas | ||
|
||
This project is educational. It exists to showcase the possibility of fine-tuning a big diffusion system on consumer GPUs. But additional components might have to be added to obtain state-of-the-art performance. Below are some commonly known gotchas that users should be aware of: | ||
|
||
|
||
* Training of text encoders is purposefully disabled. | ||
* Techniques such as prior-preservation are unsupported. | ||
* Custom instance captions for instance images are unsupported, but this should be relatively easy to integrate. | ||
|
||
|
||
Hopefully, this project gives you a template to extend it further to suit your needs. |
123 changes: 123 additions & 0 deletions
examples/research_projects/sd3_lora_colab/compute_embeddings.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
#!/usr/bin/env python | ||
# coding=utf-8 | ||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import argparse | ||
import glob | ||
import hashlib | ||
|
||
import pandas as pd | ||
import torch | ||
from transformers import T5EncoderModel | ||
|
||
from diffusers import StableDiffusion3Pipeline | ||
|
||
|
||
PROMPT = "a photo of sks dog" | ||
MAX_SEQ_LENGTH = 77 | ||
LOCAL_DATA_DIR = "dog" | ||
OUTPUT_PATH = "sample_embeddings.parquet" | ||
|
||
|
||
def bytes_to_giga_bytes(num_bytes): | ||
    return num_bytes / 1024 / 1024 / 1024 | ||
|
||
|
||
def generate_image_hash(image_path): | ||
    with open(image_path, "rb") as f: | ||
        img_data = f.read() | ||
    return hashlib.sha256(img_data).hexdigest() | ||
|
||
|
||
def load_sd3_pipeline(): | ||
    model_id = "stabilityai/stable-diffusion-3-medium-diffusers" | ||
    # Load the T5 text encoder in 8-bit to keep memory requirements manageable. | ||
    text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_3", load_in_8bit=True, device_map="auto") | ||
    # Only the text encoders are needed here, so the transformer and the VAE are skipped. | ||
    pipeline = StableDiffusion3Pipeline.from_pretrained( | ||
        model_id, text_encoder_3=text_encoder, transformer=None, vae=None, device_map="balanced" | ||
    ) | ||
    return pipeline | ||
|
||
|
||
@torch.no_grad() | ||
def compute_embeddings(pipeline, prompt, max_sequence_length): | ||
    ( | ||
        prompt_embeds, | ||
        negative_prompt_embeds, | ||
        pooled_prompt_embeds, | ||
        negative_pooled_prompt_embeds, | ||
    ) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, prompt_3=None, max_sequence_length=max_sequence_length) | ||
|
||
    print( | ||
        f"{prompt_embeds.shape=}, {negative_prompt_embeds.shape=}, {pooled_prompt_embeds.shape=}, {negative_pooled_prompt_embeds.shape=}" | ||
    ) | ||
|
||
    max_memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated()) | ||
    print(f"Max memory allocated: {max_memory:.3f} GB") | ||
    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds | ||
|
||
|
||
def run(args): | ||
    pipeline = load_sd3_pipeline() | ||
    prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = compute_embeddings( | ||
        pipeline, args.prompt, args.max_sequence_length | ||
    ) | ||
|
||
    # Assumes that the images within `args.local_data_dir` have a JPEG extension. Change | ||
    # as needed. | ||
    image_paths = glob.glob(f"{args.local_data_dir}/*.jpeg") | ||
    data = [] | ||
    for image_path in image_paths: | ||
        img_hash = generate_image_hash(image_path) | ||
        data.append( | ||
            (img_hash, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds) | ||
        ) | ||
|
||
    # Create a DataFrame | ||
    embedding_cols = [ | ||
        "prompt_embeds", | ||
        "negative_prompt_embeds", | ||
        "pooled_prompt_embeds", | ||
        "negative_pooled_prompt_embeds", | ||
    ] | ||
    df = pd.DataFrame( | ||
        data, | ||
        columns=["image_hash"] + embedding_cols, | ||
    ) | ||
|
||
    # Convert embedding tensors to flattened lists (for proper storage in parquet) | ||
    for col in embedding_cols: | ||
        df[col] = df[col].apply(lambda x: x.cpu().numpy().flatten().tolist()) | ||
|
||
    # Save the dataframe to a parquet file | ||
    df.to_parquet(args.output_path) | ||
    print(f"Data successfully serialized to {args.output_path}") | ||
|
||
|
||
if __name__ == "__main__": | ||
    parser = argparse.ArgumentParser() | ||
    parser.add_argument("--prompt", type=str, default=PROMPT, help="The instance prompt.") | ||
    parser.add_argument( | ||
        "--max_sequence_length", | ||
        type=int, | ||
        default=MAX_SEQ_LENGTH, | ||
        help="Maximum sequence length to use for computing the embeddings. Larger values increase the computational cost.", | ||
    ) | ||
    parser.add_argument( | ||
        "--local_data_dir", type=str, default=LOCAL_DATA_DIR, help="Path to the directory containing instance images." | ||
    ) | ||
    parser.add_argument("--output_path", type=str, default=OUTPUT_PATH, help="Path to serialize the parquet file.") | ||
    args = parser.parse_args() | ||
|
||
    run(args) |
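
The script above stores each embedding as a flattened list keyed by the SHA-256 hash of the image bytes, so a consumer has to hash the image again and restore the original shape. The sketch below shows that round trip using only the standard library. The helper names (`hash_bytes`, `flatten`, `reshape_2d`), the stand-in image bytes, and the tiny 2x3 embedding are assumptions for illustration; real SD3 embeddings are much larger, and how the training script reads them back is not shown here.

```python
import hashlib


def hash_bytes(data):
    """SHA-256 hex digest of raw bytes, as generate_image_hash computes for a file."""
    return hashlib.sha256(data).hexdigest()


def flatten(rows):
    """Flatten a 2D nested list into a 1D list in row-major order."""
    return [value for row in rows for value in row]


def reshape_2d(flat, num_rows, num_cols):
    """Restore a row-major 2D nested list from its flattened form."""
    assert len(flat) == num_rows * num_cols, "shape does not match data length"
    return [flat[r * num_cols:(r + 1) * num_cols] for r in range(num_rows)]


# Stand-ins for an image file's bytes and a (2, 3) embedding.
image_bytes = b"not a real image"
embedding = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

# "Serialize": key the flattened embedding by the image hash.
store = {hash_bytes(image_bytes): flatten(embedding)}

# "Load": hash the same bytes again, look up the entry, and restore the shape.
restored = reshape_2d(store[hash_bytes(image_bytes)], num_rows=2, num_cols=3)
print(restored == embedding)
```

Since flattening discards the shape, the consumer needs to know the expected dimensions in advance; storing them alongside the data would be one way to avoid hard-coding them.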