Skip to content

Commit f2a3e1e

Browse files
committed
Allegro VAE fix (#9811)
fix
1 parent d791908 commit f2a3e1e

File tree

1 file changed: +2 additions, -8 deletions

src/diffusers/models/autoencoders/autoencoder_kl_allegro.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,8 +1091,6 @@ def forward(
10911091
sample_posterior: bool = False,
10921092
return_dict: bool = True,
10931093
generator: Optional[torch.Generator] = None,
1094-
encoder_local_batch_size: int = 2,
1095-
decoder_local_batch_size: int = 2,
10961094
) -> Union[DecoderOutput, torch.Tensor]:
10971095
r"""
10981096
Args:
@@ -1103,18 +1101,14 @@ def forward(
11031101
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
11041102
generator (`torch.Generator`, *optional*):
11051103
PyTorch random number generator.
1106-
encoder_local_batch_size (`int`, *optional*, defaults to 2):
1107-
Local batch size for the encoder's batch inference.
1108-
decoder_local_batch_size (`int`, *optional*, defaults to 2):
1109-
Local batch size for the decoder's batch inference.
11101104
"""
11111105
x = sample
1112-
posterior = self.encode(x, local_batch_size=encoder_local_batch_size).latent_dist
1106+
posterior = self.encode(x).latent_dist
11131107
if sample_posterior:
11141108
z = posterior.sample(generator=generator)
11151109
else:
11161110
z = posterior.mode()
1117-
dec = self.decode(z, local_batch_size=decoder_local_batch_size).sample
1111+
dec = self.decode(z).sample
11181112

11191113
if not return_dict:
11201114
return (dec,)

0 commit comments

Comments (0)