Skip to content

Commit c66cf4d

Browse files
committed
Concatenate class and instance examples into a single batch in collate_fn
1 parent ef01331 commit c66cf4d

File tree

1 file changed

+15
-32
lines changed

1 file changed

+15
-32
lines changed

examples/dreambooth/train_dreambooth.py

Lines changed: 15 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -415,33 +415,23 @@ def main():
415415
)
416416

417417
def collate_fn(examples):
    """Collate dataset examples into a single training batch.

    Builds one padded ``input_ids`` tensor and one stacked ``pixel_values``
    tensor. When prior preservation is enabled, the class examples are
    appended after the instance examples, so both are processed in one
    forward pass; the caller can split the output along dim 0.

    Args:
        examples: list of dataset items; each item is a dict with keys
            ``instance_prompt_ids`` and ``instance_images`` and, when
            ``args.with_prior_preservation`` is set, also
            ``class_prompt_ids`` and ``class_images``.
            # NOTE(review): key schema inferred from the visible lookups —
            # confirm against DreamBoothDataset.__getitem__.

    Returns:
        dict with:
            ``input_ids``  — tokenizer-padded prompt ids (``return_tensors="pt"``).
            ``pixel_values`` — stacked image tensor, contiguous memory
            format, cast to float.
    """
    input_ids = [example["instance_prompt_ids"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]

    # Concat class and instance examples for prior preservation so a single
    # batch (and a single forward pass) covers both losses.
    if args.with_prior_preservation:
        input_ids += [example["class_prompt_ids"] for example in examples]
        pixel_values += [example["class_images"] for example in examples]

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    # Pad all prompts in the (possibly doubled) batch to a common length.
    input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids

    batch = {
        "input_ids": input_ids,
        "pixel_values": pixel_values,
    }
    return batch

447437
train_dataloader = torch.utils.data.DataLoader(
@@ -503,15 +493,8 @@ def _collate(input_ids, pixel_values):
503493
for step, batch in enumerate(train_dataloader):
504494
with accelerator.accumulate(unet):
505495
# Convert images to latent space
506-
if args.with_prior_preservation:
507-
images = torch.cat([batch["instance_images"], batch["class_images"]], dim=0)
508-
input_ids = torch.cat([batch["instance_prompt_ids"], batch["class_prompt_ids"]], dim=0)
509-
else:
510-
images = batch["instance_images"]
511-
input_ids = batch["instance_prompt_ids"]
512-
513496
with torch.no_grad():
514-
latents = vae.encode(images).latent_dist.sample()
497+
latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
515498
latents = latents * 0.18215
516499

517500
# Sample noise that we'll add to the latents
@@ -528,7 +511,7 @@ def _collate(input_ids, pixel_values):
528511

529512
# Get the text embedding for conditioning
530513
with torch.no_grad():
531-
encoder_hidden_states = text_encoder(input_ids)[0]
514+
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
532515

533516
# Predict the noise residual
534517
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

0 commit comments

Comments
 (0)