Skip to content

Commit 8685699

Browse files
author
Mishig Davaadorj
authored
Mv weights name consts to diffusers.utils (#605)
1 parent f810060 commit 8685699

File tree

6 files changed

+16
-17
lines changed

6 files changed

+16
-17
lines changed

src/diffusers/modeling_flax_utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,17 @@
2828
from requests import HTTPError
2929

3030
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
31-
from .modeling_utils import WEIGHTS_NAME, load_state_dict
32-
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
31+
from .modeling_utils import load_state_dict
32+
from .utils import (
33+
CONFIG_NAME,
34+
DIFFUSERS_CACHE,
35+
FLAX_WEIGHTS_NAME,
36+
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
37+
WEIGHTS_NAME,
38+
logging,
39+
)
3340

3441

35-
FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
36-
3742
logger = logging.get_logger(__name__)
3843

3944

src/diffusers/modeling_utils.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,7 @@
2424
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
2525
from requests import HTTPError
2626

27-
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
28-
29-
30-
WEIGHTS_NAME = "diffusion_pytorch_model.bin"
27+
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, WEIGHTS_NAME, logging
3128

3229

3330
logger = logging.get_logger(__name__)

src/diffusers/onnx_utils.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,13 @@
2424

2525
from huggingface_hub import hf_hub_download
2626

27-
from .utils import is_onnx_available, logging
27+
from .utils import ONNX_WEIGHTS_NAME, is_onnx_available, logging
2828

2929

3030
if is_onnx_available():
3131
import onnxruntime as ort
3232

3333

34-
ONNX_WEIGHTS_NAME = "model.onnx"
35-
36-
3734
logger = logging.get_logger(__name__)
3835

3936

src/diffusers/pipeline_utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,8 @@
3030
from tqdm.auto import tqdm
3131

3232
from .configuration_utils import ConfigMixin
33-
from .modeling_utils import WEIGHTS_NAME
34-
from .onnx_utils import ONNX_WEIGHTS_NAME
3533
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
36-
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, logging
34+
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging
3735

3836

3937
INDEX_FILE = "diffusion_pytorch_model.bin"

src/diffusers/utils/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@
4747

4848

4949
CONFIG_NAME = "config.json"
50+
WEIGHTS_NAME = "diffusion_pytorch_model.bin"
51+
FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
52+
ONNX_WEIGHTS_NAME = "model.onnx"
5053
HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
5154
DIFFUSERS_CACHE = default_cache_path
5255
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"

tests/test_pipelines.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,10 @@
4646
UNet2DModel,
4747
VQModel,
4848
)
49-
from diffusers.modeling_utils import WEIGHTS_NAME
5049
from diffusers.pipeline_utils import DiffusionPipeline
5150
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
5251
from diffusers.testing_utils import floats_tensor, load_image, slow, torch_device
53-
from diffusers.utils import CONFIG_NAME
52+
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME
5453
from PIL import Image
5554
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
5655

0 commit comments

Comments
 (0)