Skip to content

Commit bcc0fbc

Browse files
committed
support more argument settings for scheduler
1 parent acfd537 commit bcc0fbc

File tree

3 files changed

+42
-0
lines changed

3 files changed

+42
-0
lines changed

paddlenlp/trainer/trainer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1241,6 +1241,9 @@ def create_scheduler(self, num_training_steps: int):
12411241
learning_rate=self.args.learning_rate,
12421242
num_warmup_steps=warmup,
12431243
num_training_steps=num_training_steps,
1244+
num_cycles=self.args.num_cycles,
1245+
lr_end=self.args.lr_end,
1246+
power=self.args.power,
12441247
)
12451248

12461249
return self.lr_scheduler

paddlenlp/trainer/trainer_utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,9 @@ def get_scheduler(
394394
learning_rate: float,
395395
num_warmup_steps: Optional[int] = None,
396396
num_training_steps: Optional[int] = None,
397+
num_cycles: Optional[float] = 0.5,
398+
lr_end: Optional[float] = 1e-7,
399+
power: Optional[float] = 1.0,
397400
):
398401
"""
399402
Unified API to get any scheduler from its name.
@@ -408,6 +411,15 @@ def get_scheduler(
408411
num_training_steps (`int``, *optional*):
409412
The number of training steps to do. This is not required by all schedulers (hence the argument being
410413
optional), the function will raise an error if it's unset and the scheduler type requires it.
414+
num_cycles (``float``, *optional*):
415+
The number of waves in the cosine scheduler (the defaults is to just decrease from the max value to 0
416+
following a half-cosine). This is not required by all schedulers (hence the argument being optional)
417+
lr_end (``float``, *optional*):
418+
The end LR in the polynomial scheduler. This is not required by all schedulers (hence the argument
419+
being optional).
420+
power (``float``, *optional*):
421+
The power factor in the polynomial scheduler. This is not required by all schedulers (hence the argument
422+
being optional).
411423
"""
412424
name = SchedulerType(name)
413425
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
@@ -425,6 +437,23 @@ def get_scheduler(
425437
if num_training_steps is None:
426438
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
427439

440+
if name == SchedulerType.COSINE:
441+
return schedule_func(
442+
learning_rate,
443+
num_warmup_steps=num_warmup_steps,
444+
num_training_steps=num_training_steps,
445+
num_cycles=num_cycles,
446+
)
447+
448+
if name == SchedulerType.POLYNOMIAL:
449+
return schedule_func(
450+
learning_rate,
451+
num_warmup_steps=num_warmup_steps,
452+
num_training_steps=num_training_steps,
453+
lr_end=lr_end,
454+
power=power,
455+
)
456+
428457
return schedule_func(learning_rate, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
429458

430459

paddlenlp/trainer/training_args.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,13 @@ class TrainingArguments:
139139
Ratio of total training steps used for a linear warmup from 0 to `learning_rate`.
140140
warmup_steps (`int`, *optional*, defaults to 0):
141141
Number of steps used for a linear warmup from 0 to `learning_rate`. Overrides any effect of `warmup_ratio`.
142+
num_cycles (`float`, *optional*, defaults to 0.5):
143+
The number of waves in the cosine scheduler.
144+
lr_end (`float`, *optional*, defaults to 1e-7):
145+
The end LR used in the polynomial scheduler.
146+
power (`float`, *optional*, defaults to 1.0):
147+
The power factor used in the polynomial scheduler.
148+
142149
log_on_each_node (`bool`, *optional*, defaults to `True`):
143150
In multinode distributed training, whether to log using `log_level` once per node, or only on the main
144151
node.
@@ -363,6 +370,9 @@ class TrainingArguments:
363370
default=0.0, metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."}
364371
)
365372
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
373+
num_cycles: float = field(default=0.5, metadata={"help": "The number of waves in the cosine scheduler."})
374+
lr_end: float = field(default=1e-7, metadata={"help": "The end LR in the polynomial scheduler."})
375+
power: float = field(default=1.0, metadata={"help": "The power factor in the polynomial scheduler."})
366376

367377
log_on_each_node: bool = field(
368378
default=True,

0 commit comments

Comments
 (0)