Skip to content

Commit 1b2c6ac

Browse files
committed
Include CLIPTextModel parameters in conversion
1 parent 7265dd8 commit 1b2c6ac

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

scripts/convert_original_stable_diffusion_to_diffusers.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,22 @@ def _copy_layers(hf_layers, pt_layers):
595595
return hf_model
596596

597597

598+
def convert_ldm_clip_checkpoint(checkpoint):
    """Build a ``CLIPTextModel`` carrying the text-encoder weights of an LDM checkpoint.

    A ``openai/clip-vit-large-patch14`` text model is instantiated first, then every
    tensor stored under the ``cond_stage_model.transformer`` prefix in *checkpoint*
    is loaded into it with that prefix stripped, so the remaining key names match
    the Hugging Face model's own state-dict layout.
    """
    prefix = "cond_stage_model.transformer"
    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # Remap the relevant checkpoint entries: drop the LDM prefix plus its
    # trailing dot so keys line up with CLIPTextModel's parameter names.
    stripped_state = {
        key[len(prefix) + 1 :]: value
        for key, value in checkpoint.items()
        if key.startswith(prefix)
    }

    text_model.load_state_dict(stripped_state)
    return text_model
612+
613+
598614
if __name__ == "__main__":
599615
parser = argparse.ArgumentParser()
600616

@@ -668,7 +684,7 @@ def _copy_layers(hf_layers, pt_layers):
668684
# Convert the text model.
669685
text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
670686
if text_model_type == "FrozenCLIPEmbedder":
671-
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
687+
text_model = convert_ldm_clip_checkpoint(checkpoint)
672688
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
673689
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
674690
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")

0 commit comments

Comments
 (0)