Flax documentation #589

Merged (18 commits, Sep 23, 2022)

18 changes: 18 additions & 0 deletions docs/source/api/models.mdx
@@ -45,3 +45,21 @@ The models are built on the base class [`ModelMixin`] that is a `torch.nn.Module`

## AutoencoderKL
[[autodoc]] AutoencoderKL

## FlaxModelMixin
[[autodoc]] FlaxModelMixin

## FlaxUNet2DConditionOutput
[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput

## FlaxUNet2DConditionModel
[[autodoc]] FlaxUNet2DConditionModel

## FlaxDecoderOutput
[[autodoc]] models.vae_flax.FlaxDecoderOutput

## FlaxAutoencoderKLOutput
[[autodoc]] models.vae_flax.FlaxAutoencoderKLOutput

## FlaxAutoencoderKL
[[autodoc]] FlaxAutoencoderKL
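
The classes documented above follow the `FlaxModelMixin` loading convention, in which the parameters are returned separately from the module. The snippet below is a minimal sketch of that pattern; the checkpoint name, `subfolder` layout, and dtype are illustrative assumptions, not something this page prescribes.

```python
# Hedged sketch: loading and saving a FlaxModelMixin subclass.
import jax.numpy as jnp
from diffusers import FlaxAutoencoderKL

# Flax from_pretrained returns the module definition *and* its parameter pytree.
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
    "CompVis/stable-diffusion-v1-4",  # assumed checkpoint with Flax weights
    subfolder="vae",
    dtype=jnp.float32,
)

# Parameters are likewise passed explicitly when saving.
vae.save_pretrained("./sd-vae-flax", params=vae_params)
```
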
76 changes: 76 additions & 0 deletions src/diffusers/models/attention_flax.py
@@ -17,6 +17,22 @@


class FlaxAttentionBlock(nn.Module):
r"""
A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762

Parameters:
query_dim (:obj:`int`):
Input hidden states dimension
heads (:obj:`int`, *optional*, defaults to 8):
Number of heads
dim_head (:obj:`int`, *optional*, defaults to 64):
Hidden states dimension inside each head
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`

"""
query_dim: int
heads: int = 8
dim_head: int = 64
@@ -74,6 +90,23 @@ def __call__(self, hidden_states, context=None, deterministic=True):


class FlaxBasicTransformerBlock(nn.Module):
r"""
A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in:
https://arxiv.org/abs/1706.03762


Parameters:
dim (:obj:`int`):
Inner hidden states dimension
n_heads (:obj:`int`):
Number of heads
d_head (:obj:`int`):
Hidden states dimension inside each head
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
dim: int
n_heads: int
d_head: int
@@ -110,6 +143,25 @@ def __call__(self, hidden_states, context, deterministic=True):


class FlaxSpatialTransformer(nn.Module):
r"""
A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
https://arxiv.org/pdf/1506.02025.pdf


Parameters:
in_channels (:obj:`int`):
Input number of channels
n_heads (:obj:`int`):
Number of heads
d_head (:obj:`int`):
Hidden states dimension inside each head
depth (:obj:`int`, *optional*, defaults to 1):
Number of transformer blocks
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
n_heads: int
d_head: int
@@ -162,6 +214,18 @@ def __call__(self, hidden_states, context, deterministic=True):


class FlaxGluFeedForward(nn.Module):
r"""
Flax module that encapsulates two Linear layers separated by a gated linear unit activation from:
https://arxiv.org/abs/2002.05202

Parameters:
dim (:obj:`int`):
Inner hidden states dimension
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
dim: int
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32
@@ -179,6 +243,18 @@ def __call__(self, hidden_states, deterministic=True):


class FlaxGEGLU(nn.Module):
r"""
Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
https://arxiv.org/abs/2002.05202.

Parameters:
dim (:obj:`int`):
Input hidden states dimension
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
dim: int
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32
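
To make the shapes in these docstrings concrete, here is a small sketch of initializing and calling `FlaxAttentionBlock` on its own. The dimensions are illustrative assumptions; inside the UNet this block is driven by `FlaxSpatialTransformer`.

```python
# Hedged sketch: self-attention over a flattened feature map (shapes are illustrative).
import jax
import jax.numpy as jnp
from diffusers.models.attention_flax import FlaxAttentionBlock

attn = FlaxAttentionBlock(query_dim=320, heads=8, dim_head=40)

hidden_states = jnp.ones((1, 64 * 64, 320))              # (batch, sequence, query_dim)
params = attn.init(jax.random.PRNGKey(0), hidden_states)

# context=None (the default) gives self-attention; pass text embeddings for cross-attention.
out = attn.apply(params, hidden_states)
print(out.shape)                                          # (1, 4096, 320)
```
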
16 changes: 16 additions & 0 deletions src/diffusers/models/embeddings_flax.py
@@ -37,6 +37,15 @@ def get_sinusoidal_embeddings(timesteps, embedding_dim, freq_shift: float = 1):


class FlaxTimestepEmbedding(nn.Module):
r"""
Time step Embedding Module. Learns embeddings for input time steps.

Args:
time_embed_dim (`int`, *optional*, defaults to `32`):
Time step embedding dimension
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
time_embed_dim: int = 32
dtype: jnp.dtype = jnp.float32

@@ -49,6 +58,13 @@ def __call__(self, temb):


class FlaxTimesteps(nn.Module):
r"""
Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239

Args:
dim (`int`, *optional*, defaults to `32`):
Time step embedding dimension
"""
dim: int = 32
freq_shift: float = 1

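
As a quick illustration of how the two modules above compose, the sketch below projects raw timesteps sinusoidally and then maps them through the learned embedding MLP. The dimensions mirror common UNet settings and are assumptions.

```python
# Hedged sketch: sinusoidal projection followed by the learned time-step embedding.
import jax
import jax.numpy as jnp
from diffusers.models.embeddings_flax import FlaxTimesteps, FlaxTimestepEmbedding

timesteps = jnp.array([0, 250, 999])                         # one diffusion step per batch element

t_proj = FlaxTimesteps(dim=320, freq_shift=1)
proj_params = t_proj.init(jax.random.PRNGKey(0), timesteps)  # empty: the projection has no weights
sinusoidal = t_proj.apply(proj_params, timesteps)            # (3, 320)

t_embed = FlaxTimestepEmbedding(time_embed_dim=1280)
embed_params = t_embed.init(jax.random.PRNGKey(1), sinusoidal)
temb = t_embed.apply(embed_params, sinusoidal)               # (3, 1280)
```
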
31 changes: 24 additions & 7 deletions src/diffusers/models/unet_2d_condition_flax.py
@@ -39,10 +39,23 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the models (such as downloading or saving, etc.)

Also, this model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to
general usage and behavior.

Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

Parameters:
sample_size (`int`, *optional*):
The size of the input sample.
in_channels (`int`, *optional*, defaults to 4):
The number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 4):
The number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use. The corresponding class names will be: "FlaxCrossAttnDownBlock2D",
"FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D"
@@ -51,10 +64,14 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
"FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D"
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
attention_head_dim (`int`, *optional*, defaults to 8):
The dimension of the attention heads.
cross_attention_dim (`int`, *optional*, defaults to 768):
The dimension of the cross attention features.
dropout (`float`, *optional*, defaults to 0):
Dropout probability for down, up and bottleneck blocks.
"""

sample_size: int = 32
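
The JAX features listed in the new docstring (JIT, automatic differentiation, vmap, pmap) apply to this model as to any Flax module. The sketch below jit-compiles a single denoising forward pass; the checkpoint name and tensor shapes are assumptions made for illustration, not part of this diff.

```python
# Hedged sketch: jit-compiled UNet forward pass.
import jax
import jax.numpy as jnp
from diffusers import FlaxUNet2DConditionModel

unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4",  # assumed checkpoint with Flax weights
    subfolder="unet",
    dtype=jnp.float32,
)

sample = jnp.ones((1, 4, 64, 64))               # latents: (batch, in_channels, height, width)
timesteps = jnp.array([10], dtype=jnp.int32)
encoder_hidden_states = jnp.ones((1, 77, 768))  # text states: (batch, tokens, cross_attention_dim)

@jax.jit
def predict_noise(params, sample, timesteps, encoder_hidden_states):
    return unet.apply({"params": params}, sample, timesteps, encoder_hidden_states).sample

noise_pred = predict_noise(unet_params, sample, timesteps, encoder_hidden_states)
print(noise_pred.shape)                         # (1, 4, 64, 64)
```
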
91 changes: 91 additions & 0 deletions src/diffusers/models/unet_blocks_flax.py
@@ -19,6 +19,26 @@


class FlaxCrossAttnDownBlock2D(nn.Module):
r"""
Cross Attention 2D downsampling block - original architecture from Unet transformers:
https://arxiv.org/abs/2103.06104

Parameters:
in_channels (:obj:`int`):
Input channels
out_channels (:obj:`int`):
Output channels
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of attention block layers
attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
Number of attention heads of each spatial transformer block
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add a downsampling layer before each final output
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
out_channels: int
dropout: float = 0.0
@@ -73,6 +93,23 @@ def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):


class FlaxDownBlock2D(nn.Module):
r"""
Flax 2D downsampling block

Parameters:
in_channels (:obj:`int`):
Input channels
out_channels (:obj:`int`):
Output channels
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of ResNet block layers
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add a downsampling layer before each final output
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
out_channels: int
dropout: float = 0.0
@@ -113,6 +150,26 @@ def __call__(self, hidden_states, temb, deterministic=True):


class FlaxCrossAttnUpBlock2D(nn.Module):
r"""
Cross Attention 2D Upsampling block - original architecture from Unet transformers:
https://arxiv.org/abs/2103.06104

Parameters:
in_channels (:obj:`int`):
Input channels
out_channels (:obj:`int`):
Output channels
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of attention block layers
attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
Number of attention heads of each spatial transformer block
add_upsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add an upsampling layer before each final output
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
out_channels: int
prev_output_channel: int
@@ -170,6 +227,25 @@ def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):


class FlaxUpBlock2D(nn.Module):
r"""
Flax 2D upsampling block

Parameters:
in_channels (:obj:`int`):
Input channels
out_channels (:obj:`int`):
Output channels
prev_output_channel (:obj:`int`):
Output channels from the previous block
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of ResNet block layers
add_upsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add an upsampling layer before each final output
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
out_channels: int
prev_output_channel: int
@@ -214,6 +290,21 @@ def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True):


class FlaxUNetMidBlock2DCrossAttn(nn.Module):
r"""
Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104

Parameters:
in_channels (:obj:`int`):
Input channels
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of attention block layers
attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
Number of attention heads of each spatial transformer block
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
dropout: float = 0.0
num_layers: int = 1
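
For completeness, a sketch of driving one of these blocks in isolation. Note that the Flax blocks operate on channels-last tensors (the UNet transposes its input before calling them); the dimensions below are illustrative assumptions.

```python
# Hedged sketch: running FlaxDownBlock2D standalone (channels-last layout, illustrative dims).
import jax
import jax.numpy as jnp
from diffusers.models.unet_blocks_flax import FlaxDownBlock2D

block = FlaxDownBlock2D(in_channels=320, out_channels=320, num_layers=2, add_downsample=True)

hidden_states = jnp.ones((1, 64, 64, 320))  # (batch, height, width, channels)
temb = jnp.ones((1, 1280))                  # time-step embedding from FlaxTimestepEmbedding

params = block.init(jax.random.PRNGKey(0), hidden_states, temb)
hidden_states, skip_states = block.apply(params, hidden_states, temb)
# skip_states collects the per-layer outputs consumed later by the matching up block.
```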