FlaxDiffusionPipeline & FlaxStableDiffusionPipeline #559
Conversation
The documentation is not available anymore as the PR was closed or merged.
@mishig25 I believe we will need to use the …

Assumptions regarding schedulers:

```python
scheduler, scheduler_state = SomeFlaxScheduler.from_pretrained(...)
inference_state = InferenceState(scheduler_state=scheduler_state)
pipe = FlaxStableDiffusionPipeline.from_pretrained(
    model_path,
    scheduler=scheduler,
    inference_state=inference_state,
)
```

Is that correct? If so, my first instinct would be to return the final …
src/diffusers/models/__init__.py (outdated)

@@ -14,4 +14,5 @@
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
We need to wrap this into an `if flax_available_...` statement, I think.
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, logging

INDEX_FILE = "diffusion_flax_model.bin"
Actually this is never used (just like in PyTorch), so we can remove it I think.
src/diffusers/models/__init__.py (outdated)

@@ -14,4 +14,6 @@
from .unet_2d import UNet2DModel
These should be wrapped into `is_available(...)` checks.
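For illustration, this is roughly what such a guard could look like in an `__init__.py` (a sketch only; the exact helper name and import path are assumptions, not necessarily this PR's final code):

```python
from ..utils import is_flax_available  # assumed location of the availability helper

from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel

# Only expose the Flax model when jax/flax are actually installed,
# so `import diffusers` does not fail in a PyTorch-only environment.
if is_flax_available():
    from .unet_2d_condition_flax import FlaxUNet2DConditionModel
```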
@flax.struct.dataclass
class InferenceState:
this should be removed - let's just make it an inference state
This could potentially be helpful to override pipeline modules, as in my code snippet above #559 (comment).
We can do the same with a dictionary, but it's uglier in my opinion. Or with a helper function that returns a dict.
For now I think it can just be a dict, no? Dicts are more universal, and it means that not every pipeline has to have a dataclass state.
Removing for now -> let's maybe add it back later if necessary.
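For illustration, a minimal sketch of the two options being weighed here (the field names `unet_params` / `scheduler_state` are hypothetical, not taken from the PR):

```python
import flax

# Option 1: a frozen dataclass container, i.e. what an InferenceState would look like.
@flax.struct.dataclass
class InferenceState:
    unet_params: dict
    scheduler_state: dict

# Option 2: a plain dict with the same contents.
# Both are valid JAX pytrees, so pmap/shard treat them the same way;
# the dict just avoids having to define a dataclass per pipeline.
state_as_dataclass = InferenceState(unet_params={}, scheduler_state={})
state_as_dict = {"unet_params": {}, "scheduler_state": {}}
```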
params[name] = loaded_params
elif issubclass(class_obj, SchedulerMixin):
    loaded_sub_model = load_method(loadable_folder, **loading_kwargs)
    params[name] = loaded_sub_model.create_state()
@pcuenca @kashif @patil-suraj this means that every flax scheduler needs a `create_state()` function, but I think the design is ok / makes sense.
I think it does, I don't see a problem with this approach.
Agree!
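For context, here is a minimal sketch of what a scheduler with a `create_state()` method could look like (names such as `SchedulerState` and the stored fields are illustrative assumptions, not the PR's implementation):

```python
import flax
import jax.numpy as jnp

@flax.struct.dataclass
class SchedulerState:
    # All mutable/derived quantities live in the state so the scheduler object stays stateless.
    alphas_cumprod: jnp.ndarray
    timesteps: jnp.ndarray

class FlaxSchedulerSketch:
    def __init__(self, num_train_timesteps=1000, beta_start=1e-4, beta_end=2e-2):
        self.num_train_timesteps = num_train_timesteps
        self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps)
        self.alphas = 1.0 - self.betas

    def create_state(self) -> SchedulerState:
        # Everything derived from the config is computed once here and returned as a pytree,
        # which is what the pipeline stores in params[name] above.
        return SchedulerState(
            alphas_cumprod=jnp.cumprod(self.alphas, axis=0),
            timesteps=jnp.arange(self.num_train_timesteps)[::-1],
        )
```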
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)

# HACK for now - clean up later (PVP)
self._alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
That's a hack for now - in the future IMO we should move all this logic into `create_state()` so that the scheduler is fully stateless.
Agreed.
beta_prod_t = 1 - alpha_prod_t

# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)

# 4. Clip "predicted x_0"
Let's remove all the unimportant things for now.
@@ -148,7 +148,8 @@ def __init__(
# mainly at formula (9), (12), (13) and the Algorithm 2.
self.pndm_order = 4

self.state = PNDMSchedulerState.create(num_train_timesteps=num_train_timesteps)

def create_state(self):
Having a look, thanks! I am also adding the scheduler tests, so this will be helpful I believe.
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
alpha_prod_t = alphas_cumprod[timestep]
alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
As @pcuenca mentioned, we need `jnp.where` everywhere.
yup! onto it
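To illustrate why (a generic JAX sketch, not code from this PR): a Python `if` on a traced value fails inside `jit`/`pmap`, while `jnp.where` evaluates both branches and selects element-wise, so it traces fine.

```python
import jax
import jax.numpy as jnp

alphas_cumprod = jnp.linspace(1.0, 0.01, 1000)
final_alpha_cumprod = jnp.array(1.0)

@jax.jit
def prev_alpha(prev_timestep):
    # `if prev_timestep >= 0:` would raise a ConcretizationTypeError here,
    # because under jit the value is a tracer with no concrete boolean.
    return jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], final_alpha_cumprod)

print(prev_alpha(jnp.array(10)))   # -> alphas_cumprod[10]
print(prev_alpha(jnp.array(-1)))   # -> final_alpha_cumprod
```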
For now the API looks as follows (on a TPUv3-8):

```python
#!/usr/bin/env python3
from diffusers import FlaxStableDiffusionPipeline
from jax import pmap
import numpy as np
import jax
from flax.jax_utils import replicate
from flax.training.common_utils import shard

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained("fusing/sd-v1-4-flax", use_auth_token=True)

prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"

prng_seed = jax.random.PRNGKey(0)
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))

# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, 8)
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_prompts(prompt)
prompt_ids = shard(prompt_ids)

# set inference steps
num_inference_steps = 50

images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
```

Problem: resulting images don't look good. @patil-suraj, @pcuenca, @kashif, @mishig25 very keen to get your feedback on the API.

Note: We need to pass tensors into the forward call, which is why I've added a …
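For readers following along: `prepare_prompts` is what turns the list of prompt strings into fixed-shape token-id arrays that `shard` can split across devices. A rough sketch of that kind of helper (an assumed implementation using the CLIP tokenizer; the actual method in the PR may differ):

```python
import jax.numpy as jnp
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

def prepare_prompts(prompts, max_length=77):
    # Pad/truncate to a fixed length so every device receives an identically shaped array.
    text_inputs = tokenizer(
        prompts,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="np",
    )
    return jnp.array(text_inputs.input_ids)
```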
The pipeline runs out of memory in … What would be the best way to deal with this? (This comment applies to the model instance; params are also loaded in …)
Questions about schedulers (and overridden pipeline modules): …

For now I'm going with …
@pcuenca the … We can create …
You are right. In that case we could convert the weights to …

However, the pipeline in your repo @patil-suraj runs fine on v2-8 TPUs (I've been running an inference backend since last Friday, including the safety checker). Any idea why this one doesn't fit?
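As a generic illustration of the kind of conversion being discussed (a standard JAX idiom, not code from this PR): the whole `params` pytree can be cast to a lower-precision dtype in one pass.

```python
import jax
import jax.numpy as jnp

def cast_floating(params, dtype=jnp.bfloat16):
    # Cast every floating-point leaf of the params pytree; leave integer leaves untouched.
    return jax.tree_util.tree_map(
        lambda p: p.astype(dtype) if jnp.issubdtype(p.dtype, jnp.floating) else p,
        params,
    )

# params_bf16 = cast_floating(params)              # roughly halves device memory for the weights
# params_fp16 = cast_floating(params, jnp.float16)
```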
init_kwargs = {}

# inference_params
params = {}
Does it not have to be a special flax data structure to avoid memory fragmentation and other issues with pmap? Or is a regular Python dict okay?

cc: @patil-suraj @patrickvonplaten @pcuenca
I don't know tbh, very interesting if that's the case! Do you have a reference?
I was under the assumption that every time you shard data, it needs to be `@flax.struct.xyz`? Going over the docstring here:

```python
def shard(xs):
    """Helper for pmap to shard a pytree of arrays by local_device_count."""
```

I guess `params` is a valid pytree since it is just a dict that contains valid pytree nodes. So it should be fine?
Yes, a dict is a valid pytree that can be sharded :)
Dict works - we could make it a frozen dict to be super certain! Will update to a frozen dict so that the pipeline is not allowed to change it internally -> then we're fully in Jaxistan
sounds perfect!
Will do this in a future PR
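A quick illustration of the points above (generic JAX/Flax usage, not this PR's code): a nested dict is already a valid pytree, so `replicate`/`shard` work on it directly, and `flax.core.frozen_dict.freeze` makes it immutable.

```python
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import freeze
from flax.jax_utils import replicate

params = {"unet": {"kernel": jnp.ones((4, 4))}, "vae": {"bias": jnp.zeros(4)}}

# Tree utilities work on any nested dict of arrays.
print(jax.tree_util.tree_map(lambda p: p.shape, params))

# replicate() adds a leading device axis to every leaf so pmap can consume it.
replicated = replicate(params)

# freeze() returns an immutable FrozenDict, so the pipeline cannot mutate it internally.
frozen = freeze(params)
```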
if not os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
    raise EnvironmentError(
        f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
    )
model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
@patil-suraj @younesbelkada we need to first check `from_pt` here in case there are both Flax and PT files, otherwise it breaks.
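A sketch of the kind of resolution order the comment asks for (illustrative only; the helper and its arguments are made up for this example, not the PR's final logic):

```python
import os

def resolve_weights_file(folder, weights_name, flax_weights_name, from_pt=False):
    pt_file = os.path.join(folder, weights_name)
    flax_file = os.path.join(folder, flax_weights_name)

    # Check `from_pt` first: if both a PyTorch and a Flax checkpoint exist in the folder,
    # the flag decides which one wins instead of the Flax file always shadowing the PT file.
    if from_pt and os.path.isfile(pt_file):
        return pt_file, True                     # True -> needs PT-to-Flax conversion
    if os.path.isfile(flax_file):
        return flax_file, False
    if os.path.isfile(pt_file):
        return pt_file, True
    raise EnvironmentError(f"No file named {flax_weights_name} or {weights_name} found in {folder}")
```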
Pipeline can easily be checked with: #592
Off-topic: Considering there is only a bf16 and fp32 (flax) branch, how can I convert the model to fp16 for flax?
There is a `to_fp16` method you can use. For example:

```python
pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", revision="flax",
)
params["unet"] = pipe.unet.to_fp16(params["unet"])
params["vae"] = pipe.vae.to_fp16(params["vae"])
params["text_encoder"] = pipe.text_encoder.to_fp16(params["text_encoder"])
params["safety_checker"] = pipe.safety_checker.to_fp16(params["safety_checker"])
pipe.save_pretrained("path/to/fp16-checkpoint", params=params)  # first argument is the output directory
```
* WIP: flax FlaxDiffusionPipeline & FlaxStableDiffusionPipeline
* todo comment
* Fix imports
* Fix imports
* add dummies
* Fix empty init
* make pipeline work
* up
* Use Flax schedulers (typing, docstring)
* Wrap model imports inside availability checks.
* more updates
* make sure flax is not broken
* make style
* more fixes
* up

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@latenitesoft.com>
Implement FlaxDiffusionPipeline & FlaxStableDiffusionPipeline
Based on https://github.com/patil-suraj/stable-diffusion-jax/blob/stateless-scheduler/stable_diffusion_jax/pipeline_stable_diffusion.py
From this comment, we have decided to create FlaxDiffusionPipeline, rather than try to reuse DiffusionPipeline (which currently handles pytorch & onnx).
Design decisions & questions

* Just like `DiffusionPipeline` does not inherit from `torch.nn.Module`, `FlaxDiffusionPipeline` should not inherit from `flax.linen.Module` either. Wdyt?
* Every Flax pipeline (e.g. `FlaxStableDiffusionPipeline`) needs to implement `InferenceState` (a `flax.struct.dataclass`) so that `pmap` can consume the pipeline. Wdyt?
* `DiffusionPipeline` & `FlaxDiffusionPipeline` are quite similar except for one major difference. Since flax pretrained models are initialized as `model, params = xyz.from_pretrained()`, the `FlaxDiffusionPipeline` `from_pretrained` and `save_pretrained` methods need to handle `inference_state`. See example here.

TODOS: