@@ -13,24 +13,24 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
13
13
@register_to_config
14
14
def __init__ (
15
15
self ,
16
- sample_size = None ,
17
- in_channels = 4 ,
18
- out_channels = 4 ,
19
- center_input_sample = False ,
20
- flip_sin_to_cos = True ,
21
- freq_shift = 0 ,
22
- down_block_types = ("CrossAttnDownBlock2D" , "CrossAttnDownBlock2D" , "CrossAttnDownBlock2D" , "DownBlock2D" ),
23
- up_block_types = ("UpBlock2D" , "CrossAttnUpBlock2D" , "CrossAttnUpBlock2D" , "CrossAttnUpBlock2D" ),
24
- block_out_channels = (320 , 640 , 1280 , 1280 ),
25
- layers_per_block = 2 ,
26
- downsample_padding = 1 ,
27
- mid_block_scale_factor = 1 ,
28
- act_fn = "silu" ,
29
- norm_num_groups = 32 ,
30
- norm_eps = 1e-5 ,
31
- cross_attention_dim = 1280 ,
32
- attention_head_dim = 8 ,
33
- ) -> None :
16
+ sample_size : int = None ,
17
+ in_channels : int = 4 ,
18
+ out_channels : int = 4 ,
19
+ center_input_sample : bool = False ,
20
+ flip_sin_to_cos : bool = True ,
21
+ freq_shift : int = 0 ,
22
+ down_block_types : tuple [ str ] = ("CrossAttnDownBlock2D" , "CrossAttnDownBlock2D" , "CrossAttnDownBlock2D" , "DownBlock2D" ),
23
+ up_block_types : tuple [ str ] = ("UpBlock2D" , "CrossAttnUpBlock2D" , "CrossAttnUpBlock2D" , "CrossAttnUpBlock2D" ),
24
+ block_out_channels : tuple [ int ] = (320 , 640 , 1280 , 1280 ),
25
+ layers_per_block : int = 2 ,
26
+ downsample_padding : int = 1 ,
27
+ mid_block_scale_factor : float = 1 ,
28
+ act_fn : str = "silu" ,
29
+ norm_num_groups : int = 32 ,
30
+ norm_eps : float = 1e-5 ,
31
+ cross_attention_dim : int = 1280 ,
32
+ attention_head_dim : int = 8 ,
33
+ ):
34
34
super ().__init__ ()
35
35
36
36
self .sample_size = sample_size
0 commit comments