
Commit a15b9b3

log a warning when there are missing keys in the LoRA loading.
1 parent: 31058cd

File tree (2 files changed: +41 -1 lines)

src/diffusers/loaders/lora_pipeline.py
src/diffusers/loaders/unet.py

src/diffusers/loaders/lora_pipeline.py

Lines changed: 33 additions & 1 deletion
@@ -1367,6 +1367,14 @@ def load_lora_into_transformer(
                     f" {unexpected_keys}. "
                 )
 
+            # Filter missing keys specific to the current adapter.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to missing keys in the model: "
+                    f" {missing_keys}. "
+                )
+
         # Offload back.
         if is_model_cpu_offload:
             _pipeline.enable_model_cpu_offload()
@@ -1933,14 +1941,22 @@ def load_lora_into_transformer(
         incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
 
         if incompatible_keys is not None:
-            # check only for unexpected keys
+            # Check only for unexpected keys.
             unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
             if unexpected_keys:
                 logger.warning(
                     f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                     f" {unexpected_keys}. "
                 )
 
+            # Filter missing keys specific to the current adapter.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to missing keys in the model: "
+                    f" {missing_keys}. "
+                )
+
         # Offload back.
         if is_model_cpu_offload:
             _pipeline.enable_model_cpu_offload()
@@ -2288,6 +2304,14 @@ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, ada
                     f" {unexpected_keys}. "
                 )
 
+            # Filter missing keys specific to the current adapter.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to missing keys in the model: "
+                    f" {missing_keys}. "
+                )
+
         # Offload back.
         if is_model_cpu_offload:
             _pipeline.enable_model_cpu_offload()
@@ -2726,6 +2750,14 @@ def load_lora_into_transformer(
                     f" {unexpected_keys}. "
                 )
 
+            # Filter missing keys specific to the current adapter.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to missing keys in the model: "
+                    f" {missing_keys}. "
+                )
+
         # Offload back.
         if is_model_cpu_offload:
             _pipeline.enable_model_cpu_offload()
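
The block added in each of these hunks mirrors the existing handling of unexpected_keys: it reads the missing_keys attribute off the result of set_peft_model_state_dict (which follows the shape of torch.nn.Module.load_state_dict(strict=False), exposing missing_keys and unexpected_keys) and warns if any are present. Below is a minimal, self-contained sketch of that pattern, not the library code itself; the namedtuple stands in for the PEFT return value and the key name in the example is purely illustrative.

import logging
from collections import namedtuple

logger = logging.getLogger(__name__)

# Stand-in for the object returned by peft.set_peft_model_state_dict, which
# mirrors torch.nn.Module.load_state_dict(strict=False) in exposing
# missing_keys and unexpected_keys.
IncompatibleKeys = namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])


def report_incompatible_keys(incompatible_keys):
    # Check only for unexpected keys.
    unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
    if unexpected_keys:
        logger.warning(
            f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
            f" {unexpected_keys}. "
        )

    # New in this commit: also surface keys the model expected but the
    # state_dict did not provide.
    missing_keys = getattr(incompatible_keys, "missing_keys", None)
    if missing_keys:
        logger.warning(
            f"Loading adapter weights from state_dict led to missing keys in the model: "
            f" {missing_keys}. "
        )


if __name__ == "__main__":
    logging.basicConfig(level=logging.WARNING)
    # Illustrative key name; a real report would contain module paths from the model.
    result = IncompatibleKeys(
        missing_keys=["transformer_blocks.0.attn.to_q.lora_A.default.weight"],
        unexpected_keys=[],
    )
    report_incompatible_keys(result)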

src/diffusers/loaders/unet.py

Lines changed: 8 additions & 0 deletions
@@ -363,6 +363,14 @@ def _process_lora(
                     f" {unexpected_keys}. "
                 )
 
+            # Filter missing keys specific to the current adapter.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to missing keys in the model: "
+                    f" {missing_keys}. "
+                )
+
         return is_model_cpu_offload, is_sequential_cpu_offload
 
     @classmethod
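
For completeness, a hedged usage sketch of how the new warning surfaces to a user; the pipeline class, model id, LoRA file, and adapter name below are placeholders and not part of this commit. Any missing keys reported by the underlying state-dict load are now logged while the LoRA is being attached, instead of passing silently.

import torch
from diffusers import StableDiffusionXLPipeline
from diffusers.utils import logging

# Ensure diffusers loggers emit warnings (warning verbosity is the default).
logging.set_verbosity_warning()

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
# If the state dict leaves expected keys unfilled, the loader now logs
# "Loading adapter weights from state_dict led to missing keys in the model: ...".
pipe.load_lora_weights("my_lora.safetensors", adapter_name="my_lora")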
