
FlaxDiffusionPipeline & FlaxStableDiffusionPipeline #559


Merged

merged 20 commits into main from flax_pipeline on Sep 20, 2022

Conversation

mishig25
Contributor

@mishig25 mishig25 commented Sep 19, 2022

Implement FlaxDiffusionPipeline & FlaxStableDiffusionPipeline

Based on https://github.com/patil-suraj/stable-diffusion-jax/blob/stateless-scheduler/stable_diffusion_jax/pipeline_stable_diffusion.py

Following this comment, we decided to create FlaxDiffusionPipeline rather than trying to reuse DiffusionPipeline (which currently handles PyTorch & ONNX).

Design decisions & questions

  1. Just like DiffusionPipeline does not inherit from torch.nn.Module, FlaxDiffusionPipeline should not inherit from flax.linen.Module either. Wdyt?
  2. Every pipeline (for example, FlaxStableDiffusionPipeline) needs to implement an InferenceState (a flax.struct.dataclass) so that pmap can consume the pipeline. Wdyt?
  3. If the first two points above hold, then the implementations of DiffusionPipeline & FlaxDiffusionPipeline are quite similar, with one major difference: since Flax pretrained models are initialized as model, params = xyz.from_pretrained(), the from_pretrained and save_pretrained methods of FlaxDiffusionPipeline need to handle the inference_state. See example here and the sketch right after this list.
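
Purely as illustration of point 3, a minimal sketch of the asymmetry (the model id and directory are placeholders, and the exact names/return values are assumptions, not the final API):

from diffusers import DiffusionPipeline, FlaxDiffusionPipeline

# PyTorch: the pipeline object carries its weights internally.
pt_pipe = DiffusionPipeline.from_pretrained("some/checkpoint")

# Flax: weights live outside the pipeline, so loading returns both pieces
# and saving needs the inference state / params passed back in explicitly.
flax_pipe, params = FlaxDiffusionPipeline.from_pretrained("some/checkpoint")
flax_pipe.save_pretrained("local_dir", params=params)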

TODOS:

  • handle all the TODO comments I left in the implementation
  • test the entire pipeline

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Sep 19, 2022

The documentation is not available anymore as the PR was closed or merged.

@mishig25 mishig25 changed the title from Flax_pipeline to WIP: flax FlaxDiffusionPipeline & FlaxStableDiffusionPipeline on Sep 19, 2022
@kashif
Contributor

kashif commented Sep 19, 2022

@mishig25 I believe we will need to use the FlaxDDIMScheduler, ...

@pcuenca
Member

pcuenca commented Sep 19, 2022

Assumptions 1 and 2 sound reasonable to me.

Regarding 3, if you want to override, say, the scheduler, then you'd need to do something like this:

scheduler, scheduler_state = SomeFlaxScheduler.from_pretrained(...)
inference_state = InferenceState(scheduler_state=scheduler_state)
pipe = FlaxStableDiffusionPipeline.from_pretrained(
    model_path,
    scheduler=scheduler,
    inference_state=inference_state,
)

Is that correct?

If so, my first instinct would be to return the final InferenceState too. However I haven't played with the code yet or made myself familiar with it.

@@ -14,4 +14,5 @@

from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
Contributor

We need to wrap this in an if flax_available_... statement, I think
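
For reference, roughly what that guard could look like (a sketch assuming an is_flax_available() helper in the utils module, mirroring how other optional backends are handled):

from ..utils import is_flax_available

from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel

if is_flax_available():
    from .unet_2d_condition_flax import FlaxUNet2DConditionModel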

from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, logging


INDEX_FILE = "diffusion_flax_model.bin"
Contributor

Actually, this is never used (just like in PyTorch), so we can remove it I think

@patrickvonplaten patrickvonplaten marked this pull request as ready for review September 19, 2022 20:45
@@ -14,4 +14,6 @@

from .unet_2d import UNet2DModel
Contributor

these should be wrapped into is_available(...)



@flax.struct.dataclass
class InferenceState:
Contributor

this should be removed - let's just make it an inference state

Member

This could potentially be helpful to override pipeline modules, as in my code snippet above #559 (comment).

We can do the same with a dictionary, but it's uglier in my opinion. Or with a helper function that returns a dict.

Contributor

For now I think it can just be a dict, no? Dicts are more universal, and it means that not every pipeline has to have a dataclass state
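
To illustrate the plain-dict idea with dummy values (the real entries would be the sub-modules' weight trees and the scheduler state):

import jax.numpy as jnp

params = {
    "unet": {"dummy_kernel": jnp.ones((1,))},            # a FrozenDict of UNet weights in practice
    "vae": {"dummy_kernel": jnp.ones((1,))},
    "text_encoder": {"dummy_embedding": jnp.ones((1,))},
    "scheduler": None,                                    # the scheduler state would live here
}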

Contributor

Removing for now -> let's maybe add it again later if necessary

    params[name] = loaded_params
elif issubclass(class_obj, SchedulerMixin):
    loaded_sub_model = load_method(loadable_folder, **loading_kwargs)
    params[name] = loaded_sub_model.create_state()
Contributor

@pcuenca @kashif @patil-suraj this means that every flax scheduler needs a create_state() function, but I think the design is ok/makes sense
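
A minimal sketch of that contract (illustrative names, not the merged code): the scheduler class stays stateless and create_state() builds the flax.struct.dataclass that from_pretrained stashes under params:

import flax
import jax.numpy as jnp

@flax.struct.dataclass
class DummySchedulerState:
    timesteps: jnp.ndarray
    num_inference_steps: int = 0

class FlaxDummyScheduler:
    def __init__(self, num_train_timesteps: int = 1000):
        self.num_train_timesteps = num_train_timesteps

    def create_state(self) -> DummySchedulerState:
        return DummySchedulerState(timesteps=jnp.arange(self.num_train_timesteps)[::-1])

scheduler = FlaxDummyScheduler()
state = scheduler.create_state()  # what from_pretrained would put into params["scheduler"]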

Member

I think it does; I don't see a problem with this approach.

Contributor

Agree!

self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)

# HACK for now - clean up later (PVP)
self._alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
Contributor

that's a hack for now - in the future IMO we should move all this logic to create_state so that the scheduler is fully stateless
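
A sketch of that fully stateless direction (my reading of the comment, names are illustrative): __init__ keeps only plain config values and create_state() computes derived arrays such as alphas_cumprod:

import flax
import jax.numpy as jnp

@flax.struct.dataclass
class StatelessDDIMState:
    alphas_cumprod: jnp.ndarray

class FlaxStatelessDDIMSketch:
    def __init__(self, num_train_timesteps=1000, beta_start=1e-4, beta_end=2e-2):
        # only plain Python config on the instance, no arrays
        self.num_train_timesteps = num_train_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end

    def create_state(self):
        betas = jnp.linspace(self.beta_start, self.beta_end, self.num_train_timesteps)
        return StatelessDDIMState(alphas_cumprod=jnp.cumprod(1.0 - betas, axis=0))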

Member

Agreed.

beta_prod_t = 1 - alpha_prod_t

# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)

# 4. Clip "predicted x_0"
Contributor

Let's for now remove all the unimportant things

@@ -148,7 +148,8 @@ def __init__(
# mainly at formula (9), (12), (13) and the Algorithm 2.
self.pndm_order = 4

self.state = PNDMSchedulerState.create(num_train_timesteps=num_train_timesteps)
def create_state(self):
Contributor

cc @pcuenca @kashif what do you think about the design?

Contributor

Having a look, thanks! I am also adding the scheduler tests, so this will be helpful I believe

alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
alpha_prod_t = alphas_cumprod[timestep]
alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
Contributor

As @pcuenca mentioned, we need jnp.where everywhere
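
For context, a small standalone example of why jnp.where is needed here: under jit/pmap the timestep is a traced value, so a Python if on it fails, while jnp.where selects the branch without Python control flow (the clipping is just to keep the gather in range):

import jax
import jax.numpy as jnp

alphas_cumprod = jnp.cumprod(1.0 - jnp.linspace(1e-4, 2e-2, 1000), axis=0)
final_alpha_cumprod = jnp.array(1.0)

@jax.jit
def prev_alpha(prev_timestep):
    safe_prev = jnp.clip(prev_timestep, 0, alphas_cumprod.shape[0] - 1)
    return jnp.where(prev_timestep >= 0, alphas_cumprod[safe_prev], final_alpha_cumprod)

print(prev_alpha(jnp.array(-1)))  # 1.0, without branching on a traced value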

Contributor

yup! onto it

@patrickvonplaten
Contributor

For now the API looks as follows (on a TPU v3-8):

#!/usr/bin/env python3
from diffusers import FlaxStableDiffusionPipeline
from jax import pmap
import numpy as np
import jax
from flax.jax_utils import replicate
from flax.training.common_utils import shard

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained("fusing/sd-v1-4-flax", use_auth_token=True)

prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
prng_seed = jax.random.PRNGKey(0)

p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))

# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, 8)
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_prompts(prompt)
prompt_ids = shard(prompt_ids)

# set inference steps
num_inference_steps = 50

images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images

images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
# Problem: resulting images don't look good

@patil-suraj , @pcuenca , @kashif , @mishig25 very keen to get your feedback on the API.

Note: We need to pass tensors into the forward call, which is why I've added a prepare_prompts function.
Note: Right now the pipeline generates incorrect images and needs debugging.
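
For reference, a rough sketch of the kind of preprocessing a prepare_prompts helper has to do (my assumption, calling the pipeline's tokenizer directly rather than the actual helper): tokenize to fixed-length id arrays so pmap can shard them:

import numpy as np

def prepare_prompts_sketch(tokenizer, prompts, max_length=77):
    inputs = tokenizer(
        prompts,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="np",
    )
    return np.asarray(inputs.input_ids, dtype=np.int32)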

@pcuenca
Member

pcuenca commented Sep 20, 2022

The pipeline runs out of memory in v2-8. Using dtype=jnp.bfloat16 still loads everything in float32. I think this is in part because of #565: the dtype is passed as part of the kwargs but then gets ignored. I only wanted to exclude it when saving the configuration, as it couldn't be serialized, but it's always ignored on load even when we want to override it.

What would be the best way to deal with this?

(This comment applies to the model instance, params are also loaded in float32 and have to be converted if necessary)

@pcuenca
Member

pcuenca commented Sep 20, 2022

Questions about schedulers (and overridden pipeline modules):

  1. If the user provides their own scheduler to pipeline from_pretrained, how is the state going to be handled and added to the params dict?
    a. Invoke scheduler.create_state() inside from_pretrained anyway.
    b. Have the user pass it to the pipeline using a new kw arg called scheduler_params.
    c. Let them provide a dictionary with params for all overridden modules.
    d. Go back to using the InferenceState so we can pass it instead of a dictionary.
  2. PNDM requires the latents shape in set_timesteps. This is because it reserves space in the state for the 4 samples that are used in the step computations. Perhaps there’s a better way to resolve it without this information, but assuming there isn’t, what should we do?
    a. Always send the shape to all schedulers when invoking set_timesteps() and let them ignore it if they don’t have a use for it. This is a departure from the PyTorch version, but we already had to add the state argument anyway.
    b. Only send it in specific cases, checking types or signatures.

For now I’m going with 1.c and 2.a, would love to hear other opinions.
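
For illustration, a minimal sketch of option 2.a (the signature is illustrative, not the merged API): every scheduler's set_timesteps accepts the latents shape and simply ignores it when it has no use for it, while PNDM would use it to pre-allocate the sample buffers kept in its state:

import flax
import jax.numpy as jnp

@flax.struct.dataclass
class SketchSchedulerState:
    timesteps: jnp.ndarray

class FlaxSchedulerSketch:
    def set_timesteps(self, state, num_inference_steps, shape):
        del shape  # unused by this scheduler, only PNDM-like schedulers need it
        return state.replace(timesteps=jnp.arange(num_inference_steps)[::-1])

state = SketchSchedulerState(timesteps=jnp.array([], dtype=jnp.int32))
state = FlaxSchedulerSketch().set_timesteps(state, 50, shape=(8, 4, 64, 64))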

@patil-suraj
Contributor

patil-suraj commented Sep 20, 2022

The pipeline runs out of memory in v2-8. Using dtype=jnp.bfloat16 still loads everything in float32. I think this is in part because of #565: the dtype is passed as part of the kwargs but then gets ignored. I only wanted to exclude it when saving the configuration, as it couldn't be serialized, but it's always ignored on load even when we want to override it.

What would be the best way to deal with this?

(This comment applies to the model instance, params are also loaded in float32 and have to be converted if necessary)

@pcuenca the dtype only specifies the dtype of computation and not of the params, so all params will be loaded in fp32 by default. It's analogous to with autocast("cuda") in PT.

We can create an fp16/bf16 branch for the Flax weights, the same way we do in PT, so params get loaded in that dtype.
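
In the meantime, one way to keep memory down could be to cast the loaded fp32 params externally before replicating them (a sketch on my side, not part of this PR):

import jax
import jax.numpy as jnp

def cast_params_to_bf16(params):
    # cast only float32 leaves; leave ints and other leaves untouched
    return jax.tree_util.tree_map(
        lambda x: x.astype(jnp.bfloat16) if hasattr(x, "dtype") and x.dtype == jnp.float32 else x,
        params,
    )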

@pcuenca
Member

pcuenca commented Sep 20, 2022

The pipeline runs out of memory in v2-8. Using dtype=jnp.bfloat16 still loads everything in float32. I think this is in part because of #565: the dtype is passed as part of the kwargs but then gets ignored. I only wanted to exclude it when saving the configuration, as it couldn't be serialized, but it's always ignored on load even when we want to override it.
What would be the best way to deal with this?
(This comment applies to the model instance, params are also loaded in float32 and have to be converted if necessary)

@pcuenca the dtype only specifies the dtype of computation and not of the params, so all params will be loaded in fp32 by default. It's analogous to with autocast("cuda") in PT.

We can create an fp16/bf16 branch for the Flax weights, the same way we do in PT, so params get loaded in that dtype.

You are right. In that case we could convert the weights to bfloat16 externally to save memory; for example in the notebook demo. Or save a specific branch as you propose.

However, the pipeline in your repo @patil-suraj runs fine on v2-8 TPUs (I've been running an inference backend since last Friday, including the safety checker). Any idea why this one doesn't fit?

init_kwargs = {}

# inference_params
params = {}
Contributor Author

@mishig25 mishig25 Sep 20, 2022

Does it not have to be a special Flax data structure to avoid memory fragmentation and other issues with pmap? Or is a regular Python dict okay?
cc: @patil-suraj @patrickvonplaten @pcuenca

Member

I don't know tbh, very interesting if that's the case! Do you have a reference?

Contributor Author

I was under the assumption that every time you shard data, it needs to be a @flax.struct.xyz? Going over the docstring here:

def shard(xs):
  """Helper for pmap to shard a pytree of arrays by local_device_count.

I guess params is a valid pytree since it is just a dict that contains valid pytree nodes. So it should be fine?

Member

Yes, a dict is a valid pytree that can be sharded :)
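
A quick way to convince ourselves (sketch; assumes 8 local devices, e.g. a TPU v3-8):

import jax
import jax.numpy as jnp
from flax.training.common_utils import shard

params = {"unet": jnp.ones((8, 4)), "scheduler": {"timesteps": jnp.arange(8)}}
sharded = shard(params)  # splits the leading axis across local devices
print(jax.tree_util.tree_map(lambda x: x.shape, sharded))
# with 8 devices: {'unet': (8, 1, 4), 'scheduler': {'timesteps': (8, 1)}}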

Contributor

Dict works - we could make it a frozen dict to be super certain! Will update to a frozen dict so that the pipeline is not allowed to change it internally -> then we're fully in Jaxistan
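
For reference, what that freeze could look like (a sketch; the real call site would be wherever from_pretrained assembles the dict):

from flax.core.frozen_dict import freeze

params = freeze({"unet": {}, "vae": {}, "text_encoder": {}, "scheduler": {}})
# any later params["unet"] = ... now raises, so the pipeline cannot mutate it internally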

Contributor Author

sounds perfect!

Contributor

Will do this in a future PR

@patrickvonplaten patrickvonplaten changed the title from WIP: flax FlaxDiffusionPipeline & FlaxStableDiffusionPipeline to FlaxDiffusionPipeline & FlaxStableDiffusionPipeline on Sep 20, 2022
    if not os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
        raise EnvironmentError(
            f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
        )
    model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
Contributor

@patil-suraj @younesbelkada we need to check from_pt first here in case there are both Flax and PT files, otherwise it breaks

@patrickvonplaten
Contributor

Pipeline can easily be checked with: #592

@patrickvonplaten patrickvonplaten merged commit d934d3d into main Sep 20, 2022
@patrickvonplaten patrickvonplaten deleted the flax_pipeline branch September 20, 2022 19:30
@skirsten
Contributor

Off-topic

@patil-suraj

We can create an fp16/bf16 branch for the Flax weights, the same way we do in PT, so params get loaded in that dtype.

Considering there are only bf16 and fp32 (flax) branches, how can I convert the model to fp16 for Flax?
Probably using FlaxModelMixin.to_bf16, but I did not figure out how to apply it. Maybe somebody could upload the scripts used for the conversion to the script folder?

@patil-suraj
Contributor

There is also a FlaxModelMixin.to_fp16, and we need to convert each individual model (unet, text_encoder, etc.) to fp16 and then save and load the pipeline.

For example:

pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", revision="flax",
)
params["unet"] = pipe.unet.to_fp16(params["unet"])
params["vae"] = pipe.vae.to_fp16(params["vae"])
params["text_encoder"] = pipe.text_encoder.to_fp16(params["text_encoder"])
params["safety_checker"] = pipe.safety_checker.to_fp16(params["safety_checker"])

pipe.save_pretrained("./stable-diffusion-v1-5-fp16", params=params)  # first argument is the save directory
