
Commit f71023c

shirayu authored and Prathik Rao committed

Avoid negative strides for tensors (huggingface#717)

* Avoid negative strides for tensors
* Changed not to make torch.tensor
* Removed a needless copy

1 parent 97c1c1c · commit f71023c

File tree

2 files changed: +4 −10 lines


src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 2 additions & 5 deletions
@@ -304,12 +304,9 @@ def __call__(
         latents = init_latents

         t_start = max(num_inference_steps - init_timestep + offset, 0)
+        timesteps = self.scheduler.timesteps[t_start:]

-        # Some schedulers like PNDM have timesteps as arrays
-        # It's more optimzed to move all timesteps to correct device beforehand
-        timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
-
-        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+        for i, t in enumerate(self.progress_bar(timesteps)):
             t_index = t_start + i

             # expand the latents if we are doing classifier free guidance
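The "needless copy" mentioned in the commit message comes from re-wrapping the timesteps in `torch.tensor(...)`: that call always materializes new storage, whereas the plain slice the new code keeps is only a view. A minimal sketch of the difference, using NumPy in place of the real scheduler's timesteps:

```python
import numpy as np

# Stand-in for a scheduler's timestep schedule (hypothetical values).
timesteps = np.arange(0, 1000, 100)

# Slicing produces a view that shares memory with the original array,
# so no data is copied when we drop the first t_start entries.
t_start = 3
sliced = timesteps[t_start:]
assert np.shares_memory(sliced, timesteps)

# Rebuilding the slice as a fresh array (analogous to torch.tensor(...))
# allocates new storage -- an extra copy the commit removes.
copied = np.array(timesteps[t_start:])
assert not np.shares_memory(copied, timesteps)
```

The same view-versus-copy distinction holds for `torch.Tensor` slicing, which is why iterating over `self.scheduler.timesteps[t_start:]` directly is cheaper.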

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 2 additions & 5 deletions
@@ -342,12 +342,9 @@ def __call__(
         latents = init_latents

         t_start = max(num_inference_steps - init_timestep + offset, 0)
+        timesteps = self.scheduler.timesteps[t_start:]

-        # Some schedulers like PNDM have timesteps as arrays
-        # It's more optimzed to move all timesteps to correct device beforehand
-        timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
-
-        for i, t in tqdm(enumerate(timesteps_tensor)):
+        for i, t in tqdm(enumerate(timesteps)):
             t_index = t_start + i
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
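The "negative strides" in the commit title arise because some schedulers (the deleted comment names PNDM) build their schedule by reversing a NumPy array, and a reversed slice is a view with a negative stride, which PyTorch cannot wrap without a copy (`torch.from_numpy` rejects it outright). A minimal sketch of the underlying issue, with a hypothetical schedule standing in for the real scheduler's timesteps:

```python
import numpy as np

# Reversing a NumPy array with a slice yields a *view* whose stride is
# negative, not a copy (hypothetical timestep values for illustration).
timesteps = np.arange(0, 1000, 100)[::-1]
assert timesteps.strides[0] == -timesteps.itemsize  # negative stride

# torch.from_numpy() raises ValueError on arrays with negative strides,
# which is one way wrapping such timesteps in a tensor can fail. Forcing
# a contiguous copy clears the negative stride:
contiguous = np.ascontiguousarray(timesteps)
assert contiguous.strides[0] == contiguous.itemsize
```

Iterating over the sliced timesteps directly, as the new code does, sidesteps the tensor conversion entirely.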

0 commit comments
