@@ -1088,17 +1088,22 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer
     )
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
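Not part of the diff: a minimal, self-contained sketch of the arithmetic above, using made-up values for the dataloader length, process count, accumulation steps, and epochs. It illustrates why the scheduler step counts are estimated from a manually sharded dataloader length, since the scheduler is created before `accelerator.prepare` shards the dataloader across processes.

```python
import math

# Hypothetical values, for illustration only.
len_train_dataloader = 1000      # batches before sharding (len(train_dataloader))
num_processes = 4                # accelerator.num_processes
gradient_accumulation_steps = 2  # args.gradient_accumulation_steps
num_train_epochs = 3             # args.num_train_epochs
lr_warmup_steps = 500            # args.lr_warmup_steps
max_train_steps = None           # args.max_train_steps

# The scheduler is created before accelerator.prepare(), so the dataloader is not
# sharded yet; estimate the per-process length after sharding.
num_warmup_steps_for_scheduler = lr_warmup_steps * num_processes          # 2000
if max_train_steps is None:
    len_train_dataloader_after_sharding = math.ceil(
        len_train_dataloader / num_processes
    )                                                                      # 250
    num_update_steps_per_epoch = math.ceil(
        len_train_dataloader_after_sharding / gradient_accumulation_steps
    )                                                                      # 125
    num_training_steps_for_scheduler = (
        num_train_epochs * num_update_steps_per_epoch * num_processes
    )                                                                      # 1500
else:
    num_training_steps_for_scheduler = max_train_steps * num_processes

print(num_warmup_steps_for_scheduler, num_training_steps_for_scheduler)   # 2000 1500
```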
@@ -1110,8 +1115,14 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+    if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+        logger.warning(
+            f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+            f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+            f"This inconsistency may result in the learning rate scheduler not functioning properly."
+        )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
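Likewise a hypothetical sketch (made-up numbers, plain `print` standing in for `logger.warning`) of the consistency check this hunk adds after `accelerator.prepare`: the totals are recomputed from the real sharded dataloader, and a mismatch with the count the scheduler was built with triggers the warning.

```python
import math

# Hypothetical values, for illustration only.
num_processes = 4
gradient_accumulation_steps = 2
num_train_epochs = 3
len_train_dataloader = 260                 # actual len(train_dataloader) after accelerator.prepare
len_train_dataloader_after_sharding = 250  # estimate used when the scheduler was created
num_training_steps_for_scheduler = 1500    # value passed to get_scheduler earlier
max_train_steps = None                     # args.max_train_steps

# Recompute the totals from the dataloader that prepare() actually returned.
num_update_steps_per_epoch = math.ceil(len_train_dataloader / gradient_accumulation_steps)  # 130
if max_train_steps is None:
    max_train_steps = num_train_epochs * num_update_steps_per_epoch                         # 390

if num_training_steps_for_scheduler != max_train_steps * num_processes:                     # 1500 != 1560
    print(
        f"Dataloader length after prepare ({len_train_dataloader}) does not match the "
        f"estimate ({len_train_dataloader_after_sharding}); the LR scheduler may not behave as intended."
    )
```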