diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index af323164f562..a8c23adead49 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -89,7 +89,7 @@ ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) -def is_safetensors_compatible(filenames, passed_components=None) -> bool: +def is_safetensors_compatible(filenames, passed_components=None, folder_names=None) -> bool: """ Checking for safetensors compatibility: - The model is safetensors compatible only if there is a safetensors file for each model component present in @@ -101,6 +101,8 @@ def is_safetensors_compatible(filenames, passed_components=None) -> bool: extension is replaced with ".safetensors" """ passed_components = passed_components or [] + if folder_names is not None: + filenames = {f for f in filenames if os.path.split(f)[0] in folder_names} # extract all components of the pipeline and their associated files components = {} diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index f2882c5b1d02..631776f25043 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1416,14 +1416,18 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: if ( use_safetensors and not allow_pickle - and not is_safetensors_compatible(model_filenames, passed_components=passed_components) + and not is_safetensors_compatible( + model_filenames, passed_components=passed_components, folder_names=model_folder_names + ) ): raise EnvironmentError( f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})" ) if from_flax: ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"] - elif use_safetensors and is_safetensors_compatible(model_filenames, passed_components=passed_components): + elif use_safetensors and is_safetensors_compatible( + model_filenames, passed_components=passed_components, folder_names=model_folder_names + ): ignore_patterns = ["*.bin", "*.msgpack"] use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py index 0e3f2e8c2e27..57194acdcf2a 100644 --- a/tests/pipelines/test_pipeline_utils.py +++ b/tests/pipelines/test_pipeline_utils.py @@ -116,6 +116,30 @@ def test_transformer_model_is_not_compatible_variant(self): ] self.assertFalse(is_safetensors_compatible(filenames)) + def test_transformer_model_is_compatible_variant_extra_folder(self): + filenames = [ + "safety_checker/pytorch_model.fp16.bin", + "safety_checker/model.fp16.safetensors", + "vae/diffusion_pytorch_model.fp16.bin", + "vae/diffusion_pytorch_model.fp16.safetensors", + "text_encoder/pytorch_model.fp16.bin", + "unet/diffusion_pytorch_model.fp16.bin", + "unet/diffusion_pytorch_model.fp16.safetensors", + ] + self.assertTrue(is_safetensors_compatible(filenames, folder_names={"vae", "unet"})) + + def test_transformer_model_is_not_compatible_variant_extra_folder(self): + filenames = [ + "safety_checker/pytorch_model.fp16.bin", + "safety_checker/model.fp16.safetensors", + "vae/diffusion_pytorch_model.fp16.bin", + "vae/diffusion_pytorch_model.fp16.safetensors", + "text_encoder/pytorch_model.fp16.bin", + "unet/diffusion_pytorch_model.fp16.bin", + "unet/diffusion_pytorch_model.fp16.safetensors", + ] + self.assertFalse(is_safetensors_compatible(filenames, folder_names={"text_encoder"})) + def test_transformers_is_compatible_sharded(self): filenames = [ "text_encoder/pytorch_model.bin",