Question about flax model output classes #528

Closed
@mishig25

For example, here diffusers defines a Flax output class with the standard-library `dataclass` decorator (`from dataclasses import dataclass`):

```python
@dataclass
class FlaxUNet2DConditionOutput(BaseOutput):
    ...
```

The transformers equivalents, however, use `@flax.struct.dataclass`. For example, here:

```python
@flax.struct.dataclass
class FlaxBertForPreTrainingOutput(ModelOutput):
    ...
```

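The benefit of `@flax.struct.dataclass` over a plain Python dataclass is that it registers the class as a JAX pytree, so instances can be passed into and returned from `jax.jit`-compiled functions. A plain dataclass is not a valid JAX type, so `jax.jit` rejects it.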

So the question is: should diffusers use `@flax.struct.dataclass` as well?