Skip to content

Commit 90cfb04

Browse files
committed
fix typing, format
1 parent 8332543 commit 90cfb04

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

src/diffusers/models/unet_2d.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import torch
44
import torch.nn as nn
5+
from typing import Optional, Tuple
56

67
from ..configuration_utils import ConfigMixin, register_to_config
78
from ..modeling_utils import ModelMixin
@@ -13,16 +14,16 @@ class UNet2DModel(ModelMixin, ConfigMixin):
1314
@register_to_config
1415
def __init__(
1516
self,
16-
sample_size: int = None,
17+
sample_size: Optional[int] = None,
1718
in_channels: int = 3,
1819
out_channels: int = 3,
1920
center_input_sample: bool = False,
2021
time_embedding_type: str = "positional",
2122
freq_shift: int = 0,
2223
flip_sin_to_cos: bool = True,
23-
down_block_types: tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
24-
up_block_types: tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
25-
block_out_channels: tuple[int] = (224, 448, 672, 896),
24+
down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
25+
up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
26+
block_out_channels: Tuple[int] = (224, 448, 672, 896),
2627
layers_per_block: int = 2,
2728
mid_block_scale_factor: float = 1,
2829
downsample_padding: int = 1,

src/diffusers/models/unet_2d_condition.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import torch
44
import torch.nn as nn
5+
from typing import Optional, Tuple
56

67
from ..configuration_utils import ConfigMixin, register_to_config
78
from ..modeling_utils import ModelMixin
@@ -13,15 +14,20 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
1314
@register_to_config
1415
def __init__(
1516
self,
16-
sample_size: int = None,
17+
sample_size: Optional[int] = None,
1718
in_channels: int = 4,
1819
out_channels: int = 4,
1920
center_input_sample: bool = False,
2021
flip_sin_to_cos: bool = True,
2122
freq_shift: int = 0,
22-
down_block_types: tuple[str] = ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
23-
up_block_types: tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
24-
block_out_channels: tuple[int] = (320, 640, 1280, 1280),
23+
down_block_types: Tuple[str] = (
24+
"CrossAttnDownBlock2D",
25+
"CrossAttnDownBlock2D",
26+
"CrossAttnDownBlock2D",
27+
"DownBlock2D",
28+
),
29+
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
30+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
2531
layers_per_block: int = 2,
2632
downsample_padding: int = 1,
2733
mid_block_scale_factor: float = 1,

0 commit comments

Comments
 (0)