Skip to content

[Refactor] Remove set_seed #289

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Aug 31, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/ddpm/pipeline_ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __call__(self, batch_size=1, generator=None, output_type="pil", **kwargs):
model_output = self.unet(image, t)["sample"]

# 2. compute previous image: x_t -> x_t-1
image = self.scheduler.step(model_output, t, image)["prev_sample"]
image = self.scheduler.step(model_output, t, image, generator=generator)["prev_sample"]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DDPM scheduler is also stochastic


image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
Expand Down
6 changes: 3 additions & 3 deletions src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, outpu

model = self.unet

sample = torch.randn(*shape) * self.scheduler.config.sigma_max
sample = torch.randn(*shape, generator=generator) * self.scheduler.config.sigma_max
sample = sample.to(self.device)

self.scheduler.set_timesteps(num_inference_steps)
Expand All @@ -42,11 +42,11 @@ def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, outpu
# correction step
for _ in range(self.scheduler.correct_steps):
model_output = self.unet(sample, sigma_t)["sample"]
sample = self.scheduler.step_correct(model_output, sample)["prev_sample"]
sample = self.scheduler.step_correct(model_output, sample, generator=generator)["prev_sample"]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's pass the generator here


# prediction step
model_output = model(sample, sigma_t)["sample"]
output = self.scheduler.step_pred(model_output, t, sample)
output = self.scheduler.step_pred(model_output, t, sample, generator=generator)

sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class KarrasVePipeline(DiffusionPipeline):
differential equations." https://arxiv.org/abs/2011.13456
"""

# add type hints for linting
unet: UNet2DModel
scheduler: KarrasVeScheduler

Expand Down
26 changes: 16 additions & 10 deletions src/diffusers/schedulers/scheduling_sde_ve.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@

# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch

import warnings

# TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean up a bit
from typing import Union
from typing import Optional, Union

import numpy as np
import torch
Expand Down Expand Up @@ -98,6 +100,9 @@ def get_adjacent_sigma(self, timesteps, t):
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")

def set_seed(self, seed):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not work with set_seed(...) here, as it goes against the generator design. IMO passing generators around is the correct thing to do here because:

  • We cannot really pass seeds around and globally call manual_seed(...) every time:
    • It's not a great design to retrieve the current seed from a generator and then pass that along -> better to pass the generator directly
    • Flax passes PRNG keys around that can be split -> we cannot split PyTorch seeds in the same way, and we cannot pass a seed into a torch.randn(...) function -> so let's go for the generator here

warnings.warn(
"The method `set_seed` is deprecated and will be removed in version `0.4.0`. Please consider passing a generator instead.", DeprecationWarning
)
tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np":
np.random.seed(seed)
Expand All @@ -111,14 +116,14 @@ def step_pred(
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
seed=None,
generator: Optional[torch.Generator] = None,
**kwargs,
):
"""
Predict the sample at the previous timestep by reversing the SDE.
"""
if seed is not None:
self.set_seed(seed)
# TODO(Patrick) non-PyTorch
if "seed" in kwargs and kwargs["seed"] is not None:
self.set_seed(kwargs["seed"])

if self.timesteps is None:
raise ValueError(
Expand All @@ -140,7 +145,7 @@ def step_pred(
drift = drift - diffusion[:, None, None, None] ** 2 * model_output

# equation 6: sample noise for the diffusion term of the reverse-time SDE
noise = self.randn_like(sample)
noise = self.randn_like(sample, generator=generator)
prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep
# TODO is the variable diffusion the correct scaling term for the noise?
prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g
Expand All @@ -151,14 +156,15 @@ def step_correct(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
sample: Union[torch.FloatTensor, np.ndarray],
seed=None,
generator: Optional[torch.Generator] = None,
**kwargs,
):
"""
Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
after making the prediction for the previous timestep.
"""
if seed is not None:
self.set_seed(seed)
if "seed" in kwargs and kwargs["seed"] is not None:
self.set_seed(kwargs["seed"])

if self.timesteps is None:
raise ValueError(
Expand All @@ -167,7 +173,7 @@ def step_correct(

# For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
# sample noise for correction
noise = self.randn_like(sample)
noise = self.randn_like(sample, generator=generator)

# compute step size from the model_output, the noise, and the snr
grad_norm = self.norm(model_output)
Expand Down