2
2
3
3
import torch
4
4
import torch .nn as nn
5
+ from typing import Optional , Tuple
5
6
6
7
from ..configuration_utils import ConfigMixin , register_to_config
7
8
from ..modeling_utils import ModelMixin
@@ -13,15 +14,20 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
13
14
@register_to_config
14
15
def __init__ (
15
16
self ,
16
- sample_size : int = None ,
17
+ sample_size : Optional [ int ] = None ,
17
18
in_channels : int = 4 ,
18
19
out_channels : int = 4 ,
19
20
center_input_sample : bool = False ,
20
21
flip_sin_to_cos : bool = True ,
21
22
freq_shift : int = 0 ,
22
- down_block_types : tuple [str ] = ("CrossAttnDownBlock2D" , "CrossAttnDownBlock2D" , "CrossAttnDownBlock2D" , "DownBlock2D" ),
23
- up_block_types : tuple [str ] = ("UpBlock2D" , "CrossAttnUpBlock2D" , "CrossAttnUpBlock2D" , "CrossAttnUpBlock2D" ),
24
- block_out_channels : tuple [int ] = (320 , 640 , 1280 , 1280 ),
23
+ down_block_types : Tuple [str ] = (
24
+ "CrossAttnDownBlock2D" ,
25
+ "CrossAttnDownBlock2D" ,
26
+ "CrossAttnDownBlock2D" ,
27
+ "DownBlock2D" ,
28
+ ),
29
+ up_block_types : Tuple [str ] = ("UpBlock2D" , "CrossAttnUpBlock2D" , "CrossAttnUpBlock2D" , "CrossAttnUpBlock2D" ),
30
+ block_out_channels : Tuple [int ] = (320 , 640 , 1280 , 1280 ),
25
31
layers_per_block : int = 2 ,
26
32
downsample_padding : int = 1 ,
27
33
mid_block_scale_factor : float = 1 ,
0 commit comments