Skip to content

Add --lora_alpha and metadata handling to train_dreambooth_lora_sana.py #11744

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions examples/dreambooth/test_dreambooth_lora_sana.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os
import sys
import tempfile

import safetensors

from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY


sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
Expand Down Expand Up @@ -204,3 +207,42 @@ def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit_removes_mult
run_command(self._launch_args + resume_run_args)

self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})

def test_dreambooth_lora_sana_with_metadata(self):
lora_alpha = 8
rank = 4
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--resolution=32
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=4
--lora_alpha={lora_alpha}
--rank={rank}
--checkpointing_steps=2
--max_sequence_length 166
""".split()

test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)

state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
self.assertTrue(os.path.isfile(state_dict_file))

# Check if the metadata was properly serialized.
with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
metadata = f.metadata() or {}

metadata.pop("format", None)
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
if raw:
raw = json.loads(raw)

loaded_lora_alpha = raw["transformer.lora_alpha"]
self.assertTrue(loaded_lora_alpha == lora_alpha)
loaded_lora_rank = raw["transformer.r"]
self.assertTrue(loaded_lora_rank == rank)
18 changes: 14 additions & 4 deletions examples/dreambooth/train_dreambooth_lora_sana.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
_collate_lora_metadata,
cast_training_params,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
Expand Down Expand Up @@ -323,9 +324,13 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)

parser.add_argument(
"--lora_alpha",
type=int,
default=4,
help="LoRA alpha to be used for additional scaling.",
)
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")

parser.add_argument(
"--with_prior_preservation",
default=False,
Expand Down Expand Up @@ -1023,7 +1028,7 @@ def main(args):
# now we will add new LoRA weights the transformer layers
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=target_modules,
Expand All @@ -1039,10 +1044,11 @@ def unwrap_model(model):
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
transformer_lora_layers_to_save = None

modules_to_save = {}
for model in models:
if isinstance(model, type(unwrap_model(transformer))):
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
modules_to_save["transformer"] = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")

Expand All @@ -1052,6 +1058,7 @@ def save_model_hook(models, weights, output_dir):
SanaPipeline.save_lora_weights(
output_dir,
transformer_lora_layers=transformer_lora_layers_to_save,
**_collate_lora_metadata(modules_to_save),
)

def load_model_hook(models, input_dir):
Expand Down Expand Up @@ -1507,15 +1514,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
transformer = unwrap_model(transformer)
modules_to_save = {}
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
modules_to_save["transformer"] = transformer

SanaPipeline.save_lora_weights(
save_directory=args.output_dir,
transformer_lora_layers=transformer_lora_layers,
**_collate_lora_metadata(modules_to_save),
)

# Final inference
Expand Down