6 files changed, +48 −0 lines, all under src/diffusers/pipelines/controlnet/. Each file gains the same interruption hooks: a read-only interrupt property, an _interrupt flag reset at the top of __call__, and an early-continue guard at the top of the denoising loop.
File 1 of 6 — src/diffusers/pipelines/controlnet/… :

@@ -893,6 +893,10 @@ def cross_attention_kwargs(self):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -1089,6 +1093,7 @@ def __call__(
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -1235,6 +1240,9 @@ def __call__(
         is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # Relevant thread:
                 # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
                 if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
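The remaining five files receive the same three-part change, so the sketch below condenses the pattern once into a minimal standalone class. The class name and the _denoise_step placeholder are hypothetical; the real pipelines wrap far more machinery around this loop.

import torch


class InterruptiblePipelineSketch:
    """Minimal sketch of the interrupt pattern added by this diff."""

    @property
    def interrupt(self):
        # Read-only view of the private flag, mirroring the added property.
        return self._interrupt

    @torch.no_grad()
    def __call__(self, timesteps):
        # Reset at the start of every call so a previous run's interrupt
        # does not leak into this one.
        self._interrupt = False
        for i, t in enumerate(timesteps):
            if self.interrupt:
                # `continue` rather than `break`: the loop drains its
                # remaining iterations as cheap no-ops and the progress-bar
                # context exits normally.
                continue
            self._denoise_step(i, t)  # stand-in for the real UNet/ControlNet step

    def _denoise_step(self, i, t):
        print(f"denoising step {i} at timestep {t}")


pipe = InterruptiblePipelineSketch()
pipe(range(5))  # runs all steps; set pipe._interrupt = True mid-run to skip the rest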
File 2 of 6 — src/diffusers/pipelines/controlnet/… :

@@ -891,6 +891,10 @@ def cross_attention_kwargs(self):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -1081,6 +1085,7 @@ def __call__(
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -1211,6 +1216,9 @@ def __call__(
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
File 3 of 6 — src/diffusers/pipelines/controlnet/… :

@@ -976,6 +976,10 @@ def cross_attention_kwargs(self):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -1191,6 +1195,7 @@ def __call__(
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -1375,6 +1380,9 @@ def __call__(
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
File 4 of 6 — src/diffusers/pipelines/controlnet/… :

@@ -1145,6 +1145,10 @@ def cross_attention_kwargs(self):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -1427,6 +1431,7 @@ def __call__(
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -1695,6 +1700,9 @@ def denoising_value_valid(dnv):
 
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
File 5 of 6 — src/diffusers/pipelines/controlnet/… :

@@ -990,6 +990,10 @@ def denoising_end(self):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -1245,6 +1249,7 @@ def __call__(
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
         self._denoising_end = denoising_end
+        self._interrupt = False
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -1442,6 +1447,9 @@ def __call__(
         is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # Relevant thread:
                 # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
                 if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
File 6 of 6 — src/diffusers/pipelines/controlnet/… :

@@ -1070,6 +1070,10 @@ def cross_attention_kwargs(self):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -1338,6 +1342,7 @@ def __call__(
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -1510,6 +1515,9 @@ def __call__(
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
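In use, the flag is typically flipped from a step-end callback. The following sketch assumes the current diffusers ControlNet API (ControlNetModel, StableDiffusionControlNetPipeline, and the callback_on_step_end parameter); the model IDs, the blank control image, and the stop step of 10 are illustrative assumptions, not part of this diff.

import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

# Assumed model IDs, chosen only for illustration.
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# A blank control image keeps the sketch self-contained; a real run would
# pass a preprocessed conditioning image (e.g. Canny edges).
control_image = Image.new("RGB", (512, 512))

def stop_after_ten(pipe, step, timestep, callback_kwargs):
    # Setting the private flag makes every remaining loop iteration hit
    # `if self.interrupt: continue`, skipping the denoising work.
    if step == 10:
        pipe._interrupt = True
    return callback_kwargs

image = pipe(
    "a photo of a cat",
    image=control_image,
    num_inference_steps=50,
    callback_on_step_end=stop_after_ten,
).images[0]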