[Refactor] Remove set_seed #289
Changes from all commits: a733d72, 4a559e9, 9aacfdd, 3d093c9, 59d1e4c, f223486, 0d1ca30, 27be6fc, 109c90f, 6a905e4, 92829b3, 8d13b9e, 761130a, a37da3a
The first file is the pipeline's `__call__`:

```diff
@@ -30,7 +30,7 @@ def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, outpu

         model = self.unet

-        sample = torch.randn(*shape) * self.scheduler.config.sigma_max
+        sample = torch.randn(*shape, generator=generator) * self.scheduler.config.sigma_max
         sample = sample.to(self.device)

         self.scheduler.set_timesteps(num_inference_steps)

@@ -42,11 +42,11 @@ def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, outpu
             # correction step
             for _ in range(self.scheduler.correct_steps):
                 model_output = self.unet(sample, sigma_t)["sample"]
-                sample = self.scheduler.step_correct(model_output, sample)["prev_sample"]
+                sample = self.scheduler.step_correct(model_output, sample, generator=generator)["prev_sample"]
```
> **Review comment:** let's pass the generator here
```diff

             # prediction step
             model_output = model(sample, sigma_t)["sample"]
-            output = self.scheduler.step_pred(model_output, t, sample)
+            output = self.scheduler.step_pred(model_output, t, sample, generator=generator)

             sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
```
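With this change a `torch.Generator` is threaded from the pipeline's `__call__` into every stochastic step instead of relying on a global seed. A minimal usage sketch (the `pipeline` object here is an assumption standing in for whatever pipeline class this file defines):

```python
import torch

# A torch.Generator carries its own RNG state; seeding it does not touch
# torch's process-global RNG.
generator = torch.Generator().manual_seed(42)
output = pipeline(batch_size=1, num_inference_steps=2000, generator=generator)

# Re-seeding the same generator makes the run exactly reproducible.
generator.manual_seed(42)
output_again = pipeline(batch_size=1, num_inference_steps=2000, generator=generator)
```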
The second file is the scheduler itself:

```diff
@@ -14,8 +14,8 @@

 # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch

 # TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit
-from typing import Union
+import warnings
+from typing import Optional, Union

 import numpy as np
 import torch
```
```diff
@@ -98,6 +98,11 @@ def get_adjacent_sigma(self, timesteps, t):
             raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")

     def set_seed(self, seed):
```
> **Review comment:** We should not work with …
```diff
+        warnings.warn(
+            "The method `set_seed` is deprecated and will be removed in version `0.4.0`. Please consider passing a"
+            " generator instead.",
+            DeprecationWarning,
+        )
         tensor_format = getattr(self, "tensor_format", "pt")
         if tensor_format == "np":
             np.random.seed(seed)
```
```diff
@@ -111,14 +116,14 @@ def step_pred(
         model_output: Union[torch.FloatTensor, np.ndarray],
         timestep: int,
         sample: Union[torch.FloatTensor, np.ndarray],
-        seed=None,
+        generator: Optional[torch.Generator] = None,
         **kwargs,
     ):
         """
         Predict the sample at the previous timestep by reversing the SDE.
         """
-        if seed is not None:
-            self.set_seed(seed)
+        # TODO(Patrick) non-PyTorch
+        if "seed" in kwargs and kwargs["seed"] is not None:
+            self.set_seed(kwargs["seed"])

         if self.timesteps is None:
             raise ValueError(
```
```diff
@@ -140,7 +145,7 @@ def step_pred(
         drift = drift - diffusion[:, None, None, None] ** 2 * model_output

         # equation 6: sample noise for the diffusion term of
-        noise = self.randn_like(sample)
+        noise = self.randn_like(sample, generator=generator)
         prev_sample_mean = sample - drift  # subtract because `dt` is a small negative timestep
         # TODO is the variable diffusion the correct scaling term for the noise?
         prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise  # add impact of diffusion field g
```
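`randn_like` is defined elsewhere on the scheduler mixin, so the diff implies its signature also gained a `generator` argument. A plausible sketch of such a helper, assuming it dispatches on `tensor_format` like the rest of this file (note that `torch.randn_like` itself takes no `generator`, and the numpy branch cannot honour one, which is likely what the non-PyTorch TODO refers to):

```python
def randn_like(self, tensor, generator=None):
    # Assumed sketch, not the PR's actual helper.
    tensor_format = getattr(self, "tensor_format", "pt")
    if tensor_format == "np":
        # numpy has no torch.Generator support; draws stay globally seeded.
        return np.random.randn(*tensor.shape)
    elif tensor_format == "pt":
        # torch.randn accepts a generator (unlike torch.randn_like), so draw
        # by shape and move to the sample's device afterwards.
        return torch.randn(tensor.shape, generator=generator, dtype=tensor.dtype).to(tensor.device)
    raise ValueError(f"`self.tensor_format`: {tensor_format} is not valid.")
```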
```diff
@@ -151,14 +156,15 @@ def step_correct(
         self,
         model_output: Union[torch.FloatTensor, np.ndarray],
         sample: Union[torch.FloatTensor, np.ndarray],
-        seed=None,
+        generator: Optional[torch.Generator] = None,
         **kwargs,
     ):
         """
         Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
         after making the prediction for the previous timestep.
         """
-        if seed is not None:
-            self.set_seed(seed)
+        if "seed" in kwargs and kwargs["seed"] is not None:
+            self.set_seed(kwargs["seed"])

         if self.timesteps is None:
             raise ValueError(
```
```diff
@@ -167,7 +173,7 @@ def step_correct(

         # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
         # sample noise for correction
-        noise = self.randn_like(sample)
+        noise = self.randn_like(sample, generator=generator)

         # compute step size from the model_output, the noise, and the snr
         grad_norm = self.norm(model_output)
```
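Taken together, the generator now flows from `__call__` through `step_pred`/`step_correct` into every `randn_like` draw. A quick determinism check (a test sketch; it assumes the pipeline returns a dict with a `"sample"` tensor entry, which this diff does not show):

```python
import torch

def check_determinism(pipeline):
    # Two runs with identically seeded generators must produce identical
    # samples; the old set_seed approach relied on fragile global state.
    out1 = pipeline(num_inference_steps=50, generator=torch.Generator().manual_seed(7))
    out2 = pipeline(num_inference_steps=50, generator=torch.Generator().manual_seed(7))
    assert torch.equal(out1["sample"], out2["sample"])
```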
> **Review comment:** DDPM scheduler is also stochastic