Skip to content

[Pytorch] add dep. warning for pytorch schedulers #651

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/diffusers/schedulers/scheduling_ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,15 @@ def __init__(
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)

if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
if beta_schedule == "linear":
Expand Down
9 changes: 9 additions & 0 deletions src/diffusers/schedulers/scheduling_ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union

Expand Down Expand Up @@ -112,7 +113,15 @@ def __init__(
trained_betas: Optional[np.ndarray] = None,
variance_type: str = "fixed_small",
clip_sample: bool = True,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)

if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
elif beta_schedule == "linear":
Expand Down
9 changes: 9 additions & 0 deletions src/diffusers/schedulers/scheduling_karras_ve.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.


import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union

Expand Down Expand Up @@ -86,7 +87,15 @@ def __init__(
s_churn: float = 80,
s_min: float = 0.05,
s_max: float = 50,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)

# setable values
self.num_inference_steps: int = None
self.timesteps: np.ndarray = None
Expand Down
9 changes: 9 additions & 0 deletions src/diffusers/schedulers/scheduling_lms_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union

Expand Down Expand Up @@ -74,7 +75,15 @@ def __init__(
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)

if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
if beta_schedule == "linear":
Expand Down
8 changes: 8 additions & 0 deletions src/diffusers/schedulers/scheduling_pndm.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,15 @@ def __init__(
skip_prk_steps: bool = False,
set_alpha_to_one: bool = False,
steps_offset: int = 0,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)

if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
if beta_schedule == "linear":
Expand Down
8 changes: 8 additions & 0 deletions src/diffusers/schedulers/scheduling_sde_ve.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,15 @@ def __init__(
sigma_max: float = 1348.0,
sampling_eps: float = 1e-5,
correct_steps: int = 1,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)

# setable values
self.timesteps = None

Expand Down
9 changes: 8 additions & 1 deletion src/diffusers/schedulers/scheduling_sde_vp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit

import math
import warnings

import torch

Expand All @@ -40,7 +41,13 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
"""

@register_to_config
def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3):
def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, **kwargs):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
self.sigmas = None
self.discrete_sigmas = None
self.timesteps = None
Expand Down
10 changes: 10 additions & 0 deletions src/diffusers/schedulers/scheduling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import dataclass

import torch
Expand Down Expand Up @@ -41,3 +42,12 @@ class SchedulerMixin:
"""

config_name = SCHEDULER_CONFIG_NAME

def set_format(self, tensor_format="pt"):
warnings.warn(
"The method `set_format` is deprecated and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this function as the schedulers"
"are always in Pytorch",
DeprecationWarning,
)
return self