Skip to content

Commit 13fb61b

Browse files
authored
fix dit weights convert to ppdiffusers (PaddlePaddle#759)
1 parent 10692bb commit 13fb61b

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

ppdiffusers/examples/class_conditional_image_generation/DiT/tools/convert_dit_to_ppdiffusers.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@
4444
def main(args):
4545
num_layers, hidden_size, patch_size, num_heads = arch_settings[args.model_name]
4646

47-
state_dict = paddle.load(args.model_weights)
47+
state_dict_prefix = paddle.load(args.model_weights)
48+
state_dict = {k.replace("transformer.", ""): v for k, v in state_dict_prefix.items()}
49+
del state_dict_prefix
4850

4951
state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
5052
state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
@@ -158,7 +160,7 @@ def main(args):
158160
pipeline = DiTPipeline(transformer=transformer, vae=vae, scheduler=scheduler)
159161

160162
if args.save:
161-
pipeline.save_pretrained(args.checkpoint_path)
163+
pipeline.save_pretrained(args.checkpoint_path, safe_serialization=False)
162164

163165

164166
if __name__ == "__main__":

0 commit comments

Comments
 (0)