Skip to content

Commit e20d266

Browse files
author
Mishig Davaadorj
authored
FlaxUNet2DConditionOutput @flax.struct.dataclass (huggingface#550)
1 parent fc505a8 commit e20d266

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

models/unet_2d_condition_flax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from dataclasses import dataclass
21
from typing import Tuple, Union
32

3+
import flax
44
import flax.linen as nn
55
import jax
66
import jax.numpy as jnp
@@ -19,7 +19,7 @@
1919
)
2020

2121

22-
@dataclass
22+
@flax.struct.dataclass
2323
class FlaxUNet2DConditionOutput(BaseOutput):
2424
"""
2525
Args:

0 commit comments

Comments
 (0)