From d1747d6b024b974ef0f0ffeb16c03deba3c03af0 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 27 Sep 2022 17:43:16 +0200 Subject: [PATCH 1/2] add dep. warning for schedulers --- src/diffusers/schedulers/scheduling_ddim.py | 8 ++++++++ src/diffusers/schedulers/scheduling_ddpm.py | 9 +++++++++ src/diffusers/schedulers/scheduling_karras_ve.py | 9 +++++++++ src/diffusers/schedulers/scheduling_lms_discrete.py | 9 +++++++++ src/diffusers/schedulers/scheduling_pndm.py | 8 ++++++++ src/diffusers/schedulers/scheduling_sde_ve.py | 8 ++++++++ src/diffusers/schedulers/scheduling_sde_vp.py | 9 ++++++++- src/diffusers/schedulers/scheduling_utils.py | 10 ++++++++++ 8 files changed, 69 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 6880700ecef0..31c48bb7569f 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -120,7 +120,15 @@ def __init__( clip_sample: bool = True, set_alpha_to_one: bool = True, steps_offset: int = 0, + **kwargs, ): + if "tensor_format" in kwargs: + warnings.warn( + "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`." 
+ "If you're running your code in PyTorch, you can safely remove this argument.", + DeprecationWarning, + ) + if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) if beta_schedule == "linear": diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 0383dea224c7..af71653c1be1 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -15,6 +15,7 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim import math +import warnings from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -112,7 +113,15 @@ def __init__( trained_betas: Optional[np.ndarray] = None, variance_type: str = "fixed_small", clip_sample: bool = True, + **kwargs, ): + if "tensor_format" in kwargs: + warnings.warn( + "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`." + "If you're running your code in PyTorch, you can safely remove this argument.", + DeprecationWarning, + ) + if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) elif beta_schedule == "linear": diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py index 5826858faee4..e6e5300e73e7 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve.py +++ b/src/diffusers/schedulers/scheduling_karras_ve.py @@ -13,6 +13,7 @@ # limitations under the License. +import warnings from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -86,7 +87,15 @@ def __init__( s_churn: float = 80, s_min: float = 0.05, s_max: float = 50, + **kwargs, ): + if "tensor_format" in kwargs: + warnings.warn( + "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`." 
+ "If you're running your code in PyTorch, you can safely remove this argument.", + DeprecationWarning, + ) + # setable values self.num_inference_steps: int = None self.timesteps: np.ndarray = None diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 6167af5ad42b..6d8db7682db5 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -74,7 +75,15 @@ def __init__( beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, + **kwargs, ): + if "tensor_format" in kwargs: + warnings.warn( + "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`." + "If you're running your code in PyTorch, you can safely remove this argument.", + DeprecationWarning, + ) + if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) if beta_schedule == "linear": diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 1935a6ef93f2..37f5d349cee2 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -100,7 +100,15 @@ def __init__( skip_prk_steps: bool = False, set_alpha_to_one: bool = False, steps_offset: int = 0, + **kwargs, ): + if "tensor_format" in kwargs: + warnings.warn( + "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`." 
+ "If you're running your code in PyTorch, you can safely remove this argument.", + DeprecationWarning, + ) + if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) if beta_schedule == "linear": diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 7b06ae16c5e9..a549654c3b6f 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -76,7 +76,15 @@ def __init__( sigma_max: float = 1348.0, sampling_eps: float = 1e-5, correct_steps: int = 1, + **kwargs, ): + if "tensor_format" in kwargs: + warnings.warn( + "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`." + "If you're running your code in PyTorch, you can safely remove this argument.", + DeprecationWarning, + ) + # setable values self.timesteps = None diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py index 2f9821579c52..daea743873f1 100644 --- a/src/diffusers/schedulers/scheduling_sde_vp.py +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -17,6 +17,7 @@ # TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit import math +import warnings import torch @@ -40,7 +41,13 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): """ @register_to_config - def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3): + def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, **kwargs): + if "tensor_format" in kwargs: + warnings.warn( + "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`." 
+ "If you're running your code in PyTorch, you can safely remove this argument.", + DeprecationWarning, + ) self.sigmas = None self.discrete_sigmas = None self.timesteps = None diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index 29bf982f6adf..ffd489d5c118 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import warnings from dataclasses import dataclass import torch @@ -41,3 +42,12 @@ class SchedulerMixin: """ config_name = SCHEDULER_CONFIG_NAME + + def set_format(self, tensor_format="pt"): + warnings.warn( + "The method `set_format` is deprecated and will be removed in version `0.5.0`." + "If you're running your code in PyTorch, you can safely remove this function as the schedulers", + "are always in Pytorch", + DeprecationWarning, + ) + return self From b830eceec5586a6d5fa5a462b07596ca44716f14 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 27 Sep 2022 17:45:43 +0200 Subject: [PATCH 2/2] fix format --- src/diffusers/schedulers/scheduling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index ffd489d5c118..1cc1d94414a6 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -46,7 +46,7 @@ class SchedulerMixin: def set_format(self, tensor_format="pt"): warnings.warn( "The method `set_format` is deprecated and will be removed in version `0.5.0`." 
-                "If you're running your code in PyTorch, you can safely remove this function as the schedulers",
+                "If you're running your code in PyTorch, you can safely remove this function as the schedulers "
                 "are always in Pytorch",
                 DeprecationWarning,
             )