Skip to content

Commit 2af83c2

Browse files
santiviqueznatolambert
authored andcommitted
[Type hint] Score SDE VE pipeline (#325)
1 parent 1b7c3c3 commit 2af83c2

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,33 @@
11
#!/usr/bin/env python3
22
import warnings
3+
from typing import Optional
34

45
import torch
56

67
from diffusers import DiffusionPipeline
78

9+
from ...models import UNet2DModel
10+
from ...schedulers import ScoreSdeVeScheduler
11+
812

913
class ScoreSdeVePipeline(DiffusionPipeline):
10-
def __init__(self, unet, scheduler):
14+
15+
unet: UNet2DModel
16+
scheduler: ScoreSdeVeScheduler
17+
18+
def __init__(self, unet: UNet2DModel, scheduler: DiffusionPipeline):
1119
super().__init__()
1220
self.register_modules(unet=unet, scheduler=scheduler)
1321

1422
@torch.no_grad()
15-
def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, output_type="pil", **kwargs):
23+
def __call__(
24+
self,
25+
batch_size: int = 1,
26+
num_inference_steps: int = 2000,
27+
generator: Optional[torch.Generator] = None,
28+
output_type: Optional[str] = "pil",
29+
**kwargs,
30+
):
1631
if "torch_device" in kwargs:
1732
device = kwargs.pop("torch_device")
1833
warnings.warn(

0 commit comments

Comments
 (0)