Skip to content

Commit 721e017

Browse files
[Flax] Make room for more frameworks (#494)
* start * finish
1 parent f4781a0 commit 721e017

8 files changed

+227
-57
lines changed

setup.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
"""
6969

7070
import re
71+
import os
7172
from distutils.core import Command
7273

7374
from setuptools import find_packages, setup
@@ -82,10 +83,13 @@
8283
"datasets",
8384
"filelock",
8485
"flake8>=3.8.3",
86+
"flax>=0.4.1",
8587
"hf-doc-builder>=0.3.0",
8688
"huggingface-hub>=0.8.1",
8789
"importlib_metadata",
8890
"isort>=5.5.4",
91+
"jax>=0.2.8,!=0.3.2,<=0.3.6",
92+
"jaxlib>=0.1.65,<=0.3.6",
8993
"modelcards==0.1.4",
9094
"numpy",
9195
"pytest",
@@ -171,7 +175,14 @@ def run(self):
171175
extras["docs"] = ["hf-doc-builder"]
172176
extras["training"] = ["accelerate", "datasets", "tensorboard", "modelcards"]
173177
extras["test"] = ["datasets", "onnxruntime", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "transformers"]
174-
extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"]
178+
extras["torch"] = deps_list("torch")
179+
180+
if os.name == "nt": # windows
181+
extras["flax"] = [] # jax is not supported on windows
182+
else:
183+
extras["flax"] = deps_list("jax", "jaxlib", "flax")
184+
185+
extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"]
175186

176187
install_requires = [
177188
deps["importlib_metadata"],
@@ -180,13 +191,12 @@ def run(self):
180191
deps["numpy"],
181192
deps["regex"],
182193
deps["requests"],
183-
deps["torch"],
184194
deps["Pillow"],
185195
]
186196

187197
setup(
188198
name="diffusers",
189-
version="0.4.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
199+
version="0.4.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
190200
description="Diffusers",
191201
long_description=open("README.md", "r", encoding="utf-8").read(),
192202
long_description_content_type="text/markdown",
@@ -198,7 +208,7 @@ def run(self):
198208
package_dir={"": "src"},
199209
packages=find_packages("src"),
200210
include_package_data=True,
201-
python_requires=">=3.6.0",
211+
python_requires=">=3.7.0",
202212
install_requires=install_requires,
203213
extras_require=extras,
204214
entry_points={"console_scripts": ["diffusers-cli=diffusers.commands.diffusers_cli:main"]},

src/diffusers/__init__.py

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
is_inflect_available,
33
is_onnx_available,
44
is_scipy_available,
5+
is_torch_available,
56
is_transformers_available,
67
is_unidecode_available,
78
)
@@ -10,51 +11,52 @@
1011
__version__ = "0.4.0.dev0"
1112

1213
from .configuration_utils import ConfigMixin
13-
from .modeling_utils import ModelMixin
14-
from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
1514
from .onnx_utils import OnnxRuntimeModel
16-
from .optimization import (
17-
get_constant_schedule,
18-
get_constant_schedule_with_warmup,
19-
get_cosine_schedule_with_warmup,
20-
get_cosine_with_hard_restarts_schedule_with_warmup,
21-
get_linear_schedule_with_warmup,
22-
get_polynomial_decay_schedule_with_warmup,
23-
get_scheduler,
24-
)
25-
from .pipeline_utils import DiffusionPipeline
26-
from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
27-
from .schedulers import (
28-
DDIMScheduler,
29-
DDPMScheduler,
30-
KarrasVeScheduler,
31-
PNDMScheduler,
32-
SchedulerMixin,
33-
ScoreSdeVeScheduler,
34-
)
3515
from .utils import logging
3616

3717

38-
if is_scipy_available():
39-
from .schedulers import LMSDiscreteScheduler
18+
if is_torch_available():
19+
from .modeling_utils import ModelMixin
20+
from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
21+
from .optimization import (
22+
get_constant_schedule,
23+
get_constant_schedule_with_warmup,
24+
get_cosine_schedule_with_warmup,
25+
get_cosine_with_hard_restarts_schedule_with_warmup,
26+
get_linear_schedule_with_warmup,
27+
get_polynomial_decay_schedule_with_warmup,
28+
get_scheduler,
29+
)
30+
from .pipeline_utils import DiffusionPipeline
31+
from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
32+
from .schedulers import (
33+
DDIMScheduler,
34+
DDPMScheduler,
35+
KarrasVeScheduler,
36+
PNDMScheduler,
37+
SchedulerMixin,
38+
ScoreSdeVeScheduler,
39+
)
40+
from .training_utils import EMAModel
4041
else:
41-
from .utils.dummy_scipy_objects import * # noqa F403
42-
43-
from .training_utils import EMAModel
42+
from .utils.dummy_pt_objects import * # noqa F403
4443

44+
if is_torch_available() and is_scipy_available():
45+
from .schedulers import LMSDiscreteScheduler
46+
else:
47+
from .utils.dummy_torch_and_scipy_objects import * # noqa F403
4548

46-
if is_transformers_available():
49+
if is_torch_available() and is_transformers_available():
4750
from .pipelines import (
4851
LDMTextToImagePipeline,
4952
StableDiffusionImg2ImgPipeline,
5053
StableDiffusionInpaintPipeline,
5154
StableDiffusionPipeline,
5255
)
5356
else:
54-
from .utils.dummy_transformers_objects import * # noqa F403
55-
57+
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
5658

57-
if is_transformers_available() and is_onnx_available():
59+
if is_torch_available() and is_transformers_available() and is_onnx_available():
5860
from .pipelines import StableDiffusionOnnxPipeline
5961
else:
60-
from .utils.dummy_transformers_and_onnx_objects import * # noqa F403
62+
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403

src/diffusers/dependency_versions_table.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@
88
"datasets": "datasets",
99
"filelock": "filelock",
1010
"flake8": "flake8>=3.8.3",
11+
"flax": "flax>=0.4.1",
1112
"hf-doc-builder": "hf-doc-builder>=0.3.0",
1213
"huggingface-hub": "huggingface-hub>=0.8.1",
1314
"importlib_metadata": "importlib_metadata",
1415
"isort": "isort>=5.5.4",
16+
"jax": "jax>=0.2.8,!=0.3.2,<=0.3.6",
17+
"jaxlib": "jaxlib>=0.1.65,<=0.3.6",
1518
"modelcards": "modelcards==0.1.4",
1619
"numpy": "numpy",
1720
"pytest": "pytest",
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
# This file is autogenerated by the command `make fix-copies`, do not edit.
2+
# flake8: noqa
3+
4+
from ..utils import DummyObject, requires_backends
5+
6+
7+
class ModelMixin(metaclass=DummyObject):
8+
_backends = ["torch"]
9+
10+
def __init__(self, *args, **kwargs):
11+
requires_backends(self, ["torch"])
12+
13+
14+
class AutoencoderKL(metaclass=DummyObject):
15+
_backends = ["torch"]
16+
17+
def __init__(self, *args, **kwargs):
18+
requires_backends(self, ["torch"])
19+
20+
21+
class UNet2DConditionModel(metaclass=DummyObject):
22+
_backends = ["torch"]
23+
24+
def __init__(self, *args, **kwargs):
25+
requires_backends(self, ["torch"])
26+
27+
28+
class UNet2DModel(metaclass=DummyObject):
29+
_backends = ["torch"]
30+
31+
def __init__(self, *args, **kwargs):
32+
requires_backends(self, ["torch"])
33+
34+
35+
class VQModel(metaclass=DummyObject):
36+
_backends = ["torch"]
37+
38+
def __init__(self, *args, **kwargs):
39+
requires_backends(self, ["torch"])
40+
41+
42+
def get_constant_schedule(*args, **kwargs):
43+
requires_backends(get_constant_schedule, ["torch"])
44+
45+
46+
def get_constant_schedule_with_warmup(*args, **kwargs):
47+
requires_backends(get_constant_schedule_with_warmup, ["torch"])
48+
49+
50+
def get_cosine_schedule_with_warmup(*args, **kwargs):
51+
requires_backends(get_cosine_schedule_with_warmup, ["torch"])
52+
53+
54+
def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs):
55+
requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch"])
56+
57+
58+
def get_linear_schedule_with_warmup(*args, **kwargs):
59+
requires_backends(get_linear_schedule_with_warmup, ["torch"])
60+
61+
62+
def get_polynomial_decay_schedule_with_warmup(*args, **kwargs):
63+
requires_backends(get_polynomial_decay_schedule_with_warmup, ["torch"])
64+
65+
66+
def get_scheduler(*args, **kwargs):
67+
requires_backends(get_scheduler, ["torch"])
68+
69+
70+
class DiffusionPipeline(metaclass=DummyObject):
71+
_backends = ["torch"]
72+
73+
def __init__(self, *args, **kwargs):
74+
requires_backends(self, ["torch"])
75+
76+
77+
class DDIMPipeline(metaclass=DummyObject):
78+
_backends = ["torch"]
79+
80+
def __init__(self, *args, **kwargs):
81+
requires_backends(self, ["torch"])
82+
83+
84+
class DDPMPipeline(metaclass=DummyObject):
85+
_backends = ["torch"]
86+
87+
def __init__(self, *args, **kwargs):
88+
requires_backends(self, ["torch"])
89+
90+
91+
class KarrasVePipeline(metaclass=DummyObject):
92+
_backends = ["torch"]
93+
94+
def __init__(self, *args, **kwargs):
95+
requires_backends(self, ["torch"])
96+
97+
98+
class LDMPipeline(metaclass=DummyObject):
99+
_backends = ["torch"]
100+
101+
def __init__(self, *args, **kwargs):
102+
requires_backends(self, ["torch"])
103+
104+
105+
class PNDMPipeline(metaclass=DummyObject):
106+
_backends = ["torch"]
107+
108+
def __init__(self, *args, **kwargs):
109+
requires_backends(self, ["torch"])
110+
111+
112+
class ScoreSdeVePipeline(metaclass=DummyObject):
113+
_backends = ["torch"]
114+
115+
def __init__(self, *args, **kwargs):
116+
requires_backends(self, ["torch"])
117+
118+
119+
class DDIMScheduler(metaclass=DummyObject):
120+
_backends = ["torch"]
121+
122+
def __init__(self, *args, **kwargs):
123+
requires_backends(self, ["torch"])
124+
125+
126+
class DDPMScheduler(metaclass=DummyObject):
127+
_backends = ["torch"]
128+
129+
def __init__(self, *args, **kwargs):
130+
requires_backends(self, ["torch"])
131+
132+
133+
class KarrasVeScheduler(metaclass=DummyObject):
134+
_backends = ["torch"]
135+
136+
def __init__(self, *args, **kwargs):
137+
requires_backends(self, ["torch"])
138+
139+
140+
class PNDMScheduler(metaclass=DummyObject):
141+
_backends = ["torch"]
142+
143+
def __init__(self, *args, **kwargs):
144+
requires_backends(self, ["torch"])
145+
146+
147+
class SchedulerMixin(metaclass=DummyObject):
148+
_backends = ["torch"]
149+
150+
def __init__(self, *args, **kwargs):
151+
requires_backends(self, ["torch"])
152+
153+
154+
class ScoreSdeVeScheduler(metaclass=DummyObject):
155+
_backends = ["torch"]
156+
157+
def __init__(self, *args, **kwargs):
158+
requires_backends(self, ["torch"])
159+
160+
161+
class EMAModel(metaclass=DummyObject):
162+
_backends = ["torch"]
163+
164+
def __init__(self, *args, **kwargs):
165+
requires_backends(self, ["torch"])

src/diffusers/utils/dummy_scipy_objects.py renamed to src/diffusers/utils/dummy_torch_and_scipy_objects.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
class LMSDiscreteScheduler(metaclass=DummyObject):
8-
_backends = ["scipy"]
8+
_backends = ["torch", "scipy"]
99

1010
def __init__(self, *args, **kwargs):
11-
requires_backends(self, ["scipy"])
11+
requires_backends(self, ["torch", "scipy"])

src/diffusers/utils/dummy_transformers_and_onnx_objects.py renamed to src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
class StableDiffusionOnnxPipeline(metaclass=DummyObject):
8-
_backends = ["transformers", "onnx"]
8+
_backends = ["torch", "transformers", "onnx"]
99

1010
def __init__(self, *args, **kwargs):
11-
requires_backends(self, ["transformers", "onnx"])
11+
requires_backends(self, ["torch", "transformers", "onnx"])

src/diffusers/utils/dummy_transformers_objects.py renamed to src/diffusers/utils/dummy_torch_and_transformers_objects.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,28 +5,28 @@
55

66

77
class LDMTextToImagePipeline(metaclass=DummyObject):
8-
_backends = ["transformers"]
8+
_backends = ["torch", "transformers"]
99

1010
def __init__(self, *args, **kwargs):
11-
requires_backends(self, ["transformers"])
11+
requires_backends(self, ["torch", "transformers"])
1212

1313

1414
class StableDiffusionImg2ImgPipeline(metaclass=DummyObject):
15-
_backends = ["transformers"]
15+
_backends = ["torch", "transformers"]
1616

1717
def __init__(self, *args, **kwargs):
18-
requires_backends(self, ["transformers"])
18+
requires_backends(self, ["torch", "transformers"])
1919

2020

2121
class StableDiffusionInpaintPipeline(metaclass=DummyObject):
22-
_backends = ["transformers"]
22+
_backends = ["torch", "transformers"]
2323

2424
def __init__(self, *args, **kwargs):
25-
requires_backends(self, ["transformers"])
25+
requires_backends(self, ["torch", "transformers"])
2626

2727

2828
class StableDiffusionPipeline(metaclass=DummyObject):
29-
_backends = ["transformers"]
29+
_backends = ["torch", "transformers"]
3030

3131
def __init__(self, *args, **kwargs):
32-
requires_backends(self, ["transformers"])
32+
requires_backends(self, ["torch", "transformers"])

0 commit comments

Comments
 (0)