From a85fc37c9b072f06a60cda8f07edc44d0e378bd1 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 5 Jan 2024 09:02:50 +0200 Subject: [PATCH 1/6] unwrap text encoder when saving hook only for full text encoder tuning --- .../train_dreambooth_lora_sdxl_advanced.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 58df031c3f83..b9bd25d4de16 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -1316,9 +1316,10 @@ def save_model_hook(models, weights, output_dir): if isinstance(model, type(accelerator.unwrap_model(unet))): unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): - text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( - get_peft_model_state_dict(model) - ) + if args.train_text_encoder: + text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( + get_peft_model_state_dict(model) + ) elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) From 6189a090428cc2972dfc32cc4d11b6edd47156f6 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 5 Jan 2024 09:20:49 +0200 Subject: [PATCH 2/6] unwrap text encoder when saving hook only for full text encoder tuning --- .../train_dreambooth_lora_sdxl_advanced.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index b9bd25d4de16..a89c29c5c3b6 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -1321,9 +1321,10 @@ def save_model_hook(models, weights, output_dir): get_peft_model_state_dict(model) ) elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): - text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers( - get_peft_model_state_dict(model) - ) + if args.train_text_encoder: + text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers( + get_peft_model_state_dict(model) + ) else: raise ValueError(f"unexpected save model: {model.__class__}") From 1d973fa6e5663bf6e356031313fec0360844a0b2 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 5 Jan 2024 09:41:53 +0200 Subject: [PATCH 3/6] save embeddings in each checkpoint as well --- .../train_dreambooth_lora_sdxl_advanced.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index a89c29c5c3b6..624b137fe801 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -1337,6 +1337,10 @@ def save_model_hook(models, weights, output_dir): text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save, ) + if args.train_text_encoder_ti: + embedding_handler.save_embeddings( + f"{output_dir}/{output_dir}_emb.safetensors", + ) def load_model_hook(models, input_dir): unet_ = None From aabe1fe34aaa7c0844ac6082dd18b5fc5ba39d8d Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 5 Jan 2024 10:07:06 +0200 Subject: [PATCH 4/6] save embeddings in each checkpoint as well --- .../train_dreambooth_lora_sdxl_advanced.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 624b137fe801..a4c075f24888 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -1337,10 +1337,10 @@ def save_model_hook(models, weights, output_dir): text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save, ) - if args.train_text_encoder_ti: - embedding_handler.save_embeddings( - f"{output_dir}/{output_dir}_emb.safetensors", - ) + if args.train_text_encoder_ti: + embedding_handler.save_embeddings( + f"{output_dir}/{output_dir}_emb.safetensors", + ) def load_model_hook(models, input_dir): unet_ = None From 21b7a1a36f32d5efd22f44d9e8a587432563abd5 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 5 Jan 2024 10:07:23 +0200 Subject: [PATCH 5/6] save embeddings in each checkpoint as well --- .../train_dreambooth_lora_sdxl_advanced.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index a4c075f24888..8c29c27f1b7c 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -1339,7 +1339,7 @@ def save_model_hook(models, weights, output_dir): ) if args.train_text_encoder_ti: embedding_handler.save_embeddings( - f"{output_dir}/{output_dir}_emb.safetensors", + f"{output_dir}/{args.output_dir}_emb.safetensors", ) def load_model_hook(models, input_dir): From b7fc8af70690b52ae47b725fb38a8a42c609e0a3 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Fri, 5 Jan 2024 10:15:21 +0200 Subject: [PATCH 6/6] Update examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py Co-authored-by: Sayak Paul --- .../train_dreambooth_lora_sdxl_advanced.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 8c29c27f1b7c..979a17927182 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -1338,9 +1338,7 @@ def save_model_hook(models, weights, output_dir): text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save, ) if args.train_text_encoder_ti: - embedding_handler.save_embeddings( - f"{output_dir}/{args.output_dir}_emb.safetensors", - ) + embedding_handler.save_embeddings(f"{output_dir}/{args.output_dir}_emb.safetensors") def load_model_hook(models, input_dir): unet_ = None