@@ -149,6 +149,9 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+    )
     parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
     parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
@@ -401,7 +404,19 @@ def main():
         args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
     )
 
-    optimizer = torch.optim.AdamW(
+    if args.use_8bit_adam:
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError(
+                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+            )
+
+        optimizer_class = bnb.optim.AdamW8bit
+    else:
+        optimizer_class = torch.optim.AdamW
+
+    optimizer = optimizer_class(
         unet.parameters(),  # only optimize unet
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
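For reference, a minimal standalone sketch of the optimizer-selection pattern this diff introduces. The `argparse.Namespace` and the `torch.nn.Linear` model below are stand-ins for the script's parsed CLI arguments and its UNet; the sketch only illustrates how `--use_8bit_adam` switches between `bnb.optim.AdamW8bit` and `torch.optim.AdamW`.

```python
# Sketch of the optimizer selection added above; `args` and `model` are
# placeholders, not the training script's real objects.
import argparse

import torch

args = argparse.Namespace(
    use_8bit_adam=False,  # set True to opt into bitsandbytes' 8-bit AdamW
    learning_rate=1e-4,
    adam_beta1=0.9,
    adam_beta2=0.999,
)

if args.use_8bit_adam:
    try:
        import bitsandbytes as bnb  # optional dependency
    except ImportError:
        raise ImportError(
            "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
        )
    optimizer_class = bnb.optim.AdamW8bit
else:
    optimizer_class = torch.optim.AdamW

model = torch.nn.Linear(4, 4)  # placeholder for the UNet being trained
optimizer = optimizer_class(
    model.parameters(),
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
)
```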