Description
Currently, we have two codepaths:
- For non-sharded checkpoints we do:
- For sharded checkpoints we do:
And then, for (bnb) quantized checkpoints, we merge the sharded checkpoint:
Essentially, we shouldn't have to merge sharded checkpoints even when they're quantized.
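For context, here is a rough sketch of the three paths as I understand them. This is a simplification: the helper names, import paths, and exact keyword arguments are approximations of the diffusers internals, not the actual code.

```python
# Rough sketch of today's three loading paths (approximate; names/signatures
# of the diffusers internals are assumptions, not the real code).
import os

import torch
from accelerate import load_checkpoint_and_dispatch

# Import path may differ across diffusers versions.
from diffusers.models.model_loading_utils import load_model_dict_into_meta


def load_non_sharded(model, checkpoint_file, keep_module_in_fp32=None):
    # Single-file path: we read the state dict ourselves and assign it onto the
    # meta-initialized model, so we control casting per parameter.
    state_dict = torch.load(checkpoint_file, map_location="cpu")
    # Exact kwarg name is an assumption.
    load_model_dict_into_meta(model, state_dict, keep_in_fp32_modules=keep_module_in_fp32)


def load_sharded(model, checkpoint_dir):
    # Sharded path: everything is delegated to accelerate, which reads the index
    # file and loads shard by shard -- we never see the individual state dicts.
    load_checkpoint_and_dispatch(model, checkpoint_dir, device_map="auto")


def load_sharded_quantized(model, checkpoint_dir):
    # bnb-quantized + sharded path: merge all shards into one state dict first,
    # then reuse the single-file path (this is the merge step in question).
    merged_state_dict = {}
    for shard in sorted(os.listdir(checkpoint_dir)):
        if shard.endswith(".bin"):  # .safetensors shards would go through safetensors.torch.load_file
            merged_state_dict.update(torch.load(os.path.join(checkpoint_dir, shard), map_location="cpu"))
    load_model_dict_into_meta(model, merged_state_dict)
```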
This will also allow us to more generally use `keep_module_in_fp32` for sharded checkpoints. Currently, we have this logic for casting a model (which is tested thoroughly):
When using `load_model_dict_into_meta()`, we do consider `keep_module_in_fp32`:
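Roughly, the per-parameter part of that casting behaves like the following paraphrase (a minimal sketch, not the actual implementation, which also handles quantizer hooks, device placement, etc.):

```python
# Paraphrased sketch of the per-parameter casting considered alongside
# load_model_dict_into_meta(); cast_param is a hypothetical helper name.
import torch


def cast_param(param_name, param, target_dtype, keep_module_in_fp32):
    # Keep a parameter in float32 if it belongs to any module listed in
    # keep_module_in_fp32; otherwise cast it to the requested dtype.
    keep_module_in_fp32 = keep_module_in_fp32 or []
    if any(module_name in param_name for module_name in keep_module_in_fp32):
        return param.to(torch.float32)
    return param.to(target_dtype)


# Example: norm weights stay in fp32 while everything else is cast to fp16.
weight = torch.randn(4, dtype=torch.float32)
print(cast_param("encoder.norm.weight", weight, torch.float16, ["norm"]).dtype)  # torch.float32
print(cast_param("encoder.proj.weight", weight, torch.float16, ["norm"]).dtype)  # torch.float16
```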
But since for sharded checkpoints we use `load_checkpoint_and_dispatch()`, there is no way to pass `keep_module_in_fp32`:
https://huggingface.co/docs/accelerate/main/en/package_reference/big_modeling#accelerate.load_checkpoint_and_dispatch
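To illustrate the gap: the call only exposes a single global `dtype`, which applies to every parameter uniformly. A sketch (the checkpoint path is a placeholder and the tiny `nn.Linear` stands in for a real meta-initialized diffusers model):

```python
# load_checkpoint_and_dispatch() has one global `dtype` knob and nothing like
# keep_module_in_fp32, so selected modules cannot be kept in fp32.
import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

with init_empty_weights():
    model = torch.nn.Linear(8, 8)  # stand-in for a real meta-initialized diffusers model

model = load_checkpoint_and_dispatch(
    model,
    checkpoint="path/to/sharded/checkpoint",  # placeholder path
    device_map="auto",
    dtype=torch.float16,  # cast applied uniformly; no per-module fp32 override
)
```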
As discussed with @SunMarc, it's better to uniformize this so that we don't have to maintain two different codepaths and can instead rely completely on `load_model_dict_into_meta()`. Marc has kindly agreed to open a PR to attempt this (it could be done as a series of PRs if needed), but I will join in if any help is needed.
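For the record, the direction could look roughly like the sketch below: iterate over the shards listed in the index file and push each one through `load_model_dict_into_meta()`, so the same `keep_module_in_fp32` (and quantization) handling applies to both paths. The helper import, the index filename, and the keyword arguments are assumptions; the actual PR may structure this differently.

```python
# Hypothetical sketch of the unified sharded path (names and kwargs are assumptions).
import json
import os

from safetensors.torch import load_file

# Import path may differ across diffusers versions.
from diffusers.models.model_loading_utils import load_model_dict_into_meta


def load_sharded_via_meta(model, checkpoint_dir, dtype=None, keep_module_in_fp32=None):
    # The index file maps parameter names to shard files; the filename here is just an example.
    index_path = os.path.join(checkpoint_dir, "diffusion_pytorch_model.safetensors.index.json")
    with open(index_path) as f:
        index = json.load(f)

    shard_files = sorted(set(index["weight_map"].values()))
    for shard_file in shard_files:
        # One shard at a time: the merged checkpoint is never materialized in memory.
        state_dict = load_file(os.path.join(checkpoint_dir, shard_file))
        load_model_dict_into_meta(
            model,
            state_dict,
            dtype=dtype,
            keep_in_fp32_modules=keep_module_in_fp32,  # exact kwarg name is an assumption
        )
        del state_dict
```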