Skip to content

Commit a20a0e6

Browse files
patrickvonplatenpcuencaanton-l
authored
[Refactor] Remove set_seed (huggingface#289)
* [Refactor] Remove set_seed and class attributes * apply anton's suggestiosn * fix * Apply suggestions from code review Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * up * update * make style * Apply suggestions from code review Co-authored-by: Anton Lozhkov <anton@huggingface.co> * make fix-copies * make style * make style and new copies Co-authored-by: Pedro Cuenca <pedro@huggingface.co> Co-authored-by: Anton Lozhkov <anton@huggingface.co>
1 parent 4d2c6fe commit a20a0e6

File tree

6 files changed

+24
-15
lines changed

6 files changed

+24
-15
lines changed

pipelines/ddpm/pipeline_ddpm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __call__(self, batch_size=1, generator=None, output_type="pil", **kwargs):
5656
model_output = self.unet(image, t)["sample"]
5757

5858
# 2. compute previous image: x_t -> t_t-1
59-
image = self.scheduler.step(model_output, t, image)["prev_sample"]
59+
image = self.scheduler.step(model_output, t, image, generator=generator)["prev_sample"]
6060

6161
image = (image / 2 + 0.5).clamp(0, 1)
6262
image = image.cpu().permute(0, 2, 3, 1).numpy()

pipelines/score_sde_ve/pipeline_score_sde_ve.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, outpu
3030

3131
model = self.unet
3232

33-
sample = torch.randn(*shape) * self.scheduler.config.sigma_max
33+
sample = torch.randn(*shape, generator=generator) * self.scheduler.config.sigma_max
3434
sample = sample.to(self.device)
3535

3636
self.scheduler.set_timesteps(num_inference_steps)
@@ -42,11 +42,11 @@ def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, outpu
4242
# correction step
4343
for _ in range(self.scheduler.correct_steps):
4444
model_output = self.unet(sample, sigma_t)["sample"]
45-
sample = self.scheduler.step_correct(model_output, sample)["prev_sample"]
45+
sample = self.scheduler.step_correct(model_output, sample, generator=generator)["prev_sample"]
4646

4747
# prediction step
4848
model_output = model(sample, sigma_t)["sample"]
49-
output = self.scheduler.step_pred(model_output, t, sample)
49+
output = self.scheduler.step_pred(model_output, t, sample, generator=generator)
5050

5151
sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
5252

pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class KarrasVePipeline(DiffusionPipeline):
1919
differential equations." https://arxiv.org/abs/2011.13456
2020
"""
2121

22+
# add type hints for linting
2223
unet: UNet2DModel
2324
scheduler: KarrasVeScheduler
2425

schedulers/scheduling_sde_ve.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414

1515
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
1616

17-
# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit
18-
from typing import Union
17+
import warnings
18+
from typing import Optional, Union
1919

2020
import numpy as np
2121
import torch
@@ -98,6 +98,11 @@ def get_adjacent_sigma(self, timesteps, t):
9898
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
9999

100100
def set_seed(self, seed):
101+
warnings.warn(
102+
"The method `set_seed` is deprecated and will be removed in version `0.4.0`. Please consider passing a"
103+
" generator instead.",
104+
DeprecationWarning,
105+
)
101106
tensor_format = getattr(self, "tensor_format", "pt")
102107
if tensor_format == "np":
103108
np.random.seed(seed)
@@ -111,14 +116,14 @@ def step_pred(
111116
model_output: Union[torch.FloatTensor, np.ndarray],
112117
timestep: int,
113118
sample: Union[torch.FloatTensor, np.ndarray],
114-
seed=None,
119+
generator: Optional[torch.Generator] = None,
120+
**kwargs,
115121
):
116122
"""
117123
Predict the sample at the previous timestep by reversing the SDE.
118124
"""
119-
if seed is not None:
120-
self.set_seed(seed)
121-
# TODO(Patrick) non-PyTorch
125+
if "seed" in kwargs and kwargs["seed"] is not None:
126+
self.set_seed(kwargs["seed"])
122127

123128
if self.timesteps is None:
124129
raise ValueError(
@@ -140,7 +145,7 @@ def step_pred(
140145
drift = drift - diffusion[:, None, None, None] ** 2 * model_output
141146

142147
# equation 6: sample noise for the diffusion term of
143-
noise = self.randn_like(sample)
148+
noise = self.randn_like(sample, generator=generator)
144149
prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep
145150
# TODO is the variable diffusion the correct scaling term for the noise?
146151
prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g
@@ -151,14 +156,15 @@ def step_correct(
151156
self,
152157
model_output: Union[torch.FloatTensor, np.ndarray],
153158
sample: Union[torch.FloatTensor, np.ndarray],
154-
seed=None,
159+
generator: Optional[torch.Generator] = None,
160+
**kwargs,
155161
):
156162
"""
157163
Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
158164
after making the prediction for the previous timestep.
159165
"""
160-
if seed is not None:
161-
self.set_seed(seed)
166+
if "seed" in kwargs and kwargs["seed"] is not None:
167+
self.set_seed(kwargs["seed"])
162168

163169
if self.timesteps is None:
164170
raise ValueError(
@@ -167,7 +173,7 @@ def step_correct(
167173

168174
# For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
169175
# sample noise for correction
170-
noise = self.randn_like(sample)
176+
noise = self.randn_like(sample, generator=generator)
171177

172178
# compute step size from the model_output, the noise, and the snr
173179
grad_norm = self.norm(model_output)

utils/dummy_scipy_objects.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# This file is autogenerated by the command `make fix-copies`, do not edit.
22
# flake8: noqa
3+
34
from ..utils import DummyObject, requires_backends
45

56

utils/dummy_transformers_objects.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# This file is autogenerated by the command `make fix-copies`, do not edit.
22
# flake8: noqa
3+
34
from ..utils import DummyObject, requires_backends
45

56

0 commit comments

Comments
 (0)