Skip to content

fix a regression in is_safetensors_compatible #9234

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/diffusers/pipelines/pipeline_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])


def is_safetensors_compatible(filenames, passed_components=None) -> bool:
def is_safetensors_compatible(filenames, passed_components=None, folder_names=None) -> bool:
"""
Checking for safetensors compatibility:
- The model is safetensors compatible only if there is a safetensors file for each model component present in
Expand All @@ -101,6 +101,8 @@ def is_safetensors_compatible(filenames, passed_components=None) -> bool:
extension is replaced with ".safetensors"
"""
passed_components = passed_components or []
if folder_names is not None:
filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}

# extract all components of the pipeline and their associated files
components = {}
Expand Down
8 changes: 6 additions & 2 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1416,14 +1416,18 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
if (
use_safetensors
and not allow_pickle
and not is_safetensors_compatible(model_filenames, passed_components=passed_components)
and not is_safetensors_compatible(
model_filenames, passed_components=passed_components, folder_names=model_folder_names
)
):
raise EnvironmentError(
f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
)
if from_flax:
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
elif use_safetensors and is_safetensors_compatible(model_filenames, passed_components=passed_components):
elif use_safetensors and is_safetensors_compatible(
model_filenames, passed_components=passed_components, folder_names=model_folder_names
):
ignore_patterns = ["*.bin", "*.msgpack"]

use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
Expand Down
24 changes: 24 additions & 0 deletions tests/pipelines/test_pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,30 @@ def test_transformer_model_is_not_compatible_variant(self):
]
self.assertFalse(is_safetensors_compatible(filenames))

def test_transformer_model_is_compatible_variant_extra_folder(self):
filenames = [
"safety_checker/pytorch_model.fp16.bin",
"safety_checker/model.fp16.safetensors",
"vae/diffusion_pytorch_model.fp16.bin",
"vae/diffusion_pytorch_model.fp16.safetensors",
"text_encoder/pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}))

def test_transformer_model_is_not_compatible_variant_extra_folder(self):
filenames = [
"safety_checker/pytorch_model.fp16.bin",
"safety_checker/model.fp16.safetensors",
"vae/diffusion_pytorch_model.fp16.bin",
"vae/diffusion_pytorch_model.fp16.safetensors",
"text_encoder/pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
self.assertFalse(is_safetensors_compatible(filenames, folder_names={"text_encoder"}))

def test_transformers_is_compatible_sharded(self):
filenames = [
"text_encoder/pytorch_model.bin",
Expand Down
Loading