Skip to content

Commit 55a2686

Browse files
committed
Use env variable to control chroma padding settings
1 parent 67cc996 commit 55a2686

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

conditioner.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1318,6 +1318,16 @@ struct PixArtCLIPEmbedder : public Conditioner {
13181318
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
13191319
auto t5_attn_mask_chunk = vector_to_ggml_tensor(work_ctx, chunk_mask);
13201320

1321+
const char* SD_CHROMA_USE_T5_MASK = getenv("SD_CHROMA_USE_T5_MASK");
1322+
if (SD_CHROMA_USE_T5_MASK != nullptr) {
1323+
std::string sd_chroma_use_t5_mask_str = SD_CHROMA_USE_T5_MASK;
1324+
if (sd_chroma_use_t5_mask_str == "OFF" || sd_chroma_use_t5_mask_str == "FALSE") {
1325+
t5_attn_mask_chunk = NULL;
1326+
} else if (sd_chroma_use_t5_mask_str != "ON" && sd_chroma_use_t5_mask_str != "TRUE") {
1327+
LOG_WARN("SD_CHROMA_USE_T5_MASK environment variable has unexpected value. Assuming default (\"ON\"). (Expected \"ON\"/\"TRUE\" or\"OFF\"/\"FALSE\", got \"%s\")", SD_CHROMA_USE_T5_MASK);
1328+
}
1329+
}
1330+
13211331
t5->compute(n_threads,
13221332
input_ids,
13231333
&chunk_hidden_states,

flux.hpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1084,7 +1084,30 @@ namespace Flux {
10841084
c_concat = to_backend(c_concat);
10851085
}
10861086
if (flux_params.is_chroma) {
1087-
flux.chroma_modify_mask_to_attend_padding(y, ggml_nelements(y), 1);
1087+
int mask_pad = 1;
1088+
const char* SD_CHROMA_MASK_PAD_OVERRIDE = getenv("SD_CHROMA_MASK_PAD_OVERRIDE");
1089+
if (SD_CHROMA_MASK_PAD_OVERRIDE != nullptr) {
1090+
std::string mask_pad_str = SD_CHROMA_MASK_PAD_OVERRIDE;
1091+
try {
1092+
mask_pad = std::stoi(mask_pad_str);
1093+
} catch (const std::invalid_argument&) {
1094+
LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable is not a valid integer (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad);
1095+
} catch (const std::out_of_range&) {
1096+
LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable value is out of range for `int` type (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad);
1097+
}
1098+
}
1099+
flux.chroma_modify_mask_to_attend_padding(y, ggml_nelements(y), mask_pad);
1100+
1101+
const char* SD_CHROMA_USE_DIT_MASK = getenv("SD_CHROMA_USE_DIT_MASK");
1102+
if (SD_CHROMA_USE_DIT_MASK != nullptr) {
1103+
std::string sd_chroma_use_DiT_mask_str = SD_CHROMA_USE_DIT_MASK;
1104+
if (sd_chroma_use_DiT_mask_str == "OFF" || sd_chroma_use_DiT_mask_str == "FALSE") {
1105+
y = NULL;
1106+
} else if (sd_chroma_use_DiT_mask_str != "ON" && sd_chroma_use_DiT_mask_str != "TRUE") {
1107+
LOG_WARN("SD_CHROMA_USE_DIT_MASK environment variable has unexpected value. Assuming default (\"ON\"). (Expected \"ON\"/\"TRUE\" or\"OFF\"/\"FALSE\", got \"%s\")", SD_CHROMA_USE_DIT_MASK);
1108+
}
1109+
}
1110+
10881111
// ggml_arrange is not working on some backends, and y isn't used, so let's reuse y to precompute it
10891112
range = arange(0, 344);
10901113
precompute_arange = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, range.size());

0 commit comments

Comments
 (0)