
[Community] Make safety model end-to-end compileable - Inference time of JAX / Flax pipeline #927

Closed
@ngaer

Description


I have a question about the JAX / Flax pipeline. We run it on a v3-8 Cloud TPU VM and it works very well. As expected, the first inference takes around a minute because of compilation, and subsequent runs take around 5-6 seconds, which is really impressive. However, if we change any input parameter other than the prompt (width, height, guidance scale, number of steps), inference again takes about a minute, so it appears to recompile from scratch.
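For what it's worth, the behavior described above is consistent with how `jax.jit` works in general: compiled programs are cached per input shape and dtype (and per value of any static argument), so a new resolution or step count means a fresh trace and XLA compile. A minimal sketch of this, independent of the diffusion pipeline itself (the `denoise_step` function below is just an illustration, not pipeline code):

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def denoise_step(latents):
    # This Python side effect runs only while JAX traces the function,
    # i.e. once per new input shape/dtype, not on every call.
    global trace_count
    trace_count += 1
    return latents * 0.5

# Same shape twice: traced and compiled once, then reused from cache.
denoise_step(jnp.zeros((1, 64, 64)))
denoise_step(jnp.ones((1, 64, 64)))

# New shape: triggers a second trace + compilation.
denoise_step(jnp.zeros((1, 128, 128)))

print(trace_count)  # 2: one compile per distinct input shape
```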

So the question is: is this expected, or should it still take 5-6 seconds and we're doing something wrong? If this is expected behavior, is there any way to improve inference time when changing these parameters?
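One common workaround, assuming the set of resolutions you serve is known in advance, is to warm up the JIT cache at startup so users never pay the compile cost at request time. Also, parameters passed as traced array values (rather than static Python arguments) can change freely without recompilation. A hedged sketch, where `pipeline_step` is a stand-in for the real pipeline call:

```python
import jax
import jax.numpy as jnp

@jax.jit
def pipeline_step(latents, guidance_scale):
    # guidance_scale is a traced array argument here, so changing its
    # *value* between calls does not trigger recompilation; only a new
    # latents shape would.
    return latents * guidance_scale

# Warm-up pass: compile once per resolution we expect to serve,
# so the ~1 minute compile happens before any user request arrives.
for h, w in [(64, 64), (96, 96), (128, 128)]:
    pipeline_step(jnp.zeros((1, h, w)), jnp.asarray(7.5))

# Subsequent calls at these shapes hit the cache, even with new scales.
out = pipeline_step(jnp.zeros((1, 64, 64)), jnp.asarray(2.0))
```

Note that in the actual Flax pipeline, `width`, `height`, and `num_inference_steps` affect array shapes or loop structure inside the traced function, which is why changing them forces a recompile; folding scalar knobs like the guidance scale into traced inputs avoids recompiles only for those knobs.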
