|
14 | 14 | # limitations under the License.
|
15 | 15 |
|
16 | 16 | import gc
|
| 17 | +import os |
17 | 18 | import random
|
18 | 19 | import tempfile
|
19 | 20 | import unittest
|
|
45 | 46 | UNet2DModel,
|
46 | 47 | VQModel,
|
47 | 48 | )
|
| 49 | +from diffusers.modeling_utils import WEIGHTS_NAME |
48 | 50 | from diffusers.pipeline_utils import DiffusionPipeline
|
| 51 | +from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME |
49 | 52 | from diffusers.testing_utils import floats_tensor, load_image, slow, torch_device
|
| 53 | +from diffusers.utils import CONFIG_NAME |
50 | 54 | from PIL import Image
|
51 | 55 | from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
52 | 56 |
|
@@ -707,6 +711,27 @@ def tearDown(self):
|
707 | 711 | gc.collect()
|
708 | 712 | torch.cuda.empty_cache()
|
709 | 713 |
|
| 714 | + def test_smart_download(self): |
| 715 | + model_id = "hf-internal-testing/unet-pipeline-dummy" |
| 716 | + with tempfile.TemporaryDirectory() as tmpdirname: |
| 717 | + _ = DiffusionPipeline.from_pretrained(model_id, cache_dir=tmpdirname, force_download=True) |
| 718 | + local_repo_name = "--".join(["models"] + model_id.split("/")) |
| 719 | + snapshot_dir = os.path.join(tmpdirname, local_repo_name, "snapshots") |
| 720 | + snapshot_dir = os.path.join(snapshot_dir, os.listdir(snapshot_dir)[0]) |
| 721 | + |
| 722 | + # inspect all downloaded files to make sure that everything is included |
| 723 | + assert os.path.isfile(os.path.join(snapshot_dir, DiffusionPipeline.config_name)) |
| 724 | + assert os.path.isfile(os.path.join(snapshot_dir, CONFIG_NAME)) |
| 725 | + assert os.path.isfile(os.path.join(snapshot_dir, SCHEDULER_CONFIG_NAME)) |
| 726 | + assert os.path.isfile(os.path.join(snapshot_dir, WEIGHTS_NAME)) |
| 727 | + assert os.path.isfile(os.path.join(snapshot_dir, "scheduler", SCHEDULER_CONFIG_NAME)) |
| 728 | + assert os.path.isfile(os.path.join(snapshot_dir, "unet", WEIGHTS_NAME)) |
| 729 | + assert os.path.isfile(os.path.join(snapshot_dir, "unet", WEIGHTS_NAME)) |
| 730 | + # let's make sure the super large numpy file: |
| 731 | + # https://huggingface.co/hf-internal-testing/unet-pipeline-dummy/blob/main/big_array.npy |
| 732 | + # is not downloaded, but all the expected ones |
| 733 | + assert not os.path.isfile(os.path.join(snapshot_dir, "big_array.npy")) |
| 734 | + |
710 | 735 | @property
|
711 | 736 | def dummy_safety_checker(self):
|
712 | 737 | def check(images, *args, **kwargs):
|
|
0 commit comments