Skip to content

Fix nondeterministic tests for GPU runs #314

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 37 additions & 44 deletions tests/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ def test_ldm_text2img(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

def test_stable_diffusion_ddim(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = DDIMScheduler(
beta_start=0.00085,
Expand All @@ -259,14 +260,11 @@ def test_stable_diffusion_ddim(self):
safety_checker=self.dummy_safety_checker,
feature_extractor=self.dummy_extractor,
)
- sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe = sd_pipe.to(device)

prompt = "A painting of a squirrel eating a burger"
- generator = torch.Generator(device=torch_device).manual_seed(0)
- with torch.autocast("cuda"):
-     output = sd_pipe(
-         [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np"
-     )
+ generator = torch.Generator(device=device).manual_seed(0)
+ output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")

image = output["sample"]

Expand All @@ -277,6 +275,7 @@ def test_stable_diffusion_ddim(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

def test_stable_diffusion_pndm(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True)
vae = self.dummy_vae
Expand All @@ -293,14 +292,11 @@ def test_stable_diffusion_pndm(self):
safety_checker=self.dummy_safety_checker,
feature_extractor=self.dummy_extractor,
)
- sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe = sd_pipe.to(device)

prompt = "A painting of a squirrel eating a burger"
- generator = torch.Generator(device=torch_device).manual_seed(0)
- with torch.autocast("cuda"):
-     output = sd_pipe(
-         [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np"
-     )
+ generator = torch.Generator(device=device).manual_seed(0)
+ output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")

image = output["sample"]

Expand All @@ -311,8 +307,8 @@ def test_stable_diffusion_pndm(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

def test_stable_diffusion_k_lms(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
- scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True)
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
vae = self.dummy_vae
bert = self.dummy_text_encoder
Expand All @@ -328,14 +324,11 @@ def test_stable_diffusion_k_lms(self):
safety_checker=self.dummy_safety_checker,
feature_extractor=self.dummy_extractor,
)
- sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe = sd_pipe.to(device)

prompt = "A painting of a squirrel eating a burger"
- generator = torch.Generator(device=torch_device).manual_seed(0)
- with torch.autocast("cuda"):
-     output = sd_pipe(
-         [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np"
-     )
+ generator = torch.Generator(device=device).manual_seed(0)
+ output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")

image = output["sample"]

Expand Down Expand Up @@ -395,13 +388,14 @@ def test_karras_ve_pipeline(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

def test_stable_diffusion_img2img(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

- init_image = self.dummy_image
+ init_image = self.dummy_image.to(device)

# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionImg2ImgPipeline(
Expand All @@ -413,19 +407,18 @@ def test_stable_diffusion_img2img(self):
safety_checker=self.dummy_safety_checker,
feature_extractor=self.dummy_extractor,
)
- sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe = sd_pipe.to(device)

prompt = "A painting of a squirrel eating a burger"
- generator = torch.Generator(device=torch_device).manual_seed(0)
- with torch.autocast("cuda"):
-     output = sd_pipe(
-         [prompt],
-         generator=generator,
-         guidance_scale=6.0,
-         num_inference_steps=2,
-         output_type="np",
-         init_image=init_image,
-     )
+ generator = torch.Generator(device=device).manual_seed(0)
+ output = sd_pipe(
+     [prompt],
+     generator=generator,
+     guidance_scale=6.0,
+     num_inference_steps=2,
+     output_type="np",
+     init_image=init_image,
+ )

image = output["sample"]

Expand All @@ -436,13 +429,14 @@ def test_stable_diffusion_img2img(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

def test_stable_diffusion_inpaint(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

- image = self.dummy_image.permute(0, 2, 3, 1)[0]
+ image = self.dummy_image.to(device).permute(0, 2, 3, 1)[0]
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))

Expand All @@ -456,20 +450,19 @@ def test_stable_diffusion_inpaint(self):
safety_checker=self.dummy_safety_checker,
feature_extractor=self.dummy_extractor,
)
- sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe = sd_pipe.to(device)

prompt = "A painting of a squirrel eating a burger"
- generator = torch.Generator(device=torch_device).manual_seed(0)
- with torch.autocast("cuda"):
-     output = sd_pipe(
-         [prompt],
-         generator=generator,
-         guidance_scale=6.0,
-         num_inference_steps=2,
-         output_type="np",
-         init_image=init_image,
-         mask_image=mask_image,
-     )
+ generator = torch.Generator(device=device).manual_seed(0)
+ output = sd_pipe(
+     [prompt],
+     generator=generator,
+     guidance_scale=6.0,
+     num_inference_steps=2,
+     output_type="np",
+     init_image=init_image,
+     mask_image=mask_image,
+ )

image = output["sample"]

Expand Down
17 changes: 8 additions & 9 deletions tests/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
import torch

from diffusers import DDIMScheduler, DDPMScheduler, UNet2DModel
- from diffusers.testing_utils import slow, torch_device
- from diffusers.training_utils import enable_full_determinism, set_seed
+ from diffusers.testing_utils import slow
+ from diffusers.training_utils import set_seed


torch.backends.cuda.matmul.allow_tf32 = False
Expand All @@ -34,8 +34,7 @@ def get_model_optimizer(self, resolution=32):

@slow
def test_training_step_equality(self):
- enable_full_determinism(0)
-
+ device = "cpu" # ensure full determinism without setting the CUBLAS_WORKSPACE_CONFIG env variable
ddpm_scheduler = DDPMScheduler(
num_train_timesteps=1000,
beta_start=0.0001,
Expand All @@ -57,13 +56,13 @@ def test_training_step_equality(self):

# shared batches for DDPM and DDIM
set_seed(0)
- clean_images = [torch.randn((4, 3, 32, 32)).clip(-1, 1).to(torch_device) for _ in range(4)]
- noise = [torch.randn((4, 3, 32, 32)).to(torch_device) for _ in range(4)]
- timesteps = [torch.randint(0, 1000, (4,)).long().to(torch_device) for _ in range(4)]
+ clean_images = [torch.randn((4, 3, 32, 32)).clip(-1, 1).to(device) for _ in range(4)]
+ noise = [torch.randn((4, 3, 32, 32)).to(device) for _ in range(4)]
+ timesteps = [torch.randint(0, 1000, (4,)).long().to(device) for _ in range(4)]

# train with a DDPM scheduler
model, optimizer = self.get_model_optimizer(resolution=32)
- model.train().to(torch_device)
+ model.train().to(device)
for i in range(4):
optimizer.zero_grad()
ddpm_noisy_images = ddpm_scheduler.add_noise(clean_images[i], noise[i], timesteps[i])
Expand All @@ -75,7 +74,7 @@ def test_training_step_equality(self):

# recreate the model and optimizer, and retry with DDIM
model, optimizer = self.get_model_optimizer(resolution=32)
- model.train().to(torch_device)
+ model.train().to(device)
for i in range(4):
optimizer.zero_grad()
ddim_noisy_images = ddim_scheduler.add_noise(clean_images[i], noise[i], timesteps[i])
Expand Down