
Commit 7f567ec

Changed not to make torch.tensor
1 parent 1beb12d commit 7f567ec

File tree: 2 files changed (+4, -16 lines)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 2 additions & 8 deletions
@@ -281,15 +281,9 @@ def __call__(
         latents = init_latents

         t_start = max(num_inference_steps - init_timestep + offset, 0)
+        timesteps = self.scheduler.timesteps[t_start:]

-        # Some schedulers like PNDM have timesteps as arrays
-        # It's more optimzed to move all timesteps to correct device beforehand
-        if torch.is_tensor(self.scheduler.timesteps):
-            timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
-        else:
-            timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy()[t_start:], device=self.device)
-
-        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+        for i, t in enumerate(self.progress_bar(timesteps)):
             t_index = t_start + i

             # expand the latents if we are doing classifier free guidance

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 2 additions & 8 deletions
@@ -315,15 +315,9 @@ def __call__(
         latents = init_latents

         t_start = max(num_inference_steps - init_timestep + offset, 0)
+        timesteps = self.scheduler.timesteps.copy()[t_start:]

-        # Some schedulers like PNDM have timesteps as arrays
-        # It's more optimzed to move all timesteps to correct device beforehand
-        if torch.is_tensor(self.scheduler.timesteps):
-            timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
-        else:
-            timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy()[t_start:], device=self.device)
-
-        for i, t in tqdm(enumerate(timesteps_tensor)):
+        for i, t in tqdm(enumerate(timesteps)):
             t_index = t_start + i
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
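For context, the pattern both hunks adopt can be sketched outside the pipelines. This is a minimal, illustrative sketch: DummyScheduler, num_inference_steps=50, and t_start=10 are assumptions for the example, not part of diffusers. The point is that a scheduler whose timesteps attribute is already an array (e.g. PNDM) can simply be sliced and iterated, so the per-call torch.tensor construction and device transfer are unnecessary.

# Minimal sketch of the pattern this commit adopts (illustrative only, not pipeline code).
import numpy as np

class DummyScheduler:
    """Stand-in for a scheduler (e.g. PNDM) that stores timesteps as a NumPy array."""
    def __init__(self, num_inference_steps: int):
        # Descending timesteps, e.g. [49, 48, ..., 0]
        self.timesteps = np.arange(num_inference_steps)[::-1]

scheduler = DummyScheduler(num_inference_steps=50)
t_start = 10  # skip the first steps, as the img2img/inpaint pipelines do

# Before: every call copied the slice into a tensor on the pipeline's device, e.g.
#   timesteps = torch.tensor(scheduler.timesteps.copy()[t_start:], device=device)
# After: slice the existing array and iterate it directly.
timesteps = scheduler.timesteps.copy()[t_start:]

for i, t in enumerate(timesteps):
    t_index = t_start + i
    # ... the denoising step would consume t (and t_index) here ...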
