
Commit 7736954

anton-l authored and natolambert committed
Style the scripts directory (#250)
Style scripts
1 parent: dae7849

7 files changed: +452 −315 lines

Makefile

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
 export PYTHONPATH = src

-check_dirs := examples tests src utils
+check_dirs := examples scripts src tests utils

 modified_only_fixup:
 	$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
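The modified_only_fixup target relies on utils/get_modified_files.py, which is not part of this diff. As a rough sketch of what such a helper typically does (the real script's flags and output format may differ), it asks git for changed files and keeps only the Python files under the listed directories:

# Hypothetical sketch of a get_modified_files.py-style helper; the real
# utils/get_modified_files.py is not shown in this commit and may differ.
import subprocess
import sys

dirs = sys.argv[1:]  # e.g. examples scripts src tests utils

# Files changed relative to HEAD, excluding deleted files.
diff = subprocess.check_output(
    ["git", "diff", "--name-only", "--diff-filter=d", "HEAD"], text=True
)
modified = [
    path
    for path in diff.splitlines()
    if path.endswith(".py") and any(path.startswith(d + "/") for d in dirs)
]
print(" ".join(modified))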

scripts/change_naming_configs_and_checkpoints.py

Lines changed: 6 additions & 5 deletions
@@ -15,12 +15,15 @@
 """ Conversion script for the LDM checkpoints. """

 import argparse
-import os
 import json
+import os
+
 import torch
-from diffusers import UNet2DModel, UNet2DConditionModel
+
+from diffusers import UNet2DConditionModel, UNet2DModel
 from transformers.file_utils import has_file

+
 do_only_config = False
 do_only_weights = True
 do_only_renaming = False
@@ -37,9 +40,7 @@
     help="The config json file corresponding to the architecture.",
 )

-parser.add_argument(
-    "--dump_path", default=None, type=str, required=True, help="Path to the output model."
-)
+parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

 args = parser.parse_args()
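The import reshuffling above follows the usual isort convention: standard-library imports first, then third-party packages, with groups separated by a blank line and names alphabetized within each import. A minimal sketch of reproducing that ordering programmatically, assuming the isort package (5.x) is installed; the exact grouping and spacing depend on the repository's isort configuration:

# Minimal sketch: apply isort-style ordering to the imports fixed above.
# Assumes isort 5.x is installed; output spacing may vary with config.
import isort

messy = (
    "import argparse\n"
    "import os\n"
    "import json\n"
    "import torch\n"
    "from diffusers import UNet2DModel, UNet2DConditionModel\n"
)

# isort places stdlib before third-party imports and alphabetizes the
# names inside each "from ... import ..." statement.
print(isort.code(messy))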

scripts/conversion_ldm_uncond.py

Lines changed: 5 additions & 5 deletions
@@ -1,9 +1,10 @@
 import argparse

-import OmegaConf
 import torch

-from diffusers import UNetLDMModel, VQModel, LDMPipeline, DDIMScheduler
+import OmegaConf
+from diffusers import DDIMScheduler, LDMPipeline, UNetLDMModel, VQModel
+

 def convert_ldm_original(checkpoint_path, config_path, output_path):
     config = OmegaConf.load(config_path)
@@ -16,14 +17,14 @@ def convert_ldm_original(checkpoint_path, config_path, output_path):
     for key in keys:
         if key.startswith(first_stage_key):
             first_stage_dict[key.replace(first_stage_key, "")] = state_dict[key]
-
+
     # extract state_dict for UNetLDM
     unet_state_dict = {}
     unet_key = "model.diffusion_model."
     for key in keys:
         if key.startswith(unet_key):
             unet_state_dict[key.replace(unet_key, "")] = state_dict[key]
-
+
     vqvae_init_args = config.model.params.first_stage_config.params
     unet_init_args = config.model.params.unet_config.params

@@ -53,4 +54,3 @@ def convert_ldm_original(checkpoint_path, config_path, output_path):
 args = parser.parse_args()

 convert_ldm_original(args.checkpoint_path, args.config_path, args.output_path)
-
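The hunk above extracts per-module weights from a combined LDM checkpoint by matching a key prefix and stripping it off. A self-contained illustration of the pattern; the state-dict keys here are invented, not taken from a real checkpoint:

# Illustrative example of the prefix-stripping pattern used above;
# the keys and tensors are made up for demonstration.
import torch

state_dict = {
    "first_stage_model.encoder.weight": torch.zeros(2),
    "model.diffusion_model.time_embed.weight": torch.zeros(2),
}

unet_key = "model.diffusion_model."
unet_state_dict = {
    key.replace(unet_key, ""): value
    for key, value in state_dict.items()
    if key.startswith(unet_key)
}

print(list(unet_state_dict))  # ['time_embed.weight']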

scripts/convert_ddpm_original_checkpoint_to_diffusers.py

Lines changed: 208 additions & 136 deletions
Large diffs are not rendered by default.

scripts/convert_ldm_original_checkpoint_to_diffusers.py

Lines changed: 123 additions & 97 deletions
Large diffs are not rendered by default.

scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py

Lines changed: 3 additions & 1 deletion
@@ -16,8 +16,10 @@

 import argparse
 import json
+
 import torch
-from diffusers import UNet2DModel
+
+from diffusers import ScoreSdeVePipeline, ScoreSdeVeScheduler, UNet2DModel


 def convert_ncsnpp_checkpoint(checkpoint, config):
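The newly imported pipeline and scheduler classes suggest the script assembles a full ScoreSdeVePipeline from the converted UNet rather than saving the model alone. A hedged sketch of that wiring; the conversion itself is omitted, the config values are dummies, and the keyword names follow the diffusers API as I understand it:

# Hedged sketch: wrap a converted UNet in a ScoreSdeVePipeline.
# The checkpoint conversion step is omitted; values are illustrative.
from diffusers import ScoreSdeVePipeline, ScoreSdeVeScheduler, UNet2DModel

model = UNet2DModel(sample_size=64)  # stands in for the converted model
scheduler = ScoreSdeVeScheduler()
pipeline = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
pipeline.save_pretrained("./ncsnpp-converted")  # illustrative output path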

scripts/generate_logits.py

Lines changed: 106 additions & 70 deletions
@@ -1,91 +1,127 @@
-from huggingface_hub import HfApi
-from transformers.file_utils import has_file
-from diffusers import UNet2DModel
 import random
+
 import torch
+
+from diffusers import UNet2DModel
+from huggingface_hub import HfApi
+
+
 api = HfApi()

 results = {}
-results["google_ddpm_cifar10_32"] = torch.tensor([-0.7515, -1.6883, 0.2420, 0.0300, 0.6347, 1.3433, -1.1743, -3.7467,
-        1.2342, -2.2485, 0.4636, 0.8076, -0.7991, 0.3969, 0.8498, 0.9189,
-        -1.8887, -3.3522, 0.7639, 0.2040, 0.6271, -2.7148, -1.6316, 3.0839,
-        0.3186, 0.2721, -0.9759, -1.2461, 2.6257, 1.3557])
-results["google_ddpm_ema_bedroom_256"] = torch.tensor([-2.3639, -2.5344, 0.0054, -0.6674, 1.5990, 1.0158, 0.3124, -2.1436,
-        1.8795, -2.5429, -0.1566, -0.3973, 1.2490, 2.6447, 1.2283, -0.5208,
-        -2.8154, -3.5119, 2.3838, 1.2033, 1.7201, -2.1256, -1.4576, 2.7948,
-        2.4204, -0.9752, -1.2546, 0.8027, 3.2758, 3.1365])
-results["CompVis_ldm_celebahq_256"] = torch.tensor([-0.6531, -0.6891, -0.3172, -0.5375, -0.9140, -0.5367, -0.1175, -0.7869,
-        -0.3808, -0.4513, -0.2098, -0.0083, 0.3183, 0.5140, 0.2247, -0.1304,
-        -0.1302, -0.2802, -0.2084, -0.2025, -0.4967, -0.4873, -0.0861, 0.6925,
-        0.0250, 0.1290, -0.1543, 0.6316, 1.0460, 1.4943])
-results["google_ncsnpp_ffhq_1024"] = torch.tensor([ 0.0911, 0.1107, 0.0182, 0.0435, -0.0805, -0.0608, 0.0381, 0.2172,
-        -0.0280, 0.1327, -0.0299, -0.0255, -0.0050, -0.1170, -0.1046, 0.0309,
-        0.1367, 0.1728, -0.0533, -0.0748, -0.0534, 0.1624, 0.0384, -0.1805,
-        -0.0707, 0.0642, 0.0220, -0.0134, -0.1333, -0.1505])
-results["google_ncsnpp_bedroom_256"] = torch.tensor([ 0.1321, 0.1337, 0.0440, 0.0622, -0.0591, -0.0370, 0.0503, 0.2133,
-        -0.0177, 0.1415, -0.0116, -0.0112, 0.0044, -0.0980, -0.0789, 0.0395,
-        0.1502, 0.1785, -0.0488, -0.0514, -0.0404, 0.1539, 0.0454, -0.1559,
-        -0.0665, 0.0659, 0.0383, -0.0005, -0.1266, -0.1386])
-results["google_ncsnpp_celebahq_256"] = torch.tensor([ 0.1154, 0.1218, 0.0307, 0.0526, -0.0711, -0.0541, 0.0366, 0.2078,
-        -0.0267, 0.1317, -0.0226, -0.0193, -0.0014, -0.1055, -0.0902, 0.0330,
-        0.1391, 0.1709, -0.0562, -0.0693, -0.0560, 0.1482, 0.0381, -0.1683,
-        -0.0681, 0.0661, 0.0331, -0.0046, -0.1268, -0.1431])
-results["google_ncsnpp_church_256"] = torch.tensor([ 0.1192, 0.1240, 0.0414, 0.0606, -0.0557, -0.0412, 0.0430, 0.2042,
-        -0.0200, 0.1385, -0.0115, -0.0132, 0.0017, -0.0965, -0.0802, 0.0398,
-        0.1433, 0.1747, -0.0458, -0.0533, -0.0407, 0.1545, 0.0419, -0.1574,
-        -0.0645, 0.0626, 0.0341, -0.0010, -0.1199, -0.1390])
-results["google_ncsnpp_ffhq_256"] = torch.tensor([ 0.1075, 0.1074, 0.0205, 0.0431, -0.0774, -0.0607, 0.0298, 0.2042,
-        -0.0320, 0.1267, -0.0281, -0.0250, -0.0064, -0.1091, -0.0946, 0.0290,
-        0.1328, 0.1650, -0.0580, -0.0738, -0.0586, 0.1440, 0.0337, -0.1746,
-        -0.0712, 0.0605, 0.0250, -0.0099, -0.1316, -0.1473])
-results["google_ddpm_cat_256"] = torch.tensor([-1.4572, -2.0481, -0.0414, -0.6005, 1.4136, 0.5848, 0.4028, -2.7330,
-        1.2212, -2.1228, 0.2155, 0.4039, 0.7662, 2.0535, 0.7477, -0.3243,
-        -2.1758, -2.7648, 1.6947, 0.7026, 1.2338, -1.6078, -0.8682, 2.2810,
-        1.8574, -0.5718, -0.5586, -0.0186, 2.3415, 2.1251])
-results["google_ddpm_celebahq_256"] = torch.tensor([-1.3690, -1.9720, -0.4090, -0.6966, 1.4660, 0.9938, -0.1385, -2.7324,
-        0.7736, -1.8917, 0.2923, 0.4293, 0.1693, 1.4112, 1.1887, -0.3181,
-        -2.2160, -2.6381, 1.3170, 0.8163, 0.9240, -1.6544, -0.6099, 2.5259,
-        1.6430, -0.9090, -0.9392, -0.0126, 2.4268, 2.3266])
-results["google_ddpm_ema_celebahq_256"] = torch.tensor([-1.3525, -1.9628, -0.3956, -0.6860, 1.4664, 1.0014, -0.1259, -2.7212,
-        0.7772, -1.8811, 0.2996, 0.4388, 0.1704, 1.4029, 1.1701, -0.3027,
-        -2.2053, -2.6287, 1.3350, 0.8131, 0.9274, -1.6292, -0.6098, 2.5131,
-        1.6505, -0.8958, -0.9298, -0.0151, 2.4257, 2.3355])
-results["google_ddpm_church_256"] = torch.tensor([-2.0585, -2.7897, -0.2850, -0.8940, 1.9052, 0.5702, 0.6345, -3.8959,
-        1.5932, -3.2319, 0.1974, 0.0287, 1.7566, 2.6543, 0.8387, -0.5351,
-        -3.2736, -4.3375, 2.9029, 1.6390, 1.4640, -2.1701, -1.9013, 2.9341,
-        3.4981, -0.6255, -1.1644, -0.1591, 3.7097, 3.2066])
-results["google_ddpm_bedroom_256"] = torch.tensor([-2.3139, -2.5594, -0.0197, -0.6785, 1.7001, 1.1606, 0.3075, -2.1740,
-        1.8071, -2.5630, -0.0926, -0.3811, 1.2116, 2.6246, 1.2731, -0.5398,
-        -2.8153, -3.6140, 2.3893, 1.3262, 1.6258, -2.1856, -1.3267, 2.8395,
-        2.3779, -1.0623, -1.2468, 0.8959, 3.3367, 3.2243])
-results["google_ddpm_ema_church_256"] = torch.tensor([-2.0628, -2.7667, -0.2089, -0.8263, 2.0539, 0.5992, 0.6495, -3.8336,
-        1.6025, -3.2817, 0.1721, -0.0633, 1.7516, 2.7039, 0.8100, -0.5908,
-        -3.2113, -4.4343, 2.9257, 1.3632, 1.5562, -2.1489, -1.9894, 3.0560,
-        3.3396, -0.7328, -1.0417, 0.0383, 3.7093, 3.2343])
-results["google_ddpm_ema_cat_256"] = torch.tensor([-1.4574, -2.0569, -0.0473, -0.6117, 1.4018, 0.5769, 0.4129, -2.7344,
-        1.2241, -2.1397, 0.2000, 0.3937, 0.7616, 2.0453, 0.7324, -0.3391,
-        -2.1746, -2.7744, 1.6963, 0.6921, 1.2187, -1.6172, -0.8877, 2.2439,
-        1.8471, -0.5839, -0.5605, -0.0464, 2.3250, 2.1219])
+# fmt: off
+results["google_ddpm_cifar10_32"] = torch.tensor([
+    -0.7515, -1.6883, 0.2420, 0.0300, 0.6347, 1.3433, -1.1743, -3.7467,
+    1.2342, -2.2485, 0.4636, 0.8076, -0.7991, 0.3969, 0.8498, 0.9189,
+    -1.8887, -3.3522, 0.7639, 0.2040, 0.6271, -2.7148, -1.6316, 3.0839,
+    0.3186, 0.2721, -0.9759, -1.2461, 2.6257, 1.3557
+])
+results["google_ddpm_ema_bedroom_256"] = torch.tensor([
+    -2.3639, -2.5344, 0.0054, -0.6674, 1.5990, 1.0158, 0.3124, -2.1436,
+    1.8795, -2.5429, -0.1566, -0.3973, 1.2490, 2.6447, 1.2283, -0.5208,
+    -2.8154, -3.5119, 2.3838, 1.2033, 1.7201, -2.1256, -1.4576, 2.7948,
+    2.4204, -0.9752, -1.2546, 0.8027, 3.2758, 3.1365
+])
+results["CompVis_ldm_celebahq_256"] = torch.tensor([
+    -0.6531, -0.6891, -0.3172, -0.5375, -0.9140, -0.5367, -0.1175, -0.7869,
+    -0.3808, -0.4513, -0.2098, -0.0083, 0.3183, 0.5140, 0.2247, -0.1304,
+    -0.1302, -0.2802, -0.2084, -0.2025, -0.4967, -0.4873, -0.0861, 0.6925,
+    0.0250, 0.1290, -0.1543, 0.6316, 1.0460, 1.4943
+])
+results["google_ncsnpp_ffhq_1024"] = torch.tensor([
+    0.0911, 0.1107, 0.0182, 0.0435, -0.0805, -0.0608, 0.0381, 0.2172,
+    -0.0280, 0.1327, -0.0299, -0.0255, -0.0050, -0.1170, -0.1046, 0.0309,
+    0.1367, 0.1728, -0.0533, -0.0748, -0.0534, 0.1624, 0.0384, -0.1805,
+    -0.0707, 0.0642, 0.0220, -0.0134, -0.1333, -0.1505
+])
+results["google_ncsnpp_bedroom_256"] = torch.tensor([
+    0.1321, 0.1337, 0.0440, 0.0622, -0.0591, -0.0370, 0.0503, 0.2133,
+    -0.0177, 0.1415, -0.0116, -0.0112, 0.0044, -0.0980, -0.0789, 0.0395,
+    0.1502, 0.1785, -0.0488, -0.0514, -0.0404, 0.1539, 0.0454, -0.1559,
+    -0.0665, 0.0659, 0.0383, -0.0005, -0.1266, -0.1386
+])
+results["google_ncsnpp_celebahq_256"] = torch.tensor([
+    0.1154, 0.1218, 0.0307, 0.0526, -0.0711, -0.0541, 0.0366, 0.2078,
+    -0.0267, 0.1317, -0.0226, -0.0193, -0.0014, -0.1055, -0.0902, 0.0330,
+    0.1391, 0.1709, -0.0562, -0.0693, -0.0560, 0.1482, 0.0381, -0.1683,
+    -0.0681, 0.0661, 0.0331, -0.0046, -0.1268, -0.1431
+])
+results["google_ncsnpp_church_256"] = torch.tensor([
+    0.1192, 0.1240, 0.0414, 0.0606, -0.0557, -0.0412, 0.0430, 0.2042,
+    -0.0200, 0.1385, -0.0115, -0.0132, 0.0017, -0.0965, -0.0802, 0.0398,
+    0.1433, 0.1747, -0.0458, -0.0533, -0.0407, 0.1545, 0.0419, -0.1574,
+    -0.0645, 0.0626, 0.0341, -0.0010, -0.1199, -0.1390
+])
+results["google_ncsnpp_ffhq_256"] = torch.tensor([
+    0.1075, 0.1074, 0.0205, 0.0431, -0.0774, -0.0607, 0.0298, 0.2042,
+    -0.0320, 0.1267, -0.0281, -0.0250, -0.0064, -0.1091, -0.0946, 0.0290,
+    0.1328, 0.1650, -0.0580, -0.0738, -0.0586, 0.1440, 0.0337, -0.1746,
+    -0.0712, 0.0605, 0.0250, -0.0099, -0.1316, -0.1473
+])
+results["google_ddpm_cat_256"] = torch.tensor([
+    -1.4572, -2.0481, -0.0414, -0.6005, 1.4136, 0.5848, 0.4028, -2.7330,
+    1.2212, -2.1228, 0.2155, 0.4039, 0.7662, 2.0535, 0.7477, -0.3243,
+    -2.1758, -2.7648, 1.6947, 0.7026, 1.2338, -1.6078, -0.8682, 2.2810,
+    1.8574, -0.5718, -0.5586, -0.0186, 2.3415, 2.1251])
+results["google_ddpm_celebahq_256"] = torch.tensor([
+    -1.3690, -1.9720, -0.4090, -0.6966, 1.4660, 0.9938, -0.1385, -2.7324,
+    0.7736, -1.8917, 0.2923, 0.4293, 0.1693, 1.4112, 1.1887, -0.3181,
+    -2.2160, -2.6381, 1.3170, 0.8163, 0.9240, -1.6544, -0.6099, 2.5259,
+    1.6430, -0.9090, -0.9392, -0.0126, 2.4268, 2.3266
+])
+results["google_ddpm_ema_celebahq_256"] = torch.tensor([
+    -1.3525, -1.9628, -0.3956, -0.6860, 1.4664, 1.0014, -0.1259, -2.7212,
+    0.7772, -1.8811, 0.2996, 0.4388, 0.1704, 1.4029, 1.1701, -0.3027,
+    -2.2053, -2.6287, 1.3350, 0.8131, 0.9274, -1.6292, -0.6098, 2.5131,
+    1.6505, -0.8958, -0.9298, -0.0151, 2.4257, 2.3355
+])
+results["google_ddpm_church_256"] = torch.tensor([
+    -2.0585, -2.7897, -0.2850, -0.8940, 1.9052, 0.5702, 0.6345, -3.8959,
+    1.5932, -3.2319, 0.1974, 0.0287, 1.7566, 2.6543, 0.8387, -0.5351,
+    -3.2736, -4.3375, 2.9029, 1.6390, 1.4640, -2.1701, -1.9013, 2.9341,
+    3.4981, -0.6255, -1.1644, -0.1591, 3.7097, 3.2066
+])
+results["google_ddpm_bedroom_256"] = torch.tensor([
+    -2.3139, -2.5594, -0.0197, -0.6785, 1.7001, 1.1606, 0.3075, -2.1740,
+    1.8071, -2.5630, -0.0926, -0.3811, 1.2116, 2.6246, 1.2731, -0.5398,
+    -2.8153, -3.6140, 2.3893, 1.3262, 1.6258, -2.1856, -1.3267, 2.8395,
+    2.3779, -1.0623, -1.2468, 0.8959, 3.3367, 3.2243
+])
+results["google_ddpm_ema_church_256"] = torch.tensor([
+    -2.0628, -2.7667, -0.2089, -0.8263, 2.0539, 0.5992, 0.6495, -3.8336,
+    1.6025, -3.2817, 0.1721, -0.0633, 1.7516, 2.7039, 0.8100, -0.5908,
+    -3.2113, -4.4343, 2.9257, 1.3632, 1.5562, -2.1489, -1.9894, 3.0560,
+    3.3396, -0.7328, -1.0417, 0.0383, 3.7093, 3.2343
+])
+results["google_ddpm_ema_cat_256"] = torch.tensor([
+    -1.4574, -2.0569, -0.0473, -0.6117, 1.4018, 0.5769, 0.4129, -2.7344,
+    1.2241, -2.1397, 0.2000, 0.3937, 0.7616, 2.0453, 0.7324, -0.3391,
+    -2.1746, -2.7744, 1.6963, 0.6921, 1.2187, -1.6172, -0.8877, 2.2439,
+    1.8471, -0.5839, -0.5605, -0.0464, 2.3250, 2.1219
+])
+# fmt: on

 models = api.list_models(filter="diffusers")
 for mod in models:
-    if "google" in mod.author or mod.modelId == "CompVis/ldm-celebahq-256":
+    if "google" in mod.author or mod.modelId == "CompVis/ldm-celebahq-256":
         local_checkpoint = "/home/patrick/google_checkpoints/" + mod.modelId.split("/")[-1]

         print(f"Started running {mod.modelId}!!!")

         if mod.modelId.startswith("CompVis"):
-            model = UNet2DModel.from_pretrained(local_checkpoint, subfolder = "unet")
-        else:
+            model = UNet2DModel.from_pretrained(local_checkpoint, subfolder="unet")
+        else:
             model = UNet2DModel.from_pretrained(local_checkpoint)
-
+
         torch.manual_seed(0)
         random.seed(0)
-
+
         noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
         time_step = torch.tensor([10] * noise.shape[0])
         with torch.no_grad():
-            logits = model(noise, time_step)['sample']
+            logits = model(noise, time_step)["sample"]

-        assert torch.allclose(logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3)
+        assert torch.allclose(
+            logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3
+        )
         print(f"{mod.modelId} has passed succesfully!!!")
