Commit 19a53ba

[train_controlnet_sdxl.py] Fix the LR schedulers when num_train_epochs is passed in a distributed training env (#8476)
* Create diffusers.yml
* num_train_epochs

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent f9ba2ff commit 19a53ba
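
As a quick illustration of the fix, here is a minimal, self-contained sketch of the step arithmetic the patched script performs when only num_train_epochs is supplied. It is not part of the commit; the numbers (dataloader length, process count, and so on) are made-up assumptions, and the variable names only loosely mirror the script.

import math

# Assumed toy setup, not taken from the commit: 1000 batches before sharding,
# 4 processes, gradient accumulation of 2, 3 epochs, 100 warmup steps.
len_train_dataloader = 1000
num_processes = 4
gradient_accumulation_steps = 2
num_train_epochs = 3
lr_warmup_steps = 100

# Mirrors the new "args.max_train_steps is None" branch: estimate the
# per-process dataloader length, derive the optimizer updates per epoch,
# then scale by the process count because every process steps the scheduler.
len_after_sharding = math.ceil(len_train_dataloader / num_processes)             # 250
updates_per_epoch = math.ceil(len_after_sharding / gradient_accumulation_steps)  # 125
scheduler_training_steps = num_train_epochs * updates_per_epoch * num_processes  # 1500
scheduler_warmup_steps = lr_warmup_steps * num_processes                         # 400

print(scheduler_training_steps, scheduler_warmup_steps)  # 1500 400

With these assumed numbers, the removed code path would have sized the scheduler from the unsharded dataloader length, giving 3 * ceil(1000 / 2) * 4 = 6000 steps instead of 1500, so the learning rate would decay roughly num_processes times too slowly; see https://github.com/huggingface/diffusers/pull/8312 for the detailed explanation.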

1 file changed: +18 -7 lines

examples/controlnet/train_controlnet_sdxl.py

Lines changed: 18 additions & 7 deletions
@@ -1088,17 +1088,22 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer
     )
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
@@ -1110,8 +1115,14 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
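
The second hunk keeps the existing post-prepare recalculation and now warns when the scheduler was sized from an estimate that no longer matches the prepared dataloader. Below is a rough, standalone sketch of that consistency check; the numbers and names are assumptions for illustration, not the script's actual state.

import math

# Assumed values: the scheduler was sized earlier for 1500 training steps,
# but accelerator.prepare produced a slightly shorter per-process dataloader
# (for example because of drop_last or an uneven split).
num_processes = 4
gradient_accumulation_steps = 2
num_train_epochs = 3
scheduler_training_steps = 1500   # sized from the pre-prepare estimate
len_prepared_dataloader = 248     # hypothetical post-prepare per-process length

updates_per_epoch = math.ceil(len_prepared_dataloader / gradient_accumulation_steps)  # 124
max_train_steps = num_train_epochs * updates_per_epoch                                # 372
if scheduler_training_steps != max_train_steps * num_processes:                       # 1500 != 1488
    print(
        "warning: the dataloader length changed after accelerator.prepare; "
        "the learning rate scheduler may not be sized correctly"
    )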
