Skip to content

Commit 3ca7f09

Browse files
sayakpaulstevhliu
andcommitted
Allow SD3 DreamBooth LoRA fine-tuning on a free-tier Colab (#8762)
* add experimental scripts to train SD3 transformer lora on colab * add readme * add colab * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * fix link in the notebook. --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
1 parent 11c5f6b commit 3ca7f09

File tree

4 files changed

+3736
-0
lines changed

4 files changed

+3736
-0
lines changed
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Running Stable Diffusion 3 DreamBooth LoRA training under 16GB
2+
3+
This is an **EDUCATIONAL** project that provides utilities for DreamBooth LoRA training for [Stable Diffusion 3 (SD3)](ttps://huggingface.co/papers/2403.03206) under 16GB GPU VRAM. This means you can successfully try out this project using a [free-tier Colab Notebook](./sd3_dreambooth_lora_16gb.ipynb) instance. 🤗
4+
5+
> [!NOTE]
6+
> SD3 is gated, so you need to make sure you agree to [share your contact info](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) to access the model before using it with Diffusers. Once you have access, you need to log in so your system knows you’re authorized. Use the command below to log in:
7+
8+
```bash
9+
huggingface-cli login
10+
```
11+
12+
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
13+
14+
For setup, inference code, and details on how to run the code, please follow the Colab Notebook provided above.
15+
16+
## How
17+
18+
We make use of several techniques to make this possible:
19+
20+
* Compute the embeddings from the instance prompt and serialize them for later reuse. This is implemented in the [`compute_embeddings.py`](./compute_embeddings.py) script. We use an 8bit (as introduced in [`LLM.int8()`](https://arxiv.org/abs/2208.07339)) T5 to reduce memory requirements to ~10.5GB.
21+
* In the `train_dreambooth_sd3_lora_miniature.py` script, we make use of:
22+
* 8bit Adam for optimization through the `bitsandbytes` library.
23+
* Gradient checkpointing and gradient accumulation.
24+
* FP16 precision.
25+
* Flash attention through `F.scaled_dot_product_attention()`.
26+
27+
Computing the text embeddings is arguably the most memory-intensive part in the pipeline as SD3 employs three text encoders. If we run them in FP32, it will take about 20GB of VRAM. With FP16, we are down to 12GB.
28+
29+
30+
## Gotchas
31+
32+
This project is educational. It exists to showcase the possibility of fine-tuning a big diffusion system on consumer GPUs. But additional components might have to be added to obtain state-of-the-art performance. Below are some commonly known gotchas that users should be aware of:
33+
34+
* Training of text encoders is purposefully disabled.
35+
* Techniques such as prior-preservation is unsupported.
36+
* Custom instance captions for instance images are unsupported, but this should be relatively easy to integrate.
37+
38+
Hopefully, this project gives you a template to extend it further to suit your needs.
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
#!/usr/bin/env python
2+
# coding=utf-8
3+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import argparse
18+
import glob
19+
import hashlib
20+
21+
import pandas as pd
22+
import torch
23+
from transformers import T5EncoderModel
24+
25+
from diffusers import StableDiffusion3Pipeline
26+
27+
28+
PROMPT = "a photo of sks dog"
29+
MAX_SEQ_LENGTH = 77
30+
LOCAL_DATA_DIR = "dog"
31+
OUTPUT_PATH = "sample_embeddings.parquet"
32+
33+
34+
def bytes_to_giga_bytes(bytes):
35+
return bytes / 1024 / 1024 / 1024
36+
37+
38+
def generate_image_hash(image_path):
39+
with open(image_path, "rb") as f:
40+
img_data = f.read()
41+
return hashlib.sha256(img_data).hexdigest()
42+
43+
44+
def load_sd3_pipeline():
45+
id = "stabilityai/stable-diffusion-3-medium-diffusers"
46+
text_encoder = T5EncoderModel.from_pretrained(id, subfolder="text_encoder_3", load_in_8bit=True, device_map="auto")
47+
pipeline = StableDiffusion3Pipeline.from_pretrained(
48+
id, text_encoder_3=text_encoder, transformer=None, vae=None, device_map="balanced"
49+
)
50+
return pipeline
51+
52+
53+
@torch.no_grad()
54+
def compute_embeddings(pipeline, prompt, max_sequence_length):
55+
(
56+
prompt_embeds,
57+
negative_prompt_embeds,
58+
pooled_prompt_embeds,
59+
negative_pooled_prompt_embeds,
60+
) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, prompt_3=None, max_sequence_length=max_sequence_length)
61+
62+
print(
63+
f"{prompt_embeds.shape=}, {negative_prompt_embeds.shape=}, {pooled_prompt_embeds.shape=}, {negative_pooled_prompt_embeds.shape}"
64+
)
65+
66+
max_memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())
67+
print(f"Max memory allocated: {max_memory:.3f} GB")
68+
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
69+
70+
71+
def run(args):
72+
pipeline = load_sd3_pipeline()
73+
prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = compute_embeddings(
74+
pipeline, args.prompt, args.max_sequence_length
75+
)
76+
77+
# Assumes that the images within `args.local_image_dir` have a JPEG extension. Change
78+
# as needed.
79+
image_paths = glob.glob(f"{args.local_data_dir}/*.jpeg")
80+
data = []
81+
for image_path in image_paths:
82+
img_hash = generate_image_hash(image_path)
83+
data.append(
84+
(img_hash, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds)
85+
)
86+
87+
# Create a DataFrame
88+
embedding_cols = [
89+
"prompt_embeds",
90+
"negative_prompt_embeds",
91+
"pooled_prompt_embeds",
92+
"negative_pooled_prompt_embeds",
93+
]
94+
df = pd.DataFrame(
95+
data,
96+
columns=["image_hash"] + embedding_cols,
97+
)
98+
99+
# Convert embedding lists to arrays (for proper storage in parquet)
100+
for col in embedding_cols:
101+
df[col] = df[col].apply(lambda x: x.cpu().numpy().flatten().tolist())
102+
103+
# Save the dataframe to a parquet file
104+
df.to_parquet(args.output_path)
105+
print(f"Data successfully serialized to {args.output_path}")
106+
107+
108+
if __name__ == "__main__":
109+
parser = argparse.ArgumentParser()
110+
parser.add_argument("--prompt", type=str, default=PROMPT, help="The instance prompt.")
111+
parser.add_argument(
112+
"--max_sequence_length",
113+
type=int,
114+
default=MAX_SEQ_LENGTH,
115+
help="Maximum sequence length to use for computing the embeddings. The more the higher computational costs.",
116+
)
117+
parser.add_argument(
118+
"--local_data_dir", type=str, default=LOCAL_DATA_DIR, help="Path to the directory containing instance images."
119+
)
120+
parser.add_argument("--output_path", type=str, default=OUTPUT_PATH, help="Path to serialize the parquet file.")
121+
args = parser.parse_args()
122+
123+
run(args)

0 commit comments

Comments
 (0)