Skip to content

Commit 2d738e2

Browse files
akbaiglinoytsaban
authored andcommitted
fix: checkpoint save issue in advanced dreambooth lora sdxl script (#8926)
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
1 parent 1c550bf commit 2d738e2

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1605,13 +1605,15 @@ def save_model_hook(models, weights, output_dir):
16051605
if isinstance(model, type(unwrap_model(unet))):
16061606
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
16071607
elif isinstance(model, type(unwrap_model(text_encoder_one))):
1608-
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
1609-
get_peft_model_state_dict(model)
1610-
)
1608+
if args.train_text_encoder:
1609+
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
1610+
get_peft_model_state_dict(model)
1611+
)
16111612
elif isinstance(model, type(unwrap_model(text_encoder_two))):
1612-
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
1613-
get_peft_model_state_dict(model)
1614-
)
1613+
if args.train_text_encoder:
1614+
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
1615+
get_peft_model_state_dict(model)
1616+
)
16151617
else:
16161618
raise ValueError(f"unexpected save model: {model.__class__}")
16171619

0 commit comments

Comments
 (0)