From bd4d6743d6ce3a9fb52b67371e6eb0af67534445 Mon Sep 17 00:00:00 2001
From: Zhenhuan Liu
Date: Sun, 18 Sep 2022 10:56:39 -0700
Subject: [PATCH 01/34] Add training example for DreamBooth.

---
 examples/dreambooth/README.md           |  82 ++++
 examples/dreambooth/train_dreambooth.py | 543 ++++++++++++++++++++++++
 2 files changed, 625 insertions(+)
 create mode 100644 examples/dreambooth/README.md
 create mode 100644 examples/dreambooth/train_dreambooth.py

diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md
new file mode 100644
index 000000000000..5d6f3b9eb9a5
--- /dev/null
+++ b/examples/dreambooth/README.md
@@ -0,0 +1,82 @@
+## DreamBooth training example
+
+[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like Stable Diffusion given just a few (3-5) images of a subject.
+The `train_dreambooth.py` script shows how to implement the training procedure and adapt it for Stable Diffusion.
+
+## Running locally
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+```bash
+pip install diffusers[training] accelerate transformers
+```
+
+And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+### Dog toy example
+
+You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license, and tick the checkbox if you agree.
+
+You have to be a registered user on the 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
+
+Run the following command to authenticate your token:
+
+```bash
+huggingface-cli login
+```
+
+If you have already cloned the repo, you won't need to go through these steps; you can simply remove the `--use_auth_token` arg from the following command.
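+
+If you are working in a non-interactive environment, a minimal sketch of persisting the token from Python instead (the token string below is a placeholder for your own access token):
+
+```python
+from huggingface_hub import HfFolder
+
+# Store the token where `huggingface-cli login` would save it,
+# so that `--use_auth_token` can pick it up later.
+HfFolder.save_token("hf_xxx")
+```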
+
+Now let's get our dataset. Download images from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) and save them in a directory. This will be our training data.
+
+And launch the training using:
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+
+python train_dreambooth.py \
+  --pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \
+  --instance_data_dir=$INSTANCE_DIR \
+  --class_data_dir=$CLASS_DIR \
+  --resolution=512 \
+  --train_batch_size=4 \
+  --gradient_accumulation_steps=1 \
+  --learning_rate=1e-5 \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --output_dir="dreambooth_dog" \
+  --instance_prompt="a photo of sks dog" \
+  --class_prompt="a photo of dog" \
+  --num_class_images=1000 \
+  --max_train_steps=3000
+```
+
+
+### Inference
+
+Once you have trained a model using the above command, inference can be done simply with the `StableDiffusionPipeline`. Make sure to include the identifier (`sks` in the above example) in your prompt.
+
+```python
+import torch
+from torch import autocast
+from diffusers import StableDiffusionPipeline
+
+model_id = "path-to-your-trained-model"
+pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+prompt = "A photo of sks dog in a bucket"
+
+with autocast("cuda"):
+    image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+
+image.save("dog-bucket.png")
+```
diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
new file mode 100644
index 000000000000..4363b1c358bf
--- /dev/null
+++ b/examples/dreambooth/train_dreambooth.py
@@ -0,0 +1,543 @@
+import argparse
+import math
+import os
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch.utils.data import Dataset
+
+import PIL
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import set_seed
+from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
+from huggingface_hub import HfFolder, Repository, whoami
+from PIL import Image
+from tqdm.auto import tqdm
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+
+logger = get_logger(__name__)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Simple example of a training script.")
+    parser.add_argument(
+        "--pretrained_model_name_or_path",
+        type=str,
+        default=None,
+        required=True,
+        help="Path to pretrained model or model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--tokenizer_name",
+        type=str,
+        default=None,
+        help="Pretrained tokenizer name or path, if not the same as model_name.",
+    )
+    parser.add_argument(
+        "--instance_data_dir",
+        type=str,
+        default=None,
+        required=True,
+        help="A folder containing the training data of instance images.",
+    )
+    parser.add_argument(
+        "--class_data_dir",
+        type=str,
+        default=None,
+        required=True,
+        help="A folder containing the training data of class images.",
+    )
+    parser.add_argument(
+        "--instance_prompt",
+        type=str,
+        default=None,
+        help="The prompt with the identifier specifying the instance.",
+    )
+    parser.add_argument(
+        "--class_prompt",
+        type=str,
+        default=None,
+        help="The prompt to specify images in the same class as the provided instance images.",
+    )
+    parser.add_argument(
+        "--without_prior_preservation",
+        default=False,
+        action="store_true",
+        help="Flag to remove the prior preservation loss.",
+    )
+    parser.add_argument(
+        "--num_class_images",
+        type=int,
+        default=100,
+        help=(
+            "Minimum number of class images for the prior preservation loss. If there are not enough images"
+            " already present, additional images will be sampled with class_prompt."
+        ),
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="text-inversion-model",
+        help="The output directory where the model predictions and checkpoints will be written.",
+    )
+    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+    parser.add_argument(
+        "--resolution",
+        type=int,
+        default=512,
+        help=(
+            "The resolution for input images; all the images in the train/validation dataset will be resized to this"
+            " resolution."
+        ),
+    )
+    parser.add_argument(
+        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
+    )
+    parser.add_argument(
+        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+    )
+    parser.add_argument(
+        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+    )
+    parser.add_argument("--num_train_epochs", type=int, default=1)
+    parser.add_argument(
+        "--max_train_steps",
+        type=int,
+        default=None,
+        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+    )
+    parser.add_argument(
+        "--gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of update steps to accumulate before performing a backward/update pass.",
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=1e-5,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--scale_lr",
+        action="store_true",
+        default=False,
+        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+    )
+    parser.add_argument(
+        "--lr_scheduler",
+        type=str,
+        default="constant",
+        help=(
+            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+            ' "constant", "constant_with_warmup"]'
+        ),
+    )
+    parser.add_argument(
+        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
+    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+    parser.add_argument(
+        "--use_auth_token",
+        action="store_true",
+        help=(
+            "Will use the token generated when running `huggingface-cli login` (necessary to use this script with"
+            " private models)."
+        ),
+    )
+    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+    parser.add_argument(
+        "--hub_model_id",
+        type=str,
+        default=None,
+        help="The name of the repository to keep in sync with the local `output_dir`.",
+    )
+    parser.add_argument(
+        "--logging_dir",
+        type=str,
+        default="logs",
+        help=(
+            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+        ),
+    )
+    parser.add_argument(
+        "--mixed_precision",
+        type=str,
+        default="no",
+        choices=["no", "fp16", "bf16"],
+        help=(
+            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU."
+        ),
+    )
+    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+    args = parser.parse_args()
+    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+    if env_local_rank != -1 and env_local_rank != args.local_rank:
+        args.local_rank = env_local_rank
+
+    if args.instance_data_dir is None:
+        raise ValueError("You must specify a train data directory.")
+
+    return args
+
+
+class DreamBoothDataset(Dataset):
+    def __init__(
+        self,
+        instance_data_root,
+        instance_prompt,
+        tokenizer,
+        class_data_root=None,
+        class_prompt=None,
+        size=512,
+        interpolation="bicubic",
+        center_crop=False,
+    ):
+        self.size = size
+        self.center_crop = center_crop
+        self.tokenizer = tokenizer
+
+        self.instance_data_root = Path(instance_data_root)
+        assert self.instance_data_root.exists(), "Instance images root doesn't exist."
+        self.instance_images_path = list(Path(instance_data_root).iterdir())
+        self.num_instance_images = len(self.instance_images_path)
+        self.instance_prompt = instance_prompt
+        self._length = self.num_instance_images
+
+        if class_data_root is not None:
+            self.class_data_root = Path(class_data_root)
+            assert self.class_data_root.exists(), "Class images root doesn't exist."
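+            # Note: _length below is the larger of the two image sets; __getitem__
+            # indexes each set modulo its own size, so the smaller set is cycled
+            # through repeatedly within an epoch.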
+ self.class_images_path = list(Path(class_data_root).iterdir()) + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + self.class_prompt = class_prompt + else: + self.class_data_root = None + + self.interpolation = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] + + def __len__(self): + return self._length + + def transform_image(self, image: Image): + # default to score-sde preprocessing + img = np.array(image).astype(np.uint8) + + if self.center_crop: + crop = min(img.shape[0], img.shape[1]) + h, w, = ( + img.shape[0], + img.shape[1], + ) + img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2] + + image = Image.fromarray(img) + image = image.resize((self.size, self.size), resample=self.interpolation) + + image = np.array(image).astype(np.uint8) + image = (image / 127.5 - 1.0).astype(np.float32) + image = torch.from_numpy(image).permute(2, 0, 1) + return image + + def __getitem__(self, index): + example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["instance_images"] = self.transform_image(instance_image) + example["instance_prompt_ids"] = self.tokenizer( + self.instance_prompt, + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids[0] + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.transform_image(class_image) + example["class_prompt_ids"] = self.tokenizer( + self.class_prompt, + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids[0] + + return example + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +def main(): + args = parse_args() + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with="tensorboard", + logging_dir=logging_dir, + ) + + if args.seed is not None: + set_seed(args.seed) + + if not args.without_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + sd_model = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, use_auth_token=args.use_auth_token + ) + sd_model = accelerator.prepare(sd_model) + sd_model.to(accelerator.device) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + total_prompts = [args.class_prompt] * num_new_images + batch_prompts = [ + total_prompts[x : x + args.sample_batch_size] for x in range(0, num_new_images, args.sample_batch_size) + ] + + img_id = cur_class_images + for text in tqdm( + batch_prompts, 
desc="Generating class images", disable=not accelerator.is_local_main_process + ): + with torch.no_grad(): + images = sd_model(text, height=512, width=512, num_inference_steps=50)["sample"] + + for image in images: + image.save(class_images_dir / f"{img_id}.jpg") + img_id += 1 + del sd_model + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo = Repository(args.output_dir, clone_from=repo_name) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load the tokenizer and add the placeholder token as a additional special token + if args.tokenizer_name: + tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) + elif args.pretrained_model_name_or_path: + tokenizer = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", use_auth_token=args.use_auth_token + ) + + # Load models and create wrapper for stable diffusion + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", use_auth_token=args.use_auth_token + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", use_auth_token=args.use_auth_token + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", use_auth_token=args.use_auth_token + ) + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + optimizer = torch.optim.AdamW( + unet.parameters(), # only optimize unet + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt" + ) + + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if not args.without_prior_preservation else None, + class_prompt=args.class_prompt, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + ) + train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True) + + # Scheduler and math around the number of training steps. 
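+    # Worked example (illustrative numbers, not the defaults): with 200 batches
+    # per epoch and gradient_accumulation_steps=2, there are ceil(200 / 2) = 100
+    # optimizer updates per epoch, so the default num_train_epochs=1 implies
+    # max_train_steps = 100 when --max_train_steps is not passed.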
+    overrode_max_train_steps = False
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if args.max_train_steps is None:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        overrode_max_train_steps = True
+
+    lr_scheduler = get_scheduler(
+        args.lr_scheduler,
+        optimizer=optimizer,
+        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+    )
+
+    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, optimizer, train_dataloader, lr_scheduler
+    )
+
+    # Move vae and unet to device
+    vae.to(accelerator.device)
+    unet.to(accelerator.device)
+
+    # Keep vae in eval mode as we don't train it
+    vae.eval()
+
+    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if overrode_max_train_steps:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+    # Afterwards we recalculate our number of training epochs
+    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+    # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+    if accelerator.is_main_process:
+        accelerator.init_trackers("dreambooth", config=vars(args))
+
+    # Train!
+    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+    logger.info("***** Running training *****")
+    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
+    logger.info(f"  Num Epochs = {args.num_train_epochs}")
+    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
+    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+    logger.info(f"  Total optimization steps = {args.max_train_steps}")
+    # Only show the progress bar once on each machine.
+    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+    progress_bar.set_description("Steps")
+    global_step = 0
+
+    for epoch in range(args.num_train_epochs):
+        text_encoder.train()
+        for step, batch in enumerate(train_dataloader):
+            with accelerator.accumulate(text_encoder):
+                # Convert images to latent space
+                if not args.without_prior_preservation:
+                    images = torch.cat([batch["instance_images"], batch["class_images"]], dim=0)
+                    input_ids = torch.cat([batch["instance_prompt_ids"], batch["class_prompt_ids"]], dim=0)
+                else:
+                    images = batch["instance_images"]
+                    input_ids = batch["instance_prompt_ids"]
+
+                latents = vae.encode(images).latent_dist.sample().detach()
+                latents = latents * 0.18215
+
+                # Sample noise that we'll add to the latents
+                noise = torch.randn(latents.shape).to(latents.device)
+                bsz = latents.shape[0]
+                # Sample a random timestep for each image
+                timesteps = torch.randint(
+                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
+                ).long()
+
+                # Add noise to the latents according to the noise magnitude at each timestep
+                # (this is the forward diffusion process)
+                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                # Get the text embedding for conditioning
+                encoder_hidden_states = text_encoder(input_ids)[0]
+
+                # Predict the noise residual
+                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                accelerator.backward(loss)
+
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad()
+
+            # Checks if the accelerator has performed an optimization step behind the scenes
+            if accelerator.sync_gradients:
+                progress_bar.update(1)
+                global_step += 1
+
+            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+            progress_bar.set_postfix(**logs)
+            accelerator.log(logs, step=global_step)
+
+            if global_step >= args.max_train_steps:
+                break
+
+    accelerator.wait_for_everyone()
+
+    # Create the pipeline using the trained modules and save it.
+    if accelerator.is_main_process:
+        pipeline = StableDiffusionPipeline(
+            text_encoder=accelerator.unwrap_model(text_encoder),
+            vae=vae,
+            unet=unet,
+            tokenizer=tokenizer,
+            scheduler=PNDMScheduler(
+                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
+            ),
+            safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
+            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
+        )
+        pipeline.save_pretrained(args.output_dir)
+
+        if args.push_to_hub:
+            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+
+    accelerator.end_training()
+
+
+if __name__ == "__main__":
+    main()

From 51340f9921cbdf30fb6e7c5da6c996d15bce8854 Mon Sep 17 00:00:00 2001
From: Zhenhuan Liu
Date: Thu, 22 Sep 2022 10:19:03 -0700
Subject: [PATCH 02/34] Fix bugs.
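
A note on the objective behind the renamed flag: with prior preservation
enabled, instance and class examples are concatenated into a single batch, so
the plain MSE over that batch corresponds (up to a constant factor) to the
DreamBooth objective

    E||eps - eps_theta(z_t, c)||^2 + E||eps' - eps_theta(z'_t, c_class)||^2

with the prior-preservation term weighted equally. (Sketch of the intent only;
the notation is ours, not taken verbatim from the paper.)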
---
 examples/dreambooth/train_dreambooth.py | 94 +++++++++++++++----------
 1 file changed, 55 insertions(+), 39 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 4363b1c358bf..6b121cfd6f4a 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -52,7 +52,7 @@ def parse_args():
         "--class_data_dir",
         type=str,
         default=None,
-        required=True,
+        required=False,
         help="A folder containing the training data of class images.",
     )
     parser.add_argument(
@@ -68,10 +68,10 @@ def parse_args():
         help="The prompt to specify images in the same class as the provided instance images.",
     )
     parser.add_argument(
-        "--without_prior_preservation",
+        "--with_prior_preservation",
         default=False,
         action="store_true",
-        help="Flag to remove the prior preservation loss.",
+        help="Flag to add the prior preservation loss.",
     )
     parser.add_argument(
         "--num_class_images",
@@ -123,7 +123,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate",
         type=float,
-        default=1e-5,
+        default=5e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -193,6 +193,13 @@ def parse_args():
 
     if args.instance_data_dir is None:
         raise ValueError("You must specify a train data directory.")
+
+    if args.with_prior_preservation:
+        if args.class_data_dir is None:
+            raise ValueError("You must specify a data directory for class images.")
+        if args.class_prompt is None:
+            raise ValueError("You must specify a prompt for class images.")
+
     return args
 
 
@@ -222,7 +229,7 @@ def __init__(
 
         if class_data_root is not None:
             self.class_data_root = Path(class_data_root)
-            assert self.class_data_root.exists(), "Class images root doesn't exist."
+            self.class_data_root.mkdir(parents=True, exist_ok=True)
             self.class_images_path = list(Path(class_data_root).iterdir())
             self.num_class_images = len(self.class_images_path)
             self._length = max(self.num_class_images, self.num_instance_images)
@@ -289,6 +296,19 @@ def __getitem__(self, index):
 
         return example
 
+class PromptDataset(Dataset):
+    def __init__(self, prompt, num_samples):
+        self.prompt = prompt
+        self.num_samples = num_samples
+
+    def __len__(self):
+        return self.num_samples
+
+    def __getitem__(self, index):
+        example = {}
+        example["prompt"] = self.prompt
+        example["index"] = index
+        return example
 
 def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
     if token is None:
@@ -314,7 +334,7 @@ def main():
     if args.seed is not None:
         set_seed(args.seed)
 
-    if not args.without_prior_preservation:
+    if args.with_prior_preservation:
         class_images_dir = Path(args.class_data_dir)
         if not class_images_dir.exists():
             class_images_dir.mkdir(parents=True)
@@ -324,27 +344,24 @@ def main():
             sd_model = StableDiffusionPipeline.from_pretrained(
                 args.pretrained_model_name_or_path, use_auth_token=args.use_auth_token
             )
-            sd_model = accelerator.prepare(sd_model)
-            sd_model.to(accelerator.device)
-
             num_new_images = args.num_class_images - cur_class_images
             logger.info(f"Number of class images to sample: {num_new_images}.")
-            total_prompts = [args.class_prompt] * num_new_images
-            batch_prompts = [
-                total_prompts[x : x + args.sample_batch_size] for x in range(0, num_new_images, args.sample_batch_size)
-            ]
-
-            img_id = cur_class_images
-            for text in tqdm(
-                batch_prompts, desc="Generating class images", disable=not accelerator.is_local_main_process
+
+            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+            sample_dataloader =
torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sd_model, sample_dataloader = accelerator.prepare(sd_model, sample_dataloader) + sd_model.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process ): with torch.no_grad(): - images = sd_model(text, height=512, width=512, num_inference_steps=50)["sample"] + images = sd_model(example["prompt"], height=512, width=512, num_inference_steps=50)["sample"] - for image in images: - image.save(class_images_dir / f"{img_id}.jpg") - img_id += 1 - del sd_model + for image, index in zip(images, example["index"]): + image.save(class_images_dir / f"{index + cur_class_images}.jpg") + del sd_model # Handle the repository creation if accelerator.is_main_process: @@ -363,7 +380,7 @@ def main(): elif args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) - # Load the tokenizer and add the placeholder token as a additional special token + # Load the tokenizer if args.tokenizer_name: tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) elif args.pretrained_model_name_or_path: @@ -401,7 +418,7 @@ def main(): train_dataset = DreamBoothDataset( instance_data_root=args.instance_data_dir, instance_prompt=args.instance_prompt, - class_data_root=args.class_data_dir if not args.without_prior_preservation else None, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, class_prompt=args.class_prompt, tokenizer=tokenizer, size=args.resolution, @@ -423,15 +440,12 @@ def main(): num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, ) - text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, optimizer, train_dataloader, lr_scheduler + text_encoder, vae, unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoder, vae, unet, optimizer, train_dataloader, lr_scheduler ) - # Move vae and unet to device - vae.to(accelerator.device) - unet.to(accelerator.device) - - # Keep vae in eval model as we don't train it + # Keep text_encoder and vae in eval model as we don't train it + text_encoder.eval() vae.eval() # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
@@ -463,19 +477,20 @@ def main(): global_step = 0 for epoch in range(args.num_train_epochs): - text_encoder.train() + unet.train() for step, batch in enumerate(train_dataloader): - with accelerator.accumulate(text_encoder): + with accelerator.accumulate(unet): # Convert images to latent space - if not args.without_prior_preservation: + if args.with_prior_preservation: images = torch.cat([batch["instance_images"], batch["class_images"]], dim=0) input_ids = torch.cat([batch["instance_prompt_ids"], batch["class_prompt_ids"]], dim=0) else: images = batch["instance_images"] input_ids = batch["instance_prompt_ids"] - latents = vae.encode(images).latent_dist.sample().detach() - latents = latents * 0.18215 + with torch.no_grad(): + latents = vae.encode(images).latent_dist.sample() + latents = latents * 0.18215 # Sample noise that we'll add to the latents noise = torch.randn(latents.shape).to(latents.device) @@ -490,7 +505,8 @@ def main(): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning - encoder_hidden_states = text_encoder(input_ids)[0] + with torch.no_grad(): + encoder_hidden_states = text_encoder(input_ids)[0] # Predict the noise residual noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample @@ -520,8 +536,8 @@ def main(): if accelerator.is_main_process: pipeline = StableDiffusionPipeline( text_encoder=accelerator.unwrap_model(text_encoder), - vae=vae, - unet=unet, + vae=accelerator.unwrap_model(vae), + unet=accelerator.unwrap_model(unet), tokenizer=tokenizer, scheduler=PNDMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True From 88ab347d3ce0082ae6d44f36a423de786b46c330 Mon Sep 17 00:00:00 2001 From: Zhenhuan Liu Date: Thu, 22 Sep 2022 10:19:45 -0700 Subject: [PATCH 03/34] Update readme and default hyperparameters. --- examples/dreambooth/README.md | 36 ++++++++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md index 5d6f3b9eb9a5..97a318b38a33 100644 --- a/examples/dreambooth/README.md +++ b/examples/dreambooth/README.md @@ -38,26 +38,48 @@ Now let's get our dataset. 
 Download images from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) and save them in a directory. This will be our training data.
 
 And launch the training using:
 
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="path-to-instance-images"
+export OUTPUT_DIR="path-to-save-model"
+
+python train_dreambooth.py \
+  --pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \
+  --instance_data_dir=$INSTANCE_DIR \
+  --output_dir=$OUTPUT_DIR \
+  --instance_prompt="a photo of sks dog" \
+  --resolution=512 \
+  --train_batch_size=1 \
+  --gradient_accumulation_steps=1 \
+  --learning_rate=5e-6 \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --max_train_steps=400
+```
+
+Training with prior-preservation loss using:
 ```bash
 export MODEL_NAME="CompVis/stable-diffusion-v1-4"
 export INSTANCE_DIR="path-to-instance-images"
 export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
 
 python train_dreambooth.py \
   --pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \
   --instance_data_dir=$INSTANCE_DIR \
   --class_data_dir=$CLASS_DIR \
+  --output_dir=$OUTPUT_DIR \
+  --with_prior_preservation \
+  --instance_prompt="a photo of sks dog" \
+  --class_prompt="a photo of dog" \
   --resolution=512 \
-  --train_batch_size=4 \
+  --train_batch_size=1 \
   --gradient_accumulation_steps=1 \
-  --learning_rate=1e-5 \
+  --learning_rate=5e-6 \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
-  --output_dir="dreambooth_dog" \
-  --instance_prompt="a photo of sks dog" \
-  --class_prompt="a photo of dog" \
-  --num_class_images=1000 \
-  --max_train_steps=3000
+  --num_class_images=200 \
+  --max_train_steps=1000
 ```

From 5bb534b0a50d806d94e2d914b3eef9509a66c07b Mon Sep 17 00:00:00 2001
From: Zhenhuan Liu
Date: Thu, 22 Sep 2022 10:23:53 -0700
Subject: [PATCH 04/34] Reformatting code with black.

---
 examples/dreambooth/train_dreambooth.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 6b121cfd6f4a..24be2418f48c 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -193,14 +193,13 @@ def parse_args():
 
     if args.instance_data_dir is None:
         raise ValueError("You must specify a train data directory.")
-
+
     if args.with_prior_preservation:
         if args.class_data_dir is None:
             raise ValueError("You must specify a data directory for class images.")
         if args.class_prompt is None:
             raise ValueError("You must specify a prompt for class images.")
-
     return args
@@ -296,20 +295,22 @@ def __getitem__(self, index):
 
         return example
 
+
 class PromptDataset(Dataset):
     def __init__(self, prompt, num_samples):
         self.prompt = prompt
         self.num_samples = num_samples
-
+
     def __len__(self):
         return self.num_samples
-
+
     def __getitem__(self, index):
         example = {}
         example["prompt"] = self.prompt
         example["index"] = index
         return example
 
+
 def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
     if token is None:
         token = HfFolder.get_token()

From faffe23627d1c19dd1c8d31ef6e7139287f2a99f Mon Sep 17 00:00:00 2001
From: Zhenhuan Liu
Date: Fri, 23 Sep 2022 04:31:15 -0400
Subject: [PATCH 05/34] Update for multi-GPU training.
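
The README commands now go through `accelerate launch`, which covers multi-GPU
setups as well. A minimal sketch for a 2-GPU machine (an assumption for
illustration; adjust `--num_processes` to your hardware, remaining flags as in
the README):

```bash
accelerate launch --multi_gpu --num_processes=2 train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --learning_rate=5e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=400
```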
--- examples/dreambooth/README.md | 4 ++-- examples/dreambooth/train_dreambooth.py | 13 +++++++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md index 97a318b38a33..ee936d419b0d 100644 --- a/examples/dreambooth/README.md +++ b/examples/dreambooth/README.md @@ -43,7 +43,7 @@ export MODEL_NAME="CompVis/stable-diffusion-v1-4" export INSTANCE_DIR="path-to-instance-images" export OUTPUT_DIR="path-to-save-model" -python train_dreambooth.py \ +accelerate launch train_dreambooth.py \ --pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \ --instance_data_dir=$INSTANCE_DIR \ --output_dir=$OUTPUT_DIR \ @@ -64,7 +64,7 @@ export INSTANCE_DIR="path-to-instance-images" export CLASS_DIR="path-to-class-images" export OUTPUT_DIR="path-to-save-model" -python train_dreambooth.py \ +accelerate launch train_dreambooth.py \ --pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \ --instance_data_dir=$INSTANCE_DIR \ --class_data_dir=$CLASS_DIR \ diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 24be2418f48c..e160e00c700b 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -345,6 +345,7 @@ def main(): sd_model = StableDiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, use_auth_token=args.use_auth_token ) + sd_model.set_progress_bar_config(disable=True) num_new_images = args.num_class_images - cur_class_images logger.info(f"Number of class images to sample: {num_new_images}.") @@ -441,10 +442,14 @@ def main(): num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, ) - text_encoder, vae, unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, vae, unet, optimizer, train_dataloader, lr_scheduler + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler ) + # Move text_encode and vae to gpu + text_encoder.to(accelerator.device) + vae.to(accelerator.device) + # Keep text_encoder and vae in eval model as we don't train it text_encoder.eval() vae.eval() @@ -536,8 +541,8 @@ def main(): # Create the pipeline using using the trained modules and save it. 
if accelerator.is_main_process: pipeline = StableDiffusionPipeline( - text_encoder=accelerator.unwrap_model(text_encoder), - vae=accelerator.unwrap_model(vae), + text_encoder=text_encoder, + vae=vae, unet=accelerator.unwrap_model(unet), tokenizer=tokenizer, scheduler=PNDMScheduler( From 2eeabe7adce54ac510187091ca801813487237d1 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Mon, 26 Sep 2022 10:34:27 +0200 Subject: [PATCH 06/34] Apply suggestions from code review --- examples/dreambooth/train_dreambooth.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index e160e00c700b..498269b05f30 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -359,7 +359,7 @@ def main(): sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process ): with torch.no_grad(): - images = sd_model(example["prompt"], height=512, width=512, num_inference_steps=50)["sample"] + images = sd_model(example["prompt"], height=512, width=512, num_inference_steps=50).images for image, index in zip(images, example["index"]): image.save(class_images_dir / f"{index + cur_class_images}.jpg") @@ -450,9 +450,6 @@ def main(): text_encoder.to(accelerator.device) vae.to(accelerator.device) - # Keep text_encoder and vae in eval model as we don't train it - text_encoder.eval() - vae.eval() # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) From 195cd463a83e5b90ef24dd38f93891785e954dc7 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 26 Sep 2022 13:45:54 +0200 Subject: [PATCH 07/34] improgve sampling --- examples/dreambooth/train_dreambooth.py | 31 ++++++++++++++++--------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 498269b05f30..9064410a453b 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -1,6 +1,8 @@ import argparse import math import os +from contextlib import nullcontext +from enum import auto from pathlib import Path from typing import Optional @@ -342,28 +344,36 @@ def main(): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - sd_model = StableDiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, use_auth_token=args.use_auth_token + torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, use_auth_token=args.use_auth_token, torch_dtype=torch_dtype ) - sd_model.set_progress_bar_config(disable=True) + pipeline.set_progress_bar_config(disable=True) + num_new_images = args.num_class_images - cur_class_images logger.info(f"Number of class images to sample: {num_new_images}.") sample_dataset = PromptDataset(args.class_prompt, num_new_images) sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) - sd_model, sample_dataloader = accelerator.prepare(sd_model, sample_dataloader) - sd_model.to(accelerator.device) + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + all_images = [] for example in tqdm( sample_dataloader, desc="Generating class images", disable=not 
accelerator.is_local_main_process ): - with torch.no_grad(): - images = sd_model(example["prompt"], height=512, width=512, num_inference_steps=50).images + context = torch.autocast(accelerator.device) if accelerator.device.type == "cuda" else nullcontext + with context: + images = pipeline(example["prompt"]).images + all_images.extend(images) - for image, index in zip(images, example["index"]): - image.save(class_images_dir / f"{index + cur_class_images}.jpg") - del sd_model + for image, example in zip(all_images, sample_dataloader): + image.save(class_images_dir / f"{example['index'] + cur_class_images}.jpg") + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() # Handle the repository creation if accelerator.is_main_process: @@ -450,7 +460,6 @@ def main(): text_encoder.to(accelerator.device) vae.to(accelerator.device) - # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: From 1acc6786e5ead6e0b0e81e1cf5eae79173de97cf Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 26 Sep 2022 13:49:21 +0200 Subject: [PATCH 08/34] fix autocast --- examples/dreambooth/train_dreambooth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 9064410a453b..eccf9f14258b 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -363,7 +363,7 @@ def main(): for example in tqdm( sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process ): - context = torch.autocast(accelerator.device) if accelerator.device.type == "cuda" else nullcontext + context = torch.autocast("cuda") if accelerator.device.type == "cuda" else nullcontext with context: images = pipeline(example["prompt"]).images all_images.extend(images) From 627cc494479bc7401589465ee5932f2cb4746bfb Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 26 Sep 2022 13:58:07 +0200 Subject: [PATCH 09/34] improve sampling more --- examples/dreambooth/train_dreambooth.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index eccf9f14258b..ed0a85f1eadc 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -344,7 +344,7 @@ def main(): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 + torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 pipeline = StableDiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, use_auth_token=args.use_auth_token, torch_dtype=torch_dtype ) @@ -360,16 +360,16 @@ def main(): pipeline.to(accelerator.device) all_images = [] + context = torch.autocast("cuda") if accelerator.device.type == "cuda" else nullcontext for example in tqdm( sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process ): - context = torch.autocast("cuda") if accelerator.device.type == "cuda" else nullcontext with context: images = pipeline(example["prompt"]).images - all_images.extend(images) + all_images.extend((images, example["index"])) - for image, example in zip(all_images, sample_dataloader): - 
image.save(class_images_dir / f"{example['index'] + cur_class_images}.jpg") + for image, index in all_images: + image.save(class_images_dir / f"{index + cur_class_images}.jpg") del pipeline if torch.cuda.is_available(): From f1c3c8e5a48f22ec6e98439484bf475b28abaaf9 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 26 Sep 2022 14:01:54 +0200 Subject: [PATCH 10/34] fix saving --- examples/dreambooth/train_dreambooth.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index ed0a85f1eadc..9767adcf5d02 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -359,17 +359,15 @@ def main(): sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) - all_images = [] context = torch.autocast("cuda") if accelerator.device.type == "cuda" else nullcontext for example in tqdm( sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process ): with context: images = pipeline(example["prompt"]).images - all_images.extend((images, example["index"])) - for image, index in all_images: - image.save(class_images_dir / f"{index + cur_class_images}.jpg") + for image in images: + image.save(class_images_dir / f"{example['index'] + cur_class_images}.jpg") del pipeline if torch.cuda.is_available(): From 509e4e3f9627f674a2df6153639d8f5d2b41f55b Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 26 Sep 2022 14:06:18 +0200 Subject: [PATCH 11/34] actuallu fix saving --- examples/dreambooth/train_dreambooth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 9767adcf5d02..e615fe5416ec 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -366,8 +366,8 @@ def main(): with context: images = pipeline(example["prompt"]).images - for image in images: - image.save(class_images_dir / f"{example['index'] + cur_class_images}.jpg") + for image, index in (images, example["index"]): + image.save(class_images_dir / f"{index + cur_class_images}.jpg") del pipeline if torch.cuda.is_available(): From eafc0002d428a79d2d7a451cd916195d93013f26 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 26 Sep 2022 14:09:03 +0200 Subject: [PATCH 12/34] fix saving --- examples/dreambooth/train_dreambooth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index e615fe5416ec..0c7deadb33e7 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -366,8 +366,8 @@ def main(): with context: images = pipeline(example["prompt"]).images - for image, index in (images, example["index"]): - image.save(class_images_dir / f"{index + cur_class_images}.jpg") + for i, image in enumerate(images): + image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg") del pipeline if torch.cuda.is_available(): From 6f99f29f71ca4542749af8c106150669645303e8 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 26 Sep 2022 14:31:30 +0200 Subject: [PATCH 13/34] improve dataset --- examples/dreambooth/train_dreambooth.py | 84 ++++++++++++++----------- 1 file changed, 46 insertions(+), 38 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 0c7deadb33e7..e5b410a70f6c 100644 --- 
a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -2,17 +2,14 @@ import math import os from contextlib import nullcontext -from enum import auto from pathlib import Path from typing import Optional -import numpy as np import torch import torch.nn.functional as F import torch.utils.checkpoint from torch.utils.data import Dataset -import PIL from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import set_seed @@ -21,6 +18,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from huggingface_hub import HfFolder, Repository, whoami from PIL import Image +from torchvision import transforms from tqdm.auto import tqdm from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer @@ -214,7 +212,6 @@ def __init__( class_data_root=None, class_prompt=None, size=512, - interpolation="bicubic", center_crop=False, ): self.size = size @@ -222,7 +219,9 @@ def __init__( self.tokenizer = tokenizer self.instance_data_root = Path(instance_data_root) - assert self.instance_data_root.exists(), "Instance images root doesn't exists." + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + self.instance_images_path = list(Path(instance_data_root).iterdir()) self.num_instance_images = len(self.instance_images_path) self.instance_prompt = instance_prompt @@ -238,61 +237,41 @@ def __init__( else: self.class_data_root = None - self.interpolation = { - "linear": PIL.Image.LINEAR, - "bilinear": PIL.Image.BILINEAR, - "bicubic": PIL.Image.BICUBIC, - "lanczos": PIL.Image.LANCZOS, - }[interpolation] + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) def __len__(self): return self._length - def transform_image(self, image: Image): - # default to score-sde preprocessing - img = np.array(image).astype(np.uint8) - - if self.center_crop: - crop = min(img.shape[0], img.shape[1]) - h, w, = ( - img.shape[0], - img.shape[1], - ) - img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2] - - image = Image.fromarray(img) - image = image.resize((self.size, self.size), resample=self.interpolation) - - image = np.array(image).astype(np.uint8) - image = (image / 127.5 - 1.0).astype(np.float32) - image = torch.from_numpy(image).permute(2, 0, 1) - return image - def __getitem__(self, index): example = {} instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) if not instance_image.mode == "RGB": instance_image = instance_image.convert("RGB") - example["instance_images"] = self.transform_image(instance_image) + example["instance_images"] = self.image_transforms(instance_image) example["instance_prompt_ids"] = self.tokenizer( self.instance_prompt, - padding="max_length", + padding="do_not_pad", truncation=True, max_length=self.tokenizer.model_max_length, - return_tensors="pt", ).input_ids[0] if self.class_data_root: class_image = Image.open(self.class_images_path[index % self.num_class_images]) if not class_image.mode == "RGB": class_image = class_image.convert("RGB") - example["class_images"] = self.transform_image(class_image) + example["class_images"] = self.image_transforms(class_image) example["class_prompt_ids"] = self.tokenizer( self.class_prompt, - padding="max_length", + 
padding="do_not_pad", truncation=True, max_length=self.tokenizer.model_max_length, - return_tensors="pt", ).input_ids[0] return example @@ -434,7 +413,36 @@ def main(): size=args.resolution, center_crop=args.center_crop, ) - train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True) + + def collate_fn(examples): + def _collate(input_ids, pixel_values): + pixel_values = torch.stack([pixel_value for pixel_value in pixel_values]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = [input_id for input_id in input_ids] + input_ids = tokenizer.pad( + {"input_ids": input_ids}, + padding=True, + max_length=tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + return input_ids, pixel_values + + instance_prompt_ids, instance_images = _collate(example["instance_prompt_ids"], example["instance_images"]) + + batch = { + "instance_images": instance_images, + "input_ids": instance_prompt_ids, + } + + if args.with_prior_preservation: + class_prompt_ids, class_images = _collate(example["class_prompt_ids"], example["class_images"]) + batch["class_images"] = class_images + batch["class_prompt_ids"] = class_prompt_ids + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn + ) # Scheduler and math around the number of training steps. overrode_max_train_steps = False From 392fbf3c1d699edc5bf9131aa46388220fd1dcbf Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 26 Sep 2022 14:33:02 +0200 Subject: [PATCH 14/34] fix collate fun --- examples/dreambooth/train_dreambooth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index e5b410a70f6c..276a08780a75 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -428,7 +428,7 @@ def _collate(input_ids, pixel_values): ).input_ids return input_ids, pixel_values - instance_prompt_ids, instance_images = _collate(example["instance_prompt_ids"], example["instance_images"]) + instance_prompt_ids, instance_images = _collate(examples["instance_prompt_ids"], examples["instance_images"]) batch = { "instance_images": instance_images, @@ -436,7 +436,7 @@ def _collate(input_ids, pixel_values): } if args.with_prior_preservation: - class_prompt_ids, class_images = _collate(example["class_prompt_ids"], example["class_images"]) + class_prompt_ids, class_images = _collate(examples["class_prompt_ids"], examples["class_images"]) batch["class_images"] = class_images batch["class_prompt_ids"] = class_prompt_ids From d6c88f4ef67eb17308ef2efe6954fa1286d8b891 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 26 Sep 2022 14:35:40 +0200 Subject: [PATCH 15/34] fix collate_fn --- examples/dreambooth/train_dreambooth.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 276a08780a75..197acea4410d 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -428,7 +428,9 @@ def _collate(input_ids, pixel_values): ).input_ids return input_ids, pixel_values - instance_prompt_ids, instance_images = _collate(examples["instance_prompt_ids"], examples["instance_images"]) + instance_prompt_ids = [example["instance_prompt_ids"] for example in examples] + instance_images = [example["instance_images"] for example in 
examples] + instance_prompt_ids, instance_images = _collate(instance_prompt_ids, instance_images) batch = { "instance_images": instance_images, @@ -436,7 +438,8 @@ def _collate(input_ids, pixel_values): } if args.with_prior_preservation: - class_prompt_ids, class_images = _collate(examples["class_prompt_ids"], examples["class_images"]) + class_prompt_ids = [example["class_prompt_ids"] for example in examples] + class_images = [example["class_images"] for example in examples] batch["class_images"] = class_images batch["class_prompt_ids"] = class_prompt_ids From a3d604e99a566c9b69e63342a2206aeb7cd87b1d Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 26 Sep 2022 14:36:39 +0200 Subject: [PATCH 16/34] fix collate fn --- examples/dreambooth/train_dreambooth.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 197acea4410d..7b95f21963b1 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -442,6 +442,7 @@ def _collate(input_ids, pixel_values): class_images = [example["class_images"] for example in examples] batch["class_images"] = class_images batch["class_prompt_ids"] = class_prompt_ids + return batch train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn From f4a91a640ce140f7d5b2f7e4a9cd015aeeaddb22 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 26 Sep 2022 14:37:32 +0200 Subject: [PATCH 17/34] fix key name --- examples/dreambooth/train_dreambooth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 7b95f21963b1..2959b572313c 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -434,7 +434,7 @@ def _collate(input_ids, pixel_values): batch = { "instance_images": instance_images, - "input_ids": instance_prompt_ids, + "instance_prompt_ids": instance_prompt_ids, } if args.with_prior_preservation: From 8e92d69ca2a4bdd8d1b08a3bfd078d815d5b94a9 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 26 Sep 2022 14:44:23 +0200 Subject: [PATCH 18/34] fix dataset --- examples/dreambooth/train_dreambooth.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 2959b572313c..f51afc75bdf9 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -260,7 +260,7 @@ def __getitem__(self, index): padding="do_not_pad", truncation=True, max_length=self.tokenizer.model_max_length, - ).input_ids[0] + ).input_ids if self.class_data_root: class_image = Image.open(self.class_images_path[index % self.num_class_images]) @@ -272,7 +272,7 @@ def __getitem__(self, index): padding="do_not_pad", truncation=True, max_length=self.tokenizer.model_max_length, - ).input_ids[0] + ).input_ids return example @@ -423,7 +423,6 @@ def _collate(input_ids, pixel_values): input_ids = tokenizer.pad( {"input_ids": input_ids}, padding=True, - max_length=tokenizer.model_max_length, return_tensors="pt", ).input_ids return input_ids, pixel_values From ef013311464e727c515f9818f2ba5886614a20d8 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 26 Sep 2022 14:45:37 +0200 Subject: [PATCH 19/34] fix collate fn --- examples/dreambooth/train_dreambooth.py | 1 + 1 file changed, 1 insertion(+) diff --git 
a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index f51afc75bdf9..19c6162448e1 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -439,6 +439,7 @@ def _collate(input_ids, pixel_values): if args.with_prior_preservation: class_prompt_ids = [example["class_prompt_ids"] for example in examples] class_images = [example["class_images"] for example in examples] + class_prompt_ids, class_images = _collate(class_prompt_ids, class_images) batch["class_images"] = class_images batch["class_prompt_ids"] = class_prompt_ids return batch From c66cf4dc1af207fc067914b77d21c8f9bb59bb3e Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 26 Sep 2022 15:02:13 +0200 Subject: [PATCH 20/34] concat batch in collate fn --- examples/dreambooth/train_dreambooth.py | 47 ++++++++----------------- 1 file changed, 15 insertions(+), 32 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 19c6162448e1..d3f8413415fe 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -415,33 +415,23 @@ def main(): ) def collate_fn(examples): - def _collate(input_ids, pixel_values): - pixel_values = torch.stack([pixel_value for pixel_value in pixel_values]) - pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - - input_ids = [input_id for input_id in input_ids] - input_ids = tokenizer.pad( - {"input_ids": input_ids}, - padding=True, - return_tensors="pt", - ).input_ids - return input_ids, pixel_values + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + # concat class and instance examples for prior preservation + if args.with_prior_preservation: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - instance_prompt_ids = [example["instance_prompt_ids"] for example in examples] - instance_images = [example["instance_images"] for example in examples] - instance_prompt_ids, instance_images = _collate(instance_prompt_ids, instance_images) + input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids batch = { - "instance_images": instance_images, - "instance_prompt_ids": instance_prompt_ids, + "input_ids": input_ids, + "pixel_values": pixel_values, } - - if args.with_prior_preservation: - class_prompt_ids = [example["class_prompt_ids"] for example in examples] - class_images = [example["class_images"] for example in examples] - class_prompt_ids, class_images = _collate(class_prompt_ids, class_images) - batch["class_images"] = class_images - batch["class_prompt_ids"] = class_prompt_ids return batch train_dataloader = torch.utils.data.DataLoader( @@ -503,15 +493,8 @@ def _collate(input_ids, pixel_values): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): # Convert images to latent space - if args.with_prior_preservation: - images = torch.cat([batch["instance_images"], batch["class_images"]], dim=0) - input_ids = torch.cat([batch["instance_prompt_ids"], batch["class_prompt_ids"]], dim=0) - else: - images = batch["instance_images"] - input_ids = batch["instance_prompt_ids"] - with torch.no_grad(): - latents = vae.encode(images).latent_dist.sample() + latents = 
vae.encode(batch["pixel_values"]).latent_dist.sample() latents = latents * 0.18215 # Sample noise that we'll add to the latents @@ -528,7 +511,7 @@ def _collate(input_ids, pixel_values): # Get the text embedding for conditioning with torch.no_grad(): - encoder_hidden_states = text_encoder(input_ids)[0] + encoder_hidden_states = text_encoder(batch["input_ids"])[0] # Predict the noise residual noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample From 16ecc089e3ce400a27ec5ce91e72da0c706dd8a6 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 26 Sep 2022 15:10:54 +0200 Subject: [PATCH 21/34] add grad ckpt --- examples/dreambooth/train_dreambooth.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index d3f8413415fe..c44994bbf5ab 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -120,6 +120,11 @@ def parse_args(): default=1, help="Number of updates steps to accumulate before performing a backward/update pass.", ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) parser.add_argument( "--learning_rate", type=float, @@ -388,10 +393,14 @@ def main(): args.pretrained_model_name_or_path, subfolder="unet", use_auth_token=args.use_auth_token ) + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) + optimizer = torch.optim.AdamW( unet.parameters(), # only optimize unet lr=args.learning_rate, From 87bc75231a62d124333971a7ea9d4b0e5328f466 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 26 Sep 2022 17:02:02 +0200 Subject: [PATCH 22/34] add option for 8bit adam --- examples/dreambooth/train_dreambooth.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index c44994bbf5ab..1829929bbf6f 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -149,6 +149,9 @@ def parse_args(): parser.add_argument( "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") @@ -401,7 +404,19 @@ def main(): args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) - optimizer = torch.optim.AdamW( + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
+            )
+
+        optimizer_class = bnb.optim.AdamW8bit
+    else:
+        optimizer_class = torch.optim.AdamW
+
+    optimizer = optimizer_class(
         unet.parameters(),  # only optimize unet
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),

From 661ca4677e6dccc4ad596c2ee6ca4baad4159e95 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Mon, 26 Sep 2022 22:41:47 +0200
Subject: [PATCH 23/34] do two forward passes for prior preservation

---
 examples/dreambooth/train_dreambooth.py | 81 ++++++++++++++-----------
 1 file changed, 47 insertions(+), 34 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 1829929bbf6f..2f4bab89f588 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -442,20 +442,25 @@ def collate_fn(examples):
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
 
-        # concat class and instance examples for prior preservation
-        if args.with_prior_preservation:
-            input_ids += [example["class_prompt_ids"] for example in examples]
-            pixel_values += [example["class_images"] for example in examples]
-
-        pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
-
+        pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
         input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
 
         batch = {
             "input_ids": input_ids,
            "pixel_values": pixel_values,
         }
+
+        if args.with_prior_preservation:
+            class_input_ids = [example["class_prompt_ids"] for example in examples]
+            class_pixel_values = [example["class_images"] for example in examples]
+
+            class_pixel_values = torch.stack(class_pixel_values).to(memory_format=torch.contiguous_format).float()
+            class_input_ids = tokenizer.pad(
+                {"input_ids": class_input_ids}, padding=True, return_tensors="pt"
+            ).input_ids
+            batch["class_input_ids"] = class_input_ids
+            batch["class_pixel_values"] = class_pixel_values
+
         return batch
 
     train_dataloader = torch.utils.data.DataLoader(
@@ -521,33 +521,41 @@ def collate_fn(examples):
     unet.train()
     for step, batch in enumerate(train_dataloader):
         with accelerator.accumulate(unet):
-            # Convert images to latent space
-            with torch.no_grad():
-                latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                latents = latents * 0.18215
-
-            # Sample noise that we'll add to the latents
-            noise = torch.randn(latents.shape).to(latents.device)
-            bsz = latents.shape[0]
-            # Sample a random timestep for each image
-            timesteps = torch.randint(
-                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
-            ).long()
-
-            # Add noise to the latents according to the noise magnitude at each timestep
-            # (this is the forward diffusion process)
-            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-            # Get the text embedding for conditioning
-            with torch.no_grad():
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
-
-            # Predict the noise residual
-            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-            loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
-            accelerator.backward(loss)
+
+            def _forward(input_ids, pixel_values):
+                # Convert images to latent space
+                with torch.no_grad():
+                    latents = vae.encode(pixel_values).latent_dist.sample()
+                    latents = latents * 0.18215
+
+                # Sample noise that we'll add to the latents
+                noise = torch.randn(latents.shape).to(latents.device)
+                bsz = latents.shape[0]
+                # Sample a random timestep for each image
+                timesteps = torch.randint(
+                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
+                ).long()
+
+                # Add noise to the latents according to the noise magnitude at each timestep
+                # (this is the forward diffusion process)
+                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                # Get the text embedding for conditioning
+                with torch.no_grad():
+                    encoder_hidden_states = text_encoder(input_ids)[0]
+
+                # Predict the noise residual
+                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                return loss
+
+            loss = _forward(batch["input_ids"], batch["pixel_values"])
+
+            if args.with_prior_preservation:
+                prior_loss = _forward(batch["class_input_ids"], batch["class_pixel_values"])
+                loss = loss + prior_loss
+
+            accelerator.backward(loss)
             optimizer.step()
             lr_scheduler.step()
             optimizer.zero_grad()

From ce2a3beae263f4c5c32d8f83f10f572b5a87a6d0 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Tue, 27 Sep 2022 10:49:44 +0200
Subject: [PATCH 24/34] Revert "do two forward passes for prior preservation"

This reverts commit 661ca4677e6dccc4ad596c2ee6ca4baad4159e95.
---
 examples/dreambooth/train_dreambooth.py | 81 +++++++++++--------------
 1 file changed, 34 insertions(+), 47 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 2f4bab89f588..1829929bbf6f 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -442,25 +442,20 @@ def collate_fn(examples):
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
 
-        pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
+        # concat class and instance examples for prior preservation
+        if args.with_prior_preservation:
+            input_ids += [example["class_prompt_ids"] for example in examples]
+            pixel_values += [example["class_images"] for example in examples]
+
+        pixel_values = torch.stack(pixel_values)
+        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
         input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
 
         batch = {
             "input_ids": input_ids,
             "pixel_values": pixel_values,
         }
-
-        if args.with_prior_preservation:
-            class_input_ids = [example["class_prompt_ids"] for example in examples]
-            class_pixel_values = [example["class_images"] for example in examples]
-
-            class_pixel_values = torch.stack(class_pixel_values).to(memory_format=torch.contiguous_format).float()
-            class_input_ids = tokenizer.pad(
-                {"input_ids": class_input_ids}, padding=True, return_tensors="pt"
-            ).input_ids
-            batch["class_input_ids"] = class_input_ids
-            batch["class_pixel_values"] = class_pixel_values
-
         return batch
 
     train_dataloader = torch.utils.data.DataLoader(
@@ -521,33 +516,33 @@ def collate_fn(examples):
     unet.train()
     for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(unet):
-
-            def _forward(input_ids, pixel_values):
-                # Convert images to latent space
-                with torch.no_grad():
-                    latents = vae.encode(pixel_values).latent_dist.sample()
-                    latents = latents * 0.18215
-
-                # Sample noise that we'll add to the latents
-                noise = torch.randn(latents.shape).to(latents.device)
-                bsz = latents.shape[0]
-                # Sample a random timestep for each image
-                timesteps = torch.randint(
-                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
-                ).long()
-
-                # Add noise to the latents according to the noise magnitude at each timestep
-                # (this is the forward diffusion process)
-                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-                # Get the text embedding for conditioning
-                with torch.no_grad():
-                    encoder_hidden_states = text_encoder(input_ids)[0]
-
-                # Predict the noise residual
-                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
-                return loss
-
-            loss = _forward(batch["input_ids"], batch["pixel_values"])
-
-            if args.with_prior_preservation:
-                prior_loss = _forward(batch["class_input_ids"], batch["class_pixel_values"])
-                loss = loss + prior_loss
-
+            # Convert images to latent space
+            with torch.no_grad():
+                latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                latents = latents * 0.18215
+
+            # Sample noise that we'll add to the latents
+            noise = torch.randn(latents.shape).to(latents.device)
+            bsz = latents.shape[0]
+            # Sample a random timestep for each image
+            timesteps = torch.randint(
+                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
+            ).long()
+
+            # Add noise to the latents according to the noise magnitude at each timestep
+            # (this is the forward diffusion process)
+            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+            # Get the text embedding for conditioning
+            with torch.no_grad():
+                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+            # Predict the noise residual
+            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+            loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
             accelerator.backward(loss)
+
             optimizer.step()
             lr_scheduler.step()
optimizer.zero_grad() From abbb614bc483aa9a7ac69dd91973df918163b6b0 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 27 Sep 2022 11:46:19 +0200 Subject: [PATCH 26/34] add option for clip grad norm --- examples/dreambooth/train_dreambooth.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 0162a538e361..3f08a231908a 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -157,6 +157,7 @@ def parse_args(): parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") parser.add_argument( "--use_auth_token", @@ -554,6 +555,7 @@ def collate_fn(examples): loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() accelerator.backward(loss) + accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad() From c05b0438e589355639deeda8d557daedbd9f18d3 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 27 Sep 2022 11:50:57 +0200 Subject: [PATCH 27/34] add more comments --- examples/dreambooth/train_dreambooth.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 3f08a231908a..397b5d9b52b1 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -406,6 +406,7 @@ def main(): args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs if args.use_8bit_adam: try: import bitsandbytes as bnb @@ -444,7 +445,8 @@ def collate_fn(examples): input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] - # concat class and instance examples for prior preservation + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. if args.with_prior_preservation: input_ids += [example["class_prompt_ids"] for example in examples] pixel_values += [example["class_images"] for example in examples] @@ -543,13 +545,17 @@ def collate_fn(examples): noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample if args.with_prior_preservation: + # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) noise, noise_prior = torch.chunk(noise, 2, dim=0) - # compute instance loss + + # Compute instance loss loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() - # compute prior loss + # Compute prior loss prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean() + + # Add the prior loss to the instance loss. 
                 loss = loss + args.prior_loss_weight * prior_loss
             else:
                 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()

From 90bac836d182caefc694c87deff63940652fd016 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Tue, 27 Sep 2022 14:17:18 +0200
Subject: [PATCH 28/34] update readme

---
 examples/dreambooth/README.md | 48 +++++++++++++++++++++++++++++----
 1 file changed, 43 insertions(+), 5 deletions(-)

diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md
index ee936d419b0d..7013abb63794 100644
--- a/examples/dreambooth/README.md
+++ b/examples/dreambooth/README.md
@@ -1,8 +1,11 @@
-## DreamBooth training example
+# DreamBooth training example
 
 [DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few(3~5) images of a subject.
 The `train_dreambooth.py` script shows how to implement the training procedure and adapt it for stable diffusion.
 
+## Running on Colab
+TODO
+
 ## Running locally
 ### Installing the dependencies
 
@@ -57,7 +60,11 @@ accelerate launch train_dreambooth.py \
   --max_train_steps=400
 ```
 
-Training with prior-preservation loss using
+### Training with prior-preservation loss
+
+prior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data.
+According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases.
+
 ```bash
 export MODEL_NAME="CompVis/stable-diffusion-v1-4"
 export INSTANCE_DIR="path-to-instance-images"
 export CLASS_DIR="path-to-class-images"
@@ -69,7 +76,7 @@ accelerate launch train_dreambooth.py \
   --instance_data_dir=$INSTANCE_DIR \
   --class_data_dir=$CLASS_DIR \
   --output_dir=$OUTPUT_DIR \
-  --with_prior_preservation \
+  --with_prior_preservation --prior_loss_weight=1.0 \
   --instance_prompt="a photo of sks dog" \
   --class_prompt="a photo of dog" \
   --resolution=512 \
@@ -79,11 +86,42 @@ accelerate launch train_dreambooth.py \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
   --num_class_images=200 \
-  --max_train_steps=1000
+  --max_train_steps=800
+```
+
+### Training on a 16GB GPU
+
+With the help of gradient checkpointing and the 8-bit optimizer from bitsandbytes it's possible to train DreamBooth on a 16GB GPU.
+
+Install `bitsandbytes` with `pip install bitsandbytes`
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth.py \
+  --pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \
+  --instance_data_dir=$INSTANCE_DIR \
+  --class_data_dir=$CLASS_DIR \
+  --output_dir=$OUTPUT_DIR \
+  --with_prior_preservation --prior_loss_weight=1.0 \
+  --instance_prompt="a photo of sks dog" \
+  --class_prompt="a photo of dog" \
+  --resolution=512 \
+  --train_batch_size=1 \
+  --gradient_accumulation_steps=1 --gradient_checkpointing \
+  --use_8bit_adam \
+  --learning_rate=5e-6 \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --num_class_images=200 \
+  --max_train_steps=800
 ```
 
-### Inference
+## Inference
 
 Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier`(e.g. sks in above example) in your prompt.
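
The concatenated-batch formulation that patches 20, 25, and 27 settle on is easy to check in isolation: `collate_fn` appends the class examples after the instance examples, one UNet forward pass covers both halves, and `torch.chunk` splits the prediction back apart so the class half can be weighted as the prior-preservation term. Below is a minimal, self-contained sketch of just that loss arithmetic; the toy `fake_unet` and the 2+2 batch shape are stand-ins for the real model and data, not part of the training script.

```python
import torch
import torch.nn.functional as F

prior_loss_weight = 1.0  # mirrors the --prior_loss_weight argument

def fake_unet(latents):
    # Stand-in for the real noise predictor; only the shapes matter here.
    return latents * 0.5

# collate_fn's layout: instance samples first, class samples appended after.
noisy_latents = torch.randn(4, 4, 64, 64)  # 2 instance + 2 class latents
noise = torch.randn_like(noisy_latents)

# One forward pass covers both halves of the concatenated batch.
noise_pred = fake_unet(noisy_latents)

# Split predictions and targets back into instance / prior halves.
noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
noise, noise_prior = torch.chunk(noise, 2, dim=0)

# Per-sample MSE, then the weighted sum used as the training loss.
instance_loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
loss = instance_loss + prior_loss_weight * prior_loss
```
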
From 89991a1437fd29973037af7aaadd574051e33a00 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Tue, 27 Sep 2022 14:18:04 +0200
Subject: [PATCH 29/34] update readme

---
 examples/dreambooth/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md
index 7013abb63794..6b5f8211147b 100644
--- a/examples/dreambooth/README.md
+++ b/examples/dreambooth/README.md
@@ -111,7 +111,7 @@ accelerate launch train_dreambooth.py \
   --class_prompt="a photo of dog" \
   --resolution=512 \
   --train_batch_size=1 \
-  --gradient_accumulation_steps=1 --gradient_checkpointing \
+  --gradient_accumulation_steps=2 --gradient_checkpointing \
   --use_8bit_adam \
   --learning_rate=5e-6 \
   --lr_scheduler="constant" \

From 265d2b1e35f4db148c7d9485fc128f93f4fee998 Mon Sep 17 00:00:00 2001
From: Suraj Patil
Date: Tue, 27 Sep 2022 14:39:58 +0200
Subject: [PATCH 30/34] Apply suggestions from code review

Co-authored-by: Patrick von Platen
---
 examples/dreambooth/README.md           | 5 +++--
 examples/dreambooth/train_dreambooth.py | 3 ++-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md
index 6b5f8211147b..3a68a478894c 100644
--- a/examples/dreambooth/README.md
+++ b/examples/dreambooth/README.md
@@ -62,7 +62,7 @@ accelerate launch train_dreambooth.py \
 
 ### Training with prior-preservation loss
 
-prior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data.
+Prior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data.
 According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases.
 ```bash
@@ -129,9 +129,10 @@ Once you have trained a model using above command, the inference can be done sim
 ```python
 
 from torch import autocast
 from diffusers import StableDiffusionPipeline
+import torch
 
 model_id = "path-to-your-trained-model"
-pipe = StableDiffusionPipeline.from_pretrained(model_id,torch_dtype=torch.float16).to("cuda")
+pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
 
 prompt = "A photo of sks dog in a bucket"
 
diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 397b5d9b52b1..48f5c2f6b7fa 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -531,7 +531,8 @@ def collate_fn(examples):
             # Sample a random timestep for each image
             timesteps = torch.randint(
                 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
-            ).long()
+            )
+            timesteps = timsteps.long()
 
             # Add noise to the latents according to the noise magnitude at each timestep
             # (this is the forward diffusion process)

From d63fa4d6b3282d750e6629b47542b4d6d490077b Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Tue, 27 Sep 2022 14:46:31 +0200
Subject: [PATCH 31/34] add docstr for dataset

---
 examples/dreambooth/train_dreambooth.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 48f5c2f6b7fa..10213df6cfe9 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -214,6 +214,11 @@ def parse_args():
 
 
 class DreamBoothDataset(Dataset):
+    """
+    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+    It pre-processes the images and tokenizes the prompts.
+    """
+
     def __init__(
         self,
         instance_data_root,
@@ -288,6 +293,8 @@ def __getitem__(self, index):
 
 
 class PromptDataset(Dataset):
+    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
     def __init__(self, prompt, num_samples):
         self.prompt = prompt
         self.num_samples = num_samples
@@ -529,10 +536,8 @@ def collate_fn(examples):
             noise = torch.randn(latents.shape).to(latents.device)
             bsz = latents.shape[0]
             # Sample a random timestep for each image
-            timesteps = torch.randint(
-                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
-            )
-            timesteps = timsteps.long()
+            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+            timesteps = timesteps.long()
 
             # Add noise to the latents according to the noise magnitude at each timestep
             # (this is the forward diffusion process)

From 102ad709a5e6d3c1479abea1cfeedf88f45d1790 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Tue, 27 Sep 2022 14:51:36 +0200
Subject: [PATCH 32/34] update the saving logic

---
 examples/dreambooth/train_dreambooth.py | 12 +++---------
 1 file changed, 3 insertions(+), 9 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 10213df6cfe9..681a4aeb42c2 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -588,16 +588,10 @@ def collate_fn(examples):
 
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
-        pipeline = StableDiffusionPipeline(
-            text_encoder=text_encoder,
-            vae=vae,
+        pipeline = StableDiffusionPipeline.from_pretrained(
+            args.pretrained_model_name_or_path,
             unet=accelerator.unwrap_model(unet),
-            tokenizer=tokenizer,
-            scheduler=PNDMScheduler(
-                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
-            ),
-            safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
-            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
+            use_auth_token=args.use_auth_token,
         )
         pipeline.save_pretrained(args.output_dir)

From 7ad4316316f07b2948b1e1efc33bfc6b6417093e Mon Sep 17 00:00:00 2001
From: Suraj Patil
Date: Tue, 27 Sep 2022 14:52:27 +0200
Subject: [PATCH 33/34] Update examples/dreambooth/README.md

---
 examples/dreambooth/README.md | 2 --
 1 file changed, 2 deletions(-)

diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md
index 3a68a478894c..01bbb1c5e343 100644
--- a/examples/dreambooth/README.md
+++ b/examples/dreambooth/README.md
@@ -3,8 +3,6 @@
 [DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few(3~5) images of a subject.
 The `train_dreambooth.py` script shows how to implement the training procedure and adapt it for stable diffusion.
 
-## Running on Colab
-TODO
 
 ## Running locally
 ### Installing the dependencies

From d72c65921732e120dbad81b24eb3818d4f40e2a6 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Tue, 27 Sep 2022 14:54:50 +0200
Subject: [PATCH 34/34] remove unused imports

---
 examples/dreambooth/train_dreambooth.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 681a4aeb42c2..600653187977 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -13,14 +13,13 @@
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
-from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 from huggingface_hub import HfFolder, Repository, whoami
 from PIL import Image
 from torchvision import transforms
 from tqdm.auto import tqdm
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTokenizer
 
 
 logger = get_logger(__name__)
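
Patch 32's saving change means the output directory holds a complete pipeline rather than hand-assembled components, so it reloads with a single `from_pretrained` call. Below is a minimal sketch of that round trip under stated assumptions: the "fine-tuned" UNet here is taken straight from the base checkpoint instead of a real training run, the paths are placeholders, and the `use_auth_token` plumbing from the script is omitted.

```python
import torch
from diffusers import StableDiffusionPipeline, UNet2DConditionModel

base_model = "CompVis/stable-diffusion-v1-4"  # must match the training checkpoint
output_dir = "dreambooth-model"               # placeholder output path

# Stand-in for the fine-tuned UNet; in the training script it is first
# unwrapped with accelerator.unwrap_model(unet) before being handed over.
unet = UNet2DConditionModel.from_pretrained(base_model, subfolder="unet")

# Rebuild the full pipeline around the fine-tuned UNet and save every component.
pipeline = StableDiffusionPipeline.from_pretrained(base_model, unet=unet)
pipeline.save_pretrained(output_dir)

# The saved directory then loads like any other pretrained pipeline.
pipe = StableDiffusionPipeline.from_pretrained(output_dir, torch_dtype=torch.float16)
```
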