Skip to content

Commit 80013ea

Browse files
committed
fix: allow resetting clip_skip to its default value
CLIPTextModel currently ignores attempts to set clip_skip back to -1, retaining the previously set value instead. While this is not an issue for the sd command (which does not support changing clip_skip between generations), it affects frontends that reuse model instances for multiple images. Since each model version's default clip_skip value is defined by its respective Conditioner class, that default needs to be re-applied every time a Conditioner receives a new clip_skip value, so move that logic from the Conditioner constructors into their set_clip_skip methods.
1 parent 10c6501 commit 80013ea

File tree

2 files changed

+29
-26
lines changed

2 files changed

+29
-26
lines changed

clip.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -678,8 +678,8 @@ class CLIPTextModel : public GGMLBlock {
678678
bool with_final_ln = true;
679679

680680
CLIPTextModel(CLIPVersion version = OPENAI_CLIP_VIT_L_14,
681-
int clip_skip_value = -1,
682-
bool with_final_ln = true)
681+
bool with_final_ln = true,
682+
int clip_skip_value = -1)
683683
: version(version), with_final_ln(with_final_ln) {
684684
if (version == OPEN_CLIP_VIT_H_14) {
685685
hidden_size = 1024;
@@ -701,7 +701,7 @@ class CLIPTextModel : public GGMLBlock {
701701

702702
void set_clip_skip(int skip) {
703703
if (skip <= 0) {
704-
return;
704+
skip = -1;
705705
}
706706
clip_skip = skip;
707707
}
@@ -871,9 +871,9 @@ struct CLIPTextModelRunner : public GGMLRunner {
871871
std::map<std::string, enum ggml_type>& tensor_types,
872872
const std::string prefix,
873873
CLIPVersion version = OPENAI_CLIP_VIT_L_14,
874-
int clip_skip_value = 1,
875-
bool with_final_ln = true)
876-
: GGMLRunner(backend), model(version, clip_skip_value, with_final_ln) {
874+
bool with_final_ln = true,
875+
int clip_skip_value = -1)
876+
: GGMLRunner(backend), model(version, with_final_ln, clip_skip_value) {
877877
model.init(params_ctx, tensor_types, prefix);
878878
}
879879

conditioner.hpp

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -63,23 +63,24 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
6363
PMVersion pv = PM_VERSION_1,
6464
int clip_skip = -1)
6565
: version(version), pm_version(pv), tokenizer(sd_version_is_sd2(version) ? 0 : 49407), embd_dir(embd_dir) {
66-
if (clip_skip <= 0) {
67-
clip_skip = 1;
68-
if (sd_version_is_sd2(version) || sd_version_is_sdxl(version)) {
69-
clip_skip = 2;
70-
}
71-
}
7266
if (sd_version_is_sd1(version)) {
73-
text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip);
67+
text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14);
7468
} else if (sd_version_is_sd2(version)) {
75-
text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14, clip_skip);
69+
text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14);
7670
} else if (sd_version_is_sdxl(version)) {
77-
text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, false);
78-
text_model2 = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.1.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
71+
text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, false);
72+
text_model2 = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.1.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false);
7973
}
74+
set_clip_skip(clip_skip);
8075
}
8176

8277
void set_clip_skip(int clip_skip) {
78+
if (clip_skip <= 0) {
79+
clip_skip = 1;
80+
if (sd_version_is_sd2(version) || sd_version_is_sdxl(version)) {
81+
clip_skip = 2;
82+
}
83+
}
8384
text_model->set_clip_skip(clip_skip);
8485
if (sd_version_is_sdxl(version)) {
8586
text_model2->set_clip_skip(clip_skip);
@@ -665,15 +666,16 @@ struct SD3CLIPEmbedder : public Conditioner {
665666
std::map<std::string, enum ggml_type>& tensor_types,
666667
int clip_skip = -1)
667668
: clip_g_tokenizer(0) {
668-
if (clip_skip <= 0) {
669-
clip_skip = 2;
670-
}
671-
clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, false);
672-
clip_g = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
669+
clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, false);
670+
clip_g = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false);
673671
t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
672+
set_clip_skip(clip_skip);
674673
}
675674

676675
void set_clip_skip(int clip_skip) {
676+
if (clip_skip <= 0) {
677+
clip_skip = 2;
678+
}
677679
clip_l->set_clip_skip(clip_skip);
678680
clip_g->set_clip_skip(clip_skip);
679681
}
@@ -1008,14 +1010,15 @@ struct FluxCLIPEmbedder : public Conditioner {
10081010
FluxCLIPEmbedder(ggml_backend_t backend,
10091011
std::map<std::string, enum ggml_type>& tensor_types,
10101012
int clip_skip = -1) {
1011-
if (clip_skip <= 0) {
1012-
clip_skip = 2;
1013-
}
1014-
clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, true);
1013+
clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, true);
10151014
t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
1015+
set_clip_skip(clip_skip);
10161016
}
10171017

10181018
void set_clip_skip(int clip_skip) {
1019+
if (clip_skip <= 0) {
1020+
clip_skip = 2;
1021+
}
10191022
clip_l->set_clip_skip(clip_skip);
10201023
}
10211024

@@ -1218,4 +1221,4 @@ struct FluxCLIPEmbedder : public Conditioner {
12181221
}
12191222
};
12201223

1221-
#endif
1224+
#endif

0 commit comments

Comments
 (0)