Initialize Flax model params on CPU  #24711

Open
@gianlucadetommaso

Description

Feature request

Currently, the from_pretrained method of Flax models automatically puts the model parameters on a single GPU device, if one is available. For very large models this is a problem, as the parameters may simply not fit in GPU memory.

In contrast, when passing _do_init=False to from_pretrained, the parameters are returned on CPU, outside the model.

I would love a feature that lets me choose the device on which model parameters are initialized - in this case, CPU - while still initializing them within the model. Right now I have to pass _do_init=False to avoid out-of-memory errors, but this causes inconsistencies with my API.

The feature could be implemented either by inspecting the requested dtype (if a NumPy dtype is detected, initialize on CPU; otherwise on GPU) or as an additional argument, e.g. initialize_on_cpu: bool = False.
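Independently of transformers, JAX already provides a way to pin array creation to a particular device, which either design above could wrap internally. A minimal sketch using jax.default_device (the dict of arrays here is a stand-in for real model parameters, not the transformers API):

```python
import jax
import jax.numpy as jnp

# Select the CPU backend explicitly, even when a GPU is present.
cpu = jax.devices("cpu")[0]

# Any arrays created inside this context (e.g. freshly initialized
# model parameters) are placed on the CPU device.
with jax.default_device(cpu):
    params = {
        "kernel": jnp.ones((1024, 1024)),
        "bias": jnp.zeros((1024,)),
    }

# Confirm placement: the arrays live on the CPU backend.
print(next(iter(params["kernel"].devices())).platform)  # "cpu"
```

An initialize_on_cpu=True flag (or a NumPy-dtype check) in from_pretrained could simply wrap parameter initialization in such a context, so the resulting params live on CPU but still sit inside the model as usual.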

Motivation

Described above. Another reason is to be more consistent with the PyTorch behaviour, where parameters are initialized (as a generator) on CPU.

Your contribution

If we agree on the design, I am happy to add this myself.

Labels: Flax, Good Second Issue
