Skip to content

Commit 8332543

Browse files
committed
remove void, add types for params
1 parent a273710 commit 8332543

File tree

2 files changed

+36
-36
lines changed

2 files changed

+36
-36
lines changed

src/diffusers/models/unet_2d.py

Lines changed: 18 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -13,24 +13,24 @@ class UNet2DModel(ModelMixin, ConfigMixin):
1313
@register_to_config
1414
def __init__(
1515
self,
16-
sample_size=None,
17-
in_channels=3,
18-
out_channels=3,
19-
center_input_sample=False,
20-
time_embedding_type="positional",
21-
freq_shift=0,
22-
flip_sin_to_cos=True,
23-
down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
24-
up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
25-
block_out_channels=(224, 448, 672, 896),
26-
layers_per_block=2,
27-
mid_block_scale_factor=1,
28-
downsample_padding=1,
29-
act_fn="silu",
30-
attention_head_dim=8,
31-
norm_num_groups=32,
32-
norm_eps=1e-5,
33-
) -> None:
16+
sample_size: int = None,
17+
in_channels: int = 3,
18+
out_channels: int = 3,
19+
center_input_sample: bool = False,
20+
time_embedding_type: str = "positional",
21+
freq_shift: int = 0,
22+
flip_sin_to_cos: bool = True,
23+
down_block_types: tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
24+
up_block_types: tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
25+
block_out_channels: tuple[int] = (224, 448, 672, 896),
26+
layers_per_block: int = 2,
27+
mid_block_scale_factor: float = 1,
28+
downsample_padding: int = 1,
29+
act_fn: str = "silu",
30+
attention_head_dim: int = 8,
31+
norm_num_groups: int = 32,
32+
norm_eps: float = 1e-5,
33+
):
3434
super().__init__()
3535

3636
self.sample_size = sample_size

src/diffusers/models/unet_2d_condition.py

Lines changed: 18 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -13,24 +13,24 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
1313
@register_to_config
1414
def __init__(
1515
self,
16-
sample_size=None,
17-
in_channels=4,
18-
out_channels=4,
19-
center_input_sample=False,
20-
flip_sin_to_cos=True,
21-
freq_shift=0,
22-
down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
23-
up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
24-
block_out_channels=(320, 640, 1280, 1280),
25-
layers_per_block=2,
26-
downsample_padding=1,
27-
mid_block_scale_factor=1,
28-
act_fn="silu",
29-
norm_num_groups=32,
30-
norm_eps=1e-5,
31-
cross_attention_dim=1280,
32-
attention_head_dim=8,
33-
) -> None:
16+
sample_size: int = None,
17+
in_channels: int = 4,
18+
out_channels: int = 4,
19+
center_input_sample: bool = False,
20+
flip_sin_to_cos: bool = True,
21+
freq_shift: int = 0,
22+
down_block_types: tuple[str] = ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
23+
up_block_types: tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
24+
block_out_channels: tuple[int] = (320, 640, 1280, 1280),
25+
layers_per_block: int = 2,
26+
downsample_padding: int = 1,
27+
mid_block_scale_factor: float = 1,
28+
act_fn: str = "silu",
29+
norm_num_groups: int = 32,
30+
norm_eps: float = 1e-5,
31+
cross_attention_dim: int = 1280,
32+
attention_head_dim: int = 8,
33+
):
3434
super().__init__()
3535

3636
self.sample_size = sample_size

0 commit comments

Comments (0)