@@ -415,33 +415,23 @@ def main():
     )
 
     def collate_fn(examples):
-        def _collate(input_ids, pixel_values):
-            pixel_values = torch.stack([pixel_value for pixel_value in pixel_values])
-            pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
-
-            input_ids = [input_id for input_id in input_ids]
-            input_ids = tokenizer.pad(
-                {"input_ids": input_ids},
-                padding=True,
-                return_tensors="pt",
-            ).input_ids
-            return input_ids, pixel_values
+        input_ids = [example["instance_prompt_ids"] for example in examples]
+        pixel_values = [example["instance_images"] for example in examples]
+
+        # concat class and instance examples for prior preservation
+        if args.with_prior_preservation:
+            input_ids += [example["class_prompt_ids"] for example in examples]
+            pixel_values += [example["class_images"] for example in examples]
+
+        pixel_values = torch.stack(pixel_values)
+        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
 
-        instance_prompt_ids = [example["instance_prompt_ids"] for example in examples]
-        instance_images = [example["instance_images"] for example in examples]
-        instance_prompt_ids, instance_images = _collate(instance_prompt_ids, instance_images)
+        input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
 
         batch = {
-            "instance_images": instance_images,
-            "instance_prompt_ids": instance_prompt_ids,
+            "input_ids": input_ids,
+            "pixel_values": pixel_values,
         }
-
-        if args.with_prior_preservation:
-            class_prompt_ids = [example["class_prompt_ids"] for example in examples]
-            class_images = [example["class_images"] for example in examples]
-            class_prompt_ids, class_images = _collate(class_prompt_ids, class_images)
-            batch["class_images"] = class_images
-            batch["class_prompt_ids"] = class_prompt_ids
         return batch
 
     train_dataloader = torch.utils.data.DataLoader(
@@ -503,15 +493,8 @@ def _collate(input_ids, pixel_values):
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
                 # Convert images to latent space
-                if args.with_prior_preservation:
-                    images = torch.cat([batch["instance_images"], batch["class_images"]], dim=0)
-                    input_ids = torch.cat([batch["instance_prompt_ids"], batch["class_prompt_ids"]], dim=0)
-                else:
-                    images = batch["instance_images"]
-                    input_ids = batch["instance_prompt_ids"]
-
                 with torch.no_grad():
-                    latents = vae.encode(images).latent_dist.sample()
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                     latents = latents * 0.18215
 
                 # Sample noise that we'll add to the latents
@@ -528,7 +511,7 @@ def _collate(input_ids, pixel_values):
 
                 # Get the text embedding for conditioning
                 with torch.no_grad():
-                    encoder_hidden_states = text_encoder(input_ids)[0]
+                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
 
                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
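
For reference, here is a minimal, self-contained sketch of how the refactored `collate_fn` behaves; it is not part of the diff above. The tokenizer checkpoint (`openai/clip-vit-large-patch14`), the `SimpleNamespace` stand-in for `args`, and the fake examples are assumptions made only for illustration.

```python
# Sketch only: shows how instance and class examples are packed into one batch
# when prior preservation is enabled.
from types import SimpleNamespace

import torch
from transformers import CLIPTokenizer

# Assumed tokenizer and args for the sketch; the real script wires these up from CLI flags.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
args = SimpleNamespace(with_prior_preservation=True)


def collate_fn(examples):
    input_ids = [example["instance_prompt_ids"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]

    # concat class and instance examples for prior preservation
    if args.with_prior_preservation:
        input_ids += [example["class_prompt_ids"] for example in examples]
        pixel_values += [example["class_images"] for example in examples]

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids

    return {"input_ids": input_ids, "pixel_values": pixel_values}


# Two fake dataset items; each carries an instance and a class sample.
examples = [
    {
        "instance_prompt_ids": tokenizer("a photo of sks dog").input_ids,
        "instance_images": torch.randn(3, 512, 512),
        "class_prompt_ids": tokenizer("a photo of a dog").input_ids,
        "class_images": torch.randn(3, 512, 512),
    }
    for _ in range(2)
]

batch = collate_fn(examples)
# Instance samples occupy the first half of the batch and class samples the
# second half, so the training loop can later split the model prediction
# (e.g. with torch.chunk) to compute the prior-preservation loss.
print(batch["pixel_values"].shape)  # torch.Size([4, 3, 512, 512])
print(batch["input_ids"].shape)     # (4, padded_prompt_length)
```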