Skip to content

Commit effe9d6

Browse files
authored
[FlaxStableDiffusionPipeline] fix bug when nsfw is detected (#832)
fix nsfw bug
1 parent 0679d09 commit effe9d6

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,8 @@ def __call__(
291291
# block images
292292
if any(has_nsfw_concept):
293293
for i, is_nsfw in enumerate(has_nsfw_concept):
294-
images[i] = np.asarray(images_uint8_casted[i])
294+
if is_nsfw:
295+
images[i] = np.asarray(images_uint8_casted[i])
295296

296297
images = images.reshape(num_devices, batch_size, height, width, 3)
297298
else:

0 commit comments

Comments
 (0)