UNet2D model: use ConfigMixin & ModelMixin #10

Closed
Binary file added example.png
example.py: 8 changes (4 additions, 4 deletions)
@@ -17,7 +17,7 @@


# Local checkout until weights are available in the Hub
-flax_path = "/sddata/sd-v1-4-flax"
+flax_path = "/home/mishig/stable-diffusion-v1-4-flax"

num_samples = 8
num_inference_steps = 50
@@ -30,9 +30,9 @@
clip_model, clip_params = FlaxCLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14", _do_init=False, dtype=dtype
)
-unet, unet_params = UNet2D.from_pretrained(f"{flax_path}/unet", _do_init=False, dtype=dtype)
+unet, unet_params = UNet2D.from_pretrained(f"{flax_path}/unet", dtype=dtype)
vae, vae_params = AutoencoderKL.from_pretrained(f"{flax_path}/vae", _do_init=False, dtype=dtype)
-safety_model, safety_model_params = StableDiffusionSafetyCheckerModel.from_pretrained(f"{flax_path}/safety_checker", _do_init=False, dtype=dtype)
+# safety_model, safety_model_params = StableDiffusionSafetyCheckerModel.from_pretrained(f"{flax_path}/safety_checker", _do_init=False, dtype=dtype)

config = CLIPConfig.from_pretrained("openai/clip-vit-large-patch14")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
@@ -67,7 +67,7 @@


# prepare inputs
-p = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
+p = "Astronaut riding a horse"

input_ids = tokenizer(
    [p] * num_samples, padding="max_length", truncation=True, max_length=77, return_tensors="jax"
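Reviewer note: `FlaxModelMixin.from_pretrained` hands back the module together with its parameters, so the UNet no longer needs the `_do_init=False` escape hatch (the other models keep it for now). A minimal sketch of the new loading pattern; the import path and the fp16 dtype are assumptions mirroring this example:

```python
import jax.numpy as jnp

# Assumed import; the class lives in stable_diffusion_jax/modeling_unet2d.py.
from stable_diffusion_jax.modeling_unet2d import UNet2D

# Hypothetical local checkout, as used in example.py above.
flax_path = "/home/mishig/stable-diffusion-v1-4-flax"

# The mixin-based from_pretrained returns (module, params); no _do_init flag.
unet, unet_params = UNet2D.from_pretrained(f"{flax_path}/unet", dtype=jnp.float16)
```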
stable_diffusion_jax/modeling_unet2d.py: 125 changes (44 additions, 81 deletions)
@@ -5,9 +5,8 @@
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
-from transformers.modeling_flax_utils import FlaxPreTrainedModel

-from .configuration_unet2d import UNet2DConfig
+from diffusers.configuration_utils import ConfigMixin, flax_register_to_config
+from .modeling_utils import FlaxModelMixin


def get_sinusoidal_embeddings(timesteps, embedding_dim):
@@ -570,15 +569,22 @@ def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
        return hidden_states


-class UNet2DModule(nn.Module):
-    config: UNet2DConfig
+@flax_register_to_config
+class UNet2D(nn.Module, FlaxModelMixin, ConfigMixin):
+    sample_size: int = 32
+    in_channels: int = 4
+    out_channels: int = 4
+    down_block_types: Tuple = ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")
+    up_block_types: Tuple = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
+    block_out_channels: Tuple = (224, 448, 672, 896)
+    layers_per_block: int = 2
+    attention_head_dim: int = 8
+    cross_attention_dim: int = 768
+    dropout: float = 0.1
    dtype: jnp.dtype = jnp.float32

    def setup(self):
-        config = self.config

-        self.sample_size = config.sample_size
-        block_out_channels = config.block_out_channels
+        block_out_channels = self.block_out_channels
        time_embed_dim = block_out_channels[0] * 4

        # input
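Reviewer note: `@flax_register_to_config` records the dataclass fields above as the module's config, which is what allows the separate `UNet2DConfig` class (and its import) to be deleted. A rough sketch of the behavior this relies on, following diffusers' `ConfigMixin` conventions rather than anything shown in this diff:

```python
# Hedged sketch: constructing the module registers its hyperparameters.
unet = UNet2D(sample_size=64)  # remaining fields keep the defaults above

# Registered fields are readable back through the config object...
assert unet.config.sample_size == 64
assert unet.config.cross_attention_dim == 768

# ...and serializable to disk as config.json via ConfigMixin.
unet.save_config("./unet2d")
```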
@@ -597,7 +603,7 @@
        # down
        down_blocks = []
        output_channel = block_out_channels[0]
-        for i, down_block_type in enumerate(config.down_block_types):
+        for i, down_block_type in enumerate(self.down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
@@ -606,18 +612,18 @@
                down_block = CrossAttnDownBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
-                    dropout=config.dropout,
-                    num_layers=config.layers_per_block,
-                    attn_num_head_channels=config.attention_head_dim,
+                    dropout=self.dropout,
+                    num_layers=self.layers_per_block,
+                    attn_num_head_channels=self.attention_head_dim,
                    add_downsample=not is_final_block,
                    dtype=self.dtype,
                )
            else:
                down_block = DownBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
-                    dropout=config.dropout,
-                    num_layers=config.layers_per_block,
+                    dropout=self.dropout,
+                    num_layers=self.layers_per_block,
                    add_downsample=not is_final_block,
                    dtype=self.dtype,
                )
@@ -628,16 +634,16 @@
        # mid
        self.mid_block = UNetMidBlock2DCrossAttn(
            in_channels=block_out_channels[-1],
-            dropout=config.dropout,
-            attn_num_head_channels=config.attention_head_dim,
+            dropout=self.dropout,
+            attn_num_head_channels=self.attention_head_dim,
            dtype=self.dtype,
        )

        # up
        up_blocks = []
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
-        for i, up_block_type in enumerate(config.up_block_types):
+        for i, up_block_type in enumerate(self.up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
@@ -649,20 +655,20 @@
                    in_channels=input_channel,
                    out_channels=output_channel,
                    prev_output_channel=prev_output_channel,
-                    num_layers=config.layers_per_block + 1,
-                    attn_num_head_channels=config.attention_head_dim,
+                    num_layers=self.layers_per_block + 1,
+                    attn_num_head_channels=self.attention_head_dim,
                    add_upsample=not is_final_block,
-                    dropout=config.dropout,
+                    dropout=self.dropout,
                    dtype=self.dtype,
                )
            else:
                up_block = UpBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    prev_output_channel=prev_output_channel,
-                    num_layers=config.layers_per_block + 1,
+                    num_layers=self.layers_per_block + 1,
                    add_upsample=not is_final_block,
-                    dropout=config.dropout,
+                    dropout=self.dropout,
                    dtype=self.dtype,
                )

@@ -673,13 +679,25 @@
        # out
        self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-5)
        self.conv_out = nn.Conv(
-            config.out_channels,
+            self.out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

+    def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
+        # init input tensors
+        sample_shape = (1, self.config.sample_size, self.config.sample_size, self.config.in_channels)
+        sample = jnp.zeros(sample_shape, dtype=jnp.float32)
+        timesteps = jnp.ones((1,), dtype=jnp.int32)
+        encoder_hidden_states = jnp.zeros((1, 1, self.config.cross_attention_dim), dtype=jnp.float32)
+
+        params_rng, dropout_rng = jax.random.split(rng)
+        rngs = {"params": params_rng, "dropout": dropout_rng}
+
+        return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]

    def __call__(self, sample, timesteps, encoder_hidden_states, deterministic=True):

        # 1. time
@@ -705,8 +723,8 @@

        # 5. up
        for up_block in self.up_blocks:
-            res_samples = down_block_res_samples[-(self.config.layers_per_block + 1) :]
-            down_block_res_samples = down_block_res_samples[: -(self.config.layers_per_block + 1)]
+            res_samples = down_block_res_samples[-(self.layers_per_block + 1) :]
+            down_block_res_samples = down_block_res_samples[: -(self.layers_per_block + 1)]
            if isinstance(up_block, CrossAttnUpBlock2D):
                sample = up_block(
                    sample,
@@ -723,58 +741,3 @@
        sample = self.conv_out(sample)

        return sample
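Reviewer note: with the `FlaxPreTrainedModel` wrapper deleted below, initialization and the forward pass go through plain flax.linen APIs. A hedged end-to-end sketch using the default field values; shapes follow `init_weights` above and the 77-token CLIP context from example.py:

```python
import jax
import jax.numpy as jnp

unet = UNet2D()  # defaults: 32x32 latents, 4 channels, 768-dim cross-attention
params = unet.init_weights(jax.random.PRNGKey(0))

sample = jnp.zeros((1, 32, 32, 4))               # NHWC latent batch
timesteps = jnp.ones((1,), dtype=jnp.int32)
encoder_hidden_states = jnp.zeros((1, 77, 768))  # CLIP text embeddings

# Standard linen apply; deterministic=True (the default) disables dropout.
noise_pred = unet.apply({"params": params}, sample, timesteps, encoder_hidden_states)
```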


-class UNet2DPretrainedModel(FlaxPreTrainedModel):
-    config_class = UNet2DConfig
-    base_model_prefix = "model"
-    module_class: nn.Module = None
-
-    def __init__(
-        self,
-        config: UNet2DConfig,
-        input_shape: Tuple = (1, 32, 32, 4),
-        seed: int = 0,
-        dtype: jnp.dtype = jnp.float32,
-        _do_init: bool = True,
-        **kwargs,
-    ):
-        module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
-
-    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
-        # init input tensors
-        sample_shape = (1, self.config.sample_size, self.config.sample_size, self.config.in_channels)
-        sample = jnp.zeros(sample_shape, dtype=jnp.float32)
-        timestpes = jnp.ones((1,), dtype=jnp.int32)
-        encoder_hidden_states = jnp.zeros((1, 1, self.config.cross_attention_dim), dtype=jnp.float32)
-
-        params_rng, dropout_rng = jax.random.split(rng)
-        rngs = {"params": params_rng, "dropout": dropout_rng}
-
-        return self.module.init(rngs, sample, timestpes, encoder_hidden_states)["params"]
-
-    def __call__(
-        self,
-        sample,
-        timesteps,
-        encoder_hidden_states,
-        params: dict = None,
-        dropout_rng: jax.random.PRNGKey = None,
-        train: bool = False,
-    ):
-        # Handle any PRNG if needed
-        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
-
-        return self.module.apply(
-            {"params": params or self.params},
-            jnp.array(sample),
-            jnp.array(timesteps, dtype=jnp.int32),
-            encoder_hidden_states,
-            not train,
-            rngs=rngs,
-        )
-
-
-class UNet2D(UNet2DPretrainedModel):
-    module_class = UNet2DModule