From 7393fae4c19b430accadacee213e5d958da17af0 Mon Sep 17 00:00:00 2001 From: Boseong Jeon Date: Tue, 1 Oct 2024 21:56:19 +0900 Subject: [PATCH] Handling mixed precision and add unwarp --- examples/dreambooth/train_dreambooth_lora_flux.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index fcc11386abcf..6e5373e95333 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -177,7 +177,7 @@ def log_validation( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) - pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) + pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) # run inference @@ -1686,7 +1686,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ) # handle guidance - if transformer.config.guidance_embeds: + if accelerator.unwrap_model(transformer).config.guidance_embeds: guidance = torch.tensor([args.guidance_scale], device=accelerator.device) guidance = guidance.expand(model_input.shape[0]) else: @@ -1799,6 +1799,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # create pipeline if not args.train_text_encoder: text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) + text_encoder_one.to(weight_dtype) + text_encoder_two.to(weight_dtype) pipeline = FluxPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae,