From 34c5ebd3e322a0f7c8c4589b3cc9c4ecc3b9b458 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Sun, 18 Sep 2022 08:50:21 +0000 Subject: [PATCH] FlaxUNet2DConditionOutput @flax.struct.dataclass --- src/diffusers/models/unet_2d_condition_flax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 1ac68e10c159..636a7ef9816a 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -1,6 +1,6 @@ -from dataclasses import dataclass from typing import Tuple, Union +import flax import flax.linen as nn import jax import jax.numpy as jnp @@ -19,7 +19,7 @@ ) -@dataclass +@flax.struct.dataclass class FlaxUNet2DConditionOutput(BaseOutput): """ Args: