
Commit 4d9e420

init
1 parent 130c9d7 commit 4d9e420

96 files changed: 16,597 additions, 0 deletions


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
/.idea

README.md

Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
# MagicFusion: Boosting Text-to-Image Generation Performance by Fusing Diffusion Models

### Abstract

> The advent of open-source AI communities has produced a cornucopia of powerful text-guided diffusion models that are trained on various datasets. However, few explorations have been conducted on ensembling such models to combine their strengths. In this work, we propose a simple yet effective method called Saliency-aware Noise Blending (SNB) that can empower the fused text-guided diffusion models to achieve more controllable generation. Specifically, we experimentally find that the responses of classifier-free guidance are highly related to the saliency of generated images. Thus we propose to trust different models in their areas of expertise by blending the predicted noises of two diffusion models in a saliency-aware manner. SNB is training-free and can be completed within a DDIM sampling process. Additionally, it can automatically align the semantics of two noise spaces without requiring additional annotations such as masks. Extensive experiments show the impressive effectiveness of SNB in various applications. The project page is available at https://magicfusion.github.io.

### An overview of our Saliency-aware Noise Blending

![](figures/method.png)

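To make the idea concrete, here is a minimal, hedged sketch of what a saliency-aware blend of two models' noise predictions could look like inside one DDIM step. It is PyTorch-style illustration only: the tensor names, the channel-averaged and smoothed saliency, and the way `kg`/`ke` enter are assumptions rather than the repository's implementation, whose exact rule is Equation 4 of the paper.

```
import torch
import torch.nn.functional as F

def saliency_aware_noise_blending(eps_cond0, eps_uncond0, eps_cond1, eps_uncond1,
                                  guidance_scale=7.5, kg=1.0, ke=1.0):
    # Hypothetical sketch of SNB for one DDIM step; not the official code.
    # eps_*: conditional/unconditional noise predictions of the two models,
    # each of shape (B, C, H, W) in latent space.

    # Classifier-free guidance responses of each model.
    guid0 = eps_cond0 - eps_uncond0
    guid1 = eps_cond1 - eps_uncond1

    # Saliency proxy: magnitude of the guidance response, averaged over
    # channels and lightly smoothed so the mask follows coherent regions.
    sal0 = F.avg_pool2d(guid0.abs().mean(dim=1, keepdim=True), 3, stride=1, padding=1)
    sal1 = F.avg_pool2d(guid1.abs().mean(dim=1, keepdim=True), 3, stride=1, padding=1)

    # Trust each model where its (scaled) saliency dominates; kg and ke
    # stand in here for the scaling terms of Eq. 4.
    mask = (kg * sal0 > ke * sal1).float()

    # Classifier-free guided noise of each model, blended spatially.
    eps0 = eps_uncond0 + guidance_scale * guid0
    eps1 = eps_uncond1 + guidance_scale * guid1
    return mask * eps0 + (1.0 - mask) * eps1
```
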
### Preparation

#### Environment

First set up the `ldm` environment following the instructions from the [textual inversion](https://github.com/rinongal/textual_inversion) repo, or the original [Stable Diffusion](https://github.com/CompVis/stable-diffusion) repo.

#### Models

To use our method, you need to obtain the pre-trained Stable Diffusion models following their instructions. You can decide which version of the checkpoint to use; we use `sd-v1-4-full-ema.ckpt`. You can grab `sd-v1-4-full-ema.ckpt` from https://huggingface.co/CompVis/stable-diffusion-v-1-4-original and make sure to put it in the root path. Similarly, for the cartoon model, head over to https://huggingface.co/Linaqruf/anything-v3.0/tree/main and download the file `anything-v3-full.safetensors` to the root path.

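If you prefer to script the downloads, the hedged sketch below uses the `huggingface_hub` Python client to fetch both checkpoints into the repository root. The repo itself does not ship such a helper, and the Stable Diffusion checkpoint may require accepting the license on its Hugging Face page before downloads work.

```
# Hypothetical download helper; requires `pip install huggingface_hub`
# and, for the SD checkpoint, accepting the license on the model page.
from huggingface_hub import hf_hub_download

hf_hub_download(
    repo_id="CompVis/stable-diffusion-v-1-4-original",
    filename="sd-v1-4-full-ema.ckpt",
    local_dir=".",  # place the checkpoint in the repository root
)
hf_hub_download(
    repo_id="Linaqruf/anything-v3.0",
    filename="anything-v3-full.safetensors",
    local_dir=".",
)
```
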
### Try MagicFusion

Our method can be used quickly through the `magicfusion.sh` file. In this file, you can specify the paths to two pre-trained models and their corresponding prompts. The program will then generate individual results for each model as well as the fused output, which will be saved in the designated `outdir` path.

MagicFusion provides strong controllability over the fusion effects of two pre-trained models. To obtain content that meets our generation requirements, we can first observe the generation results of the two pre-trained models. For example, in Cross-domain Fusion, the generated results for the prompt "A lion with a crown on his head" with the two pre-trained models are as follows:

**General Model:**

![](figures/github_lion_general.png)

**Cartoon Model:**

![](figures/github_lion_cartoon.png)

It is clear that the cartoon model goes a bit too far in anthropomorphizing the lion, whereas the composition generated by the general model is more aligned with our creative requirements. However, the general model fails to generate a crown.

By simply fusing the two models, we can achieve a composition that matches the output of the general model while still maintaining a cartoon style.

![](figures/github_lion_fusion_full.png)

Furthermore, we can enhance the realism of the generated images by using more of the noise generated by the general model during the sampling process. Specifically, the fusion is performed from step 1000 to step 600, and after step 600 only the noise generated by the general model is used.

![](figures/github_lion_fusion.png)

In this case, we set `magicfusion.sh` as follows:

```
model0="sd-v1-4-full-ema.ckpt"
model1="anything-v3-full.safetensors"
prompt="A lion with a crown on his head"
prompt1="A lion with a crown on his head"

outdir="output/lion_crown"
fusion_selection="600,1,1"
merge_mode=2
```

Feel free to explore the fusion effects of other pre-trained models by replacing `model0` and `model1` and adjusting `fusion_selection`. Specifically, in `fusion_selection="v1,v2,v3"`, the first parameter `v1` specifies the time step at which our Saliency-aware Noise Blending (SNB) is introduced (or stopped) during the sampling process, with one model serving as the default noise generator before this point. The remaining parameters, `v2` and `v3`, correspond to the `kg` and `ke` terms in Equation 4 of the paper, respectively, and play a crucial role in determining the fusion details at each time step; see the sketch below.

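As a rough mental model, the hedged Python sketch below shows how `v1`, `kg`, `ke`, and `merge_mode` could drive the choice of noise at each sampling step. The `merge_mode` convention (which model is the default before or after the switch point) is inferred from the lion (`merge_mode=2`) and bee (`merge_mode=1`) examples in this README, not read from the repository's code.

```
def choose_noise(t, v1, kg, ke, merge_mode, eps_general, eps_expert, snb):
    # Hypothetical per-step schedule; `t` counts down from 1000 to 1, and
    # `snb(...)` is a Saliency-aware Noise Blending routine in the spirit
    # of the earlier sketch, applied to the two models' noise predictions.
    if merge_mode == 2:
        # Lion example: blend while t > v1, then trust the general model only.
        return snb(eps_general, eps_expert, kg, ke) if t > v1 else eps_general
    else:  # merge_mode == 1
        # Bee example: let the expert (cartoon) model set the composition
        # while t > v1, then switch to blending for the remaining steps.
        return eps_expert if t > v1 else snb(eps_general, eps_expert, kg, ke)
```
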
For prompts like "a bee is making honey", the cartoon model provides a more accurate composition. Therefore, we use more of the noise generated by the cartoon model during the sampling process. For instance, over steps 1000 to 850 of sampling, we exclusively use the cartoon model's noise to establish the basic composition of the generated image. We then switch to our Saliency-aware Noise Blending for the remaining 850 steps to enhance the realism of the scene.

In this case, we set `magicfusion.sh` as follows:

```
model0="sd-v1-4-full-ema.ckpt"
model1="anything-v3-full.safetensors"
prompt="A bee is making honey"
prompt1="A bee is making honey"

outdir="output/making_honey"
fusion_selection="850,1,1"
merge_mode=1
```

For Application 1, i.e., fine-grained fusion, you just need to change the expert model to a fine-grained car model, then give a scene prompt to the general model and a fine-grained prompt to the car model.

For Application 2, i.e., recontextualization, you need to fine-tune the Stable Diffusion model for a specific object as [Dreambooth](https://github.com/XavierXiao/Dreambooth-Stable-Diffusion) does. To generate the specific object with the placeholder `[]` (e.g., sks), both `model0` and `model1` are specified as the fine-tuned model. The fusion result can then be obtained by giving the prompts '`a photo of [] <class>`' and '`a photo of [] <class> <environment>`'.

For Applications 1 and 2, we obtain the experimental results simply by setting `merge_mode=1`.

The program is run with the following command:

```
python scripts/stable_txt2img.py --model0 $model0 --model1 $model1 --outdir $outdir --prompt "$prompt" --prompt1 "$prompt1" --fusion_selection $fusion_selection --mode $merge_mode
```

Overall, the saliency map of noise is an important and interesting finding, and MagicFusion achieves high controllability of the generated content simply by setting the fusion strategy in the sampling process (e.g., when to start or stop fusion). More principles and laws about the noise space of diffusion models are expected to be discovered by exploring our MagicFusion.

Thanks.

### BibTeX

```
@misc{zhao2023magicfusion,
      title={MagicFusion: Boosting Text-to-Image Generation Performance by Fusing Diffusion Models},
      author={Jing Zhao and Heliang Zheng and Chaoyue Wang and Long Lan and Wenjing Yang},
      year={2023},
      eprint={2303.13126},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
model:
  base_learning_rate: 4.5e-6
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    monitor: "val/rec_loss"
    embed_dim: 16
    lossconfig:
      target: ldm.modules.losses.LPIPSWithDiscriminator
      params:
        disc_start: 50001
        kl_weight: 0.000001
        disc_weight: 0.5

    ddconfig:
      double_z: True
      z_channels: 16
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [ 1,1,2,2,4]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: [16]
      dropout: 0.0


data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 12
    wrap: True
    train:
      target: ldm.data.imagenet.ImageNetSRTrain
      params:
        size: 256
        degradation: pil_nearest
    validation:
      target: ldm.data.imagenet.ImageNetSRValidation
      params:
        size: 256
        degradation: pil_nearest

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: True

  trainer:
    benchmark: True
    accumulate_grad_batches: 2

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
model:
  base_learning_rate: 4.5e-6
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    monitor: "val/rec_loss"
    embed_dim: 4
    lossconfig:
      target: ldm.modules.losses.LPIPSWithDiscriminator
      params:
        disc_start: 50001
        kl_weight: 0.000001
        disc_weight: 0.5

    ddconfig:
      double_z: True
      z_channels: 4
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [ 1,2,4,4 ]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: [ ]
      dropout: 0.0

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 12
    wrap: True
    train:
      target: ldm.data.imagenet.ImageNetSRTrain
      params:
        size: 256
        degradation: pil_nearest
    validation:
      target: ldm.data.imagenet.ImageNetSRValidation
      params:
        size: 256
        degradation: pil_nearest

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: True

  trainer:
    benchmark: True
    accumulate_grad_batches: 2

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
model:
  base_learning_rate: 4.5e-6
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    monitor: "val/rec_loss"
    embed_dim: 3
    lossconfig:
      target: ldm.modules.losses.LPIPSWithDiscriminator
      params:
        disc_start: 50001
        kl_weight: 0.000001
        disc_weight: 0.5

    ddconfig:
      double_z: True
      z_channels: 3
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [ 1,2,4 ]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: [ ]
      dropout: 0.0


data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 12
    wrap: True
    train:
      target: ldm.data.imagenet.ImageNetSRTrain
      params:
        size: 256
        degradation: pil_nearest
    validation:
      target: ldm.data.imagenet.ImageNetSRValidation
      params:
        size: 256
        degradation: pil_nearest

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: True

  trainer:
    benchmark: True
    accumulate_grad_batches: 2

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
model:
  base_learning_rate: 4.5e-6
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    monitor: "val/rec_loss"
    embed_dim: 64
    lossconfig:
      target: ldm.modules.losses.LPIPSWithDiscriminator
      params:
        disc_start: 50001
        kl_weight: 0.000001
        disc_weight: 0.5

    ddconfig:
      double_z: True
      z_channels: 64
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [ 1,1,2,2,4,4]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: [16,8]
      dropout: 0.0

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 12
    wrap: True
    train:
      target: ldm.data.imagenet.ImageNetSRTrain
      params:
        size: 256
        degradation: pil_nearest
    validation:
      target: ldm.data.imagenet.ImageNetSRValidation
      params:
        size: 256
        degradation: pil_nearest

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: True

  trainer:
    benchmark: True
    accumulate_grad_batches: 2

