For example, the Flax UNet in diffusers currently uses the standard library's from dataclasses import dataclass here:
diffusers/src/diffusers/models/unet_2d_condition_flax.py
Lines 22 to 23 in d8b0e4f
But the transformers equivalents use @flax.struct.dataclass. For example, here:
@flax.struct.dataclass
class FlaxBertForPreTrainingOutput(ModelOutput):
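The benefit of using @flax.struct.dataclass over a plain Python dataclass is that jax.jit can consume it: flax.struct.dataclass registers the class as a JAX pytree, so its instances can be passed to and returned from jitted functions.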
So the question is: should diffusers use @flax.struct.dataclass as well?