Skip to content

Commit 2fada8d

Browse files
[bug fix] fixes #6444 - checkpointing save issue in advanced dreambooth lora sdxl script (#6464)

* unwrap text encoder in the saving hook only for full text-encoder tuning
* save embeddings in each checkpoint as well
* Update examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent f2d51a2 commit 2fada8d

File tree

1 file changed: +10 −6 lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1316,13 +1316,15 @@ def save_model_hook(models, weights, output_dir):
                 if isinstance(model, type(accelerator.unwrap_model(unet))):
                     unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
-                    text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
-                        get_peft_model_state_dict(model)
-                    )
+                    if args.train_text_encoder:
+                        text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
+                            get_peft_model_state_dict(model)
+                        )
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
-                    text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
-                        get_peft_model_state_dict(model)
-                    )
+                    if args.train_text_encoder:
+                        text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
+                            get_peft_model_state_dict(model)
+                        )
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
13281330

@@ -1335,6 +1337,8 @@ def save_model_hook(models, weights, output_dir):
                 text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
                 text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
             )
+            if args.train_text_encoder_ti:
+                embedding_handler.save_embeddings(f"{output_dir}/{args.output_dir}_emb.safetensors")

     def load_model_hook(models, input_dir):
         unet_ = None

0 commit comments

Comments (0)