@@ -575,6 +575,46 @@ def test_stable_diffusion_attention_chunk(self):
575
575
576
576
assert np .abs (output_2 .images .flatten () - output_1 .images .flatten ()).max () < 1e-4
577
577
578
def test_stable_diffusion_negative_prompt(self):
    """Check text-to-image output when a negative prompt is supplied.

    Builds a ``StableDiffusionPipeline`` from the dummy test components,
    runs two inference steps on CPU with a seeded generator, and compares
    the bottom-right 3x3 patch of the last channel with stored reference
    values (tolerance ``1e-2``).
    """
    # torch.Generator is device-dependent; pinning to CPU keeps the run deterministic.
    torch_device = "cpu"

    # The PNDM scheduler is explicitly configured to skip its PRK steps.
    components = {
        "unet": self.dummy_cond_unet,
        "scheduler": PNDMScheduler(skip_prk_steps=True),
        "vae": self.dummy_vae,
        "text_encoder": self.dummy_text_encoder,
        "tokenizer": CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip"),
        "safety_checker": self.dummy_safety_checker,
        "feature_extractor": self.dummy_extractor,
    }
    pipeline = StableDiffusionPipeline(**components).to(torch_device)
    pipeline.set_progress_bar_config(disable=None)

    seeded_generator = torch.Generator(device=torch_device).manual_seed(0)
    result = pipeline(
        "A painting of a squirrel eating a burger",
        negative_prompt="french fries",
        generator=seeded_generator,
        guidance_scale=6.0,
        num_inference_steps=2,
        output_type="np",
    )

    images = result.images
    # Bottom-right 3x3 patch of the last channel of the first image.
    corner_patch = images[0, -3:, -3:, -1]

    assert images.shape == (1, 128, 128, 3)
    reference_patch = np.array([0.4851, 0.4617, 0.4765, 0.5127, 0.4845, 0.5153, 0.5141, 0.4886, 0.4719])
    deviation = np.abs(corner_patch.flatten() - reference_patch)
    assert deviation.max() < 1e-2
617
+
578
618
def test_score_sde_ve_pipeline (self ):
579
619
unet = self .dummy_uncond_unet
580
620
scheduler = ScoreSdeVeScheduler ()
@@ -704,6 +744,48 @@ def test_stable_diffusion_img2img(self):
704
744
assert np .abs (image_slice .flatten () - expected_slice ).max () < 1e-2
705
745
assert np .abs (image_from_tuple_slice .flatten () - expected_slice ).max () < 1e-2
706
746
747
def test_stable_diffusion_img2img_negative_prompt(self):
    """Check image-to-image output when a negative prompt is supplied.

    Builds a ``StableDiffusionImg2ImgPipeline`` from the dummy test
    components, feeds it the dummy init image, runs two inference steps on
    CPU with a seeded generator, and compares the bottom-right 3x3 patch of
    the last channel with stored reference values (tolerance ``1e-2``).
    """
    # torch.Generator is device-dependent; pinning to CPU keeps the run deterministic.
    torch_device = "cpu"

    denoiser = self.dummy_cond_unet
    # The PNDM scheduler is explicitly configured to skip its PRK steps.
    pndm = PNDMScheduler(skip_prk_steps=True)
    autoencoder = self.dummy_vae
    text_encoder = self.dummy_text_encoder
    clip_tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

    start_image = self.dummy_image.to(torch_device)

    pipeline = StableDiffusionImg2ImgPipeline(
        unet=denoiser,
        scheduler=pndm,
        vae=autoencoder,
        text_encoder=text_encoder,
        tokenizer=clip_tokenizer,
        safety_checker=self.dummy_safety_checker,
        feature_extractor=self.dummy_extractor,
    ).to(torch_device)
    pipeline.set_progress_bar_config(disable=None)

    seeded_generator = torch.Generator(device=torch_device).manual_seed(0)
    call_kwargs = {
        "negative_prompt": "french fries",
        "generator": seeded_generator,
        "guidance_scale": 6.0,
        "num_inference_steps": 2,
        "output_type": "np",
        "init_image": start_image,
    }
    result = pipeline("A painting of a squirrel eating a burger", **call_kwargs)

    images = result.images
    # Bottom-right 3x3 patch of the last channel of the first image.
    corner_patch = images[0, -3:, -3:, -1]

    assert images.shape == (1, 32, 32, 3)
    reference_patch = np.array([0.4065, 0.3783, 0.4050, 0.5266, 0.4781, 0.4252, 0.4203, 0.4692, 0.4365])
    deviation = np.abs(corner_patch.flatten() - reference_patch)
    assert deviation.max() < 1e-2
788
+
707
789
def test_stable_diffusion_img2img_multiple_init_images (self ):
708
790
device = "cpu" # ensure determinism for the device-dependent torch.Generator
709
791
unet = self .dummy_cond_unet
@@ -861,6 +943,52 @@ def test_stable_diffusion_inpaint(self):
861
943
assert np .abs (image_slice .flatten () - expected_slice ).max () < 1e-2
862
944
assert np .abs (image_from_tuple_slice .flatten () - expected_slice ).max () < 1e-2
863
945
946
def test_stable_diffusion_inpaint_negative_prompt(self):
    """Check inpainting output when a negative prompt is supplied.

    Builds a ``StableDiffusionInpaintPipeline`` from the dummy test
    components, derives an RGB init image and a 128x128 mask from the dummy
    image, runs two inference steps on CPU with a seeded generator, and
    compares the bottom-right 3x3 patch of the last channel with stored
    reference values (tolerance ``1e-2``).
    """
    # torch.Generator is device-dependent; pinning to CPU keeps the run deterministic.
    torch_device = "cpu"

    denoiser = self.dummy_cond_unet
    # The PNDM scheduler is explicitly configured to skip its PRK steps.
    pndm = PNDMScheduler(skip_prk_steps=True)
    autoencoder = self.dummy_vae
    text_encoder = self.dummy_text_encoder
    clip_tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

    # Move channels to the last axis and take the first sample so PIL can consume it.
    channels_last = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
    start_image = Image.fromarray(np.uint8(channels_last)).convert("RGB")
    # The mask is the same image shifted by 4, converted to RGB and resized to 128x128.
    shifted = Image.fromarray(np.uint8(channels_last + 4))
    mask = shifted.convert("RGB").resize((128, 128))

    pipeline = StableDiffusionInpaintPipeline(
        unet=denoiser,
        scheduler=pndm,
        vae=autoencoder,
        text_encoder=text_encoder,
        tokenizer=clip_tokenizer,
        safety_checker=self.dummy_safety_checker,
        feature_extractor=self.dummy_extractor,
    ).to(torch_device)
    pipeline.set_progress_bar_config(disable=None)

    seeded_generator = torch.Generator(device=torch_device).manual_seed(0)
    call_kwargs = {
        "negative_prompt": "french fries",
        "generator": seeded_generator,
        "guidance_scale": 6.0,
        "num_inference_steps": 2,
        "output_type": "np",
        "init_image": start_image,
        "mask_image": mask,
    }
    result = pipeline("A painting of a squirrel eating a burger", **call_kwargs)

    images = result.images
    # Bottom-right 3x3 patch of the last channel of the first image.
    corner_patch = images[0, -3:, -3:, -1]

    assert images.shape == (1, 32, 32, 3)
    reference_patch = np.array([0.4765, 0.5339, 0.4541, 0.6240, 0.5439, 0.4055, 0.5503, 0.5891, 0.5150])
    deviation = np.abs(corner_patch.flatten() - reference_patch)
    assert deviation.max() < 1e-2
991
+
864
992
def test_stable_diffusion_num_images_per_prompt (self ):
865
993
device = "cpu" # ensure determinism for the device-dependent torch.Generator
866
994
unet = self .dummy_cond_unet
0 commit comments