Commit 87bc752

add option for 8bit adam
1 parent 16ecc08 commit 87bc752

File tree

1 file changed: +16 -1

examples/dreambooth/train_dreambooth.py

Lines changed: 16 additions & 1 deletion
@@ -149,6 +149,9 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+    )
     parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
     parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
@@ -401,7 +404,19 @@ def main():
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )
 
-    optimizer = torch.optim.AdamW(
+    if args.use_8bit_adam:
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError(
+                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+            )
+
+        optimizer_class = bnb.optim.AdamW8bit
+    else:
+        optimizer_class = torch.optim.AdamW
+
+    optimizer = optimizer_class(
         unet.parameters(),  # only optimize unet
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
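
For context, a minimal runnable sketch of the swap this commit makes, assuming bitsandbytes is installed and a CUDA GPU is available. The toy parameter list and the hard-coded learning rate below are illustrative placeholders (they stand in for the script's unet.parameters() and args), not part of the commit; betas and weight_decay mirror the argparse defaults shown above.

    # Minimal sketch: bnb.optim.AdamW8bit as a drop-in replacement for torch.optim.AdamW.
    import torch
    import bitsandbytes as bnb

    use_8bit_adam = True  # stands in for args.use_8bit_adam

    # AdamW8bit takes the same constructor arguments as torch.optim.AdamW,
    # so the rest of the training loop does not need to change.
    optimizer_class = bnb.optim.AdamW8bit if use_8bit_adam else torch.optim.AdamW

    # Toy parameter list standing in for unet.parameters(); lr here is illustrative.
    params = [torch.nn.Parameter(torch.randn(4, 4, device="cuda"))]
    optimizer = optimizer_class(params, lr=5e-6, betas=(0.9, 0.999), weight_decay=1e-2)

On the command line, passing --use_8bit_adam to train_dreambooth.py enables the bitsandbytes path; omitting the flag keeps the previous torch.optim.AdamW behavior.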
