Skip to content

Commit 57a861a

Browse files
kashifPrathik Rao
authored andcommitted
[Pytorch] add dep. warning for pytorch schedulers (huggingface#651)
* add dep. warning for schedulers * fix format
1 parent 24902f8 commit 57a861a

File tree

8 files changed

+69
-1
lines changed

8 files changed

+69
-1
lines changed

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,15 @@ def __init__(
120120
clip_sample: bool = True,
121121
set_alpha_to_one: bool = True,
122122
steps_offset: int = 0,
123+
**kwargs,
123124
):
125+
if "tensor_format" in kwargs:
126+
warnings.warn(
127+
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
128+
"If you're running your code in PyTorch, you can safely remove this argument.",
129+
DeprecationWarning,
130+
)
131+
124132
if trained_betas is not None:
125133
self.betas = torch.from_numpy(trained_betas)
126134
if beta_schedule == "linear":

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
1616

1717
import math
18+
import warnings
1819
from dataclasses import dataclass
1920
from typing import Optional, Tuple, Union
2021

@@ -112,7 +113,15 @@ def __init__(
112113
trained_betas: Optional[np.ndarray] = None,
113114
variance_type: str = "fixed_small",
114115
clip_sample: bool = True,
116+
**kwargs,
115117
):
118+
if "tensor_format" in kwargs:
119+
warnings.warn(
120+
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
121+
"If you're running your code in PyTorch, you can safely remove this argument.",
122+
DeprecationWarning,
123+
)
124+
116125
if trained_betas is not None:
117126
self.betas = torch.from_numpy(trained_betas)
118127
elif beta_schedule == "linear":

src/diffusers/schedulers/scheduling_karras_ve.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515

16+
import warnings
1617
from dataclasses import dataclass
1718
from typing import Optional, Tuple, Union
1819

@@ -86,7 +87,15 @@ def __init__(
8687
s_churn: float = 80,
8788
s_min: float = 0.05,
8889
s_max: float = 50,
90+
**kwargs,
8991
):
92+
if "tensor_format" in kwargs:
93+
warnings.warn(
94+
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
95+
"If you're running your code in PyTorch, you can safely remove this argument.",
96+
DeprecationWarning,
97+
)
98+
9099
# setable values
91100
self.num_inference_steps: int = None
92101
self.timesteps: np.ndarray = None

src/diffusers/schedulers/scheduling_lms_discrete.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import warnings
1516
from dataclasses import dataclass
1617
from typing import Optional, Tuple, Union
1718

@@ -74,7 +75,15 @@ def __init__(
7475
beta_end: float = 0.02,
7576
beta_schedule: str = "linear",
7677
trained_betas: Optional[np.ndarray] = None,
78+
**kwargs,
7779
):
80+
if "tensor_format" in kwargs:
81+
warnings.warn(
82+
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
83+
"If you're running your code in PyTorch, you can safely remove this argument.",
84+
DeprecationWarning,
85+
)
86+
7887
if trained_betas is not None:
7988
self.betas = torch.from_numpy(trained_betas)
8089
if beta_schedule == "linear":

src/diffusers/schedulers/scheduling_pndm.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,15 @@ def __init__(
100100
skip_prk_steps: bool = False,
101101
set_alpha_to_one: bool = False,
102102
steps_offset: int = 0,
103+
**kwargs,
103104
):
105+
if "tensor_format" in kwargs:
106+
warnings.warn(
107+
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
108+
"If you're running your code in PyTorch, you can safely remove this argument.",
109+
DeprecationWarning,
110+
)
111+
104112
if trained_betas is not None:
105113
self.betas = torch.from_numpy(trained_betas)
106114
if beta_schedule == "linear":

src/diffusers/schedulers/scheduling_sde_ve.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,15 @@ def __init__(
7676
sigma_max: float = 1348.0,
7777
sampling_eps: float = 1e-5,
7878
correct_steps: int = 1,
79+
**kwargs,
7980
):
81+
if "tensor_format" in kwargs:
82+
warnings.warn(
83+
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
84+
"If you're running your code in PyTorch, you can safely remove this argument.",
85+
DeprecationWarning,
86+
)
87+
8088
# setable values
8189
self.timesteps = None
8290

src/diffusers/schedulers/scheduling_sde_vp.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit
1818

1919
import math
20+
import warnings
2021

2122
import torch
2223

@@ -40,7 +41,13 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
4041
"""
4142

4243
@register_to_config
43-
def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3):
44+
def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, **kwargs):
45+
if "tensor_format" in kwargs:
46+
warnings.warn(
47+
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
48+
"If you're running your code in PyTorch, you can safely remove this argument.",
49+
DeprecationWarning,
50+
)
4451
self.sigmas = None
4552
self.discrete_sigmas = None
4653
self.timesteps = None

src/diffusers/schedulers/scheduling_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import warnings
1415
from dataclasses import dataclass
1516

1617
import torch
@@ -41,3 +42,12 @@ class SchedulerMixin:
4142
"""
4243

4344
config_name = SCHEDULER_CONFIG_NAME
45+
46+
def set_format(self, tensor_format="pt"):
47+
warnings.warn(
48+
"The method `set_format` is deprecated and will be removed in version `0.5.0`."
49+
"If you're running your code in PyTorch, you can safely remove this function as the schedulers"
50+
"are always in Pytorch",
51+
DeprecationWarning,
52+
)
53+
return self

0 commit comments

Comments
 (0)