Commit 825979d

[training] fix: registration of out_channels in the control flux scripts. (#10367)
* fix: registration of out_channels in the control flux scripts.
* free memory.
1 parent 023b0e0 commit 825979d

2 files changed: +12 -2 lines changed
examples/flux-control/train_control_flux.py

Lines changed: 6 additions & 1 deletion
@@ -795,7 +795,7 @@ def main(args):
         flux_transformer.x_embedder = new_linear

     assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
-    flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
+    flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)

     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
@@ -1166,6 +1166,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             flux_transformer.to(torch.float32)
         flux_transformer.save_pretrained(args.output_dir)

+        del flux_transformer
+        del text_encoding_pipeline
+        del vae
+        free_memory()
+
         # Run a final round of validation.
         image_logs = None
         if args.validation_prompt is not None:
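Why the extra out_channels matters: the control trainers widen x_embedder so the transformer ingests the noisy latents concatenated with the packed control latents, and they record the new width via register_to_config. Since FluxTransformer2DModel falls back to out_channels = in_channels when out_channels is missing from the config, registering only the doubled in_channels would make a reloaded checkpoint build a doubled output projection; passing out_channels=initial_input_channels pins the original value. The sketch below illustrates the pattern outside the training loop; the model id, dtype, and zero-initialization details are illustrative rather than copied verbatim from the scripts.

# Sketch: double the x_embedder input width for packed control latents, then keep
# the serialized config consistent so the saved checkpoint reloads with correct shapes.
import torch
from diffusers import FluxTransformer2DModel

flux_transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

initial_input_channels = flux_transformer.config.in_channels
new_linear = torch.nn.Linear(
    flux_transformer.x_embedder.in_features * 2,
    flux_transformer.x_embedder.out_features,
    bias=flux_transformer.x_embedder.bias is not None,
    dtype=flux_transformer.dtype,
)
with torch.no_grad():
    # Keep the pretrained weights for the original latent channels and
    # zero-initialize the control half so training starts from base behavior.
    new_linear.weight.zero_()
    new_linear.weight[:, :initial_input_channels].copy_(flux_transformer.x_embedder.weight)
    new_linear.bias.copy_(flux_transformer.x_embedder.bias)
flux_transformer.x_embedder = new_linear

# Without an explicit out_channels, reloading the saved checkpoint would resolve
# out_channels to the (now doubled) in_channels and build a mismatched projection.
flux_transformer.register_to_config(
    in_channels=initial_input_channels * 2,
    out_channels=initial_input_channels,
)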

examples/flux-control/train_control_lora_flux.py

Lines changed: 6 additions & 1 deletion
@@ -830,7 +830,7 @@ def main(args):
         flux_transformer.x_embedder = new_linear

     assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
-    flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
+    flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)

     if args.train_norm_layers:
         for name, param in flux_transformer.named_parameters():
@@ -1319,6 +1319,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             transformer_lora_layers=transformer_lora_layers,
         )

+        del flux_transformer
+        del text_encoding_pipeline
+        del vae
+        free_memory()
+
         # Run a final round of validation.
         image_logs = None
         if args.validation_prompt is not None:
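The second hunk in each script is the "free memory" part of the commit: once the weights are written out, the trained transformer, the text-encoding pipeline, and the VAE are deleted and free_memory() is called, so the final validation pass starts with as much device memory as possible. A small, self-contained sketch of that pattern follows, with torch.nn.Linear placeholders standing in for the real components; free_memory comes from diffusers.training_utils and combines Python garbage collection with clearing the accelerator cache.

# Self-contained sketch of the cleanup added before the final validation round.
# The Linear modules are placeholders for the transformer, text pipeline, and VAE.
import torch
from diffusers.training_utils import free_memory

flux_transformer = torch.nn.Linear(64, 64)        # placeholder for the trained transformer
text_encoding_pipeline = torch.nn.Linear(64, 64)  # placeholder for the text-encoding pipeline
vae = torch.nn.Linear(64, 64)                     # placeholder for the VAE

# Drop the last references so the memory can actually be reclaimed,
# then clear caches before validation builds its inference pipeline.
del flux_transformer
del text_encoding_pipeline
del vae
free_memory()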
