
Allow pipeline to run in bfloat16 #581


Closed
wants to merge 16 commits into from

Conversation

pcuenca
Member

@pcuenca pcuenca commented Sep 20, 2022

Changes:

  • Allow dtype to be specified on model load. This is a temporary solution until #567 (Save training dtype as part of the configuration) is addressed in a more principled way.
  • Convert loaded params to bfloat16 or float16 if necessary.
  • Keep latents in float32 during inference loop.

See comment: #559 (comment)
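
A rough usage sketch of what this enables (not part of the diff; the pipeline class and checkpoint name here are assumptions for illustration only):

```python
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline

# Specify the dtype on load; loaded params are converted to bfloat16 if
# necessary, while latents stay in float32 during the inference loop.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",  # illustrative checkpoint only
    dtype=jnp.bfloat16,
)
```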

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@pcuenca pcuenca mentioned this pull request Sep 20, 2022
@patil-suraj
Contributor

The dtype here specifies the dtype of computation and should not be used for params. We should actually add a different argument to specify the dtype of parameters; for example, flax now has a param_dtype argument: https://github.com/google/flax/blob/main/flax/linen/linear.py#L78.

The reason to do this is that it's not recommended to keep params in bfloat16 on TPUs during training, as it can cause numerical instabilities; it's only done for extremely large models. Also, having these two arguments allows you to do mixed-precision training: keep params in fp32 but do the forward computation in half precision.
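
A minimal sketch of that distinction using flax.linen directly (illustrative only, not diffusers code): params are initialized in fp32 while the forward pass computes in bfloat16.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# dtype controls the computation dtype; param_dtype controls the dtype the
# parameters are created in (mixed precision: fp32 params, bf16 compute).
layer = nn.Dense(features=16, dtype=jnp.bfloat16, param_dtype=jnp.float32)

x = jnp.ones((4, 8))
params = layer.init(jax.random.PRNGKey(0), x)
y = layer.apply(params, x)

print(jax.tree_util.tree_map(lambda p: p.dtype, params))  # float32 leaves
print(y.dtype)                                            # bfloat16
```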

For now we could create a branch with bfloat16 weights and then use that for inference.

@pcuenca
Member Author

pcuenca commented Sep 20, 2022

I removed the automatic conversion but kept the other change so that models can still be loaded with a specified dtype. @patil-suraj please let me know if that's ok.

Contributor

@patil-suraj patil-suraj left a comment


LGTM, thanks a lot!

Comment on lines 175 to 177
```python
    latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
else:
    if latents.shape != latents_shape:
```
Contributor


This file doesn't seem to be in main. It should be removed from here before merging.

Member Author


The merge target is flax_pipeline for now.

Base automatically changed from flax_pipeline to main September 20, 2022 19:30
@pcuenca
Member Author

pcuenca commented Sep 21, 2022

Replaced by #600.

@pcuenca pcuenca closed this Sep 21, 2022
@pcuenca pcuenca deleted the flax_pipeline_bf16 branch October 2, 2022 18:03