Skip to content

Commit df90f0c

Browse files
authored
Add is_torch_available, is_flax_available (#204)
* Add is_<framework>_available, refactor import utils * deps * quality
1 parent ed22b4f commit df90f0c

File tree

7 files changed

+312
-185
lines changed

7 files changed

+312
-185
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@
8383
"filelock",
8484
"flake8>=3.8.3",
8585
"hf-doc-builder>=0.3.0",
86-
"huggingface-hub",
86+
"huggingface-hub>=0.8.1,<1.0",
8787
"importlib_metadata",
8888
"isort>=5.5.4",
8989
"modelcards==0.1.4",

src/diffusers/configuration_utils.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,11 @@
2323
from typing import Any, Dict, Tuple, Union
2424

2525
from huggingface_hub import hf_hub_download
26+
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
2627
from requests import HTTPError
2728

2829
from . import __version__
29-
from .utils import (
30-
DIFFUSERS_CACHE,
31-
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
32-
EntryNotFoundError,
33-
RepositoryNotFoundError,
34-
RevisionNotFoundError,
35-
logging,
36-
)
30+
from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
3731

3832

3933
logger = logging.get_logger(__name__)

src/diffusers/dependency_versions_table.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,18 @@
55
"Pillow": "Pillow",
66
"accelerate": "accelerate>=0.11.0",
77
"black": "black~=22.0,>=22.3",
8+
"datasets": "datasets",
89
"filelock": "filelock",
910
"flake8": "flake8>=3.8.3",
1011
"hf-doc-builder": "hf-doc-builder>=0.3.0",
11-
"huggingface-hub": "huggingface-hub",
12+
"huggingface-hub": "huggingface-hub>=0.8.1,<1.0",
1213
"importlib_metadata": "importlib_metadata",
1314
"isort": "isort>=5.5.4",
1415
"modelcards": "modelcards==0.1.4",
1516
"numpy": "numpy",
1617
"pytest": "pytest",
1718
"regex": "regex!=2019.12.17",
1819
"requests": "requests",
19-
"torch": "torch>=1.4",
2020
"tensorboard": "tensorboard",
21+
"torch": "torch>=1.4",
2122
}

src/diffusers/modeling_utils.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,10 @@
2121
from torch import Tensor, device
2222

2323
from huggingface_hub import hf_hub_download
24+
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
2425
from requests import HTTPError
2526

26-
from .utils import (
27-
CONFIG_NAME,
28-
DIFFUSERS_CACHE,
29-
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
30-
EntryNotFoundError,
31-
RepositoryNotFoundError,
32-
RevisionNotFoundError,
33-
logging,
34-
)
27+
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
3528

3629

3730
WEIGHTS_NAME = "diffusion_pytorch_model.bin"

src/diffusers/utils/__init__.py

Lines changed: 23 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
1+
# flake8: noqa
2+
# There's no way to ignore "F401 '...' imported but unused" warnings in this
3+
# module, but to preserve other warnings. So, don't check this module at all.
4+
5+
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
26
#
37
# Licensed under the Apache License, Version 2.0 (the "License");
48
# you may not use this file except in compliance with the License.
@@ -11,13 +15,26 @@
1115
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1216
# See the License for the specific language governing permissions and
1317
# limitations under the License.
14-
import importlib
15-
import os
16-
from collections import OrderedDict
1718

18-
import importlib_metadata
19-
from requests.exceptions import HTTPError
2019

20+
import os
21+
22+
from .import_utils import (
23+
ENV_VARS_TRUE_AND_AUTO_VALUES,
24+
ENV_VARS_TRUE_VALUES,
25+
USE_JAX,
26+
USE_TF,
27+
USE_TORCH,
28+
DummyObject,
29+
is_flax_available,
30+
is_inflect_available,
31+
is_scipy_available,
32+
is_tf_available,
33+
is_torch_available,
34+
is_transformers_available,
35+
is_unidecode_available,
36+
requires_backends,
37+
)
2138
from .logging import get_logger
2239

2340

@@ -35,135 +52,3 @@
3552
DIFFUSERS_CACHE = default_cache_path
3653
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
3754
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
38-
39-
40-
_transformers_available = importlib.util.find_spec("transformers") is not None
41-
try:
42-
_transformers_version = importlib_metadata.version("transformers")
43-
logger.debug(f"Successfully imported transformers version {_transformers_version}")
44-
except importlib_metadata.PackageNotFoundError:
45-
_transformers_available = False
46-
47-
48-
_inflect_available = importlib.util.find_spec("inflect") is not None
49-
try:
50-
_inflect_version = importlib_metadata.version("inflect")
51-
logger.debug(f"Successfully imported inflect version {_inflect_version}")
52-
except importlib_metadata.PackageNotFoundError:
53-
_inflect_available = False
54-
55-
56-
_unidecode_available = importlib.util.find_spec("unidecode") is not None
57-
try:
58-
_unidecode_version = importlib_metadata.version("unidecode")
59-
logger.debug(f"Successfully imported unidecode version {_unidecode_version}")
60-
except importlib_metadata.PackageNotFoundError:
61-
_unidecode_available = False
62-
63-
64-
_modelcards_available = importlib.util.find_spec("modelcards") is not None
65-
try:
66-
_modelcards_version = importlib_metadata.version("modelcards")
67-
logger.debug(f"Successfully imported modelcards version {_modelcards_version}")
68-
except importlib_metadata.PackageNotFoundError:
69-
_modelcards_available = False
70-
71-
72-
_scipy_available = importlib.util.find_spec("scipy") is not None
73-
try:
74-
_scipy_version = importlib_metadata.version("scipy")
75-
logger.debug(f"Successfully imported transformers version {_scipy_version}")
76-
except importlib_metadata.PackageNotFoundError:
77-
_scipy_available = False
78-
79-
80-
def is_transformers_available():
81-
return _transformers_available
82-
83-
84-
def is_inflect_available():
85-
return _inflect_available
86-
87-
88-
def is_unidecode_available():
89-
return _unidecode_available
90-
91-
92-
def is_modelcards_available():
93-
return _modelcards_available
94-
95-
96-
def is_scipy_available():
97-
return _scipy_available
98-
99-
100-
class RepositoryNotFoundError(HTTPError):
101-
"""
102-
Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
103-
not have access to.
104-
"""
105-
106-
107-
class EntryNotFoundError(HTTPError):
108-
"""Raised when trying to access a hf.co URL with a valid repository and revision but an invalid filename."""
109-
110-
111-
class RevisionNotFoundError(HTTPError):
112-
"""Raised when trying to access a hf.co URL with a valid repository but an invalid revision."""
113-
114-
115-
TRANSFORMERS_IMPORT_ERROR = """
116-
{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip
117-
install transformers`
118-
"""
119-
120-
121-
UNIDECODE_IMPORT_ERROR = """
122-
{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install
123-
Unidecode`
124-
"""
125-
126-
127-
INFLECT_IMPORT_ERROR = """
128-
{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install
129-
inflect`
130-
"""
131-
132-
133-
SCIPY_IMPORT_ERROR = """
134-
{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install
135-
scipy`
136-
"""
137-
138-
139-
BACKENDS_MAPPING = OrderedDict(
140-
[
141-
("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
142-
("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
143-
("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)),
144-
("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
145-
]
146-
)
147-
148-
149-
def requires_backends(obj, backends):
150-
if not isinstance(backends, (list, tuple)):
151-
backends = [backends]
152-
153-
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
154-
checks = (BACKENDS_MAPPING[backend] for backend in backends)
155-
failed = [msg.format(name) for available, msg in checks if not available()]
156-
if failed:
157-
raise ImportError("".join(failed))
158-
159-
160-
class DummyObject(type):
161-
"""
162-
Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
163-
`requires_backend` each time a user tries to access any method of that class.
164-
"""
165-
166-
def __getattr__(cls, key):
167-
if key.startswith("_"):
168-
return super().__getattr__(cls, key)
169-
requires_backends(cls, cls._backends)

0 commit comments

Comments
 (0)