diff --git a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py
index fd2b5568d6d8..f3b4602c7fcf 100644
--- a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py
+++ b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py
@@ -349,7 +349,7 @@ def parse_args(input_args=None):
         "--optimizer",
         type=str,
         default="AdamW",
-        help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
+        choices=["AdamW", "Prodigy", "AdEMAMix"],
     )
 
     parser.add_argument(
@@ -357,6 +357,11 @@ def parse_args(input_args=None):
         action="store_true",
         help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
     )
+    parser.add_argument(
+        "--use_8bit_ademamix",
+        action="store_true",
+        help="Whether or not to use 8-bit AdEMAMix from bitsandbytes.",
+    )
 
     parser.add_argument(
         "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
@@ -820,16 +825,15 @@ def load_model_hook(models, input_dir):
     params_to_optimize = [transformer_parameters_with_lr]
 
     # Optimizer creation
-    if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+    if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
         logger.warning(
-            f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
-            "Defaulting to adamW"
+            f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+            f"set to {args.optimizer.lower()}"
         )
-        args.optimizer = "adamw"
 
-    if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+    if args.use_8bit_ademamix and not args.optimizer.lower() == "ademamix":
         logger.warning(
-            f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+            f"use_8bit_ademamix is ignored when optimizer is not set to 'AdEMAMix'. Optimizer was "
             f"set to {args.optimizer.lower()}"
         )
 
@@ -853,6 +857,20 @@ def load_model_hook(models, input_dir):
             eps=args.adam_epsilon,
         )
 
+    elif args.optimizer.lower() == "ademamix":
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError(
+                "To use AdEMAMix (or its 8bit variant), please install the bitsandbytes library: `pip install -U bitsandbytes`."
+            )
+        if args.use_8bit_ademamix:
+            optimizer_class = bnb.optim.AdEMAMix8bit
+        else:
+            optimizer_class = bnb.optim.AdEMAMix
+
+        optimizer = optimizer_class(params_to_optimize)
+
     if args.optimizer.lower() == "prodigy":
         try:
             import prodigyopt
@@ -868,7 +886,6 @@ def load_model_hook(models, input_dir):
 
         optimizer = optimizer_class(
             params_to_optimize,
-            lr=args.learning_rate,
             betas=(args.adam_beta1, args.adam_beta2),
             beta3=args.prodigy_beta3,
             weight_decay=args.adam_weight_decay,
@@ -1020,12 +1037,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)
 
-                vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
+                vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
 
                 latent_image_ids = FluxPipeline._prepare_latent_image_ids(
                     model_input.shape[0],
-                    model_input.shape[2],
-                    model_input.shape[3],
+                    model_input.shape[2] // 2,
+                    model_input.shape[3] // 2,
                     accelerator.device,
                     weight_dtype,
                 )
@@ -1059,7 +1076,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )
 
                 # handle guidance
-                if transformer.config.guidance_embeds:
+                if unwrap_model(transformer).config.guidance_embeds:
                     guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
                     guidance = guidance.expand(model_input.shape[0])
                 else:
@@ -1082,8 +1099,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )[0]
                 model_pred = FluxPipeline._unpack_latents(
                     model_pred,
-                    height=int(model_input.shape[2] * vae_scale_factor / 2),
-                    width=int(model_input.shape[3] * vae_scale_factor / 2),
+                    height=model_input.shape[2] * vae_scale_factor,
+                    width=model_input.shape[3] * vae_scale_factor,
                     vae_scale_factor=vae_scale_factor,
                 )
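Note on the shape-related changes (not part of the patch): the `- 1` in `vae_scale_factor`, the `// 2` passed to `FluxPipeline._prepare_latent_image_ids`, and the dropped `/ 2` in the `FluxPipeline._unpack_latents` call all follow from the same latent geometry. The sketch below walks through the arithmetic under two stated assumptions: a 1024x1024 training resolution and the usual Flux VAE config with four `block_out_channels` entries (8x spatial downsampling plus 2x2 patch packing). The variable names are illustrative, not taken from the script.

```python
# Minimal sketch of the latent geometry the patch relies on.
# Assumptions (not from the script): 1024x1024 inputs, VAE config with
# four block_out_channels entries, 2x2 patch packing of latents.
block_out_channels = [128, 256, 512, 512]

# The old code used 2 ** len(...) == 16; the actual downsampling factor is 8.
vae_scale_factor = 2 ** (len(block_out_channels) - 1)
assert vae_scale_factor == 8

pixel_h = pixel_w = 1024
latent_h = pixel_h // vae_scale_factor  # 128, i.e. model_input.shape[2]
latent_w = pixel_w // vae_scale_factor  # 128, i.e. model_input.shape[3]

# Latents are packed into 2x2 patches, so the position-id grid is half the
# latent resolution in each direction -- hence the `// 2` arguments.
ids_h, ids_w = latent_h // 2, latent_w // 2  # 64 x 64 -> 4096 packed tokens
assert (ids_h, ids_w) == (64, 64)

# With vae_scale_factor == 8, latent_h * vae_scale_factor already gives the
# pixel-space size, so the old `* 16 / 2` compensation is no longer needed.
assert latent_h * vae_scale_factor == pixel_h
assert latent_w * vae_scale_factor == pixel_w
```

On the optimizer side, `bnb.optim.AdEMAMix` and `bnb.optim.AdEMAMix8bit` are the classes the new branch resolves to; if I remember correctly they ship with bitsandbytes 0.44.0 and later, which is why the error message suggests `pip install -U bitsandbytes`.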