From a733d72b50ed46bc69f5dbed4b58769130c83727 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 31 Aug 2022 11:21:04 +0200 Subject: [PATCH 01/11] [Refactor] Remove set_seed and class attributes --- .../score_sde_ve/pipeline_score_sde_ve.py | 6 ++--- .../pipeline_stochastic_karras_ve.py | 5 ---- src/diffusers/schedulers/scheduling_sde_ve.py | 26 ++++++++++++------- 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py index 7d72ddf74625..0ab92effbee8 100644 --- a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +++ b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py @@ -30,7 +30,7 @@ def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, outpu model = self.unet - sample = torch.randn(*shape) * self.scheduler.config.sigma_max + sample = torch.randn(*shape, generator=generator) * self.scheduler.config.sigma_max sample = sample.to(self.device) self.scheduler.set_timesteps(num_inference_steps) @@ -42,11 +42,11 @@ def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, outpu # correction step for _ in range(self.scheduler.correct_steps): model_output = self.unet(sample, sigma_t)["sample"] - sample = self.scheduler.step_correct(model_output, sample)["prev_sample"] + sample = self.scheduler.step_correct(model_output, sample, generator=generator)["prev_sample"] # prediction step model_output = model(sample, sigma_t)["sample"] - output = self.scheduler.step_pred(model_output, t, sample) + output = self.scheduler.step_pred(model_output, t, sample, generator=generator) sample, sample_mean = output["prev_sample"], output["prev_sample_mean"] diff --git a/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py b/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py index 970272999c67..0bb65045e5b1 100644 --- a/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py +++ b/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py @@ -3,9 +3,7 @@ import torch -from ...models import UNet2DModel from ...pipeline_utils import DiffusionPipeline -from ...schedulers import KarrasVeScheduler class KarrasVePipeline(DiffusionPipeline): @@ -18,9 +16,6 @@ class KarrasVePipeline(DiffusionPipeline): differential equations." https://arxiv.org/abs/2011.13456 """ - unet: UNet2DModel - scheduler: KarrasVeScheduler - def __init__(self, unet, scheduler): super().__init__() scheduler = scheduler.set_format("pt") diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 1d6e05d97feb..5fa86eacebf7 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -14,8 +14,10 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch +import warnings + # TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit -from typing import Union +from typing import Optional, Union import numpy as np import torch @@ -98,6 +100,9 @@ def get_adjacent_sigma(self, timesteps, t): raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") def set_seed(self, seed): + warnings.warn( + "The method `set_seed` is outdated. Please consider passing a generator instead.", DeprecationWarning + ) tensor_format = getattr(self, "tensor_format", "pt") if tensor_format == "np": np.random.seed(seed) @@ -111,14 +116,14 @@ def step_pred( model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray], - seed=None, + generator: Optional[torch.Generator] = None, + **kwargs, ): """ Predict the sample at the previous timestep by reversing the SDE. """ - if seed is not None: - self.set_seed(seed) - # TODO(Patrick) non-PyTorch + if "seed" in kwargs and kwargs["seed"] is not None: + self.set_seed(kwargs["seed"]) if self.timesteps is None: raise ValueError( @@ -140,7 +145,7 @@ def step_pred( drift = drift - diffusion[:, None, None, None] ** 2 * model_output # equation 6: sample noise for the diffusion term of - noise = self.randn_like(sample) + noise = self.randn_like(sample, generator=generator) prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep # TODO is the variable diffusion the correct scaling term for the noise? prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g @@ -151,14 +156,15 @@ def step_correct( self, model_output: Union[torch.FloatTensor, np.ndarray], sample: Union[torch.FloatTensor, np.ndarray], - seed=None, + generator: Optional[torch.Generator] = None, + **kwargs, ): """ Correct the predicted sample based on the output model_output of the network. This is often run repeatedly after making the prediction for the previous timestep. """ - if seed is not None: - self.set_seed(seed) + if "seed" in kwargs and kwargs["seed"] is not None: + self.set_seed(kwargs["seed"]) if self.timesteps is None: raise ValueError( @@ -167,7 +173,7 @@ def step_correct( # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z" # sample noise for correction - noise = self.randn_like(sample) + noise = self.randn_like(sample, generator=generator) # compute step size from the model_output, the noise, and the snr grad_norm = self.norm(model_output) From 4a559e9329d2c1f3bc6c553cab1e2a7c5cc1c3e2 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 31 Aug 2022 10:39:51 +0000 Subject: [PATCH 02/11] apply anton's suggestiosn --- .../pipelines/score_sde_ve/pipeline_score_sde_ve.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py index 0ab92effbee8..ce8a87983450 100644 --- a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +++ b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py @@ -5,8 +5,17 @@ from diffusers import DiffusionPipeline +from ...models import UNet2DModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import KarrasVeScheduler + class ScoreSdeVePipeline(DiffusionPipeline): + + # add type hints for linting + unet: UNet2DModel + scheduler: KarrasVeScheduler + def __init__(self, unet, scheduler): super().__init__() self.register_modules(unet=unet, scheduler=scheduler) From 9aacfddd891704795c717bea0e1d2fcd9fbb7066 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 31 Aug 2022 10:41:13 +0000 Subject: [PATCH 03/11] fix --- .../pipelines/score_sde_ve/pipeline_score_sde_ve.py | 9 --------- .../stochatic_karras_ve/pipeline_stochastic_karras_ve.py | 6 ++++++ 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py index ce8a87983450..0ab92effbee8 100644 --- a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +++ b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py @@ -5,17 +5,8 @@ from diffusers import DiffusionPipeline -from ...models import UNet2DModel -from ...pipeline_utils import DiffusionPipeline -from ...schedulers import KarrasVeScheduler - class ScoreSdeVePipeline(DiffusionPipeline): - - # add type hints for linting - unet: UNet2DModel - scheduler: KarrasVeScheduler - def __init__(self, unet, scheduler): super().__init__() self.register_modules(unet=unet, scheduler=scheduler) diff --git a/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py b/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py index 0bb65045e5b1..211019c59245 100644 --- a/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py +++ b/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py @@ -3,7 +3,9 @@ import torch +from ...models import UNet2DModel from ...pipeline_utils import DiffusionPipeline +from ...schedulers import KarrasVeScheduler class KarrasVePipeline(DiffusionPipeline): @@ -16,6 +18,10 @@ class KarrasVePipeline(DiffusionPipeline): differential equations." https://arxiv.org/abs/2011.13456 """ + # add type hints for linting + unet: UNet2DModel + scheduler: KarrasVeScheduler + def __init__(self, unet, scheduler): super().__init__() scheduler = scheduler.set_format("pt") From 3d093c9a146306dc07fa672834e587a5116e80fc Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 31 Aug 2022 13:30:36 +0200 Subject: [PATCH 04/11] Apply suggestions from code review Co-authored-by: Pedro Cuenca --- src/diffusers/schedulers/scheduling_sde_ve.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 5fa86eacebf7..147108ef5ba0 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -101,7 +101,7 @@ def get_adjacent_sigma(self, timesteps, t): def set_seed(self, seed): warnings.warn( - "The method `set_seed` is outdated. Please consider passing a generator instead.", DeprecationWarning + "The method `set_seed` is deprecated and will be removed in version `0.4.0`. Please consider passing a generator instead.", DeprecationWarning ) tensor_format = getattr(self, "tensor_format", "pt") if tensor_format == "np": From 59d1e4c97bdc7891fbaed9d3b60ca94a9ca74af0 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 31 Aug 2022 11:35:46 +0000 Subject: [PATCH 05/11] up --- src/diffusers/pipelines/ddpm/pipeline_ddpm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index 27c156def843..5d735a3901ea 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -56,7 +56,7 @@ def __call__(self, batch_size=1, generator=None, output_type="pil", **kwargs): model_output = self.unet(image, t)["sample"] # 2. compute previous image: x_t -> t_t-1 - image = self.scheduler.step(model_output, t, image)["prev_sample"] + image = self.scheduler.step(model_output, t, image, generator=generator)["prev_sample"] image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() From 0d1ca30b09245e4b7cf493fcb844997b069b6318 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 31 Aug 2022 18:24:11 +0200 Subject: [PATCH 06/11] update --- src/diffusers/utils/dummy_transformers_objects.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/utils/dummy_transformers_objects.py b/src/diffusers/utils/dummy_transformers_objects.py index 753e3fdbe291..dc929427221a 100644 --- a/src/diffusers/utils/dummy_transformers_objects.py +++ b/src/diffusers/utils/dummy_transformers_objects.py @@ -1,3 +1,4 @@ # This file is autogenerated by the command `make fix-copies`, do not edit. # flake8: noqa from ..utils import DummyObject, requires_backends + From 109c90fcada61b717fc5457ce22d765c6a551e63 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 31 Aug 2022 18:24:49 +0200 Subject: [PATCH 07/11] make style --- src/diffusers/schedulers/scheduling_sde_ve.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 147108ef5ba0..924dd8599e73 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -101,7 +101,9 @@ def get_adjacent_sigma(self, timesteps, t): def set_seed(self, seed): warnings.warn( - "The method `set_seed` is deprecated and will be removed in version `0.4.0`. Please consider passing a generator instead.", DeprecationWarning + "The method `set_seed` is deprecated and will be removed in version `0.4.0`. Please consider passing a" + " generator instead.", + DeprecationWarning, ) tensor_format = getattr(self, "tensor_format", "pt") if tensor_format == "np": From 6a905e473b692da307b309199901b62a8175c7c5 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 31 Aug 2022 18:25:09 +0200 Subject: [PATCH 08/11] Apply suggestions from code review Co-authored-by: Anton Lozhkov --- src/diffusers/schedulers/scheduling_sde_ve.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 924dd8599e73..44a392e002b6 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -16,7 +16,6 @@ import warnings -# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit from typing import Optional, Union import numpy as np From 92829b364d9643b5b6ae0cc74f8cc6e8b2e05243 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 31 Aug 2022 18:38:20 +0200 Subject: [PATCH 09/11] make fix-copies --- src/diffusers/schedulers/scheduling_sde_ve.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 44a392e002b6..e3fec0353dea 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -15,7 +15,6 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch import warnings - from typing import Optional, Union import numpy as np From 8d13b9e93a5a25379e2a522915d7330e2c758e25 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 31 Aug 2022 18:38:42 +0200 Subject: [PATCH 10/11] make style --- src/diffusers/utils/dummy_transformers_objects.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/utils/dummy_transformers_objects.py b/src/diffusers/utils/dummy_transformers_objects.py index dc929427221a..753e3fdbe291 100644 --- a/src/diffusers/utils/dummy_transformers_objects.py +++ b/src/diffusers/utils/dummy_transformers_objects.py @@ -1,4 +1,3 @@ # This file is autogenerated by the command `make fix-copies`, do not edit. # flake8: noqa from ..utils import DummyObject, requires_backends - From 761130a61b6af881af028afbf8358d3ddfacc793 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 31 Aug 2022 18:45:02 +0200 Subject: [PATCH 11/11] make style and new copies --- src/diffusers/utils/dummy_scipy_objects.py | 1 + src/diffusers/utils/dummy_transformers_objects.py | 1 + utils/check_dummies.py | 2 +- 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/utils/dummy_scipy_objects.py b/src/diffusers/utils/dummy_scipy_objects.py index 889baf67a8a2..3706c57541c1 100644 --- a/src/diffusers/utils/dummy_scipy_objects.py +++ b/src/diffusers/utils/dummy_scipy_objects.py @@ -1,5 +1,6 @@ # This file is autogenerated by the command `make fix-copies`, do not edit. # flake8: noqa + from ..utils import DummyObject, requires_backends diff --git a/src/diffusers/utils/dummy_transformers_objects.py b/src/diffusers/utils/dummy_transformers_objects.py index 753e3fdbe291..b69af5613977 100644 --- a/src/diffusers/utils/dummy_transformers_objects.py +++ b/src/diffusers/utils/dummy_transformers_objects.py @@ -1,3 +1,4 @@ # This file is autogenerated by the command `make fix-copies`, do not edit. # flake8: noqa + from ..utils import DummyObject, requires_backends diff --git a/utils/check_dummies.py b/utils/check_dummies.py index f9a45284f3fe..2b426048f26c 100644 --- a/utils/check_dummies.py +++ b/utils/check_dummies.py @@ -107,7 +107,7 @@ def create_dummy_files(): for backend, objects in backend_specific_objects.items(): backend_name = "[" + ", ".join(f'"{b}"' for b in backend.split("_and_")) + "]" dummy_file = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n" - dummy_file += "# flake8: noqa\n" + dummy_file += "# flake8: noqa\n\n" dummy_file += "from ..utils import DummyObject, requires_backends\n\n" dummy_file += "\n".join([create_dummy_object(o, backend_name) for o in objects]) dummy_files[backend] = dummy_file