Skip to content

Commit 4674fdf

Browse files
patil-surajpcuenca
andauthored
Add image2image example script. (#231)
* boom boom * reorganise examples * add image2image in example inference * add readme * fix example * update colab url * Apply suggestions from code review Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * fix init_timestep * update colab url * update main readme * rename readme Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
1 parent 6028d58 commit 4674fdf

File tree

5 files changed

+213
-1
lines changed

5 files changed

+213
-1
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ More precisely, 🤗 Diffusers offers:
2323
- State-of-the-art diffusion pipelines that can be run in inference with just a couple of lines of code (see [src/diffusers/pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines)).
2424
- Various noise schedulers that can be used interchangeably for the prefered speed vs. quality trade-off in inference (see [src/diffusers/schedulers](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers)).
2525
- Multiple types of models, such as UNet, can be used as building blocks in an end-to-end diffusion system (see [src/diffusers/models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models)).
26-
- Training examples to show how to train the most popular diffusion models (see [examples](https://github.com/huggingface/diffusers/tree/main/examples)).
26+
- Training examples to show how to train the most popular diffusion models (see [examples/training](https://github.com/huggingface/diffusers/tree/main/examples/training)).
27+
- Inference examples to show how to create custom pipelines for advanced tasks such as image2image, in-painting (see [examples/inference](https://github.com/huggingface/diffusers/tree/main/examples/inference))
2728

2829
## Quickstart
2930

examples/inference/image_to_image.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
import inspect
2+
from typing import List, Optional, Union
3+
4+
import numpy as np
5+
import torch
6+
7+
import PIL
8+
from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, PNDMScheduler, UNet2DConditionModel
9+
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
10+
from tqdm.auto import tqdm
11+
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
12+
13+
14+
def preprocess(image):
15+
w, h = image.size
16+
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
17+
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
18+
image = np.array(image).astype(np.float32) / 255.0
19+
image = image[None].transpose(0, 3, 1, 2)
20+
image = torch.from_numpy(image)
21+
return 2.0 * image - 1.0
22+
23+
24+
class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
25+
def __init__(
26+
self,
27+
vae: AutoencoderKL,
28+
text_encoder: CLIPTextModel,
29+
tokenizer: CLIPTokenizer,
30+
unet: UNet2DConditionModel,
31+
scheduler: Union[DDIMScheduler, PNDMScheduler],
32+
safety_checker: StableDiffusionSafetyChecker,
33+
feature_extractor: CLIPFeatureExtractor,
34+
):
35+
super().__init__()
36+
scheduler = scheduler.set_format("pt")
37+
self.register_modules(
38+
vae=vae,
39+
text_encoder=text_encoder,
40+
tokenizer=tokenizer,
41+
unet=unet,
42+
scheduler=scheduler,
43+
safety_checker=safety_checker,
44+
feature_extractor=feature_extractor,
45+
)
46+
47+
@torch.no_grad()
48+
def __call__(
49+
self,
50+
prompt: Union[str, List[str]],
51+
init_image: torch.FloatTensor,
52+
strength: float = 0.8,
53+
num_inference_steps: Optional[int] = 50,
54+
guidance_scale: Optional[float] = 7.5,
55+
eta: Optional[float] = 0.0,
56+
generator: Optional[torch.Generator] = None,
57+
output_type: Optional[str] = "pil",
58+
):
59+
60+
if isinstance(prompt, str):
61+
batch_size = 1
62+
elif isinstance(prompt, list):
63+
batch_size = len(prompt)
64+
else:
65+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
66+
67+
# set timesteps
68+
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
69+
extra_set_kwargs = {}
70+
offset = 0
71+
if accepts_offset:
72+
offset = 1
73+
extra_set_kwargs["offset"] = 1
74+
75+
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
76+
77+
# encode the init image into latents and scale the latents
78+
init_latents = self.vae.encode(init_image.to(self.device)).sample()
79+
init_latents = 0.18215 * init_latents
80+
81+
# prepare init_latents noise to latents
82+
init_latents = torch.cat([init_latents] * batch_size)
83+
84+
# get the original timestep using init_timestep
85+
init_timestep = int(num_inference_steps * strength) + offset
86+
init_timestep = min(init_timestep, num_inference_steps)
87+
timesteps = self.scheduler.timesteps[-init_timestep]
88+
timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
89+
90+
# add noise to latents using the timesteps
91+
noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
92+
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
93+
94+
# get prompt text embeddings
95+
text_input = self.tokenizer(
96+
prompt,
97+
padding="max_length",
98+
max_length=self.tokenizer.model_max_length,
99+
truncation=True,
100+
return_tensors="pt",
101+
)
102+
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
103+
104+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
105+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
106+
# corresponds to doing no classifier free guidance.
107+
do_classifier_free_guidance = guidance_scale > 1.0
108+
# get unconditional embeddings for classifier free guidance
109+
if do_classifier_free_guidance:
110+
max_length = text_input.input_ids.shape[-1]
111+
uncond_input = self.tokenizer(
112+
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
113+
)
114+
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
115+
116+
# For classifier free guidance, we need to do two forward passes.
117+
# Here we concatenate the unconditional and text embeddings into a single batch
118+
# to avoid doing two forward passes
119+
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
120+
121+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
122+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
123+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
124+
# and should be between [0, 1]
125+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
126+
extra_step_kwargs = {}
127+
if accepts_eta:
128+
extra_step_kwargs["eta"] = eta
129+
130+
latents = init_latents
131+
t_start = max(num_inference_steps - init_timestep + offset, 0)
132+
for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
133+
# expand the latents if we are doing classifier free guidance
134+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
135+
136+
# predict the noise residual
137+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
138+
139+
# perform guidance
140+
if do_classifier_free_guidance:
141+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
142+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
143+
144+
# compute the previous noisy sample x_t -> x_t-1
145+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
146+
147+
# scale and decode the image latents with vae
148+
latents = 1 / 0.18215 * latents
149+
image = self.vae.decode(latents)
150+
151+
image = (image / 2 + 0.5).clamp(0, 1)
152+
image = image.cpu().permute(0, 2, 3, 1).numpy()
153+
154+
# run safety checker
155+
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
156+
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
157+
158+
if output_type == "pil":
159+
image = self.numpy_to_pil(image)
160+
161+
return {"sample": image, "nsfw_content_detected": has_nsfw_concept}

examples/inference/readme.md

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Inference Examples
2+
3+
## Installing the dependencies
4+
5+
Before running the scipts, make sure to install the library's dependencies:
6+
7+
```bash
8+
pip install diffusers transformers ftfy
9+
```
10+
11+
## Image-to-Image text-guided generation with Stable Diffusion
12+
13+
The `image_to_image.py` script implements `StableDiffusionImg2ImgPipeline`. It lets you pass a text prompt and an initial image to condition the generation of new images. This example also showcases how you can write custom diffusion pipelines using `diffusers`!
14+
15+
### How to use it
16+
17+
18+
```python
19+
from torch import autocast
20+
import requests
21+
from PIL import Image
22+
from io import BytesIO
23+
24+
from image_to_image import StableDiffusionImg2ImgPipeline, preprocess
25+
26+
# load the pipeline
27+
device = "cuda"
28+
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
29+
"CompVis/stable-diffusion-v1-4",
30+
revision="fp16",
31+
torch_dtype=torch.float16,
32+
use_auth_token=True
33+
).to(device)
34+
35+
# let's download an initial image
36+
url = "https://rg.gosu.cc/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
37+
38+
response = requests.get(url)
39+
init_image = Image.open(BytesIO(response.content)).convert("RGB")
40+
init_image = init_image.resize((768, 512))
41+
init_image = preprocess(init_image)
42+
43+
prompt = "A fantasy landscape, trending on artstation"
44+
45+
with autocast("cuda"):
46+
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5)["sample"]
47+
48+
images[0].save("fantasy_landscape.png")
49+
```
50+
You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/image_2_image_using_diffusers.ipynb)
File renamed without changes.

0 commit comments

Comments
 (0)