
Commit 6bd005e

[ONNX] Collate the external weights, speed up loading from the hub (#610)
1 parent a9fdb3d commit 6bd005e

File tree (4 files changed, +45 -13 lines)

scripts/convert_stable_diffusion_checkpoint_to_onnx.py
setup.py
src/diffusers/dependency_versions_table.py
tests/test_pipelines.py

scripts/convert_stable_diffusion_checkpoint_to_onnx.py

Lines changed: 20 additions & 1 deletion

@@ -13,11 +13,14 @@
 # limitations under the License.
 
 import argparse
+import os
+import shutil
 from pathlib import Path
 
 import torch
 from torch.onnx import export
 
+import onnx
 from diffusers import StableDiffusionOnnxPipeline, StableDiffusionPipeline
 from diffusers.onnx_utils import OnnxRuntimeModel
 from packaging import version
@@ -92,10 +95,11 @@ def convert_models(model_path: str, output_path: str, opset: int):
     )
 
     # UNET
+    unet_path = output_path / "unet" / "model.onnx"
     onnx_export(
         pipeline.unet,
         model_args=(torch.randn(2, 4, 64, 64), torch.LongTensor([0, 1]), torch.randn(2, 77, 768), False),
-        output_path=output_path / "unet" / "model.onnx",
+        output_path=unet_path,
         ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"],
         output_names=["out_sample"],  # has to be different from "sample" for correct tracing
         dynamic_axes={
@@ -106,6 +110,21 @@ def convert_models(model_path: str, output_path: str, opset: int):
         opset=opset,
         use_external_data_format=True,  # UNet is > 2GB, so the weights need to be split
     )
+    unet_model_path = str(unet_path.absolute().as_posix())
+    unet_dir = os.path.dirname(unet_model_path)
+    unet = onnx.load(unet_model_path)
+    # clean up existing tensor files
+    shutil.rmtree(unet_dir)
+    os.mkdir(unet_dir)
+    # collate external tensor files into one
+    onnx.save_model(
+        unet,
+        unet_model_path,
+        save_as_external_data=True,
+        all_tensors_to_one_file=True,
+        location="weights.pb",
+        convert_attribute=False,
+    )
 
     # VAE ENCODER
     vae_encoder = pipeline.vae
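The new block re-saves the exported UNet so that the many per-tensor files produced by use_external_data_format are collated into a single weights.pb next to model.onnx, which is what speeds up loading from the hub. For reference, the same step can be read as a small standalone helper; this is only a sketch, and the collate_external_weights name and the example path are not part of the diff:

import os
import shutil

import onnx


def collate_external_weights(model_path: str, weights_name: str = "weights.pb") -> None:
    """Re-save an ONNX model so all external tensors live in one sidecar file."""
    model_dir = os.path.dirname(model_path)
    # onnx.load pulls the external tensor data into memory alongside the graph
    model = onnx.load(model_path)
    # remove the directory full of per-tensor files, then recreate it empty
    shutil.rmtree(model_dir)
    os.mkdir(model_dir)
    # write the graph back out with every tensor collated into one external file
    onnx.save_model(
        model,
        model_path,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=weights_name,
        convert_attribute=False,
    )


# Example (hypothetical path): collate_external_weights("sd-onnx/unet/model.onnx")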

setup.py

Lines changed: 18 additions & 5 deletions

@@ -90,8 +90,10 @@
     "isort>=5.5.4",
     "jax>=0.2.8,!=0.3.2,<=0.3.6",
     "jaxlib>=0.1.65,<=0.3.6",
-    "modelcards==0.1.4",
+    "modelcards>=0.1.4",
     "numpy",
+    "onnxruntime",
+    "onnxruntime-gpu",
     "pytest",
     "pytest-timeout",
     "pytest-xdist",
@@ -100,6 +102,7 @@
     "requests",
     "tensorboard",
     "torch>=1.4",
+    "torchvision",
     "transformers>=4.21.0",
 ]
 
@@ -171,10 +174,20 @@ def run(self):
 
 
 extras = {}
-extras["quality"] = ["black==22.8", "isort>=5.5.4", "flake8>=3.8.3", "hf-doc-builder"]
-extras["docs"] = ["hf-doc-builder"]
-extras["training"] = ["accelerate", "datasets", "tensorboard", "modelcards"]
-extras["test"] = ["datasets", "onnxruntime", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "torchvision", "transformers"]
+extras["quality"] = deps_list("black", "isort", "flake8", "hf-doc-builder")
+extras["docs"] = deps_list("hf-doc-builder")
+extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
+extras["test"] = deps_list(
+    "datasets",
+    "onnxruntime",
+    "onnxruntime-gpu",
+    "pytest",
+    "pytest-timeout",
+    "pytest-xdist",
+    "scipy",
+    "torchvision",
+    "transformers"
+)
 extras["torch"] = deps_list("torch")
 
 if os.name == "nt":  # windows
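The extras are now built with deps_list, which resolves package names against the pinned specifiers in src/diffusers/dependency_versions_table.py (shown in the next file) instead of repeating version constraints inline. A minimal sketch of that lookup, assuming a deps mapping shaped like the table; the entries listed here are only a subset:

deps = {
    "onnxruntime": "onnxruntime",
    "onnxruntime-gpu": "onnxruntime-gpu",
    "torchvision": "torchvision",
    "transformers": "transformers>=4.21.0",
}


def deps_list(*pkgs):
    # resolve each short name to its pinned specifier so the extras and
    # install_requires always agree on versions
    return [deps[pkg] for pkg in pkgs]


# e.g. deps_list("onnxruntime", "transformers") -> ["onnxruntime", "transformers>=4.21.0"]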

src/diffusers/dependency_versions_table.py

Lines changed: 4 additions & 1 deletion

@@ -15,8 +15,10 @@
     "isort": "isort>=5.5.4",
     "jax": "jax>=0.2.8,!=0.3.2,<=0.3.6",
     "jaxlib": "jaxlib>=0.1.65,<=0.3.6",
-    "modelcards": "modelcards==0.1.4",
+    "modelcards": "modelcards>=0.1.4",
     "numpy": "numpy",
+    "onnxruntime": "onnxruntime",
+    "onnxruntime-gpu": "onnxruntime-gpu",
     "pytest": "pytest",
     "pytest-timeout": "pytest-timeout",
     "pytest-xdist": "pytest-xdist",
@@ -25,5 +27,6 @@
     "requests": "requests",
     "tensorboard": "tensorboard",
     "torch": "torch>=1.4",
+    "torchvision": "torchvision",
     "transformers": "transformers>=4.21.0",
 }

tests/test_pipelines.py

Lines changed: 3 additions & 6 deletions

@@ -1373,12 +1373,9 @@ def test_stable_diffusion_inpaint_pipeline_k_lms(self):
 
     @slow
     def test_stable_diffusion_onnx(self):
-        from scripts.convert_stable_diffusion_checkpoint_to_onnx import convert_models
-
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            convert_models("CompVis/stable-diffusion-v1-4", tmpdirname, opset=14)
-
-            sd_pipe = StableDiffusionOnnxPipeline.from_pretrained(tmpdirname, provider="CUDAExecutionProvider")
+        sd_pipe = StableDiffusionOnnxPipeline.from_pretrained(
+            "CompVis/stable-diffusion-v1-4", revision="onnx", provider="CUDAExecutionProvider", use_auth_token=True
+        )
 
         prompt = "A painting of a squirrel eating a burger"
         np.random.seed(0)
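Because the collated ONNX weights now live on the hub under the onnx revision, the test no longer converts the checkpoint on the fly. Outside the test suite, loading and running the pipeline looks roughly like the sketch below; it assumes the gated CompVis/stable-diffusion-v1-4 license has been accepted and a Hugging Face token is available:

import numpy as np

from diffusers import StableDiffusionOnnxPipeline

# Pull the pre-converted ONNX weights (onnx revision) straight from the hub
pipe = StableDiffusionOnnxPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="onnx",
    provider="CUDAExecutionProvider",  # use "CPUExecutionProvider" without a GPU
    use_auth_token=True,
)

np.random.seed(0)
prompt = "A painting of a squirrel eating a burger"
image = pipe(prompt).images[0]
image.save("squirrel.png")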
