File tree Expand file tree Collapse file tree 1 file changed +9
-0
lines changed Expand file tree Collapse file tree 1 file changed +9
-0
lines changed Original file line number Diff line number Diff line change @@ -120,6 +120,11 @@ def parse_args():
120
120
default = 1 ,
121
121
help = "Number of updates steps to accumulate before performing a backward/update pass." ,
122
122
)
123
+ parser .add_argument (
124
+ "--gradient_checkpointing" ,
125
+ action = "store_true" ,
126
+ help = "Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass." ,
127
+ )
123
128
parser .add_argument (
124
129
"--learning_rate" ,
125
130
type = float ,
@@ -388,10 +393,14 @@ def main():
388
393
args .pretrained_model_name_or_path , subfolder = "unet" , use_auth_token = args .use_auth_token
389
394
)
390
395
396
+ if args .gradient_checkpointing :
397
+ unet .enable_gradient_checkpointing ()
398
+
391
399
if args .scale_lr :
392
400
args .learning_rate = (
393
401
args .learning_rate * args .gradient_accumulation_steps * args .train_batch_size * accelerator .num_processes
394
402
)
403
+
395
404
optimizer = torch .optim .AdamW (
396
405
unet .parameters (), # only optimize unet
397
406
lr = args .learning_rate ,
You can’t perform that action at this time.
0 commit comments