Skip to content

Commit 2efc2ed

Browse files
patrickvonplatenanton-l
authored andcommitted
[Type hint] Karras VE pipeline (#288)
* [Type hint] Karras VE pipeline * Apply suggestions from code review Co-authored-by: Anton Lozhkov <anton@huggingface.co> Co-authored-by: Anton Lozhkov <anton@huggingface.co>
1 parent 3d79aa4 commit 2efc2ed

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#!/usr/bin/env python3
22
import warnings
3+
from typing import Optional
34

45
import torch
56

@@ -21,13 +22,20 @@ class KarrasVePipeline(DiffusionPipeline):
2122
unet: UNet2DModel
2223
scheduler: KarrasVeScheduler
2324

24-
def __init__(self, unet, scheduler):
25+
def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler):
2526
super().__init__()
2627
scheduler = scheduler.set_format("pt")
2728
self.register_modules(unet=unet, scheduler=scheduler)
2829

2930
@torch.no_grad()
30-
def __call__(self, batch_size=1, num_inference_steps=50, generator=None, output_type="pil", **kwargs):
31+
def __call__(
32+
self,
33+
batch_size: int = 1,
34+
num_inference_steps: int = 50,
35+
generator: Optional[torch.Generator] = None,
36+
output_type: Optional[str] = "pil",
37+
**kwargs,
38+
):
3139
if "torch_device" in kwargs:
3240
device = kwargs.pop("torch_device")
3341
warnings.warn(

0 commit comments

Comments
 (0)