Commit 9c4b21f: type hints

Parent: 5164c9f

File tree

3 files changed, +30 -4 lines changed


src/diffusers/pipelines/ddim/pipeline_ddim.py

Lines changed: 10 additions & 1 deletion

@@ -15,6 +15,7 @@
 
 
 import warnings
+from typing import Optional
 
 import torch
 
@@ -28,7 +29,15 @@ def __init__(self, unet, scheduler):
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(self, batch_size=1, generator=None, eta=0.0, num_inference_steps=50, output_type="pil", **kwargs):
+    def __call__(
+        self,
+        batch_size: int = 1,
+        generator: Optional[torch.Generator] = None,
+        eta: float = 0.0,
+        num_inference_steps: int = 50,
+        output_type: Optional[str] = "pil",
+        **kwargs,
+    ):
 
         if "torch_device" in kwargs:
             device = kwargs.pop("torch_device")
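
For reference, a minimal sketch of calling the newly annotated signature. The checkpoint id is an assumption chosen for illustration and is not part of this commit.

```python
import torch

from diffusers import DDIMPipeline

# Hypothetical checkpoint choice; any DDIM-compatible checkpoint works.
pipe = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32")

# Each argument now carries an explicit, statically checkable type:
output = pipe(
    batch_size=2,                                # int
    generator=torch.Generator().manual_seed(0),  # Optional[torch.Generator]
    eta=0.0,                                     # float; eta=0.0 is deterministic DDIM sampling
    num_inference_steps=50,                      # int
    output_type="pil",                           # Optional[str]
)
```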

src/diffusers/pipelines/ddpm/pipeline_ddpm.py

Lines changed: 8 additions & 1 deletion

@@ -15,6 +15,7 @@
 
 
 import warnings
+from typing import Optional
 
 import torch
 
@@ -28,7 +29,13 @@ def __init__(self, unet, scheduler):
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(self, batch_size=1, generator=None, output_type="pil", **kwargs):
+    def __call__(
+        self,
+        batch_size: int = 1,
+        generator: Optional[torch.Generator] = None,
+        output_type: Optional[str] = "pil",
+        **kwargs,
+    ):
         if "torch_device" in kwargs:
             device = kwargs.pop("torch_device")
             warnings.warn(
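
A note on the `Optional` annotations above: `Optional[X]` means "`X` or `None`", which is why it fits `generator` (defaulting to `None`) and also `output_type`, whose default is `"pil"` but which may still be passed as `None`. A self-contained sketch with illustrative names:

```python
from typing import Optional

import torch

def sample_noise(
    batch_size: int = 1,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    # Optional[torch.Generator] reads as "torch.Generator or None";
    # torch.randn itself accepts generator=None, so no branching is needed.
    return torch.randn(batch_size, 3, 32, 32, generator=generator)

sample_noise()                                     # generator defaults to None
sample_noise(4, torch.Generator().manual_seed(0))  # explicit, reproducible generator
# sample_noise(4, generator="cpu")  # a static checker such as mypy flags this
```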

src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py

Lines changed: 12 additions & 2 deletions

@@ -10,13 +10,23 @@
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_outputs import BaseModelOutput
 from transformers.modeling_utils import PreTrainedModel
+from transformers.tokenization_utils import PreTrainedTokenizer
 from transformers.utils import logging
 
+from ...models import UNet2DModel, UNet2DConditionModel, AutoencoderKL, VQModel
 from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 
 
 class LDMTextToImagePipeline(DiffusionPipeline):
-    def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
+    def __init__(
+        self,
+        vqvae: Union[VQModel, AutoencoderKL],
+        bert: PreTrainedModel,
+        tokenizer: PreTrainedTokenizer,
+        unet: Union[UNet2DModel, UNet2DConditionModel],
+        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+    ):
         super().__init__()
         scheduler = scheduler.set_format("pt")
         self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)

@@ -614,7 +624,7 @@ def custom_forward(*inputs):
 
 
 class LDMBertModel(LDMBertPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: LDMBertConfig):
         super().__init__(config)
         self.model = LDMBertEncoder(config)
         self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)
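
The `Union` annotations document which concrete classes each constructor slot accepts (`Union` itself comes from `typing` and is presumably already imported in this file, since the diff does not add it). A stand-alone sketch of the pattern, using stand-in classes rather than the real diffusers types:

```python
from typing import Union

class VQModel: ...
class AutoencoderKL: ...

class TinyPipeline:
    def __init__(self, vqvae: Union[VQModel, AutoencoderKL]):
        # Either first-stage model is accepted; anything outside the Union is
        # flagged by a static type checker (annotations are not enforced at runtime).
        self.vqvae = vqvae

TinyPipeline(VQModel())        # ok
TinyPipeline(AutoencoderKL())  # ok
# TinyPipeline(object())       # rejected by mypy/pyright, though it would run
```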
