Closed
Description
I have a question about the JAX / Flax pipeline. We run it on a v3-8 Cloud TPU VM and it works very well. The first inference, as expected, takes around a minute because of compilation, and subsequent ones take around 5-6 seconds, which is really impressive. However, if we change any input parameter other than the prompt (width, height, guidance scale, steps), inference again takes about a minute, so it looks like it compiles everything over again.
So the question: is this expected, or should it still take 5-6 seconds and we're doing something wrong? If this is expected behavior, are there any ways we can improve inference time when changing these parameters?
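For context, here is a minimal sketch of what I believe is happening (the `denoise_step` function is a hypothetical stand-in for the real pipeline, not its actual code): `jax.jit` caches compiled programs keyed on input shapes, so changing width or height changes the latent shapes and forces a retrace and recompile, while changing a plain value of the same shape does not. The trace counter below only runs inside the Python body, which executes once per compilation.

```python
import jax
import jax.numpy as jnp

# The Python body only runs when jax.jit (re)traces,
# so each new input shape appends exactly one entry.
traces = []

@jax.jit
def denoise_step(latents, scale):
    traces.append(latents.shape)
    # stand-in for one step of a real diffusion pipeline
    return latents * scale + jnp.sin(latents)

denoise_step(jnp.zeros((1, 64, 64, 4)), 7.5)   # first shape: traces + compiles
denoise_step(jnp.zeros((1, 64, 64, 4)), 9.0)   # same shapes: cache hit, no retrace
denoise_step(jnp.zeros((1, 96, 96, 4)), 7.5)   # new shape: retraces + recompiles

print(traces)  # → [(1, 64, 64, 4), (1, 96, 96, 4)]
```

If this model of the behavior is right, one mitigation would be to "warm up" the cache once per (width, height, steps) combination we expect to serve, so that later requests with those settings hit the compiled cache and stay in the 5-6 second range.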