
Add from_pt argument in .from_pretrained #527


Merged
117 changes: 117 additions & 0 deletions src/diffusers/modeling_flax_pytorch_utils.py
@@ -0,0 +1,117 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch - Flax general utilities."""
import re

import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey

from .utils import logging


logger = logging.get_logger(__name__)


def rename_key(key):
    regex = r"\w+[.]\d+"
    pats = re.findall(regex, key)
    for pat in pats:
        key = key.replace(pat, "_".join(pat.split(".")))
    return key
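
A quick aside on what rename_key does (the key below is a made-up example, not from a real checkpoint): each "<name>.<digit>" segment is rewritten to "<name>_<digit>", turning PyTorch ModuleList indices into Flax-style attribute names while leaving non-numeric segments untouched.

# Hypothetical key, for illustration only:
print(rename_key("down_blocks.0.attentions.1.proj.weight"))
# -> "down_blocks_0.attentions_1.proj.weight"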


#####################
# PyTorch => Flax #
#####################

# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
    """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""

    # conv norm or layer norm
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
    if (
        any("norm" in str_ for str_ in pt_tuple_key)
        and (pt_tuple_key[-1] == "bias")
        and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict)
        and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
    ):
        renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
        return renamed_pt_tuple_key, pt_tensor
    elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
        renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
        return renamed_pt_tuple_key, pt_tensor

    # embedding
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
    if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
        return renamed_pt_tuple_key, pt_tensor

    # conv layer
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
    if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
        # PyTorch conv weights are (out, in, h, w); Flax kernels are (h, w, in, out)
        pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
        return renamed_pt_tuple_key, pt_tensor

    # linear layer
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
    if pt_tuple_key[-1] == "weight":
        # PyTorch linear weights are (out, in); Flax kernels are (in, out)
        pt_tensor = pt_tensor.T
        return renamed_pt_tuple_key, pt_tensor

    # old PyTorch layer norm weight
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
    if pt_tuple_key[-1] == "gamma":
        return renamed_pt_tuple_key, pt_tensor

    # old PyTorch layer norm bias
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
    if pt_tuple_key[-1] == "beta":
        return renamed_pt_tuple_key, pt_tensor

    return pt_tuple_key, pt_tensor


def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42):
    # Step 1: Convert pytorch tensor to numpy
    pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

    # Step 2: Since the model is stateless, get random Flax params
    random_flax_params = flax_model.init_weights(PRNGKey(init_key))

    random_flax_state_dict = flatten_dict(random_flax_params)
    flax_state_dict = {}

    # Need to change some parameter names to match Flax names
    for pt_key, pt_tensor in pt_state_dict.items():
        renamed_pt_key = rename_key(pt_key)
        pt_tuple_key = tuple(renamed_pt_key.split("."))

        # Correctly rename weight parameters
        flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict)

        if flax_key in random_flax_state_dict:
            if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
                raise ValueError(
                    f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
                    f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
                )

        # also add unexpected weight so that warning is thrown
        flax_state_dict[flax_key] = jnp.asarray(flax_tensor)

    return unflatten_dict(flax_state_dict)
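
To make the two reshapes above concrete, here is a small self-contained sketch (shapes are arbitrary examples, not taken from the PR): PyTorch Conv2d weights are laid out OIHW while Flax nn.Conv kernels are HWIO, and PyTorch Linear weights are (out, in) while Flax nn.Dense kernels are (in, out).

import numpy as np

# OIHW -> HWIO, as in the conv branch of rename_key_and_reshape_tensor
pt_conv_weight = np.zeros((32, 16, 3, 3))
flax_conv_kernel = pt_conv_weight.transpose(2, 3, 1, 0)
assert flax_conv_kernel.shape == (3, 3, 16, 32)

# (out, in) -> (in, out), as in the linear branch
pt_linear_weight = np.zeros((8, 4))
assert pt_linear_weight.T.shape == (4, 8)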
61 changes: 39 additions & 22 deletions src/diffusers/modeling_flax_utils.py
Original file line number Diff line number Diff line change
@@ -27,7 +27,8 @@
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError

-from .modeling_utils import WEIGHTS_NAME
+from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
+from .modeling_utils import WEIGHTS_NAME, load_state_dict
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging


@@ -245,6 +246,8 @@ def from_pretrained(
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
+            from_pt (`bool`, *optional*, defaults to `False`):
+                Load the model weights from a PyTorch checkpoint save file.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
@@ -272,6 +275,7 @@
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        force_download = kwargs.pop("force_download", False)
+        from_pt = kwargs.pop("from_pt", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
@@ -306,10 +310,16 @@ def from_pretrained(
                # Load from a Flax checkpoint
                model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
            # At this stage we don't have a weight file so we will raise an error.
+            elif from_pt:
+                if not os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
+                    raise EnvironmentError(
+                        f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
+                    )
+                model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                raise EnvironmentError(
-                    f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
-                    "but there is a file for PyTorch weights."
+                    f"{WEIGHTS_NAME} file found in directory {pretrained_model_name_or_path}. Please load the model"
+                    " using `from_pt=True`."
                )
            else:
                raise EnvironmentError(
Expand All @@ -320,7 +330,7 @@ def from_pretrained(
            try:
                model_file = hf_hub_download(
                    pretrained_model_name_or_path,
-                    filename=FLAX_WEIGHTS_NAME,
+                    filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
@@ -370,25 +380,32 @@
f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
)

try:
with open(model_file, "rb") as state_f:
state = from_bytes(cls, state_f.read())
except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
if from_pt:
# Step 1: Get the pytorch file
pytorch_model_file = load_state_dict(model_file)

# Step 2: Convert the weights
state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model)
else:
try:
with open(model_file) as f:
if f.read().startswith("version"):
raise OSError(
"You seem to have cloned a repository without having git-lfs installed. Please"
" install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
" folder you cloned."
)
else:
raise ValueError from e
except (UnicodeDecodeError, ValueError):
raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
# make sure all arrays are stored as jnp.ndarray
# NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
# https://github.com/google/flax/issues/1261
with open(model_file, "rb") as state_f:
state = from_bytes(cls, state_f.read())
except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
try:
with open(model_file) as f:
if f.read().startswith("version"):
raise OSError(
"You seem to have cloned a repository without having git-lfs installed. Please"
" install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
" folder you cloned."
)
else:
raise ValueError from e
except (UnicodeDecodeError, ValueError):
raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
# make sure all arrays are stored as jnp.ndarray
# NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
# https://github.com/google/flax/issues/1261
        state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)

        # flatten dicts
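
With these changes in place, loading PyTorch weights into a Flax model becomes a one-flag call. A hedged usage sketch (the repo id is illustrative, and the (model, params) return convention is assumed from the diffusers Flax API):

from diffusers import FlaxUNet2DConditionModel

# from_pt=True fetches the PyTorch weights file instead of the Flax
# msgpack file and routes it through convert_pytorch_state_dict_to_flax.
model, params = FlaxUNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4",  # assumed repo id, for illustration
    subfolder="unet",
    from_pt=True,
)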
36 changes: 25 additions & 11 deletions src/diffusers/models/attention_flax.py
@@ -32,7 +32,7 @@ def setup(self):
        self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
        self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")

-        self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out")
+        self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")

    def reshape_heads_to_batch_dim(self, tensor):
        batch_size, seq_len, dim = tensor.shape
@@ -82,9 +82,9 @@ class FlaxBasicTransformerBlock(nn.Module):

    def setup(self):
        # self attention
-        self.self_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+        self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
        # cross attention
-        self.cross_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+        self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
        self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
        self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
Expand All @@ -93,12 +93,12 @@ def setup(self):
    def __call__(self, hidden_states, context, deterministic=True):
        # self attention
        residual = hidden_states
-        hidden_states = self.self_attn(self.norm1(hidden_states), deterministic=deterministic)
+        hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
        hidden_states = hidden_states + residual

        # cross attention
        residual = hidden_states
-        hidden_states = self.cross_attn(self.norm2(hidden_states), context, deterministic=deterministic)
+        hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
        hidden_states = hidden_states + residual

        # feed forward
@@ -167,14 +167,28 @@ class FlaxGluFeedForward(nn.Module):
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
-        inner_dim = self.dim * 4
-        self.dense1 = nn.Dense(inner_dim * 2, dtype=self.dtype)
-        self.dense2 = nn.Dense(self.dim, dtype=self.dtype)
+        # The second linear layer needs to be called
+        # net_2 for now to match the index of the Sequential layer
+        self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
+        self.net_2 = nn.Dense(self.dim, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
-        hidden_states = self.dense1(hidden_states)
-        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
-        hidden_states = hidden_linear * nn.gelu(hidden_gelu)
-        hidden_states = self.dense2(hidden_states)
-        return hidden_states
+        hidden_states = self.net_0(hidden_states)
+        hidden_states = self.net_2(hidden_states)
+        return hidden_states
+
+
+class FlaxGEGLU(nn.Module):
+    dim: int
+    dropout: float = 0.0
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        inner_dim = self.dim * 4
+        self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
+
+    def __call__(self, hidden_states, deterministic=True):
+        hidden_states = self.proj(hidden_states)
+        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
+        return hidden_linear * nn.gelu(hidden_gelu)
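
For readers unfamiliar with GEGLU: the single projection produces twice the inner width, the result is split along the feature axis into a linear half and a gate half, and the output is the linear half multiplied elementwise by GELU of the gate. A minimal standalone sketch, with illustrative shapes only:

import jax.numpy as jnp
from flax import linen as nn
from jax import random

x = random.normal(random.PRNGKey(0), (1, 7, 16))  # (batch, seq, dim)
proj = nn.Dense(16 * 4 * 2)  # inner_dim * 2, mirroring FlaxGEGLU
params = proj.init(random.PRNGKey(1), x)
h = proj.apply(params, x)  # (1, 7, 128)
hidden_linear, hidden_gelu = jnp.split(h, 2, axis=2)
out = hidden_linear * nn.gelu(hidden_gelu)  # (1, 7, 64)
assert out.shape == (1, 7, 64)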
10 changes: 9 additions & 1 deletion src/diffusers/models/unet_2d_condition_flax.py
@@ -76,7 +76,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):

    def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
        # init input tensors
-        sample_shape = (1, self.sample_size, self.sample_size, self.in_channels)
+        sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
        sample = jnp.zeros(sample_shape, dtype=jnp.float32)
        timesteps = jnp.ones((1,), dtype=jnp.int32)
        encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
@@ -214,10 +214,17 @@ def __call__(
            When returning a tuple, the first element is the sample tensor.
        """
        # 1. time
+        if not isinstance(timesteps, jnp.ndarray):
+            timesteps = jnp.array([timesteps], dtype=jnp.int32)
+        elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
+            timesteps = timesteps.astype(dtype=jnp.float32)
+            timesteps = jnp.expand_dims(timesteps, 0)
+
        t_emb = self.time_proj(timesteps)
        t_emb = self.time_embedding(t_emb)

        # 2. pre-process
+        sample = jnp.transpose(sample, (0, 2, 3, 1))
        sample = self.conv_in(sample)

        # 3. down
@@ -251,6 +258,7 @@ def __call__(
        sample = self.conv_norm_out(sample)
        sample = nn.silu(sample)
        sample = self.conv_out(sample)
+        sample = jnp.transpose(sample, (0, 3, 1, 2))

        if not return_dict:
            return (sample,)
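
The two transposes added above let the Flax UNet keep a PyTorch-style NCHW interface while running its convolutions in Flax's native NHWC layout: inputs are flipped to channel-last before conv_in and flipped back after conv_out. A small sketch of that round trip (shapes are illustrative):

import jax.numpy as jnp

sample = jnp.zeros((1, 4, 64, 64))  # PyTorch-style (batch, channels, h, w)

nhwc = jnp.transpose(sample, (0, 2, 3, 1))  # NCHW -> NHWC before conv_in
assert nhwc.shape == (1, 64, 64, 4)

nchw = jnp.transpose(nhwc, (0, 3, 1, 2))  # NHWC -> NCHW after conv_out
assert nchw.shape == sample.shape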