Skip to content

Commit e5902ed

Browse files
[Download] Smart downloading (#512)
* [Download] Smart downloading * add test * finish test * update * make style
1 parent a54cfe6 commit e5902ed

File tree

3 files changed

+45
-3
lines changed

3 files changed

+45
-3
lines changed

src/diffusers/onnx_utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,6 @@
3838

3939

4040
class OnnxRuntimeModel:
41-
base_model_prefix = "onnx_model"
42-
4341
def __init__(self, model=None, **kwargs):
4442
logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.")
4543
self.model = model

src/diffusers/pipeline_utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@
3030
from tqdm.auto import tqdm
3131

3232
from .configuration_utils import ConfigMixin
33-
from .utils import DIFFUSERS_CACHE, BaseOutput, logging
33+
from .modeling_utils import WEIGHTS_NAME
34+
from .onnx_utils import ONNX_WEIGHTS_NAME
35+
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
36+
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, logging
3437

3538

3639
INDEX_FILE = "diffusion_pytorch_model.bin"
@@ -285,6 +288,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
285288
# 1. Download the checkpoints and configs
286289
# use `snapshot_download` here so that downloading works from within `from_pretrained`
287290
if not os.path.isdir(pretrained_model_name_or_path):
291+
config_dict = cls.get_config_dict(
292+
pretrained_model_name_or_path,
293+
cache_dir=cache_dir,
294+
resume_download=resume_download,
295+
proxies=proxies,
296+
local_files_only=local_files_only,
297+
use_auth_token=use_auth_token,
298+
revision=revision,
299+
)
300+
# make sure we only download sub-folders and `diffusers` filenames
301+
folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
302+
allow_patterns = [os.path.join(k, "*") for k in folder_names]
303+
allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
304+
305+
# download all allow_patterns
288306
cached_folder = snapshot_download(
289307
pretrained_model_name_or_path,
290308
cache_dir=cache_dir,
@@ -293,6 +311,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
293311
local_files_only=local_files_only,
294312
use_auth_token=use_auth_token,
295313
revision=revision,
314+
allow_patterns=allow_patterns,
296315
)
297316
else:
298317
cached_folder = pretrained_model_name_or_path

tests/test_pipelines.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515

1616
import gc
17+
import os
1718
import random
1819
import tempfile
1920
import unittest
@@ -45,8 +46,11 @@
4546
UNet2DModel,
4647
VQModel,
4748
)
49+
from diffusers.modeling_utils import WEIGHTS_NAME
4850
from diffusers.pipeline_utils import DiffusionPipeline
51+
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
4952
from diffusers.testing_utils import floats_tensor, load_image, slow, torch_device
53+
from diffusers.utils import CONFIG_NAME
5054
from PIL import Image
5155
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
5256

@@ -707,6 +711,27 @@ def tearDown(self):
707711
gc.collect()
708712
torch.cuda.empty_cache()
709713

714+
def test_smart_download(self):
715+
model_id = "hf-internal-testing/unet-pipeline-dummy"
716+
with tempfile.TemporaryDirectory() as tmpdirname:
717+
_ = DiffusionPipeline.from_pretrained(model_id, cache_dir=tmpdirname, force_download=True)
718+
local_repo_name = "--".join(["models"] + model_id.split("/"))
719+
snapshot_dir = os.path.join(tmpdirname, local_repo_name, "snapshots")
720+
snapshot_dir = os.path.join(snapshot_dir, os.listdir(snapshot_dir)[0])
721+
722+
# inspect all downloaded files to make sure that everything is included
723+
assert os.path.isfile(os.path.join(snapshot_dir, DiffusionPipeline.config_name))
724+
assert os.path.isfile(os.path.join(snapshot_dir, CONFIG_NAME))
725+
assert os.path.isfile(os.path.join(snapshot_dir, SCHEDULER_CONFIG_NAME))
726+
assert os.path.isfile(os.path.join(snapshot_dir, WEIGHTS_NAME))
727+
assert os.path.isfile(os.path.join(snapshot_dir, "scheduler", SCHEDULER_CONFIG_NAME))
728+
assert os.path.isfile(os.path.join(snapshot_dir, "unet", WEIGHTS_NAME))
729+
assert os.path.isfile(os.path.join(snapshot_dir, "unet", WEIGHTS_NAME))
730+
# finally, make sure that the very large numpy file
731+
# https://huggingface.co/hf-internal-testing/unet-pipeline-dummy/blob/main/big_array.npy
732+
# is NOT downloaded, while all the expected files checked above are
733+
assert not os.path.isfile(os.path.join(snapshot_dir, "big_array.npy"))
734+
710735
@property
711736
def dummy_safety_checker(self):
712737
def check(images, *args, **kwargs):

0 commit comments

Comments
 (0)