
Commit 726aba0

[Pytorch] pytorch only timesteps (#724)
* pytorch timesteps
* style
* get rid of if-else
* fix test

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
1 parent 60c9634 commit 726aba0

12 files changed: +42 −32 lines
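Taken together, the commit makes every scheduler store its `timesteps` as a `torch.Tensor` and adds a `device` argument to `set_timesteps`. A minimal usage sketch of the new calling convention (not part of this commit; it assumes the post-commit `diffusers` API with `DDIMScheduler` importable):

```python
import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)
device = "cuda" if torch.cuda.is_available() else "cpu"

# After this commit, timesteps are torch tensors and can be created
# directly on the sampling device, avoiding per-step host-to-device copies.
scheduler.set_timesteps(50, device=device)
assert torch.is_tensor(scheduler.timesteps)
```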

docs/source/api/schedulers.mdx

Lines changed: 1 addition & 1 deletion
```diff
@@ -36,7 +36,7 @@ This allows for rapid experimentation and cleaner abstractions in the code, wher
 To this end, the design of schedulers is such that:
 
 - Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality.
-- Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial Numpy support currently exists).
+- Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial Jax support currently exists).
 
 
 ## API
```

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 2 additions & 5 deletions
```diff
@@ -278,11 +278,8 @@ def __call__(
         self.scheduler.set_timesteps(num_inference_steps)
 
         # Some schedulers like PNDM have timesteps as arrays
-        # It's more optimzed to move all timesteps to correct device beforehand
-        if torch.is_tensor(self.scheduler.timesteps):
-            timesteps_tensor = self.scheduler.timesteps.to(self.device)
-        else:
-            timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy(), device=self.device)
+        # It's more optimized to move all timesteps to correct device beforehand
+        timesteps_tensor = self.scheduler.timesteps.to(self.device)
 
         # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
         if isinstance(self.scheduler, LMSDiscreteScheduler):
```
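The if/else could be dropped because schedulers now always expose `timesteps` as a `torch.Tensor`, and `Tensor.to` returns the tensor itself when it already lives on the target device. A standalone sketch of the before/after (illustrative, not from the pipeline):

```python
import numpy as np
import torch

# Old: timesteps could be a numpy array, so the pipeline had to branch
# on torch.is_tensor and convert explicitly.
maybe_numpy = np.arange(10)[::-1]
legacy = torch.tensor(maybe_numpy.copy(), device="cpu")

# New: timesteps is always a tensor; .to() is effectively a no-op when
# the tensor is already on the device, so a single code path suffices.
timesteps = torch.arange(10).flip(0)
timesteps = timesteps.to("cpu")
```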

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -304,7 +304,10 @@ def __call__(
         latents = init_latents
 
         t_start = max(num_inference_steps - init_timestep + offset, 0)
-        timesteps = self.scheduler.timesteps[t_start:]
+
+        # Some schedulers like PNDM have timesteps as arrays
+        # It's more optimized to move all timesteps to correct device beforehand
+        timesteps = self.scheduler.timesteps[t_start:].to(self.device)
 
         for i, t in enumerate(self.progress_bar(timesteps)):
             t_index = t_start + i
```

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -342,7 +342,10 @@ def __call__(
         latents = init_latents
 
         t_start = max(num_inference_steps - init_timestep + offset, 0)
-        timesteps = self.scheduler.timesteps[t_start:]
+
+        # Some schedulers like PNDM have timesteps as arrays
+        # It's more optimized to move all timesteps to correct device beforehand
+        timesteps = self.scheduler.timesteps[t_start:].to(self.device)
 
         for i, t in tqdm(enumerate(timesteps)):
             t_index = t_start + i
```
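The img2img and inpaint pipelines receive the same change: since `scheduler.timesteps` is now a tensor, slicing off the first `t_start` entries still yields a tensor, and device placement chains directly onto the slice. A standalone sketch of the pattern (values are illustrative):

```python
import torch

timesteps = torch.arange(1000).flip(0)  # stand-in for scheduler.timesteps
num_inference_steps, init_timestep, offset = 50, 40, 1

t_start = max(num_inference_steps - init_timestep + offset, 0)
# Slicing a tensor yields a tensor, so the truncated schedule can be
# moved to the device in the same expression.
truncated = timesteps[t_start:].to("cpu")
assert truncated[0] == timesteps[t_start]
```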

src/diffusers/schedulers/README.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -2,7 +2,7 @@
 
 - Schedulers are the algorithms to use diffusion models in inference as well as for training. They include the noise schedules and define algorithm-specific diffusion steps.
 - Schedulers can be used interchangeable between diffusion models in inference to find the preferred trade-off between speed and generation quality.
-- Schedulers are available in numpy, but can easily be transformed into PyTorch.
+- Schedulers are available in PyTorch and Jax.
 
 ## API
 
```

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -154,7 +154,7 @@ def __init__(
 
         # setable values
         self.num_inference_steps = None
-        self.timesteps = np.arange(0, num_train_timesteps)[::-1]
+        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
 
     def _get_variance(self, timestep, prev_timestep):
         alpha_prod_t = self.alphas_cumprod[timestep]
@@ -166,7 +166,7 @@ def _get_variance(self, timestep, prev_timestep):
 
         return variance
 
-    def set_timesteps(self, num_inference_steps: int, **kwargs):
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs):
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
 
@@ -183,7 +183,8 @@ def set_timesteps(self, num_inference_steps: int, **kwargs):
         step_ratio = self.config.num_train_timesteps // self.num_inference_steps
         # creates integer timesteps by multiplying by ratio
         # casting to int to avoid issues when num_inference_step is power of 3
-        self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1]
+        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
+        self.timesteps = torch.from_numpy(timesteps).to(device)
         self.timesteps += offset
 
     def step(
```
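The `.copy()` calls that appear throughout these scheduler changes are load-bearing: `np.arange(...)[::-1]` produces a negative-stride view, and `torch.from_numpy` shares memory with the array, so it rejects negative strides with a `ValueError`. A minimal demonstration:

```python
import numpy as np
import torch

reversed_view = np.arange(5)[::-1]        # negative-stride view, no data copied

# torch.from_numpy(reversed_view) would raise ValueError here, because
# from_numpy shares the array's memory and cannot represent the stride.
contiguous = reversed_view.copy()         # materialize a contiguous array
timesteps = torch.from_numpy(contiguous)  # safe: tensor([4, 3, 2, 1, 0])
```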

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -142,11 +142,11 @@ def __init__(
 
         # setable values
         self.num_inference_steps = None
-        self.timesteps = np.arange(0, num_train_timesteps)[::-1]
+        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
 
         self.variance_type = variance_type
 
-    def set_timesteps(self, num_inference_steps: int):
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
 
@@ -156,9 +156,10 @@ def set_timesteps(self, num_inference_steps: int):
         """
         num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
         self.num_inference_steps = num_inference_steps
-        self.timesteps = np.arange(
+        timesteps = np.arange(
             0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
-        )[::-1]
+        )[::-1].copy()
+        self.timesteps = torch.from_numpy(timesteps).to(device)
 
     def _get_variance(self, t, predicted_variance=None, variance_type=None):
         alpha_prod_t = self.alphas_cumprod[t]
```
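The DDPM schedule itself is unchanged: evenly strided training timesteps, reversed. The sketch below (illustrative values) shows the resulting tensor and that `device=None`, the default, is a no-op for `.to()`, so existing callers keep the old CPU behavior:

```python
import numpy as np
import torch

num_train_timesteps, num_inference_steps = 1000, 10
stride = num_train_timesteps // num_inference_steps

timesteps = np.arange(0, num_train_timesteps, stride)[::-1].copy()
tensor = torch.from_numpy(timesteps).to(None)  # device=None leaves it on CPU

print(tensor)  # tensor([900, 800, 700, 600, 500, 400, 300, 200, 100, 0])
assert tensor.device.type == "cpu"
```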

src/diffusers/schedulers/scheduling_karras_ve.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -97,10 +97,10 @@ def __init__(
 
         # setable values
         self.num_inference_steps: int = None
-        self.timesteps: np.ndarray = None
+        self.timesteps: np.IntTensor = None
         self.schedule: torch.FloatTensor = None  # sigma(t_i)
 
-    def set_timesteps(self, num_inference_steps: int):
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
         Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
 
@@ -110,15 +110,16 @@ def set_timesteps(self, num_inference_steps: int):
 
         """
         self.num_inference_steps = num_inference_steps
-        self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
+        timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
+        self.timesteps = torch.from_numpy(timesteps).to(device)
         schedule = [
             (
                 self.config.sigma_max**2
                 * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
             )
             for i in self.timesteps
         ]
-        self.schedule = torch.tensor(schedule, dtype=torch.float32)
+        self.schedule = torch.tensor(schedule, dtype=torch.float32, device=device)
 
     def add_noise_to_input(
         self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None
```
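Here `set_timesteps` also builds the sigma schedule directly on the requested device. The schedule is a geometric interpolation between squared sigmas; a standalone sketch of the formula with illustrative constants and plain Python indices in place of the tensor elements:

```python
import torch

sigma_min, sigma_max, n = 0.02, 100.0, 5

# i runs over the reversed timesteps n-1 .. 0, so the first entry is
# sigma_min**2 and the last is sigma_max**2 (note: these are squared sigmas).
schedule = torch.tensor(
    [sigma_max**2 * (sigma_min**2 / sigma_max**2) ** (i / (n - 1)) for i in range(n - 1, -1, -1)],
    dtype=torch.float32,
)
```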

src/diffusers/schedulers/scheduling_pndm.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -147,7 +147,7 @@ def __init__(
         self.plms_timesteps = None
         self.timesteps = None
 
-    def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor:
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs):
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
 
@@ -184,7 +184,8 @@ def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor
             ::-1
         ].copy()  # we copy to avoid having negative strides which are not supported by torch.from_numpy
 
-        self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
+        timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
+        self.timesteps = torch.from_numpy(timesteps).to(device)
 
         self.ets = []
         self.counter = 0
```
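PNDM keeps two sub-schedules (Runge-Kutta warm-up steps and linear multistep steps) and concatenates them before handing the result to PyTorch; casting to `np.int64` maps onto `torch.long`, the dtype expected for indexing. A sketch with hypothetical stand-in values for the two arrays:

```python
import numpy as np
import torch

prk_timesteps = np.array([980, 970, 960])   # hypothetical warm-up steps
plms_timesteps = np.array([950, 900, 850])  # hypothetical multistep steps

timesteps = np.concatenate([prk_timesteps, plms_timesteps]).astype(np.int64)
tensor = torch.from_numpy(timesteps)        # dtype torch.int64 (torch.long)
```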

src/diffusers/schedulers/scheduling_sde_ve.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -89,7 +89,9 @@ def __init__(
 
         self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)
 
-    def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None):
+    def set_timesteps(
+        self, num_inference_steps: int, sampling_eps: float = None, device: Union[str, torch.device] = None
+    ):
         """
         Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
 
@@ -101,7 +103,7 @@ def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None):
         """
         sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
 
-        self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps)
+        self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps, device=device)
 
     def set_sigmas(
         self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None
```
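For the continuous-time VE scheduler, passing `device=` to `torch.linspace` allocates the timesteps directly on the target device rather than creating them on the CPU and copying. Roughly:

```python
import torch

sampling_eps, num_inference_steps = 1e-5, 4

# Four evenly spaced values from 1 down to sampling_eps, allocated on the
# given device in a single call.
timesteps = torch.linspace(1, sampling_eps, num_inference_steps, device="cpu")
```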

src/diffusers/schedulers/scheduling_sde_vp.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -14,9 +14,8 @@
 
 # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
 
-# TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit
-
 import math
+from typing import Union
 
 import torch
 
@@ -52,8 +51,8 @@ def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling
         self.discrete_sigmas = None
         self.timesteps = None
 
-    def set_timesteps(self, num_inference_steps):
-        self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
+    def set_timesteps(self, num_inference_steps, device: Union[str, torch.device] = None):
+        self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps, device=device)
 
     def step_pred(self, score, x, t, generator=None):
         if self.timesteps is None:
```
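The new `from typing import Union` import supports the signature shared by all the schedulers in this commit. Since the default is `None`, the annotation is effectively `Optional[Union[str, torch.device]]`; a condensed sketch of the convention (a free function standing in for the method):

```python
from typing import Union

import torch

def set_timesteps(num_inference_steps: int, device: Union[str, torch.device] = None) -> torch.Tensor:
    # device=None preserves the old CPU behavior, so callers that pass
    # no device are unaffected.
    return torch.linspace(1, 1e-3, num_inference_steps, device=device)
```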

tests/test_scheduler.py

Lines changed: 6 additions & 4 deletions
```diff
@@ -354,7 +354,7 @@ def test_steps_offset(self):
         scheduler_config = self.get_scheduler_config(steps_offset=1)
         scheduler = scheduler_class(**scheduler_config)
         scheduler.set_timesteps(5)
-        assert np.equal(scheduler.timesteps, np.array([801, 601, 401, 201, 1])).all()
+        assert torch.equal(scheduler.timesteps, torch.LongTensor([801, 601, 401, 201, 1]))
 
     def test_betas(self):
         for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
@@ -568,10 +568,12 @@ def test_steps_offset(self):
         scheduler_config = self.get_scheduler_config(steps_offset=1)
         scheduler = scheduler_class(**scheduler_config)
         scheduler.set_timesteps(10)
-        assert np.equal(
+        assert torch.equal(
             scheduler.timesteps,
-            np.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]),
-        ).all()
+            torch.LongTensor(
+                [901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]
+            ),
+        )
 
     def test_betas(self):
         for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]):
```
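The tests swap the numpy idiom `np.equal(x, y).all()` for `torch.equal`, which compares size and elements in a single call and returns a Python bool. A quick illustration of the semantics:

```python
import torch

expected = torch.LongTensor([801, 601, 401, 201, 1])

assert torch.equal(expected, torch.LongTensor([801, 601, 401, 201, 1]))
assert not torch.equal(expected, torch.LongTensor([801, 601, 401, 201, 2]))
```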
