@@ -442,20 +442,27 @@ def collate_fn(examples):
     input_ids = [example["instance_prompt_ids"] for example in examples]
     pixel_values = [example["instance_images"] for example in examples]
 
-    # concat class and instance examples for prior preservation
-    if args.with_prior_preservation:
-        input_ids += [example["class_prompt_ids"] for example in examples]
-        pixel_values += [example["class_images"] for example in examples]
-
-    pixel_values = torch.stack(pixel_values)
-    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
-
+    pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
     input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
 
     batch = {
         "input_ids": input_ids,
         "pixel_values": pixel_values,
     }
+
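+    # Batch the class (prior preservation) examples under separate keys so the
+    # training loop can compute a separate loss for them.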
+    if args.with_prior_preservation:
+        class_input_ids = [example["class_prompt_ids"] for example in examples]
+        class_pixel_values = [example["class_images"] for example in examples]
+
+        class_pixel_values = torch.stack(class_pixel_values).to(memory_format=torch.contiguous_format).float()
+        class_input_ids = tokenizer.pad(
+            {"input_ids": class_input_ids}, padding=True, return_tensors="pt"
+        ).input_ids
+        batch["class_input_ids"] = class_input_ids
+        batch["class_pixel_values"] = class_pixel_values
+
     return batch
 
 train_dataloader = torch.utils.data.DataLoader(
@@ -516,33 +523,46 @@ def collate_fn(examples):
     unet.train()
     for step, batch in enumerate(train_dataloader):
         with accelerator.accumulate(unet):
-            # Convert images to latent space
-            with torch.no_grad():
-                latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                latents = latents * 0.18215
-
-            # Sample noise that we'll add to the latents
-            noise = torch.randn(latents.shape).to(latents.device)
-            bsz = latents.shape[0]
-            # Sample a random timestep for each image
-            timesteps = torch.randint(
-                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
-            ).long()
-
-            # Add noise to the latents according to the noise magnitude at each timestep
-            # (this is the forward diffusion process)
-            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-            # Get the text embedding for conditioning
-            with torch.no_grad():
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
-
-            # Predict the noise residual
-            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-            loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
-            accelerator.backward(loss)
-
+
+            def _forward(input_ids, pixel_values):
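+                # Shared denoising step: adds noise to the latents and scores the
+                # UNet's noise prediction; reused for the class batch further down.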
+                # Convert images to latent space
+                with torch.no_grad():
+                    latents = vae.encode(pixel_values).latent_dist.sample()
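+                    # 0.18215 is the latent scaling factor of Stable Diffusion's VAE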
+                    latents = latents * 0.18215
+
+                # Sample noise that we'll add to the latents
+                noise = torch.randn(latents.shape).to(latents.device)
+                bsz = latents.shape[0]
+                # Sample a random timestep for each image
+                timesteps = torch.randint(
+                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
+                ).long()
+
+                # Add noise to the latents according to the noise magnitude at each timestep
+                # (this is the forward diffusion process)
+                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                # Get the text embedding for conditioning
+                with torch.no_grad():
+                    encoder_hidden_states = text_encoder(input_ids)[0]
+
+                # Predict the noise residual
+                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                return loss
+
+            loss = _forward(batch["input_ids"], batch["pixel_values"])
+
+            if args.with_prior_preservation:
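+                # Prior preservation (DreamBooth): the extra loss on the generated
+                # class images keeps the model from forgetting the broader class.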
+                prior_loss = _forward(batch["class_input_ids"], batch["class_pixel_values"])
+                loss = loss + prior_loss
+
+            accelerator.backward(loss)
             optimizer.step()
             lr_scheduler.step()
             optimizer.zero_grad()