Skip to content

Commit 16ecc08

Browse files
committed
add grad ckpt
1 parent 2894a92 commit 16ecc08

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

examples/dreambooth/train_dreambooth.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,11 @@ def parse_args():
120120
default=1,
121121
help="Number of update steps to accumulate before performing a backward/update pass.",
122122
)
123+
parser.add_argument(
124+
"--gradient_checkpointing",
125+
action="store_true",
126+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
127+
)
123128
parser.add_argument(
124129
"--learning_rate",
125130
type=float,
@@ -388,10 +393,14 @@ def main():
388393
args.pretrained_model_name_or_path, subfolder="unet", use_auth_token=args.use_auth_token
389394
)
390395

396+
if args.gradient_checkpointing:
397+
unet.enable_gradient_checkpointing()
398+
391399
if args.scale_lr:
392400
args.learning_rate = (
393401
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
394402
)
403+
395404
optimizer = torch.optim.AdamW(
396405
unet.parameters(), # only optimize unet
397406
lr=args.learning_rate,

0 commit comments

Comments
 (0)