diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 2037bd787433..5e01ec567f9a 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -1358,14 +1358,30 @@ def load_lora_into_transformer(
         inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
         incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
 
+        warn_msg = ""
         if incompatible_keys is not None:
-            # check only for unexpected keys
+            # Check only for unexpected keys.
             unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
             if unexpected_keys:
-                logger.warning(
-                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                    f" {unexpected_keys}. "
-                )
+                lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+                if lora_unexpected_keys:
+                    warn_msg = (
+                        f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+                        f" {', '.join(lora_unexpected_keys)}. "
+                    )
+
+            # Filter missing keys specific to the current adapter.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+                if lora_missing_keys:
+                    warn_msg += (
+                        f"Loading adapter weights from state_dict led to missing keys in the model:"
+                        f" {', '.join(lora_missing_keys)}."
+                    )
+
+        if warn_msg:
+            logger.warning(warn_msg)
 
         # Offload back.
         if is_model_cpu_offload:
@@ -1932,14 +1948,30 @@ def load_lora_into_transformer(
         inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
         incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
 
+        warn_msg = ""
         if incompatible_keys is not None:
-            # check only for unexpected keys
+            # Check only for unexpected keys.
             unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
             if unexpected_keys:
-                logger.warning(
-                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                    f" {unexpected_keys}. "
-                )
+                lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+                if lora_unexpected_keys:
+                    warn_msg = (
+                        f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+                        f" {', '.join(lora_unexpected_keys)}. "
+                    )
+
+            # Filter missing keys specific to the current adapter.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+                if lora_missing_keys:
+                    warn_msg += (
+                        f"Loading adapter weights from state_dict led to missing keys in the model:"
+                        f" {', '.join(lora_missing_keys)}."
+                    )
+
+        if warn_msg:
+            logger.warning(warn_msg)
 
         # Offload back.
         if is_model_cpu_offload:
@@ -2279,14 +2311,30 @@ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, ada
         inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
         incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
 
+        warn_msg = ""
         if incompatible_keys is not None:
-            # check only for unexpected keys
+            # Check only for unexpected keys.
             unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
             if unexpected_keys:
-                logger.warning(
-                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                    f" {unexpected_keys}. "
-                )
+                lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+                if lora_unexpected_keys:
+                    warn_msg = (
+                        f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+                        f" {', '.join(lora_unexpected_keys)}. "
+                    )
+
+            # Filter missing keys specific to the current adapter.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+                if lora_missing_keys:
+                    warn_msg += (
+                        f"Loading adapter weights from state_dict led to missing keys in the model:"
+                        f" {', '.join(lora_missing_keys)}."
+                    )
+
+        if warn_msg:
+            logger.warning(warn_msg)
 
         # Offload back.
         if is_model_cpu_offload:
@@ -2717,14 +2765,30 @@ def load_lora_into_transformer(
         inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
         incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
 
+        warn_msg = ""
         if incompatible_keys is not None:
-            # check only for unexpected keys
+            # Check only for unexpected keys.
             unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
             if unexpected_keys:
-                logger.warning(
-                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                    f" {unexpected_keys}. "
-                )
+                lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+                if lora_unexpected_keys:
+                    warn_msg = (
+                        f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+                        f" {', '.join(lora_unexpected_keys)}. "
+                    )
+
+            # Filter missing keys specific to the current adapter.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+                if lora_missing_keys:
+                    warn_msg += (
+                        f"Loading adapter weights from state_dict led to missing keys in the model:"
+                        f" {', '.join(lora_missing_keys)}."
+                    )
+
+        if warn_msg:
+            logger.warning(warn_msg)
 
         # Offload back.
         if is_model_cpu_offload:
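Note: the same warning block recurs verbatim at every patched call site (the four `load_lora_into_transformer` overloads above, plus `_process_lora` in `unet.py` below). For illustration only, the filtering reduces to the standalone sketch that follows; `build_lora_warning` is a hypothetical name, not a helper the patch adds. The adapter-specific filter matters because PEFT's `set_peft_model_state_dict` ultimately loads with `strict=False`, so `missing_keys` would otherwise include every base-model parameter.

# Illustrative sketch, not part of the patch.
from typing import List, Optional

def build_lora_warning(
    unexpected_keys: Optional[List[str]],
    missing_keys: Optional[List[str]],
    adapter_name: str,
) -> str:
    warn_msg = ""
    if unexpected_keys:
        # Keys present in the incoming state dict but absent from the model.
        lora_unexpected = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
        if lora_unexpected:
            warn_msg = (
                "Loading adapter weights from state_dict led to unexpected keys found in the model: "
                + ", ".join(lora_unexpected) + ". "
            )
    if missing_keys:
        # strict=False loading also reports every base-model weight as missing,
        # so keep only this adapter's LoRA keys.
        lora_missing = [k for k in missing_keys if "lora_" in k and adapter_name in k]
        if lora_missing:
            warn_msg += (
                "Loading adapter weights from state_dict led to missing keys in the model: "
                + ", ".join(lora_missing) + "."
            )
    return warn_msg

msg = build_lora_warning(
    unexpected_keys=["to_q.lora_A.default_0.weight.diffusers_cat"],
    missing_keys=["to_q.lora_B.default_0.weight", "time_embedding.linear_1.weight"],
    adapter_name="default_0",
)
assert "lora_A" in msg and "time_embedding" not in msg  # base-model key filtered out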
" - ) + lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k] + if lora_unexpected_keys: + warn_msg = ( + f"Loading adapter weights from state_dict led to unexpected keys found in the model:" + f" {', '.join(lora_unexpected_keys)}. " + ) + + # Filter missing keys specific to the current adapter. + missing_keys = getattr(incompatible_keys, "missing_keys", None) + if missing_keys: + lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k] + if lora_missing_keys: + warn_msg += ( + f"Loading adapter weights from state_dict led to missing keys in the model:" + f" {', '.join(lora_missing_keys)}." + ) + + if warn_msg: + logger.warning(warn_msg) # Offload back. if is_model_cpu_offload: @@ -2717,14 +2765,30 @@ def load_lora_into_transformer( inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs) incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs) + warn_msg = "" if incompatible_keys is not None: - # check only for unexpected keys + # Check only for unexpected keys. unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) if unexpected_keys: - logger.warning( - f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " - f" {unexpected_keys}. " - ) + lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k] + if lora_unexpected_keys: + warn_msg = ( + f"Loading adapter weights from state_dict led to unexpected keys found in the model:" + f" {', '.join(lora_unexpected_keys)}. " + ) + + # Filter missing keys specific to the current adapter. + missing_keys = getattr(incompatible_keys, "missing_keys", None) + if missing_keys: + lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k] + if lora_missing_keys: + warn_msg += ( + f"Loading adapter weights from state_dict led to missing keys in the model:" + f" {', '.join(lora_missing_keys)}." + ) + + if warn_msg: + logger.warning(warn_msg) # Offload back. if is_model_cpu_offload: diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index eaac52df6202..2fa7732a6a3b 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -354,14 +354,30 @@ def _process_lora( inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs) incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs) + warn_msg = "" if incompatible_keys is not None: - # check only for unexpected keys + # Check only for unexpected keys. unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) if unexpected_keys: - logger.warning( - f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " - f" {unexpected_keys}. " - ) + lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k] + if lora_unexpected_keys: + warn_msg = ( + f"Loading adapter weights from state_dict led to unexpected keys found in the model:" + f" {', '.join(lora_unexpected_keys)}. " + ) + + # Filter missing keys specific to the current adapter. + missing_keys = getattr(incompatible_keys, "missing_keys", None) + if missing_keys: + lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k] + if lora_missing_keys: + warn_msg += ( + f"Loading adapter weights from state_dict led to missing keys in the model:" + f" {', '.join(lora_missing_keys)}." 
diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py
index 4629c24c8cd8..3bc46d1e9b13 100644
--- a/tests/lora/test_lora_layers_flux.py
+++ b/tests/lora/test_lora_layers_flux.py
@@ -27,6 +27,7 @@
 from diffusers.utils.testing_utils import (
     floats_tensor,
     is_peft_available,
+    numpy_cosine_similarity_distance,
     require_peft_backend,
     require_torch_gpu,
     slow,
@@ -166,7 +167,7 @@ def test_modify_padding_mode(self):
 @slow
 @require_torch_gpu
 @require_peft_backend
-@unittest.skip("We cannot run inference on this model with the current CI hardware")
+# @unittest.skip("We cannot run inference on this model with the current CI hardware")
 # TODO (DN6, sayakpaul): move these tests to a beefier GPU
 class FluxLoRAIntegrationTests(unittest.TestCase):
     """internal note: The integration slices were obtained on audace.
@@ -208,9 +209,11 @@ def test_flux_the_last_ben(self):
             generator=torch.manual_seed(self.seed),
         ).images
         out_slice = out[0, -3:, -3:, -1].flatten()
-        expected_slice = np.array([0.1719, 0.1719, 0.1699, 0.1719, 0.1719, 0.1738, 0.1641, 0.1621, 0.2090])
+        expected_slice = np.array([0.1855, 0.1855, 0.1836, 0.1855, 0.1836, 0.1875, 0.1777, 0.1758, 0.2246])
 
-        assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
+        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
+
+        assert max_diff < 1e-3
 
     def test_flux_kohya(self):
         self.pipeline.load_lora_weights("Norod78/brain-slug-flux")
@@ -230,7 +233,9 @@ def test_flux_kohya(self):
         out_slice = out[0, -3:, -3:, -1].flatten()
         expected_slice = np.array([0.6367, 0.6367, 0.6328, 0.6367, 0.6328, 0.6289, 0.6367, 0.6328, 0.6484])
 
-        assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
+        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
+
+        assert max_diff < 1e-3
 
     def test_flux_kohya_with_text_encoder(self):
         self.pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors")
@@ -248,9 +253,11 @@ def test_flux_kohya_with_text_encoder(self):
         ).images
 
         out_slice = out[0, -3:, -3:, -1].flatten()
-        expected_slice = np.array([0.4023, 0.4043, 0.4023, 0.3965, 0.3984, 0.3984, 0.3906, 0.3906, 0.4219])
+        expected_slice = np.array([0.4023, 0.4023, 0.4023, 0.3965, 0.3984, 0.3965, 0.3926, 0.3906, 0.4219])
 
-        assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
+        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
+
+        assert max_diff < 1e-3
 
     def test_flux_xlabs(self):
         self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
@@ -268,6 +275,8 @@ def test_flux_xlabs(self):
             generator=torch.manual_seed(self.seed),
         ).images
         out_slice = out[0, -3:, -3:, -1].flatten()
-        expected_slice = np.array([0.3984, 0.4199, 0.4453, 0.4102, 0.4375, 0.4590, 0.4141, 0.4355, 0.4980])
+        expected_slice = np.array([0.3965, 0.4180, 0.4434, 0.4082, 0.4375, 0.4590, 0.4141, 0.4375, 0.4980])
+
+        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
 
-        assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
+        assert max_diff < 1e-3
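Note: the integration tests above switch from a tight elementwise `np.allclose(..., atol=1e-4, rtol=1e-4)` check to a cosine-similarity distance over the output slice, which tolerates the small per-element drift behind the updated expected slices. A minimal sketch of that metric, assuming it matches `numpy_cosine_similarity_distance` in `diffusers.utils.testing_utils` (verify against the actual helper):

# Illustrative sketch of the comparison metric, not part of the patch.
import numpy as np

def cosine_similarity_distance(a: np.ndarray, b: np.ndarray) -> float:
    # 1 - cosine similarity: near 0 when the two slices point the same way,
    # regardless of small elementwise offsets (e.g. across GPU generations).
    similarity = (a * b).sum() / max(np.linalg.norm(a) * np.linalg.norm(b), 1e-8)
    return float(1.0 - similarity)

expected = np.array([0.1855, 0.1855, 0.1836, 0.1855, 0.1836, 0.1875, 0.1777, 0.1758, 0.2246])
out = expected + 1e-3  # simulate small elementwise drift across hardware
assert not np.allclose(out, expected, atol=1e-4, rtol=1e-4)  # the old check would fail here
assert cosine_similarity_distance(expected.flatten(), out) < 1e-3  # the new check passes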
diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index 48c7d5fcec89..e7fc840fcaa5 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -27,8 +27,10 @@
     LCMScheduler,
     UNet2DConditionModel,
 )
+from diffusers.utils import logging
 from diffusers.utils.import_utils import is_peft_available
 from diffusers.utils.testing_utils import (
+    CaptureLogger,
     floats_tensor,
     require_peft_backend,
     require_peft_version_greater,
@@ -219,10 +221,18 @@ def _get_modules_to_save(self, pipe, has_denoiser=False):
         modules_to_save = {}
         lora_loadable_modules = self.pipeline_class._lora_loadable_modules
 
-        if "text_encoder" in lora_loadable_modules and hasattr(pipe, "text_encoder"):
+        if (
+            "text_encoder" in lora_loadable_modules
+            and hasattr(pipe, "text_encoder")
+            and getattr(pipe.text_encoder, "peft_config", None) is not None
+        ):
             modules_to_save["text_encoder"] = pipe.text_encoder
 
-        if "text_encoder_2" in lora_loadable_modules and hasattr(pipe, "text_encoder_2"):
+        if (
+            "text_encoder_2" in lora_loadable_modules
+            and hasattr(pipe, "text_encoder_2")
+            and getattr(pipe.text_encoder_2, "peft_config", None) is not None
+        ):
             modules_to_save["text_encoder_2"] = pipe.text_encoder_2
 
         if has_denoiser:
@@ -1747,6 +1757,83 @@ def test_simple_inference_with_dora(self):
             "DoRA lora should change the output",
         )
 
+    def test_missing_keys_warning(self):
+        scheduler_cls = self.scheduler_classes[0]
+        # Skip text encoder check for now as that is handled with `transformers`.
+        components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
+            )
+            pipe.unload_lora_weights()
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True)
+
+        # To make things dynamic since we cannot settle with a single key for all the models where we
+        # offer PEFT support.
+        missing_key = [k for k in state_dict if "lora_A" in k][0]
+        del state_dict[missing_key]
+
+        logger = (
+            logging.get_logger("diffusers.loaders.unet")
+            if self.unet_kwargs is not None
+            else logging.get_logger("diffusers.loaders.lora_pipeline")
+        )
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            pipe.load_lora_weights(state_dict)
+
+        # Since the missing key won't contain the adapter name ("default_0").
+        # Also strip out the component prefix (such as "unet." from `missing_key`).
+        component = list({k.split(".")[0] for k in state_dict})[0]
+        self.assertTrue(missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", ""))
+
+    def test_unexpected_keys_warning(self):
+        scheduler_cls = self.scheduler_classes[0]
+        # Skip text encoder check for now as that is handled with `transformers`.
+        components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
+            )
+            pipe.unload_lora_weights()
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            state_dict = torch.load(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), weights_only=True)
+
+        unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat"
+        state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device)
+
+        logger = (
+            logging.get_logger("diffusers.loaders.unet")
+            if self.unet_kwargs is not None
+            else logging.get_logger("diffusers.loaders.lora_pipeline")
+        )
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            pipe.load_lora_weights(state_dict)
+
+        self.assertTrue(".diffusers_cat" in cap_logger.out)
+
     @unittest.skip("This is failing for now - need to investigate")
     def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
         """
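Note: both new tests capture the loader's output through `CaptureLogger` from `diffusers.utils.testing_utils`. For readers unfamiliar with that helper, the sketch below is a minimal stand-in inferred from how the tests use it, not its exact source; also note the tests use diffusers' `logging.get_logger` wrapper rather than stdlib `logging.getLogger`, and `logger.setLevel(30)` is `logging.WARNING`.

# Minimal stand-in for CaptureLogger, sketched from usage; the real helper
# lives in diffusers.utils.testing_utils.
import logging
from io import StringIO

class CaptureLogger:
    def __init__(self, logger: logging.Logger):
        self.logger = logger
        self.io = StringIO()
        self.handler = logging.StreamHandler(self.io)
        self.out = ""

    def __enter__(self):
        # Route the logger's records into an in-memory buffer.
        self.logger.addHandler(self.handler)
        return self

    def __exit__(self, *exc):
        self.logger.removeHandler(self.handler)
        self.out = self.io.getvalue()

# Usage mirroring the tests: warnings emitted while loading are captured in
# cap.out and can be asserted on.
logger = logging.getLogger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.WARNING)  # same threshold as logger.setLevel(30)
with CaptureLogger(logger) as cap:
    logger.warning("Loading adapter weights from state_dict led to missing keys in the model: ...")
assert "missing keys" in cap.out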