Skip to content

Commit 07bd2fa

Browse files
authored
make controlnet support interrupt (#9620)
* make controlnet support interrupt * remove white space in controlnet interrupt
1 parent af28ae2 commit 07bd2fa

File tree

6 files changed

+48
-0
lines changed

6 files changed

+48
-0
lines changed

src/diffusers/pipelines/controlnet/pipeline_controlnet.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -893,6 +893,10 @@ def cross_attention_kwargs(self):
893893
def num_timesteps(self):
894894
return self._num_timesteps
895895

896+
@property
897+
def interrupt(self):
898+
return self._interrupt
899+
896900
@torch.no_grad()
897901
@replace_example_docstring(EXAMPLE_DOC_STRING)
898902
def __call__(
@@ -1089,6 +1093,7 @@ def __call__(
10891093
self._guidance_scale = guidance_scale
10901094
self._clip_skip = clip_skip
10911095
self._cross_attention_kwargs = cross_attention_kwargs
1096+
self._interrupt = False
10921097

10931098
# 2. Define call parameters
10941099
if prompt is not None and isinstance(prompt, str):
@@ -1235,6 +1240,9 @@ def __call__(
12351240
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
12361241
with self.progress_bar(total=num_inference_steps) as progress_bar:
12371242
for i, t in enumerate(timesteps):
1243+
if self.interrupt:
1244+
continue
1245+
12381246
# Relevant thread:
12391247
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
12401248
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:

src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -891,6 +891,10 @@ def cross_attention_kwargs(self):
891891
def num_timesteps(self):
892892
return self._num_timesteps
893893

894+
@property
895+
def interrupt(self):
896+
return self._interrupt
897+
894898
@torch.no_grad()
895899
@replace_example_docstring(EXAMPLE_DOC_STRING)
896900
def __call__(
@@ -1081,6 +1085,7 @@ def __call__(
10811085
self._guidance_scale = guidance_scale
10821086
self._clip_skip = clip_skip
10831087
self._cross_attention_kwargs = cross_attention_kwargs
1088+
self._interrupt = False
10841089

10851090
# 2. Define call parameters
10861091
if prompt is not None and isinstance(prompt, str):
@@ -1211,6 +1216,9 @@ def __call__(
12111216
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
12121217
with self.progress_bar(total=num_inference_steps) as progress_bar:
12131218
for i, t in enumerate(timesteps):
1219+
if self.interrupt:
1220+
continue
1221+
12141222
# expand the latents if we are doing classifier free guidance
12151223
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
12161224
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -976,6 +976,10 @@ def cross_attention_kwargs(self):
976976
def num_timesteps(self):
977977
return self._num_timesteps
978978

979+
@property
980+
def interrupt(self):
981+
return self._interrupt
982+
979983
@torch.no_grad()
980984
@replace_example_docstring(EXAMPLE_DOC_STRING)
981985
def __call__(
@@ -1191,6 +1195,7 @@ def __call__(
11911195
self._guidance_scale = guidance_scale
11921196
self._clip_skip = clip_skip
11931197
self._cross_attention_kwargs = cross_attention_kwargs
1198+
self._interrupt = False
11941199

11951200
# 2. Define call parameters
11961201
if prompt is not None and isinstance(prompt, str):
@@ -1375,6 +1380,9 @@ def __call__(
13751380
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
13761381
with self.progress_bar(total=num_inference_steps) as progress_bar:
13771382
for i, t in enumerate(timesteps):
1383+
if self.interrupt:
1384+
continue
1385+
13781386
# expand the latents if we are doing classifier free guidance
13791387
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
13801388
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,6 +1145,10 @@ def cross_attention_kwargs(self):
11451145
def num_timesteps(self):
11461146
return self._num_timesteps
11471147

1148+
@property
1149+
def interrupt(self):
1150+
return self._interrupt
1151+
11481152
@torch.no_grad()
11491153
@replace_example_docstring(EXAMPLE_DOC_STRING)
11501154
def __call__(
@@ -1427,6 +1431,7 @@ def __call__(
14271431
self._guidance_scale = guidance_scale
14281432
self._clip_skip = clip_skip
14291433
self._cross_attention_kwargs = cross_attention_kwargs
1434+
self._interrupt = False
14301435

14311436
# 2. Define call parameters
14321437
if prompt is not None and isinstance(prompt, str):
@@ -1695,6 +1700,9 @@ def denoising_value_valid(dnv):
16951700

16961701
with self.progress_bar(total=num_inference_steps) as progress_bar:
16971702
for i, t in enumerate(timesteps):
1703+
if self.interrupt:
1704+
continue
1705+
16981706
# expand the latents if we are doing classifier free guidance
16991707
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
17001708

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -990,6 +990,10 @@ def denoising_end(self):
990990
def num_timesteps(self):
991991
return self._num_timesteps
992992

993+
@property
994+
def interrupt(self):
995+
return self._interrupt
996+
993997
@torch.no_grad()
994998
@replace_example_docstring(EXAMPLE_DOC_STRING)
995999
def __call__(
@@ -1245,6 +1249,7 @@ def __call__(
12451249
self._clip_skip = clip_skip
12461250
self._cross_attention_kwargs = cross_attention_kwargs
12471251
self._denoising_end = denoising_end
1252+
self._interrupt = False
12481253

12491254
# 2. Define call parameters
12501255
if prompt is not None and isinstance(prompt, str):
@@ -1442,6 +1447,9 @@ def __call__(
14421447
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
14431448
with self.progress_bar(total=num_inference_steps) as progress_bar:
14441449
for i, t in enumerate(timesteps):
1450+
if self.interrupt:
1451+
continue
1452+
14451453
# Relevant thread:
14461454
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
14471455
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,6 +1070,10 @@ def cross_attention_kwargs(self):
10701070
def num_timesteps(self):
10711071
return self._num_timesteps
10721072

1073+
@property
1074+
def interrupt(self):
1075+
return self._interrupt
1076+
10731077
@torch.no_grad()
10741078
@replace_example_docstring(EXAMPLE_DOC_STRING)
10751079
def __call__(
@@ -1338,6 +1342,7 @@ def __call__(
13381342
self._guidance_scale = guidance_scale
13391343
self._clip_skip = clip_skip
13401344
self._cross_attention_kwargs = cross_attention_kwargs
1345+
self._interrupt = False
13411346

13421347
# 2. Define call parameters
13431348
if prompt is not None and isinstance(prompt, str):
@@ -1510,6 +1515,9 @@ def __call__(
15101515
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
15111516
with self.progress_bar(total=num_inference_steps) as progress_bar:
15121517
for i, t in enumerate(timesteps):
1518+
if self.interrupt:
1519+
continue
1520+
15131521
# expand the latents if we are doing classifier free guidance
15141522
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
15151523
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

0 commit comments

Comments
 (0)