Add training example for DreamBooth. #554
Merged
Changes from all commits (35 commits):

- bd4d674 Add training example for DreamBooth. (Victarry)
- 51340f9 Fix bugs. (Victarry)
- 88ab347 Update readme and default hyperparameters. (Victarry)
- 5bb534b Reformatting code with black. (Victarry)
- faffe23 Update for multi-gpu trianing. (Victarry)
- 2eeabe7 Apply suggestions from code review (patil-suraj)
- 195cd46 improgve sampling (patil-suraj)
- 1acc678 fix autocast (patil-suraj)
- 627cc49 improve sampling more (patil-suraj)
- f1c3c8e fix saving (patil-suraj)
- 509e4e3 actuallu fix saving (patil-suraj)
- eafc000 fix saving (patil-suraj)
- 6f99f29 improve dataset (patil-suraj)
- 392fbf3 fix collate fun (patil-suraj)
- d6c88f4 fix collate_fn (patil-suraj)
- a3d604e fix collate fn (patil-suraj)
- f4a91a6 fix key name (patil-suraj)
- 8e92d69 fix dataset (patil-suraj)
- ef01331 fix collate fn (patil-suraj)
- c66cf4d concat batch in collate fn (patil-suraj)
- 2894a92 Merge branch 'main' of https://github.com/huggingface/diffusers into … (patil-suraj)
- 16ecc08 add grad ckpt (patil-suraj)
- 87bc752 add option for 8bit adam (patil-suraj)
- 661ca46 do two forward passes for prior preservation (patil-suraj)
- ce2a3be Revert "do two forward passes for prior preservation" (patil-suraj)
- 248e77d add option for prior_loss_weight (patil-suraj)
- abbb614 add option for clip grad norm (patil-suraj)
- c05b043 add more comments (patil-suraj)
- 90bac83 update readme (patil-suraj)
- 89991a1 update readme (patil-suraj)
- 265d2b1 Apply suggestions from code review (patil-suraj)
- d63fa4d add docstr for dataset (patil-suraj)
- 102ad70 update the saving logic (patil-suraj)
- 7ad4316 Update examples/dreambooth/README.md (patil-suraj)
- d72c659 remove unused imports (patil-suraj)
File: examples/dreambooth/README.md (new file, 141 lines added)
# DreamBooth training example

[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like Stable Diffusion given just a few (3-5) images of a subject.
The `train_dreambooth.py` script shows how to implement the training procedure and adapt it for Stable Diffusion.

## Running locally
### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

```bash
pip install diffusers[training] accelerate transformers
```

And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

### Dog toy example

You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.

You have to be a registered user on the 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).

Run the following command to authenticate your token:

```bash
huggingface-cli login
```
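
If you are working from a notebook or a Python session instead of a shell, the `huggingface_hub` library provides an equivalent login helper:

```python
# Equivalent of `huggingface-cli login` from Python / notebooks.
from huggingface_hub import notebook_login

notebook_login()  # prompts for your access token and caches it locally
```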

If you have already cloned the model repository locally, you won't need to go through these steps. You can simply remove the `--use_auth_token` argument from the following commands.

<br>

Now let's get our dataset. Download images from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) and save them in a directory. This will be our training data.
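
As an optional sanity check before training, you can verify that every downloaded image opens correctly. The directory name below is a placeholder; point it at wherever you saved the images:

```python
# Optional sanity check: make sure every instance image can be opened.
from pathlib import Path
from PIL import Image

instance_dir = Path("path-to-instance-images")  # placeholder, adjust to your folder
for path in sorted(instance_dir.iterdir()):
    with Image.open(path) as img:
        print(path.name, img.size, img.mode)
```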

And launch the training using:

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export OUTPUT_DIR="path-to-save-model"

accelerate launch train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --learning_rate=5e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=400
```
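
For the curious, the launched script follows the standard 🤗 Accelerate training pattern. The sketch below is a heavily simplified illustration of that pattern, not the actual `train_dreambooth.py` code; `compute_loss` is a hypothetical placeholder for the diffusion loss:

```python
# Heavily simplified illustration of the Accelerate training pattern
# used by scripts launched with `accelerate launch` (not the real script).
from accelerate import Accelerator

def train(model, optimizer, dataloader, compute_loss, max_train_steps=400):
    accelerator = Accelerator(gradient_accumulation_steps=1)
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    global_step = 0
    while global_step < max_train_steps:
        for batch in dataloader:
            with accelerator.accumulate(model):
                loss = compute_loss(model, batch)  # placeholder for the diffusion loss
                accelerator.backward(loss)         # replaces loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            global_step += 1
            if global_step >= max_train_steps:
                break
    accelerator.wait_for_everyone()
```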

### Training with prior-preservation loss

Prior preservation is used to avoid overfitting and language drift (refer to the paper to learn more about it). For prior preservation, we first generate images with the model itself using a class prompt, and then use those images together with our own data during training.
According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior preservation; 200-300 images work well for most cases.

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"

accelerate launch train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \
  --instance_data_dir=$INSTANCE_DIR \
  --class_data_dir=$CLASS_DIR \
  --output_dir=$OUTPUT_DIR \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --instance_prompt="a photo of sks dog" \
  --class_prompt="a photo of dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --learning_rate=5e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --num_class_images=200 \
  --max_train_steps=800
```
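
For illustration, here is a minimal sketch of how class images for prior preservation could be generated with the regular pipeline ahead of training. It assumes the same `v1-4` weights and mirrors the class prompt and paths from the command above:

```python
# Minimal sketch: pre-generate class images for prior preservation.
from pathlib import Path

import torch
from diffusers import StableDiffusionPipeline

class_dir = Path("path-to-class-images")  # same directory as $CLASS_DIR above
class_dir.mkdir(parents=True, exist_ok=True)

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16, use_auth_token=True
).to("cuda")

num_class_images = 200
batch_size = 4
class_prompt = "a photo of dog"

for start in range(0, num_class_images, batch_size):
    images = pipe([class_prompt] * batch_size, num_inference_steps=50).images
    for i, image in enumerate(images):
        image.save(class_dir / f"class-{start + i}.jpg")
```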

### Training on a 16GB GPU

With the help of gradient checkpointing and the 8-bit Adam optimizer from bitsandbytes, it's possible to run DreamBooth training on a 16GB GPU.

Install `bitsandbytes` with `pip install bitsandbytes`, then launch training with:

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"

accelerate launch train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \
  --instance_data_dir=$INSTANCE_DIR \
  --class_data_dir=$CLASS_DIR \
  --output_dir=$OUTPUT_DIR \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --instance_prompt="a photo of sks dog" \
  --class_prompt="a photo of dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=2 --gradient_checkpointing \
  --use_8bit_adam \
  --learning_rate=5e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --num_class_images=200 \
  --max_train_steps=800
```
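
For reference, the two flags correspond roughly to the following calls (an illustrative sketch, not the script's exact code): gradient checkpointing is enabled on the UNet, and the optimizer is swapped for the 8-bit Adam implementation from `bitsandbytes`:

```python
# Illustrative sketch of what --gradient_checkpointing and --use_8bit_adam imply.
import bitsandbytes as bnb
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet", use_auth_token=True
)

# --gradient_checkpointing: recompute activations in the backward pass to save memory.
unet.enable_gradient_checkpointing()

# --use_8bit_adam: 8-bit Adam from bitsandbytes instead of torch.optim.AdamW.
optimizer = bnb.optim.AdamW8bit(unet.parameters(), lr=5e-6)
```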

## Inference

Once you have trained a model using one of the commands above, inference can be done simply with the `StableDiffusionPipeline`. Make sure to include the identifier (`sks` in the example above) in your prompt.

```python
from torch import autocast
from diffusers import StableDiffusionPipeline
import torch

model_id = "path-to-your-trained-model"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

prompt = "A photo of sks dog in a bucket"

with autocast("cuda"):
    image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]

image.save("dog-bucket.png")
```
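
Optionally, for reproducible samples you can pass a seeded generator to the same pipeline call; this continues from the snippet above:

```python
# Optional: seed the sampler for reproducible outputs (continues the snippet above).
generator = torch.Generator(device="cuda").manual_seed(0)
with autocast("cuda"):
    image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5, generator=generator).images[0]

image.save("dog-bucket-seed0.png")
```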