Commit fa7b906

patrickvonplaten, pcuenca, and anton-l authored and committed
[Refactor] Remove set_seed (#289)
* [Refactor] Remove set_seed and class attributes
* apply anton's suggestions
* fix
* Apply suggestions from code review
  Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* up
* update
* make style
* Apply suggestions from code review
  Co-authored-by: Anton Lozhkov <anton@huggingface.co>
* make fix-copies
* make style
* make style and new copies

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
1 parent eb311b4 commit fa7b906

7 files changed: 25 additions, 16 deletions


src/diffusers/pipelines/ddpm/pipeline_ddpm.py

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ def __call__(self, batch_size=1, generator=None, output_type="pil", **kwargs):
             model_output = self.unet(image, t)["sample"]
 
             # 2. compute previous image: x_t -> t_t-1
-            image = self.scheduler.step(model_output, t, image)["prev_sample"]
+            image = self.scheduler.step(model_output, t, image, generator=generator)["prev_sample"]
 
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
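
With this change the user-supplied generator is forwarded into every scheduler step, so DDPM sampling becomes reproducible without any global seeding. A minimal usage sketch follows; the checkpoint name and the dict-style return value are assumptions about this version of the library, not part of the diff:

import torch
from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")  # example checkpoint, assumed

generator = torch.Generator().manual_seed(0)
images_a = pipe(batch_size=1, generator=generator, output_type="pil")["sample"]

generator = torch.Generator().manual_seed(0)
images_b = pipe(batch_size=1, generator=generator, output_type="pil")["sample"]
# images_a and images_b should be identical, because every stochastic
# scheduler step now draws its noise from the same generator state.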

src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py

Lines changed: 3 additions & 3 deletions
@@ -30,7 +30,7 @@ def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, outpu
 
         model = self.unet
 
-        sample = torch.randn(*shape) * self.scheduler.config.sigma_max
+        sample = torch.randn(*shape, generator=generator) * self.scheduler.config.sigma_max
         sample = sample.to(self.device)
 
         self.scheduler.set_timesteps(num_inference_steps)
@@ -42,11 +42,11 @@ def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, outpu
             # correction step
             for _ in range(self.scheduler.correct_steps):
                 model_output = self.unet(sample, sigma_t)["sample"]
-                sample = self.scheduler.step_correct(model_output, sample)["prev_sample"]
+                sample = self.scheduler.step_correct(model_output, sample, generator=generator)["prev_sample"]
 
             # prediction step
             model_output = model(sample, sigma_t)["sample"]
-            output = self.scheduler.step_pred(model_output, t, sample)
+            output = self.scheduler.step_pred(model_output, t, sample, generator=generator)
 
             sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
 
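
The VE pipeline now draws its initial sample, the corrector noise, and the predictor noise from the same generator. The mechanism is simply that torch.randn accepts an explicit generator; a minimal, self-contained sketch of the resulting determinism (the shape is illustrative):

import torch

shape = (1, 3, 32, 32)  # illustrative sample shape

gen = torch.Generator().manual_seed(42)
a = torch.randn(*shape, generator=gen)

gen = torch.Generator().manual_seed(42)
b = torch.randn(*shape, generator=gen)

assert torch.equal(a, b)  # same generator state -> identical noise draws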

src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ class KarrasVePipeline(DiffusionPipeline):
     differential equations." https://arxiv.org/abs/2011.13456
     """
 
+    # add type hints for linting
     unet: UNet2DModel
     scheduler: KarrasVeScheduler
 
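
The added annotations are plain class-level type hints: they tell linters and type checkers what the pipeline's components are without creating class attributes or changing runtime behavior. A small sketch of the pattern (the surrounding class is illustrative, not the real pipeline):

class ExamplePipeline:
    # add type hints for linting
    unet: "UNet2DModel"            # annotation only; the attribute is set in __init__
    scheduler: "KarrasVeScheduler"

    def __init__(self, unet, scheduler):
        self.unet = unet
        self.scheduler = scheduler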

src/diffusers/schedulers/scheduling_sde_ve.py

Lines changed: 17 additions & 11 deletions
@@ -14,8 +14,8 @@
 
 # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
 
-# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit
-from typing import Union
+import warnings
+from typing import Optional, Union
 
 import numpy as np
 import torch
@@ -99,6 +99,11 @@ def get_adjacent_sigma(self, timesteps, t):
             raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
 
     def set_seed(self, seed):
+        warnings.warn(
+            "The method `set_seed` is deprecated and will be removed in version `0.4.0`. Please consider passing a"
+            " generator instead.",
+            DeprecationWarning,
+        )
         tensor_format = getattr(self, "tensor_format", "pt")
         if tensor_format == "np":
             np.random.seed(seed)
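
set_seed itself is kept for now but emits a DeprecationWarning, steering callers toward the generator argument. A simplified sketch of that pattern outside the real class (the numpy tensor_format branch is omitted):

import warnings

import torch

class SchedulerSketch:
    def set_seed(self, seed):
        warnings.warn(
            "The method `set_seed` is deprecated and will be removed in version `0.4.0`."
            " Please consider passing a generator instead.",
            DeprecationWarning,
        )
        torch.manual_seed(seed)  # the real method also handles the "np" tensor_format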
@@ -112,14 +117,14 @@ def step_pred(
         model_output: Union[torch.FloatTensor, np.ndarray],
         timestep: int,
         sample: Union[torch.FloatTensor, np.ndarray],
-        seed=None,
+        generator: Optional[torch.Generator] = None,
+        **kwargs,
     ):
         """
         Predict the sample at the previous timestep by reversing the SDE.
         """
-        if seed is not None:
-            self.set_seed(seed)
-        # TODO(Patrick) non-PyTorch
+        if "seed" in kwargs and kwargs["seed"] is not None:
+            self.set_seed(kwargs["seed"])
 
         if self.timesteps is None:
             raise ValueError(
@@ -141,7 +146,7 @@ def step_pred(
         drift = drift - diffusion[:, None, None, None] ** 2 * model_output
 
         # equation 6: sample noise for the diffusion term of
-        noise = self.randn_like(sample)
+        noise = self.randn_like(sample, generator=generator)
         prev_sample_mean = sample - drift  # subtract because `dt` is a small negative timestep
         # TODO is the variable diffusion the correct scaling term for the noise?
         prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise  # add impact of diffusion field g
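
step_pred now accepts a torch.Generator and forwards it to the noise draw. self.randn_like here is the scheduler's own helper rather than torch.randn_like (which has no generator argument); a hedged sketch of what a generator-aware version of such a helper can look like, not the library's actual implementation:

from typing import Optional

import torch

def randn_like(sample: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.Tensor:
    # torch.randn_like does not take a generator, so draw with torch.randn and
    # move the result onto the sample's device afterwards
    noise = torch.randn(sample.shape, generator=generator, dtype=sample.dtype)
    return noise.to(sample.device)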
@@ -152,14 +157,15 @@ def step_correct(
         self,
         model_output: Union[torch.FloatTensor, np.ndarray],
         sample: Union[torch.FloatTensor, np.ndarray],
-        seed=None,
+        generator: Optional[torch.Generator] = None,
+        **kwargs,
     ):
         """
         Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
         after making the prediction for the previous timestep.
         """
-        if seed is not None:
-            self.set_seed(seed)
+        if "seed" in kwargs and kwargs["seed"] is not None:
+            self.set_seed(kwargs["seed"])
 
         if self.timesteps is None:
             raise ValueError(
@@ -168,7 +174,7 @@ def step_correct(
 
         # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
         # sample noise for correction
-        noise = self.randn_like(sample)
+        noise = self.randn_like(sample, generator=generator)
 
         # compute step size from the model_output, the noise, and the snr
         grad_norm = self.norm(model_output)
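
The seed parameter is gone from both signatures, but old callers that still pass seed=... are caught through **kwargs and routed to the deprecated set_seed path. A minimal, self-contained sketch of that shim (the function body is a placeholder, not the real update rule):

from typing import Optional

import torch

def step_sketch(generator: Optional[torch.Generator] = None, **kwargs):
    if "seed" in kwargs and kwargs["seed"] is not None:
        torch.manual_seed(kwargs["seed"])  # stands in for self.set_seed(kwargs["seed"])
    # ... the actual step would draw its noise from `generator` ...
    return generator

step_sketch(seed=0)                                       # legacy call still accepted
step_sketch(generator=torch.Generator().manual_seed(0))   # preferred call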

src/diffusers/utils/dummy_scipy_objects.py

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 # This file is autogenerated by the command `make fix-copies`, do not edit.
 # flake8: noqa
+
 from ..utils import DummyObject, requires_backends
 
 

src/diffusers/utils/dummy_transformers_objects.py

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 # This file is autogenerated by the command `make fix-copies`, do not edit.
 # flake8: noqa
+
 from ..utils import DummyObject, requires_backends
 
 

utils/check_dummies.py

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ def create_dummy_files():
     for backend, objects in backend_specific_objects.items():
         backend_name = "[" + ", ".join(f'"{b}"' for b in backend.split("_and_")) + "]"
         dummy_file = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
-        dummy_file += "# flake8: noqa\n"
+        dummy_file += "# flake8: noqa\n\n"
         dummy_file += "from ..utils import DummyObject, requires_backends\n\n"
         dummy_file += "\n".join([create_dummy_object(o, backend_name) for o in objects])
         dummy_files[backend] = dummy_file
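
The only change here is an extra newline after the flake8 directive, so regenerated dummy files match the hand-edited ones above. A short sketch of how the file header is assembled after the change (the object bodies produced by create_dummy_object are omitted):

dummy_file = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
dummy_file += "# flake8: noqa\n\n"  # the trailing "\n" is the newly added blank line
dummy_file += "from ..utils import DummyObject, requires_backends\n\n"
print(dummy_file)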
