Skip to content

Commit 214372a

Browse files
authored
fix a regression in is_safetensors_compatible (#9234)
fix
1 parent 867e0c9 commit 214372a

File tree

3 files changed

+33
-3
lines changed

3 files changed

+33
-3
lines changed

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@
8989
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
9090

9191

92-
def is_safetensors_compatible(filenames, passed_components=None) -> bool:
92+
def is_safetensors_compatible(filenames, passed_components=None, folder_names=None) -> bool:
9393
"""
9494
Checking for safetensors compatibility:
9595
- The model is safetensors compatible only if there is a safetensors file for each model component present in
@@ -101,6 +101,8 @@ def is_safetensors_compatible(filenames, passed_components=None) -> bool:
101101
extension is replaced with ".safetensors"
102102
"""
103103
passed_components = passed_components or []
104+
if folder_names is not None:
105+
filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}
104106

105107
# extract all components of the pipeline and their associated files
106108
components = {}

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1416,14 +1416,18 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
14161416
if (
14171417
use_safetensors
14181418
and not allow_pickle
1419-
and not is_safetensors_compatible(model_filenames, passed_components=passed_components)
1419+
and not is_safetensors_compatible(
1420+
model_filenames, passed_components=passed_components, folder_names=model_folder_names
1421+
)
14201422
):
14211423
raise EnvironmentError(
14221424
f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
14231425
)
14241426
if from_flax:
14251427
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
1426-
elif use_safetensors and is_safetensors_compatible(model_filenames, passed_components=passed_components):
1428+
elif use_safetensors and is_safetensors_compatible(
1429+
model_filenames, passed_components=passed_components, folder_names=model_folder_names
1430+
):
14271431
ignore_patterns = ["*.bin", "*.msgpack"]
14281432

14291433
use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx

tests/pipelines/test_pipeline_utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,30 @@ def test_transformer_model_is_not_compatible_variant(self):
116116
]
117117
self.assertFalse(is_safetensors_compatible(filenames))
118118

119+
def test_transformer_model_is_compatible_variant_extra_folder(self):
120+
filenames = [
121+
"safety_checker/pytorch_model.fp16.bin",
122+
"safety_checker/model.fp16.safetensors",
123+
"vae/diffusion_pytorch_model.fp16.bin",
124+
"vae/diffusion_pytorch_model.fp16.safetensors",
125+
"text_encoder/pytorch_model.fp16.bin",
126+
"unet/diffusion_pytorch_model.fp16.bin",
127+
"unet/diffusion_pytorch_model.fp16.safetensors",
128+
]
129+
self.assertTrue(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}))
130+
131+
def test_transformer_model_is_not_compatible_variant_extra_folder(self):
132+
filenames = [
133+
"safety_checker/pytorch_model.fp16.bin",
134+
"safety_checker/model.fp16.safetensors",
135+
"vae/diffusion_pytorch_model.fp16.bin",
136+
"vae/diffusion_pytorch_model.fp16.safetensors",
137+
"text_encoder/pytorch_model.fp16.bin",
138+
"unet/diffusion_pytorch_model.fp16.bin",
139+
"unet/diffusion_pytorch_model.fp16.safetensors",
140+
]
141+
self.assertFalse(is_safetensors_compatible(filenames, folder_names={"text_encoder"}))
142+
119143
def test_transformers_is_compatible_sharded(self):
120144
filenames = [
121145
"text_encoder/pytorch_model.bin",

0 commit comments

Comments
 (0)