Skip to content

Commit ef51185

Browse files
Fix `is_inference_mode` import fallback (PaddlePaddle#711)
paddle 3.0beta does not support `from paddle.incubate.jit import is_inference_mode`; this commit adds a fallback so the import failure no longer breaks the pipeline.
1 parent 0f37251 commit ef51185

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@
2727
from ...utils.paddle_utils import randn_tensor
2828
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
2929

30+
try:
    # paddle.incubate.jit.is_inference_mode is available in paddle develop
    # but not in paddle 3.0beta, so fall back to a stub when the import fails.
    from paddle.incubate.jit import is_inference_mode
except ImportError:
    # Fallback for paddle builds without the symbol: nothing can be running
    # under paddle inference, so always report False.
    def is_inference_mode(func):
        """Stub: report that *func* is not running under paddle inference."""
        return False
3038

3139
class DiTPipeline(DiffusionPipeline):
3240
r"""
@@ -192,7 +200,7 @@ def __call__(
192200
)
193201
# predict noise model_output
194202
noise_pred_out = self.transformer(latent_model_input, timestep=timesteps, class_labels=class_labels_input)
195-
if paddle.incubate.jit.is_inference_mode(self.transformer):
203+
if is_inference_mode(self.transformer):
196204
# self.transformer run in paddle inference.
197205
noise_pred = noise_pred_out
198206
else:
@@ -227,7 +235,7 @@ def __call__(
227235
latents = 1 / self.vae.config.scaling_factor * latents
228236

229237
samples_out = self.vae.decode(latents)
230-
if paddle.incubate.jit.is_inference_mode(self.vae.decode):
238+
if is_inference_mode(self.vae.decode):
231239
# self.vae.decode run in paddle inference.
232240
samples = samples_out
233241
else:

0 commit comments

Comments
 (0)