@@ -394,6 +394,9 @@ def get_scheduler(
394
394
learning_rate : float ,
395
395
num_warmup_steps : Optional [int ] = None ,
396
396
num_training_steps : Optional [int ] = None ,
397
+ num_cycles : Optional [float ] = 0.5 ,
398
+ lr_end : Optional [float ] = 1e-7 ,
399
+ power : Optional [float ] = 1.0 ,
397
400
):
398
401
"""
399
402
Unified API to get any scheduler from its name.
@@ -408,6 +411,15 @@ def get_scheduler(
408
411
num_training_steps (`int`, *optional*):
409
412
The number of training steps to do. This is not required by all schedulers (hence the argument being
410
413
optional), the function will raise an error if it's unset and the scheduler type requires it.
414
+ num_cycles (``float``, *optional*):
415
+ The number of waves in the cosine scheduler (the default is to just decrease from the max value to 0
416
+ following a half-cosine). This is not required by all schedulers (hence the argument being optional).
417
+ lr_end (``float``, *optional*):
418
+ The end LR in the polynomial scheduler. This is not required by all schedulers (hence the argument
419
+ being optional).
420
+ power (``float``, *optional*):
421
+ The power factor in the polynomial scheduler. This is not required by all schedulers (hence the argument
422
+ being optional).
411
423
"""
412
424
name = SchedulerType (name )
413
425
schedule_func = TYPE_TO_SCHEDULER_FUNCTION [name ]
@@ -425,6 +437,23 @@ def get_scheduler(
425
437
if num_training_steps is None :
426
438
raise ValueError (f"{ name } requires `num_training_steps`, please provide that argument." )
427
439
440
+ if name == SchedulerType .COSINE :
441
+ return schedule_func (
442
+ learning_rate ,
443
+ num_warmup_steps = num_warmup_steps ,
444
+ num_training_steps = num_training_steps ,
445
+ num_cycles = num_cycles ,
446
+ )
447
+
448
+ if name == SchedulerType .POLYNOMIAL :
449
+ return schedule_func (
450
+ learning_rate ,
451
+ num_warmup_steps = num_warmup_steps ,
452
+ num_training_steps = num_training_steps ,
453
+ lr_end = lr_end ,
454
+ power = power ,
455
+ )
456
+
428
457
return schedule_func (learning_rate , num_warmup_steps = num_warmup_steps , num_training_steps = num_training_steps )
429
458
430
459
0 commit comments