Skip to content

Commit 226a8b7

Browse files
sidthekidderanton-l
authored andcommitted
[Type Hints] DDIM pipelines (#345)
* type hints * Apply suggestions from code review Co-authored-by: Anton Lozhkov <anton@huggingface.co>
1 parent 43b3584 commit 226a8b7

File tree

3 files changed

+25
-10
lines changed

3 files changed

+25
-10
lines changed

src/diffusers/pipelines/ddim/pipeline_ddim.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616

1717
import warnings
18-
from typing import Tuple, Union
18+
from typing import Optional, Tuple, Union
1919

2020
import torch
2121

@@ -31,11 +31,11 @@ def __init__(self, unet, scheduler):
3131
@torch.no_grad()
3232
def __call__(
3333
self,
34-
batch_size=1,
35-
generator=None,
36-
eta=0.0,
37-
num_inference_steps=50,
38-
output_type="pil",
34+
batch_size: int = 1,
35+
generator: Optional[torch.Generator] = None,
36+
eta: float = 0.0,
37+
num_inference_steps: int = 50,
38+
output_type: Optional[str] = "pil",
3939
return_dict: bool = True,
4040
**kwargs,
4141
) -> Union[ImagePipelineOutput, Tuple]:

src/diffusers/pipelines/ddpm/pipeline_ddpm.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616

1717
import warnings
18-
from typing import Tuple, Union
18+
from typing import Optional, Tuple, Union
1919

2020
import torch
2121

@@ -30,7 +30,12 @@ def __init__(self, unet, scheduler):
3030

3131
@torch.no_grad()
3232
def __call__(
33-
self, batch_size=1, generator=None, output_type="pil", return_dict: bool = True, **kwargs
33+
self,
34+
batch_size: int = 1,
35+
generator: Optional[torch.Generator] = None,
36+
output_type: Optional[str] = "pil",
37+
return_dict: bool = True,
38+
**kwargs,
3439
) -> Union[ImagePipelineOutput, Tuple]:
3540
if "torch_device" in kwargs:
3641
device = kwargs.pop("torch_device")

src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,23 @@
1010
from transformers.configuration_utils import PretrainedConfig
1111
from transformers.modeling_outputs import BaseModelOutput
1212
from transformers.modeling_utils import PreTrainedModel
13+
from transformers.tokenization_utils import PreTrainedTokenizer
1314
from transformers.utils import logging
1415

16+
from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
1517
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
18+
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
1619

1720

1821
class LDMTextToImagePipeline(DiffusionPipeline):
19-
def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
22+
def __init__(
23+
self,
24+
vqvae: Union[VQModel, AutoencoderKL],
25+
bert: PreTrainedModel,
26+
tokenizer: PreTrainedTokenizer,
27+
unet: Union[UNet2DModel, UNet2DConditionModel],
28+
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
29+
):
2030
super().__init__()
2131
scheduler = scheduler.set_format("pt")
2232
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
@@ -618,7 +628,7 @@ def custom_forward(*inputs):
618628

619629

620630
class LDMBertModel(LDMBertPreTrainedModel):
621-
def __init__(self, config):
631+
def __init__(self, config: LDMBertConfig):
622632
super().__init__(config)
623633
self.model = LDMBertEncoder(config)
624634
self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)

0 commit comments

Comments
 (0)