Commit 3b747de

Authored by Victarry, patil-suraj, and patrickvonplaten

Add training example for DreamBooth. (#554)

* Add training example for DreamBooth.
* Fix bugs.
* Update readme and default hyperparameters.
* Reformatting code with black.
* Update for multi-GPU training.
* Apply suggestions from code review
* improve sampling
* fix autocast
* improve sampling more
* fix saving
* actually fix saving
* improve dataset
* fix collate_fn
* fix key name
* fix dataset
* concat batch in collate_fn
* add grad ckpt
* add option for 8bit adam
* do two forward passes for prior preservation
* Revert "do two forward passes for prior preservation" (reverts commit 661ca46)
* add option for prior_loss_weight
* add option for clip grad norm
* add more comments
* update readme
* Apply suggestions from code review (Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>)
* add docstring for dataset
* update the saving logic
* Update examples/dreambooth/README.md
* remove unused imports

Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

1 parent d886e49 commit 3b747de

File tree

2 files changed: +747 -0 lines changed


examples/dreambooth/README.md

Lines changed: 141 additions & 0 deletions
# DreamBooth training example
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like Stable Diffusion given just a few (3-5) images of a subject.

The `train_dreambooth.py` script shows how to implement the training procedure and adapt it for Stable Diffusion.

## Running locally

### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
```bash
pip install diffusers[training] accelerate transformers
```

And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```
### Dog toy example

You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license, and tick the checkbox if you agree.

You have to be a registered user on the 🤗 Hugging Face Hub, and you'll also need an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).

Run the following command to authenticate your token:

```bash
huggingface-cli login
```

If you have already cloned the repo, then you won't need to go through these steps. You can simply remove the `--use_auth_token` arg from the following command.

<br>

Now let's get our dataset. Download images from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) and save them in a directory. This will be our training data.
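Before launching, it can help to sanity-check that every downloaded image opens correctly. A minimal sketch (the directory name is just a placeholder for whatever you chose as `INSTANCE_DIR`):

```python
from pathlib import Path

from PIL import Image

# Quick sanity check on the training images before launching a run.
instance_dir = Path("path-to-instance-images")  # same directory as INSTANCE_DIR below
for image_path in sorted(instance_dir.iterdir()):
    with Image.open(image_path) as img:
        # DreamBooth works best with a handful of clear photos of one subject.
        print(f"{image_path.name}: size={img.size}, mode={img.mode}")
```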
And launch the training using:

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export OUTPUT_DIR="path-to-save-model"

accelerate launch train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --learning_rate=5e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=400
```
### Training with prior-preservation loss

Prior-preservation is used to avoid overfitting and language drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those images during training along with our data.

According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 images work well for most cases.
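For intuition, generating class images amounts to sampling from the base model with the class prompt. A minimal sketch of that step (the training script handles this for you; the paths and loop count here are only illustrative):

```python
# Illustrative sketch of class-image generation for prior preservation;
# train_dreambooth.py performs an equivalent step for you.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16, use_auth_token=True
).to("cuda")

class_prompt = "a photo of dog"
for i in range(200):  # matches --num_class_images
    image = pipe(class_prompt).images[0]
    image.save(f"path-to-class-images/{i}.jpg")
```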
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"

accelerate launch train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \
  --instance_data_dir=$INSTANCE_DIR \
  --class_data_dir=$CLASS_DIR \
  --output_dir=$OUTPUT_DIR \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --instance_prompt="a photo of sks dog" \
  --class_prompt="a photo of dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --learning_rate=5e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --num_class_images=200 \
  --max_train_steps=800
```
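Conceptually, `--prior_loss_weight` balances two denoising losses: one on the instance images and one on the generated class images. A minimal sketch of the combined objective (the function and variable names are illustrative, not the script's exact code):

```python
import torch
import torch.nn.functional as F

def dreambooth_loss(
    noise_pred: torch.Tensor, noise: torch.Tensor, prior_loss_weight: float = 1.0
) -> torch.Tensor:
    """Loss for a batch that concatenates instance and class examples."""
    # Split the concatenated batch back into its instance and class halves.
    pred_instance, pred_class = noise_pred.chunk(2, dim=0)
    target_instance, target_class = noise.chunk(2, dim=0)
    instance_loss = F.mse_loss(pred_instance, target_instance)
    # The prior term keeps the model close to what it already knows about the class.
    prior_loss = F.mse_loss(pred_class, target_class)
    return instance_loss + prior_loss_weight * prior_loss
```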
### Training on a 16GB GPU

With the help of gradient checkpointing and the 8-bit optimizer from bitsandbytes, it's possible to train DreamBooth on a 16GB GPU.

Install `bitsandbytes` with `pip install bitsandbytes`, then pass the `--gradient_checkpointing` and `--use_8bit_adam` flags:
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"

accelerate launch train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \
  --instance_data_dir=$INSTANCE_DIR \
  --class_data_dir=$CLASS_DIR \
  --output_dir=$OUTPUT_DIR \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --instance_prompt="a photo of sks dog" \
  --class_prompt="a photo of dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=2 --gradient_checkpointing \
  --use_8bit_adam \
  --learning_rate=5e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --num_class_images=200 \
  --max_train_steps=800
```
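Inside the script, these two memory-saving flags correspond roughly to the following (a minimal sketch assuming a diffusers `UNet2DConditionModel`; the exact wiring lives in `train_dreambooth.py`):

```python
import bitsandbytes as bnb
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet", use_auth_token=True
)

# --gradient_checkpointing: trade compute for memory by recomputing forward
# activations during the backward pass instead of storing them all.
unet.enable_gradient_checkpointing()

# --use_8bit_adam: bitsandbytes' 8-bit Adam keeps optimizer state in 8-bit,
# cutting its memory footprint roughly 4x versus torch.optim.AdamW.
optimizer = bnb.optim.AdamW8bit(unet.parameters(), lr=5e-6)
```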
## Inference

Once you have trained a model using the above command, inference can be done simply with the `StableDiffusionPipeline`. Make sure to include the identifier (e.g. `sks` in the above example) in your prompt.
```python
import torch
from torch import autocast

from diffusers import StableDiffusionPipeline

# Load the fine-tuned weights saved by train_dreambooth.py.
model_id = "path-to-your-trained-model"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

prompt = "A photo of sks dog in a bucket"

with autocast("cuda"):
    image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]

image.save("dog-bucket.png")
```
