-
Notifications
You must be signed in to change notification settings - Fork 6.1k
[Refactor] Remove set_seed #289
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py
Show resolved
Hide resolved
The documentation is not available anymore as the PR was closed or merged. |
@@ -98,6 +100,9 @@ def get_adjacent_sigma(self, timesteps, t): | |||
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") | |||
|
|||
def set_seed(self, seed): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should not work with set_seed(...)
here as it goes against the generator
design. IMO passing generators around is the correct thing to do here because:
- We cannot really pass seeds around and globally set
manual_seed(...)
every time:- It's not a great design to retrieve the current seed from a generator and then pass this -> better to pass generator directly
- Flax passes PNRG keys around that can be split -> we cannot split PyTorch seeds in the same way and we cannot pass a
seed
into atorch.randn(...)
function -> so let's go for the generator here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good!
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
@@ -42,11 +42,11 @@ def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, outpu | |||
# correction step | |||
for _ in range(self.scheduler.correct_steps): | |||
model_output = self.unet(sample, sigma_t)["sample"] | |||
sample = self.scheduler.step_correct(model_output, sample)["prev_sample"] | |||
sample = self.scheduler.step_correct(model_output, sample, generator=generator)["prev_sample"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's pass the generator here
…into refactors
@@ -56,7 +56,7 @@ def __call__(self, batch_size=1, generator=None, output_type="pil", **kwargs): | |||
model_output = self.unet(image, t)["sample"] | |||
|
|||
# 2. compute previous image: x_t -> t_t-1 | |||
image = self.scheduler.step(model_output, t, image)["prev_sample"] | |||
image = self.scheduler.step(model_output, t, image, generator=generator)["prev_sample"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DDPM scheduler is also stochastic
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the updates! Let's see how the tests behave now
…into refactors
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
* [Refactor] Remove set_seed and class attributes * apply anton's suggestiosn * fix * Apply suggestions from code review Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * up * update * make style * Apply suggestions from code review Co-authored-by: Anton Lozhkov <anton@huggingface.co> * make fix-copies * make style * make style and new copies Co-authored-by: Pedro Cuenca <pedro@huggingface.co> Co-authored-by: Anton Lozhkov <anton@huggingface.co>
* [Refactor] Remove set_seed and class attributes * apply anton's suggestiosn * fix * Apply suggestions from code review Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * up * update * make style * Apply suggestions from code review Co-authored-by: Anton Lozhkov <anton@huggingface.co> * make fix-copies * make style * make style and new copies Co-authored-by: Pedro Cuenca <pedro@huggingface.co> Co-authored-by: Anton Lozhkov <anton@huggingface.co>
No description provided.