
[Core] refactor model loading #10013

Closed

Description

@sayakpaul

Currently, we have two codepaths:

  1. For non-sharded checkpoints we do:
    unexpected_keys = load_model_dict_into_meta(
  2. For sharded checkpoints we do:
    accelerate.load_checkpoint_and_dispatch(

And then for (bnb) quantized checkpoints, we merge the sharded checkpoint into a single file:

model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
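
Condensed, the branching above looks roughly like the sketch below. It is illustrative only: the helper names follow the snippets quoted here, the argument lists are trimmed, the real control flow in from_pretrained is more involved, and the diffusers internals are passed in as arguments just so the sketch stays self-contained.

```python
from accelerate import load_checkpoint_and_dispatch


def current_loading_paths(
    model,
    model_file,
    is_sharded,
    sharded_ckpt_cached_folder,
    sharded_metadata,
    torch_dtype,
    device_map,
    hf_quantizer,
    # diffusers internals, passed in so the sketch stays self-contained
    load_state_dict,
    load_model_dict_into_meta,
    _merge_sharded_checkpoints,
):
    if hf_quantizer is not None and is_sharded:
        # (bnb) quantized + sharded: the shards are first merged into one file
        model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
        is_sharded = False

    if not is_sharded:
        # codepath 1: non-sharded (or merged) checkpoints
        state_dict = load_state_dict(model_file)
        unexpected_keys = load_model_dict_into_meta(model, state_dict, dtype=torch_dtype)
    else:
        # codepath 2: sharded checkpoints; no keep_in_fp32_modules hook here
        load_checkpoint_and_dispatch(
            model,
            sharded_ckpt_cached_folder,
            device_map=device_map,
            dtype=torch_dtype,
        )
    return model
```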

Essentially, we shouldn't have to merge sharded checkpoints even when the checkpoint is quantized.

This will also let us support keep_in_fp32_modules for sharded checkpoints more generally. Currently, we have this logic for casting a model (which is thoroughly tested):

elif torch_dtype is not None and hf_quantizer is None and not use_keep_in_fp32_modules:

When using load_model_dict_into_meta(), we do honor keep_in_fp32_modules:

keep_in_fp32_modules=None,

But since we use load_checkpoint_and_dispatch() for sharded checkpoints, there is no way to pass keep_in_fp32_modules:
https://huggingface.co/docs/accelerate/main/en/package_reference/big_modeling#accelerate.load_checkpoint_and_dispatch
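
For reference, the behaviour keep_in_fp32_modules is meant to provide looks roughly like this. It is a self-contained toy illustration with a hypothetical helper (load_with_fp32_overrides) and a toy model, not the diffusers implementation; diffusers does this inside load_model_dict_into_meta.

```python
import torch
import torch.nn as nn


def load_with_fp32_overrides(model, state_dict, dtype, keep_in_fp32_modules):
    """Copy `state_dict` into `model`, casting to `dtype` except for parameters
    whose name matches one of `keep_in_fp32_modules`, which stay in float32."""
    for name, tensor in state_dict.items():
        keep_fp32 = any(m in name for m in keep_in_fp32_modules)
        target_dtype = torch.float32 if keep_fp32 else dtype
        model.get_parameter(name).data = tensor.to(target_dtype)


model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
state_dict = {k: v.clone() for k, v in model.state_dict().items()}
load_with_fp32_overrides(model, state_dict, torch.float16, keep_in_fp32_modules=["1"])
print(model[0].weight.dtype)  # torch.float16
print(model[1].weight.dtype)  # torch.float32 (the LayerNorm was kept in fp32)
```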

As discussed with @SunMarc, it's better to unify this so that we don't have to maintain two different codepaths and can rely entirely on load_model_dict_into_meta(). Marc has kindly agreed to open a PR to attempt this (possibly as a series of PRs if needed), and I will pitch in if any help is needed.
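
A rough shape of what the unified path could look like, assuming the shard filenames come from the index metadata: each shard is loaded and routed through load_model_dict_into_meta, so keep_in_fp32_modules and the quantizer are honored for sharded checkpoints too, with no merge step. The import path and keyword arguments are taken from the current diffusers internals referenced above and may well differ in the actual PR.

```python
import os

from safetensors.torch import load_file

from diffusers.models.model_loading_utils import load_model_dict_into_meta


def load_sharded_into_meta(
    model,
    checkpoint_folder,
    shard_filenames,
    torch_dtype,
    keep_in_fp32_modules=None,
    hf_quantizer=None,
):
    """Load a sharded checkpoint one shard at a time through load_model_dict_into_meta."""
    unexpected_keys = []
    for shard_file in shard_filenames:
        state_dict = load_file(os.path.join(checkpoint_folder, shard_file))
        unexpected_keys += load_model_dict_into_meta(
            model,
            state_dict,
            dtype=torch_dtype,
            keep_in_fp32_modules=keep_in_fp32_modules,
            hf_quantizer=hf_quantizer,
        )
        del state_dict  # free the shard before loading the next one
    return unexpected_keys
```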
