Commit 82bd9b1

add DDIM to CLIP Guided Stable Diffusion and add code example (#4920)
* add DDIM to CLIP Guided Stable Diffusion and add example code
* modify code sample
* use generator and modify code sample

1 parent 9a00c15 · commit 82bd9b1

File tree: 3 files changed, +126 −30 lines changed
Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
# Community Examples

Community examples contain inference and training examples contributed by the community. The table below gives an overview of all community examples. Click **Code Example** to jump to the runnable code for that example, which you can copy and run. If an example does not work as expected, please open an issue.

|Example|Description|Code Example|Author|
|-|-|-|-|
|CLIP Guided Stable Diffusion|Text-to-image generation with Stable Diffusion guided by CLIP|[CLIP Guided Stable Diffusion](#CLIP%20Guided%20Stable%20Diffusion)||

## Example usages

### CLIP Guided Stable Diffusion

Guiding the Stable Diffusion denoising process with a CLIP model can produce more realistic images.

The following code requires about 16GB of GPU memory to run.

```python
import os

import paddle
from clip_guided_stable_diffusion import CLIPGuidedStableDiffusion

from paddlenlp.transformers import CLIPFeatureExtractor, CLIPModel

feature_extractor = CLIPFeatureExtractor.from_pretrained(
    "laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
clip_model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K",
                                       dtype=paddle.float32)

guided_pipeline = CLIPGuidedStableDiffusion.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    clip_model=clip_model,
    feature_extractor=feature_extractor,
    paddle_dtype=paddle.float16,
)
guided_pipeline.enable_attention_slicing()

prompt = "fantasy book cover, full moon, fantasy forest landscape, golden vector elements, fantasy magic, dark light night, intricate, elegant, sharp focus, illustration, highly detailed, digital painting, concept art, matte, art by WLOP and Artgerm and Albert Bierstadt, masterpiece"

generator = paddle.Generator().manual_seed(2022)
with paddle.amp.auto_cast(True, level="O2"):
    images = []
    for i in range(4):
        image = guided_pipeline(
            prompt,
            num_inference_steps=50,
            guidance_scale=7.5,
            clip_guidance_scale=100,
            num_cutouts=4,
            use_cutouts=False,
            generator=generator,
            unfreeze_unet=False,
            unfreeze_vae=False,
        ).images[0]
        images.append(image)

# save images locally
if not os.path.exists("clip_guided_sd"):
    os.mkdir("clip_guided_sd")
for i, img in enumerate(images):
    img.save(f"./clip_guided_sd/image_{i}.png")
```

The generated images are kept in the `images` list; sample outputs are shown below:

| image_0 | image_1 | image_2 | image_3 |
|:-------------------:|:-------------------:|:-------------------:|:-------------------:|
|![][clip_guided_sd_0]|![][clip_guided_sd_1]|![][clip_guided_sd_2]|![][clip_guided_sd_3]|

[clip_guided_sd_0]: https://user-images.githubusercontent.com/40912707/220514674-e5cb29a3-b07e-4e8f-a4c8-323b35637294.png
[clip_guided_sd_1]: https://user-images.githubusercontent.com/40912707/220514703-1eaf444e-1506-4c44-b686-5950fd79a3da.png
[clip_guided_sd_2]: https://user-images.githubusercontent.com/40912707/220514765-89e48c13-156f-4e61-b433-06f1283d2265.png
[clip_guided_sd_3]: https://user-images.githubusercontent.com/40912707/220514751-82d63fd4-e35e-482b-a8e1-c5c956119b2e.png
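Since this commit also registers `DDIMScheduler` with the pipeline, the example above can be run with DDIM sampling as well. The snippet below is a minimal sketch, not part of the committed README: the `DDIMScheduler` settings shown are the commonly used Stable Diffusion v1 defaults and are an assumption rather than values taken from this commit, and it reuses `guided_pipeline`, `prompt`, and `generator` from the example above.

```python
# Hedged sketch: swap the pipeline's scheduler for DDIM and pass eta.
# The scheduler hyperparameters below are assumed SD-v1 defaults, not taken from this commit.
from ppdiffusers import DDIMScheduler

guided_pipeline.scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)

with paddle.amp.auto_cast(True, level="O2"):
    image = guided_pipeline(
        prompt,
        num_inference_steps=50,
        guidance_scale=7.5,
        clip_guidance_scale=100,
        num_cutouts=4,
        use_cutouts=False,
        eta=0.0,  # eta is only honored by schedulers whose step() accepts it, e.g. DDIM
        generator=generator,
        unfreeze_unet=False,
        unfreeze_vae=False,
    ).images[0]
image.save("./clip_guided_sd/image_ddim.png")
```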

ppdiffusers/examples/community/clip_guided_stable_diffusion.py

Lines changed: 32 additions & 11 deletions
```diff
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 from typing import Callable, List, Optional, Union
 
 import paddle
@@ -28,6 +29,7 @@
 )
 from ppdiffusers import (
     AutoencoderKL,
+    DDIMScheduler,
     DiffusionPipeline,
     LMSDiscreteScheduler,
     PNDMScheduler,
@@ -84,7 +86,7 @@ def __init__(
         clip_model: CLIPModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[PNDMScheduler, LMSDiscreteScheduler],
+        scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler],
         feature_extractor: CLIPFeatureExtractor,
     ):
         super().__init__()
@@ -99,7 +101,12 @@ def __init__(
         )
 
         self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
-        self.make_cutouts = MakeCutouts(feature_extractor.size)
+        self.cut_out_size = (
+            feature_extractor.size
+            if isinstance(feature_extractor.size, int)
+            else feature_extractor.size["shortest_edge"]
+        )
+        self.make_cutouts = MakeCutouts(self.cut_out_size)
 
         set_stop_gradient(self.text_encoder, True)
         set_stop_gradient(self.clip_model, True)
@@ -152,7 +159,7 @@ def cond_fn(
         # predict the noise residual
         noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample
 
-        if isinstance(self.scheduler, PNDMScheduler):
+        if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)):
             alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
             beta_prod_t = 1 - alpha_prod_t
             # compute predicted original sample from predicted noise also called
@@ -174,7 +181,7 @@ def cond_fn(
         if use_cutouts:
             image = self.make_cutouts(image, num_cutouts)
         else:
-            resize_transform = transforms.Resize(self.feature_extractor.size)
+            resize_transform = transforms.Resize(self.cut_out_size)
             image = paddle.stack([resize_transform(img) for img in image], axis=0)
         image = self.normalize(image).astype(latents.dtype)
 
@@ -207,11 +214,12 @@ def __call__(
         guidance_scale: Optional[float] = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
         clip_guidance_scale: Optional[float] = 100,
         clip_prompt: Optional[Union[str, List[str]]] = None,
         num_cutouts: Optional[int] = 4,
         use_cutouts: Optional[bool] = True,
-        seed: Optional[int] = None,
+        generator: Optional[paddle.Generator] = None,
         latents: Optional[paddle.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
@@ -277,9 +285,9 @@ def __call__(
         text_embeddings_clip = self.clip_model.get_text_features(clip_text_input_ids)
         text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, axis=-1, keepdim=True)
         # duplicate text embeddings clip for each generation per prompt
-        bs_embed, seq_len, _ = text_embeddings.shape
-        text_embeddings_clip = text_embeddings_clip.tile([1, num_images_per_prompt, 1])
-        text_embeddings_clip = text_embeddings_clip.reshape([bs_embed * num_images_per_prompt, seq_len, -1])
+        bs_embed, _ = text_embeddings_clip.shape
+        text_embeddings_clip = text_embeddings_clip.tile([1, num_images_per_prompt])
+        text_embeddings_clip = text_embeddings_clip.reshape([bs_embed * num_images_per_prompt, -1])
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -334,8 +342,7 @@ def __call__(
         # However this currently doesn't work in `mps`.
         latents_shape = [batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8]
         if latents is None:
-            paddle.seed(seed)
-            latents = paddle.randn(latents_shape, dtype=text_embeddings.dtype)
+            latents = paddle.randn(latents_shape, generator=generator, dtype=text_embeddings.dtype)
         else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
@@ -350,6 +357,20 @@ def __call__(
         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma
 
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+
         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = paddle.concat([latents] * 2) if do_classifier_free_guidance else latents
@@ -381,7 +402,7 @@ def __call__(
             )
 
             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
             # call the callback, if provided
             if callback is not None and i % callback_steps == 0:
```
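A note on the `cond_fn` change above: `DDIMScheduler` can share the `PNDMScheduler` branch because both expose `alphas_cumprod`, and the predicted original sample follows from the standard epsilon parameterization of diffusion models. The sketch below illustrates that relation with made-up tensors; the function name and the toy noise schedule are illustrative, not this file's exact code.

```python
import paddle


def predict_original_sample(latents, noise_pred, alphas_cumprod, timestep):
    """Recover x_0 from x_t and the predicted noise (epsilon parameterization):
    x_t = sqrt(a_t) * x_0 + sqrt(1 - a_t) * eps  =>  x_0 = (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t)
    """
    alpha_prod_t = alphas_cumprod[timestep]
    beta_prod_t = 1 - alpha_prod_t
    return (latents - beta_prod_t ** 0.5 * noise_pred) / alpha_prod_t ** 0.5


# toy check with a made-up noise schedule (illustrative values only)
alphas_cumprod = paddle.linspace(0.999, 0.01, 1000)
x0 = paddle.randn([1, 4, 8, 8])
eps = paddle.randn([1, 4, 8, 8])
t = 500
xt = alphas_cumprod[t] ** 0.5 * x0 + (1 - alphas_cumprod[t]) ** 0.5 * eps
x0_pred = predict_original_sample(xt, eps, alphas_cumprod, t)
print(float((x0_pred - x0).abs().max()))  # ~0, up to float32 rounding
```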

ppdiffusers/examples/community/inference_clip_guided_stable_diffusion.py

Lines changed: 22 additions & 19 deletions
```diff
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import paddle
 from clip_guided_stable_diffusion import CLIPGuidedStableDiffusion
 from IPython.display import display
 from PIL import Image
@@ -34,7 +35,7 @@ def image_grid(imgs, rows, cols):
 def create_clip_guided_pipeline(
     model_id="CompVis/stable-diffusion-v1-4", clip_model_id="openai/clip-vit-large-patch14", scheduler="plms"
 ):
-    pipeline = StableDiffusionPipeline.from_pretrained(model_id)
+    pipeline = StableDiffusionPipeline.from_pretrained(model_id, paddle_dtype=paddle.float16)
 
     if scheduler == "lms":
         scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
@@ -116,29 +117,31 @@ def infer(
 clip_guidance_scale = 100  # @param {type: "number"}
 num_cutouts = 4  # @param {type: "number"}
 use_cutouts = False  # @param ["False", "True"]
-unfreeze_unet = True  # @param ["False", "True"]
-unfreeze_vae = True  # @param ["False", "True"]
+unfreeze_unet = False  # @param ["False", "True"]
+unfreeze_vae = False  # @param ["False", "True"]
 seed = 3788086447  # @param {type: "number"}
 
 model_id = "CompVis/stable-diffusion-v1-4"
 clip_model_id = "openai/clip-vit-large-patch14"  # @param ["openai/clip-vit-base-patch32", "openai/clip-vit-base-patch14", "openai/clip-rn101", "openai/clip-rn50"] {allow-input: true}
 scheduler = "plms"  # @param ['plms', 'lms']
 guided_pipeline = create_clip_guided_pipeline(model_id, clip_model_id)
-grid_image = infer(
-    prompt=prompt,
-    negative_prompt=negative_prompt,
-    clip_prompt=clip_prompt,
-    num_return_images=num_return_images,
-    num_images_per_prompt=num_images_per_prompt,
-    num_inference_steps=num_inference_steps,
-    clip_guidance_scale=clip_guidance_scale,
-    guidance_scale=guidance_scale,
-    guided_pipeline=guided_pipeline,
-    use_cutouts=use_cutouts,
-    num_cutouts=num_cutouts,
-    seed=seed,
-    unfreeze_unet=unfreeze_unet,
-    unfreeze_vae=unfreeze_vae,
-)
+
+with paddle.amp.auto_cast(True, level="O2"):
+    grid_image = infer(
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        clip_prompt=clip_prompt,
+        num_return_images=num_return_images,
+        num_images_per_prompt=num_images_per_prompt,
+        num_inference_steps=num_inference_steps,
+        clip_guidance_scale=clip_guidance_scale,
+        guidance_scale=guidance_scale,
+        guided_pipeline=guided_pipeline,
+        use_cutouts=use_cutouts,
+        num_cutouts=num_cutouts,
+        seed=seed,
+        unfreeze_unet=unfreeze_unet,
+        unfreeze_vae=unfreeze_vae,
+    )
 
 display(grid_image)
```
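One practical note on the changed signature: the pipeline's `__call__` now takes a `paddle.Generator` instead of an integer `seed`, while this script still exposes `seed` as a notebook parameter. How `infer` bridges the two is not visible in this diff; a minimal, assumed conversion looks like the sketch below.

```python
import paddle

# Assumed conversion (not shown in this diff): turn the integer notebook seed
# into the paddle.Generator that the updated pipeline __call__ expects.
seed = 3788086447
generator = paddle.Generator().manual_seed(seed)

# The generator would then be passed to the pipeline in place of the removed
# `seed` argument, e.g. guided_pipeline(prompt, ..., generator=generator).
```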
