
Commit 661ca46

do two forward passes for prior preservation
1 parent: 87bc752
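
In short: collate_fn no longer concatenates the class (prior) examples into the instance batch. It now returns them under separate class_input_ids / class_pixel_values keys, and the training step runs the denoising pass twice through a local _forward helper, once on the instance batch and once on the class batch, backpropagating the sum of the instance loss and the prior-preservation loss. Because _forward samples noise and timesteps internally, each pass also draws its own noise and timesteps.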

1 file changed: +47 -34

examples/dreambooth/train_dreambooth.py

@@ -442,20 +442,25 @@ def collate_fn(examples):
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
 
-        # concat class and instance examples for prior preservation
-        if args.with_prior_preservation:
-            input_ids += [example["class_prompt_ids"] for example in examples]
-            pixel_values += [example["class_images"] for example in examples]
-
-        pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
-
+        pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
         input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
 
         batch = {
             "input_ids": input_ids,
             "pixel_values": pixel_values,
         }
+
+        if args.with_prior_preservation:
+            class_input_ids = [example["class_prompt_ids"] for example in examples]
+            class_pixel_values = [example["class_images"] for example in examples]
+
+            class_pixel_values = torch.stack(class_pixel_values).to(memory_format=torch.contiguous_format).float()
+            class_input_ids = tokenizer.pad(
+                {"input_ids": class_input_ids}, padding=True, return_tensors="pt"
+            ).input_ids
+            batch["class_input_ids"] = class_input_ids
+            batch["class_pixel_values"] = class_pixel_values
+
         return batch
 
     train_dataloader = torch.utils.data.DataLoader(
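
For reference, here is the batch contract the new collate_fn produces, as a minimal runnable sketch. The toy examples, the image size, and the pad_stub helper are hypothetical stand-ins for the dataset entries and tokenizer.pad, so the snippet runs with torch alone:

import torch

# Hypothetical stand-in for tokenizer.pad(...), for illustration only.
def pad_stub(prompt_ids):
    return torch.nn.utils.rnn.pad_sequence(prompt_ids, batch_first=True)

# Toy examples with made-up prompt lengths and a made-up 3x64x64 image size.
examples = [
    {
        "instance_prompt_ids": torch.tensor([1, 2, 3]),
        "instance_images": torch.randn(3, 64, 64),
        "class_prompt_ids": torch.tensor([4, 5]),
        "class_images": torch.randn(3, 64, 64),
    }
    for _ in range(2)
]
with_prior_preservation = True

# Instance tensors are always present...
batch = {
    "input_ids": pad_stub([e["instance_prompt_ids"] for e in examples]),
    "pixel_values": torch.stack([e["instance_images"] for e in examples]).float(),
}
# ...class tensors only under prior preservation, and under their own keys.
if with_prior_preservation:
    batch["class_input_ids"] = pad_stub([e["class_prompt_ids"] for e in examples])
    batch["class_pixel_values"] = torch.stack([e["class_images"] for e in examples]).float()

print({k: tuple(v.shape) for k, v in batch.items()})
# {'input_ids': (2, 3), 'pixel_values': (2, 3, 64, 64),
#  'class_input_ids': (2, 2), 'class_pixel_values': (2, 3, 64, 64)}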
@@ -516,33 +521,41 @@ def collate_fn(examples):
         unet.train()
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
-                # Convert images to latent space
-                with torch.no_grad():
-                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                    latents = latents * 0.18215
-
-                # Sample noise that we'll add to the latents
-                noise = torch.randn(latents.shape).to(latents.device)
-                bsz = latents.shape[0]
-                # Sample a random timestep for each image
-                timesteps = torch.randint(
-                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
-                ).long()
-
-                # Add noise to the latents according to the noise magnitude at each timestep
-                # (this is the forward diffusion process)
-                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-                # Get the text embedding for conditioning
-                with torch.no_grad():
-                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
-
-                # Predict the noise residual
-                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
-                accelerator.backward(loss)
 
+                def _forward(input_ids, pixel_values):
+                    # Convert images to latent space
+                    with torch.no_grad():
+                        latents = vae.encode(pixel_values).latent_dist.sample()
+                        latents = latents * 0.18215
+
+                    # Sample noise that we'll add to the latents
+                    noise = torch.randn(latents.shape).to(latents.device)
+                    bsz = latents.shape[0]
+                    # Sample a random timestep for each image
+                    timesteps = torch.randint(
+                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
+                    ).long()
+
+                    # Add noise to the latents according to the noise magnitude at each timestep
+                    # (this is the forward diffusion process)
+                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                    # Get the text embedding for conditioning
+                    with torch.no_grad():
+                        encoder_hidden_states = text_encoder(input_ids)[0]
+
+                    # Predict the noise residual
+                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    return loss
+
+                loss = _forward(batch["input_ids"], batch["pixel_values"])
+
+                if args.with_prior_preservation:
+                    prior_loss = _forward(batch["class_input_ids"], batch["class_pixel_values"])
+                    loss = loss + prior_loss
+
+                accelerator.backward(loss)
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()
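
Worth noting about the reduction (a toy sketch, with random tensors standing in for the UNet's noise predictions; it isolates the loss arithmetic and ignores that each pass now draws its own noise and timesteps): summing the two per-pass means, as the new code does, yields twice the value of the single mean over the old concatenated batch when the instance and class halves are the same size.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Random tensors stand in for UNet predictions and noise targets (batch of 2 each).
inst_pred, inst_target = torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8)
cls_pred, cls_target = torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8)

def per_batch_loss(pred, target):
    # Same reduction as the script: per-image MSE, then a mean over the batch.
    return F.mse_loss(pred, target, reduction="none").mean([1, 2, 3]).mean()

# Before this commit: one pass over the concatenated batch, a single mean.
old = per_batch_loss(torch.cat([inst_pred, cls_pred]), torch.cat([inst_target, cls_target]))

# After this commit: two passes, losses summed before backward.
new = per_batch_loss(inst_pred, inst_target) + per_batch_loss(cls_pred, cls_target)

torch.testing.assert_close(new, 2 * old)  # sum of two means == twice the pooled mean here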
