@@ -1,4 +1,5 @@
 import argparse
+import math
 import os
 
 import torch
@@ -29,6 +30,7 @@
 def main(args):
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
     accelerator = Accelerator(
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with="tensorboard",
         logging_dir=logging_dir,
@@ -105,6 +107,8 @@ def transforms(examples):
         model, optimizer, train_dataloader, lr_scheduler
     )
 
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+
     ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
 
     if args.push_to_hub:
@@ -117,7 +121,7 @@ def transforms(examples):
     global_step = 0
     for epoch in range(args.num_epochs):
         model.train()
-        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
+        progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
         progress_bar.set_description(f"Epoch {epoch}")
         for step, batch in enumerate(train_dataloader):
            clean_images = batch["input"]
@@ -146,13 +150,16 @@ def transforms(examples):
                 ema_model.step(model)
             optimizer.zero_grad()
 
-            progress_bar.update(1)
+            # Checks if the accelerator has performed an optimization step behind the scenes
+            if accelerator.sync_gradients:
+                progress_bar.update(1)
+                global_step += 1
+
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
             if args.use_ema:
                 logs["ema_decay"] = ema_model.decay
             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)
-            global_step += 1
 
         progress_bar.close()
 
         accelerator.wait_for_everyone()
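Taken together, the changes above hand gradient accumulation over to Accelerate: the Accelerator is constructed with gradient_accumulation_steps, the per-epoch progress bar counts optimizer updates, ceil(batches / accumulation steps), instead of raw batches, and global_step only advances when accelerator.sync_gradients reports that an optimizer step actually ran. Below is a minimal, self-contained sketch of that pattern, not this script; it assumes the usual accelerate.accumulate() context manager around the training step (the backward/step lines are outside the hunks shown here), and the toy model, dataset, and hyperparameters are placeholders, not part of this commit.

import math

import torch
from accelerate import Accelerator

# Accelerate handles the accumulation bookkeeping once this is set.
accelerator = Accelerator(gradient_accumulation_steps=4)

# Toy stand-ins for the real model and dataset (placeholders, not from the diff).
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataset = torch.utils.data.TensorDataset(torch.randn(64, 10), torch.randn(64, 1))
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)

model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)

# One optimizer update per `gradient_accumulation_steps` batches, so an epoch
# has ceil(len(dataloader) / steps) updates, the quantity the diff computes.
num_update_steps_per_epoch = math.ceil(
    len(train_dataloader) / accelerator.gradient_accumulation_steps
)

global_step = 0
for inputs, targets in train_dataloader:
    # accumulate() suppresses the gradient sync, and lets the prepared
    # optimizer skip stepping, on all but every 4th batch.
    with accelerator.accumulate(model):
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

    # True exactly on the batches where an update happened, which is why the
    # diff gates progress_bar.update(1) and global_step += 1 on it.
    if accelerator.sync_gradients:
        global_step += 1

assert global_step == num_update_steps_per_epoch  # 8 batches / 4 -> 2 updates

Counting updates rather than batches keeps the progress bar and global_step consistent with the learning-rate scheduler and logging: with 8 batches and an accumulation factor of 4, the loop above performs 2 optimizer updates, matching ceil(8 / 4).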