Skip to content

Commit 0243798

Browse files
authored
[Type Hint] Unet Models (huggingface#330)
* add void check * remove void, add types for params
1 parent 8912ee6 commit 0243798

File tree

2 files changed

+41
-36
lines changed

2 files changed

+41
-36
lines changed

models/unet_2d.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Union
1+
from typing import Dict, Optional, Tuple, Union
22

33
import torch
44
import torch.nn as nn
@@ -13,23 +13,23 @@ class UNet2DModel(ModelMixin, ConfigMixin):
1313
@register_to_config
1414
def __init__(
1515
self,
16-
sample_size=None,
17-
in_channels=3,
18-
out_channels=3,
19-
center_input_sample=False,
20-
time_embedding_type="positional",
21-
freq_shift=0,
22-
flip_sin_to_cos=True,
23-
down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
24-
up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
25-
block_out_channels=(224, 448, 672, 896),
26-
layers_per_block=2,
27-
mid_block_scale_factor=1,
28-
downsample_padding=1,
29-
act_fn="silu",
30-
attention_head_dim=8,
31-
norm_num_groups=32,
32-
norm_eps=1e-5,
16+
sample_size: Optional[int] = None,
17+
in_channels: int = 3,
18+
out_channels: int = 3,
19+
center_input_sample: bool = False,
20+
time_embedding_type: str = "positional",
21+
freq_shift: int = 0,
22+
flip_sin_to_cos: bool = True,
23+
down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
24+
up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
25+
block_out_channels: Tuple[int] = (224, 448, 672, 896),
26+
layers_per_block: int = 2,
27+
mid_block_scale_factor: float = 1,
28+
downsample_padding: int = 1,
29+
act_fn: str = "silu",
30+
attention_head_dim: int = 8,
31+
norm_num_groups: int = 32,
32+
norm_eps: float = 1e-5,
3333
):
3434
super().__init__()
3535

models/unet_2d_condition.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Union
1+
from typing import Dict, Optional, Tuple, Union
22

33
import torch
44
import torch.nn as nn
@@ -13,23 +13,28 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
1313
@register_to_config
1414
def __init__(
1515
self,
16-
sample_size=None,
17-
in_channels=4,
18-
out_channels=4,
19-
center_input_sample=False,
20-
flip_sin_to_cos=True,
21-
freq_shift=0,
22-
down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
23-
up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
24-
block_out_channels=(320, 640, 1280, 1280),
25-
layers_per_block=2,
26-
downsample_padding=1,
27-
mid_block_scale_factor=1,
28-
act_fn="silu",
29-
norm_num_groups=32,
30-
norm_eps=1e-5,
31-
cross_attention_dim=1280,
32-
attention_head_dim=8,
16+
sample_size: Optional[int] = None,
17+
in_channels: int = 4,
18+
out_channels: int = 4,
19+
center_input_sample: bool = False,
20+
flip_sin_to_cos: bool = True,
21+
freq_shift: int = 0,
22+
down_block_types: Tuple[str] = (
23+
"CrossAttnDownBlock2D",
24+
"CrossAttnDownBlock2D",
25+
"CrossAttnDownBlock2D",
26+
"DownBlock2D",
27+
),
28+
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
29+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
30+
layers_per_block: int = 2,
31+
downsample_padding: int = 1,
32+
mid_block_scale_factor: float = 1,
33+
act_fn: str = "silu",
34+
norm_num_groups: int = 32,
35+
norm_eps: float = 1e-5,
36+
cross_attention_dim: int = 1280,
37+
attention_head_dim: int = 8,
3338
):
3439
super().__init__()
3540

0 commit comments

Comments
 (0)