Description
Feature request
Currently, the `from_pretrained` method of Flax models automatically puts model parameters on a single GPU device, if one is available. For very large models this is problematic, as the parameters may simply not fit in GPU memory.
In contrast, when passing `_do_init=False` to `from_pretrained`, the parameters are returned on CPU, outside the model.
I would love to have a feature that lets me initialize model parameters on the device of my choice - in this case, on CPU - while still initializing them within the model. Right now I have to pass `_do_init=False` to avoid running out of memory, but this causes inconsistencies with my API.
The feature could be implemented either through the type of the input (if we detect a NumPy type, we initialize on CPU; otherwise on GPU) or as an additional argument, e.g. `initialize_on_cpu: bool = False`.
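As a rough illustration of the flag-based option, here is a minimal sketch in plain JAX (outside `transformers`); the function name `init_params` and the shapes are hypothetical, and `jax.device_put` is used only to demonstrate explicit device placement:

```python
import jax
import jax.numpy as jnp

def init_params(shape, initialize_on_cpu=False):
    # Hypothetical helper mirroring the proposed `initialize_on_cpu` argument:
    # pick the target device up front, then materialize the parameters there.
    if initialize_on_cpu:
        device = jax.devices("cpu")[0]
    else:
        device = jax.devices()[0]  # default backend (GPU if available)
    # device_put places the array on the chosen device explicitly,
    # instead of letting it land on the default accelerator.
    return jax.device_put(jnp.zeros(shape), device)

# Parameters stay on CPU even when an accelerator is present.
params = init_params((1024, 1024), initialize_on_cpu=True)
print(params.shape)
```

In `from_pretrained` itself the same idea would apply at the point where the checkpoint weights are converted to device arrays, rather than in a standalone helper like this.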
Motivation
Described above. Another reason is to be more consistent with the PyTorch behaviour, where parameters are initialized (as a generator) on CPU.
Your contribution
If we agree on the design, I am happy to add this myself.