Skip to content

[ONNX] Collate the external weights, speed up loading from the hub #610

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion scripts/convert_stable_diffusion_checkpoint_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@
# limitations under the License.

import argparse
import os
import shutil
from pathlib import Path

import torch
from torch.onnx import export

import onnx
from diffusers import StableDiffusionOnnxPipeline, StableDiffusionPipeline
from diffusers.onnx_utils import OnnxRuntimeModel
from packaging import version
Expand Down Expand Up @@ -92,10 +95,11 @@ def convert_models(model_path: str, output_path: str, opset: int):
)

# UNET
unet_path = output_path / "unet" / "model.onnx"
onnx_export(
pipeline.unet,
model_args=(torch.randn(2, 4, 64, 64), torch.LongTensor([0, 1]), torch.randn(2, 77, 768), False),
output_path=output_path / "unet" / "model.onnx",
output_path=unet_path,
ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"],
output_names=["out_sample"], # has to be different from "sample" for correct tracing
dynamic_axes={
Expand All @@ -106,6 +110,21 @@ def convert_models(model_path: str, output_path: str, opset: int):
opset=opset,
use_external_data_format=True, # UNet is > 2GB, so the weights need to be split
)
unet_model_path = str(unet_path.absolute().as_posix())
unet_dir = os.path.dirname(unet_model_path)
unet = onnx.load(unet_model_path)
# clean up existing tensor files
shutil.rmtree(unet_dir)
os.mkdir(unet_dir)
# collate external tensor files into one
onnx.save_model(
unet,
unet_model_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location="weights.pb",
convert_attribute=False,
)

# VAE ENCODER
vae_encoder = pipeline.vae
Expand Down
23 changes: 18 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,10 @@
"isort>=5.5.4",
"jax>=0.2.8,!=0.3.2,<=0.3.6",
"jaxlib>=0.1.65,<=0.3.6",
"modelcards==0.1.4",
"modelcards>=0.1.4",
"numpy",
"onnxruntime",
"onnxruntime-gpu",
"pytest",
"pytest-timeout",
"pytest-xdist",
Expand All @@ -100,6 +102,7 @@
"requests",
"tensorboard",
"torch>=1.4",
"torchvision",
"transformers>=4.21.0",
]

Expand Down Expand Up @@ -171,10 +174,20 @@ def run(self):


extras = {}
extras["quality"] = ["black==22.8", "isort>=5.5.4", "flake8>=3.8.3", "hf-doc-builder"]
extras["docs"] = ["hf-doc-builder"]
extras["training"] = ["accelerate", "datasets", "tensorboard", "modelcards"]
extras["test"] = ["datasets", "onnxruntime", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "torchvision", "transformers"]
extras["quality"] = deps_list("black", "isort", "flake8", "hf-doc-builder")
extras["docs"] = deps_list("hf-doc-builder")
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
extras["test"] = deps_list(
"datasets",
"onnxruntime",
"onnxruntime-gpu",
"pytest",
"pytest-timeout",
"pytest-xdist",
"scipy",
"torchvision",
"transformers"
)
extras["torch"] = deps_list("torch")

if os.name == "nt": # windows
Expand Down
5 changes: 4 additions & 1 deletion src/diffusers/dependency_versions_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
"isort": "isort>=5.5.4",
"jax": "jax>=0.2.8,!=0.3.2,<=0.3.6",
"jaxlib": "jaxlib>=0.1.65,<=0.3.6",
"modelcards": "modelcards==0.1.4",
"modelcards": "modelcards>=0.1.4",
"numpy": "numpy",
"onnxruntime": "onnxruntime",
"onnxruntime-gpu": "onnxruntime-gpu",
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

forgot to add onnxruntime-gpu for our CI, fixing that here

"pytest": "pytest",
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
Expand All @@ -25,5 +27,6 @@
"requests": "requests",
"tensorboard": "tensorboard",
"torch": "torch>=1.4",
"torchvision": "torchvision",
"transformers": "transformers>=4.21.0",
}
9 changes: 3 additions & 6 deletions tests/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -1373,12 +1373,9 @@ def test_stable_diffusion_inpaint_pipeline_k_lms(self):

@slow
def test_stable_diffusion_onnx(self):
from scripts.convert_stable_diffusion_checkpoint_to_onnx import convert_models

with tempfile.TemporaryDirectory() as tmpdirname:
convert_models("CompVis/stable-diffusion-v1-4", tmpdirname, opset=14)

sd_pipe = StableDiffusionOnnxPipeline.from_pretrained(tmpdirname, provider="CUDAExecutionProvider")
sd_pipe = StableDiffusionOnnxPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="onnx", provider="CUDAExecutionProvider", use_auth_token=True
)
Comment on lines +1376 to +1378
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

9 minutes -> 10 seconds!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

awesome - great work!


prompt = "A painting of a squirrel eating a burger"
np.random.seed(0)
Expand Down