From 7ef43c941803876f7bbf1c8913365f1b9a1a60cd Mon Sep 17 00:00:00 2001
From: imbr92 <40306754+imbr92@users.noreply.github.com>
Date: Wed, 18 Jun 2025 10:42:34 -0500
Subject: [PATCH] Add --lora_alpha and metadata handling to
 train_dreambooth_lora_sana.py

---
 .../dreambooth/test_dreambooth_lora_sana.py  | 42 +++++++++++++++++++
 .../dreambooth/train_dreambooth_lora_sana.py | 18 ++++++--
 2 files changed, 56 insertions(+), 4 deletions(-)

diff --git a/examples/dreambooth/test_dreambooth_lora_sana.py b/examples/dreambooth/test_dreambooth_lora_sana.py
index 6e5727ae7176..7564ab08b028 100644
--- a/examples/dreambooth/test_dreambooth_lora_sana.py
+++ b/examples/dreambooth/test_dreambooth_lora_sana.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
 import logging
 import os
 import sys
@@ -20,6 +21,8 @@
 
 import safetensors
 
+from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
+
 sys.path.append("..")
 from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
 
@@ -204,3 +207,42 @@ def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit_removes_mult
             run_command(self._launch_args + resume_run_args)
 
             self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
+
+    def test_dreambooth_lora_sana_with_metadata(self):
+        lora_alpha = 8
+        rank = 4
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+                --instance_data_dir={self.instance_data_dir}
+                --output_dir={tmpdir}
+                --resolution=32
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=4
+                --lora_alpha={lora_alpha}
+                --rank={rank}
+                --checkpointing_steps=2
+                --max_sequence_length 166
+                """.split()
+
+            test_args.extend(["--instance_prompt", ""])
+            run_command(self._launch_args + test_args)
+
+            state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+            self.assertTrue(os.path.isfile(state_dict_file))
+
+            # Check if the metadata was properly serialized.
+            with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
+                metadata = f.metadata() or {}
+
+            metadata.pop("format", None)
+            raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
+            if raw:
+                raw = json.loads(raw)
+
+            loaded_lora_alpha = raw["transformer.lora_alpha"]
+            self.assertTrue(loaded_lora_alpha == lora_alpha)
+            loaded_lora_rank = raw["transformer.r"]
+            self.assertTrue(loaded_lora_rank == rank)
diff --git a/examples/dreambooth/train_dreambooth_lora_sana.py b/examples/dreambooth/train_dreambooth_lora_sana.py
index 0c4a16d1802f..c156523db3d5 100644
--- a/examples/dreambooth/train_dreambooth_lora_sana.py
+++ b/examples/dreambooth/train_dreambooth_lora_sana.py
@@ -52,6 +52,7 @@
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
+    _collate_lora_metadata,
     cast_training_params,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
@@ -323,9 +324,13 @@ def parse_args(input_args=None):
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
-
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=4,
+        help="LoRA alpha to be used for additional scaling.",
+    )
     parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
-
     parser.add_argument(
         "--with_prior_preservation",
         default=False,
@@ -1023,7 +1028,7 @@ def main(args):
     # now we will add new LoRA weights the transformer layers
     transformer_lora_config = LoraConfig(
         r=args.rank,
-        lora_alpha=args.rank,
+        lora_alpha=args.lora_alpha,
         lora_dropout=args.lora_dropout,
         init_lora_weights="gaussian",
         target_modules=target_modules,
@@ -1039,10 +1044,11 @@ def unwrap_model(model):
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = None
-
+            modules_to_save = {}
             for model in models:
                 if isinstance(model, type(unwrap_model(transformer))):
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                    modules_to_save["transformer"] = model
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
@@ -1052,6 +1058,7 @@ def save_model_hook(models, weights, output_dir):
             SanaPipeline.save_lora_weights(
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
+                **_collate_lora_metadata(modules_to_save),
             )
 
     def load_model_hook(models, input_dir):
@@ -1507,15 +1514,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         transformer = unwrap_model(transformer)
+        modules_to_save = {}
         if args.upcast_before_saving:
             transformer.to(torch.float32)
         else:
             transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)
+        modules_to_save["transformer"] = transformer
 
         SanaPipeline.save_lora_weights(
             save_directory=args.output_dir,
             transformer_lora_layers=transformer_lora_layers,
+            **_collate_lora_metadata(modules_to_save),
         )
 
     # Final inference
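Note: outside the test suite, the metadata written by this patch can be inspected the same way the new test does. The sketch below is not part of the patch; it reuses only the imports the patch introduces, and "path/to/output_dir" is a placeholder for whatever --output_dir was passed to train_dreambooth_lora_sana.py.

# inspect_lora_metadata.py -- minimal sketch mirroring the new test's metadata check.
import json
import os

import safetensors.torch

from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY

output_dir = "path/to/output_dir"  # placeholder: the --output_dir used during training
state_dict_file = os.path.join(output_dir, "pytorch_lora_weights.safetensors")

# Read the safetensors header metadata without loading the tensors themselves.
with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
    metadata = f.metadata() or {}

metadata.pop("format", None)  # safetensors' own format key, not LoRA metadata
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
adapter_metadata = json.loads(raw) if raw else {}

# With this patch applied, the transformer adapter's alpha and rank are recorded here.
print(adapter_metadata.get("transformer.lora_alpha"), adapter_metadata.get("transformer.r"))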