@@ -1091,8 +1091,6 @@ def forward(
1091
1091
sample_posterior : bool = False ,
1092
1092
return_dict : bool = True ,
1093
1093
generator : Optional [torch .Generator ] = None ,
1094
- encoder_local_batch_size : int = 2 ,
1095
- decoder_local_batch_size : int = 2 ,
1096
1094
) -> Union [DecoderOutput , torch .Tensor ]:
1097
1095
r"""
1098
1096
Args:
@@ -1103,18 +1101,14 @@ def forward(
1103
1101
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
1104
1102
generator (`torch.Generator`, *optional*):
1105
1103
PyTorch random number generator.
1106
- encoder_local_batch_size (`int`, *optional*, defaults to 2):
1107
- Local batch size for the encoder's batch inference.
1108
- decoder_local_batch_size (`int`, *optional*, defaults to 2):
1109
- Local batch size for the decoder's batch inference.
1110
1104
"""
1111
1105
x = sample
1112
- posterior = self .encode (x , local_batch_size = encoder_local_batch_size ).latent_dist
1106
+ posterior = self .encode (x ).latent_dist
1113
1107
if sample_posterior :
1114
1108
z = posterior .sample (generator = generator )
1115
1109
else :
1116
1110
z = posterior .mode ()
1117
- dec = self .decode (z , local_batch_size = decoder_local_batch_size ).sample
1111
+ dec = self .decode (z ).sample
1118
1112
1119
1113
if not return_dict :
1120
1114
return (dec ,)
0 commit comments