
Commit a21dd7d

patil-suraj authored and natolambert committed
[train_unconditional] fix gradient accumulation. (#308)
fix grad accum
1 parent 305a1a1 commit a21dd7d
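
In short, the fix hands gradient accumulation over to Accelerate: the Accelerator is now constructed with gradient_accumulation_steps, the per-epoch progress bar is sized in optimizer updates rather than dataloader batches, and progress_bar and global_step advance only when accelerator.sync_gradients reports that a real optimization step took place. As a hypothetical illustration of the new arithmetic: with 1,000 batches per epoch and gradient_accumulation_steps=4, the bar counts math.ceil(1000 / 4) = 250 updates instead of 1,000.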

File tree

1 file changed: +10 -3 lines changed


examples/unconditional_image_generation/train_unconditional.py

Lines changed: 10 additions & 3 deletions
@@ -1,4 +1,5 @@
 import argparse
+import math
 import os
 
 import torch
@@ -29,6 +30,7 @@
 def main(args):
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
     accelerator = Accelerator(
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with="tensorboard",
         logging_dir=logging_dir,
@@ -105,6 +107,8 @@ def transforms(examples):
         model, optimizer, train_dataloader, lr_scheduler
     )
 
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+
     ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
 
     if args.push_to_hub:
@@ -117,7 +121,7 @@ def transforms(examples):
     global_step = 0
     for epoch in range(args.num_epochs):
         model.train()
-        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
+        progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
         progress_bar.set_description(f"Epoch {epoch}")
         for step, batch in enumerate(train_dataloader):
             clean_images = batch["input"]
@@ -146,13 +150,16 @@ def transforms(examples):
                     ema_model.step(model)
                 optimizer.zero_grad()
 
-            progress_bar.update(1)
+            # Checks if the accelerator has performed an optimization step behind the scenes
+            if accelerator.sync_gradients:
+                progress_bar.update(1)
+                global_step += 1
+
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
             if args.use_ema:
                 logs["ema_decay"] = ema_model.decay
             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)
-            global_step += 1
         progress_bar.close()
 
         accelerator.wait_for_everyone()
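
Reduced to a runnable sketch, the pattern the patched script relies on looks like this. The toy linear model, SGD optimizer, and random TensorDataset below are stand-ins invented for illustration; only the Accelerator calls (gradient_accumulation_steps, accumulate, sync_gradients) mirror what the diff uses.

import math

import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader, TensorDataset

# One real optimizer update per 4 batches, as if the script were launched
# with --gradient_accumulation_steps 4.
accelerator = Accelerator(gradient_accumulation_steps=4)

model = torch.nn.Linear(8, 1)  # stand-in for the diffusion model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
train_dataloader = DataLoader(dataset, batch_size=4)  # 16 batches per epoch

model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)

# Same arithmetic as the patch: size progress in updates, not batches.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)

global_step = 0
for x, y in train_dataloader:
    # Inside accumulate(), backward() scales the loss by 1/accumulation_steps,
    # and the prepared optimizer's step()/zero_grad() are effectively no-ops
    # until enough batches have been seen.
    with accelerator.accumulate(model):
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

    # The fix in this commit: count a step only when an update really happened.
    if accelerator.sync_gradients:
        global_step += 1

print(global_step, num_update_steps_per_epoch)  # both 4 here: ceil(16 / 4)

Letting Accelerate own the accumulation means the loop body stays identical whether gradient_accumulation_steps is 1 or N; the only place the script has to care is when counting steps, which is exactly what the sync_gradients check restores.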

0 commit comments