[Trainer] Support skip data intervals #8989
Changes from 3 commits
```diff
@@ -138,6 +138,7 @@
     get_scheduler,
     has_length,
     set_seed,
+    should_skip_data,
     speed_metrics,
 )
 from .training_args import TrainingArguments
```
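The body of the new `should_skip_data` helper is not part of this diff. A minimal sketch of what it could look like, assuming `skip_data_intervals` is a list of `[start, end]` pairs of 1-based, inclusive global-step numbers (the exact interval format is an assumption):

```python
from typing import List, Optional


def should_skip_data(global_step: int, skip_data_intervals: Optional[List[List[int]]]) -> bool:
    """Return True if the step produced by the current batch falls in a skip interval.

    Sketch only: assumes each interval is a [start, end] pair of 1-based,
    inclusive global-step numbers, e.g. [[10, 20], [35, 35]].
    """
    if not skip_data_intervals:
        return False
    next_step = global_step + 1  # the step this batch would contribute to
    return any(start <= next_step <= end for start, end in skip_data_intervals)
```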
```diff
@@ -277,7 +278,7 @@
         # Seed must be set before instantiating the model when using model
         set_seed(seed=self.args.seed)

-        if model is None:
+        if model is None and not args.debug_data:
             raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")

         if self.args.to_static:
```
```diff
@@ -339,7 +340,7 @@
                 "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
             )

-        if self.args.pipeline_parallel_degree > 1 and self.args.use_hybrid_parallel:
+        if self.args.pipeline_parallel_degree > 1 and self.args.use_hybrid_parallel and not args.debug_data:
             from paddle.distributed.fleet.meta_parallel import PipelineLayer

             assert (isinstance(model, LoRAModel) and isinstance(model.model, PipelineLayer)) or isinstance(
```
```diff
@@ -357,6 +358,7 @@
         self._load_ckpt_func = dist.load_state_dict if self.args.enable_auto_parallel else paddle.load
         if self.args.use_async_save:
             self._async_optimizer_saver = AsyncSaver()
+        self.skip_global_steps = 0

         if args.max_steps > 0:
             logger.info("max_steps is given, it will override any value given in num_train_epochs")
```
```diff
@@ -377,26 +379,28 @@

         self.do_grad_scaling = False
         self.enable_autocast_context_manager = False
-        if args.fp16 or args.bf16:
-            # set do_grad_scaling, enable_autocast_context_manager
-            self._wrap_amp_model(args, model)
-
-        if args.recompute:
-
-            def fn(layer):
-                if hasattr(layer, "enable_recompute") and (
-                    layer.enable_recompute is False or layer.enable_recompute == 0
-                ):
-                    layer.enable_recompute = True
-
-            model.apply(fn)
+        if not args.debug_data:
+            if args.fp16 or args.bf16:
+                # set do_grad_scaling, enable_autocast_context_manager
+                self._wrap_amp_model(args, model)
+
+            if args.recompute:
+
+                def fn(layer):
+                    if hasattr(layer, "enable_recompute") and (
+                        layer.enable_recompute is False or layer.enable_recompute == 0
+                    ):
+                        layer.enable_recompute = True
+
+                model.apply(fn)

         default_label_names = (
             ["start_positions", "end_positions"]
             if "QusetionAnswering" in type(self.model).__name__ or "UIE" in type(self.model).__name__
             else ["labels"]
         )
         self.label_names = default_label_names if self.args.label_names is None else self.args.label_names

         self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
         self.print_config()
```

Review thread on `if not args.debug_data:`:
- Reviewer: Hmm, so `debug_data` means the model and everything else does not run at all? Does this need to be exposed to users, or should it be removed once development is finished?
- Author: Yes, `debug_data` only prints the data; it neither loads the model nor trains. The intent here is to add it as a general-purpose feature.
- Reviewer: If it is only a debug mode for our internal use, I don't think adding it is very meaningful.
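As a usage illustration, assuming the PR also adds a boolean `debug_data` field and a `skip_data_intervals` field to `TrainingArguments` (their definitions are not shown in this diff), a data-only dry run might be configured like this; everything except those two field names is a placeholder:

```python
from paddlenlp.trainer import TrainingArguments

# Hypothetical sketch: only `debug_data` and `skip_data_intervals` come from
# this PR; the remaining arguments are illustrative placeholders.
args = TrainingArguments(
    output_dir="./checkpoints",
    debug_data=True,                 # only iterate and print data, skip model setup and training
    skip_data_intervals=[[10, 20]],  # assumed format: [start, end] global-step ranges
)
```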
```diff
@@ -924,6 +928,7 @@
             step_control = 0  # used in loop control, reset to 0 after every step
             self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

+            step = -1
             for step, inputs in enumerate(epoch_iterator):
                 if self.args.use_hybrid_parallel and self.args.sep_parallel_degree > 1:
                     inputs = split_inputs_sequence_dim(inputs)
```
```diff
@@ -960,6 +965,31 @@
                     steps_trained_progress_bar.close()
                     steps_trained_progress_bar = None

+                # Skip data
+                if should_skip_data(self.state.global_step, self.args.skip_data_intervals):
+                    logger.warning(f"Skip data at global step {self.state.global_step+1}, sub step {step_control}")
+                    logger.warning(f"{self.tokenizer.batch_decode(inputs['input_ids'], skip_special_tokens=True)}")
+
+                    if (step_control + 1) % args.gradient_accumulation_steps == 0 or (
+                        # last step in epoch but step is always smaller than gradient_accumulation_steps
+                        steps_in_epoch <= args.gradient_accumulation_steps
+                        and (step + 1) == steps_in_epoch
+                    ):
+                        self.skip_global_steps += 1
+                        self.state.global_step += 1
+                        self.state.epoch = epoch + (step + 1) / steps_in_epoch
+                        self.control = self.callback_handler.on_step_end(args, self.state, self.control)
+                        self._maybe_log_save_evaluate(tr_loss, model, epoch, ignore_keys_for_eval, inputs=inputs)
+                        self._print_timer()
+                        step_control = 0
+                    else:
+                        self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
+                        step_control += 1
+                    if self.control.should_epoch_stop or self.control.should_training_stop:
+                        break
+                    self.timers and self.timers("read-data").start()
+                    continue

                 if step_control % args.gradient_accumulation_steps == 0:
                     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
                     self.timers and self.timers("forward-backward").start()
```

Review thread on the second `logger.warning(...)` (the one that decodes the skipped batch):
- Reviewer: Let's not add this one (suggested removing the line).
- Author: This warning prints the skipped data. Removing it is also fine; the main intent is to let users see exactly which data was skipped.

Review thread on `self._maybe_log_save_evaluate(...)`:
- Reviewer: This shouldn't be needed either, right?
- Author: `_maybe_log_save_evaluate` is called here in order to go through: 1. paddlenlp/trainer/trainer.py Line 1308 in 48820cb; 2. the update of `_globalstep_last_logged`: paddlenlp/trainer/trainer.py Line 1346 in 48820cb; 3. the normal eval flow. Otherwise there will be a problem when eval computes `consumed_samples` at the end: https://github.com/PaddlePaddle/PaddleNLP/blob/48820cbc1fe986004f817c0517886735675732d2/paddlenlp/trainer/trainer.py#L2792C6-L2797C18
- Reviewer: My main concern is whether hitting all kinds of callbacks such as eval or save while skipping data could cause problems.

Review thread on the skip block as a whole:
- Reviewer: I suspect you may not need much of this. Without any computation, I'm not sure whether triggering some of the callbacks is safe.
- Author: This is here to make decisions such as whether to run eval, save, or stop training. Executing the callbacks directly without a forward/backward pass did not raise errors in my testing, but there may indeed be untested potential risks.
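To make the step accounting in the skip branch concrete, here is a small self-contained simulation of the boundary condition used above; the variable names mirror the diff, but nothing here touches the Trainer:

```python
# Simulate how skipped micro-batches advance global_step only at
# gradient-accumulation boundaries, mirroring the condition in the diff.
gradient_accumulation_steps = 4
steps_in_epoch = 10

global_step = 0
skip_global_steps = 0
step_control = 0

for step in range(steps_in_epoch):  # pretend every micro-batch in this epoch is skipped
    at_boundary = (step_control + 1) % gradient_accumulation_steps == 0 or (
        steps_in_epoch <= gradient_accumulation_steps and (step + 1) == steps_in_epoch
    )
    if at_boundary:
        skip_global_steps += 1
        global_step += 1
        step_control = 0
    else:
        step_control += 1

print(global_step, skip_global_steps)  # 2 2 -> 10 skipped micro-batches count as 2 skipped optimizer steps
```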
```diff
@@ -1181,7 +1211,10 @@
         )

         self._total_loss_scalar += tr_loss.item()
-        train_loss = self._total_loss_scalar / self.state.global_step
+        if self.state.global_step == self.skip_global_steps:
+            train_loss = 0.0
+        else:
+            train_loss = self._total_loss_scalar / (self.state.global_step - self.skip_global_steps)

         metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
```
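The guarded average above avoids a division by zero when every step was skipped and excludes skipped steps from the reported mean. A quick numeric check with illustrative numbers only:

```python
# Illustrative numbers: 100 global steps, 20 skipped, loss accumulated
# only over the 80 steps that actually trained.
total_loss_scalar = 96.0
global_step = 100
skip_global_steps = 20

if global_step == skip_global_steps:
    train_loss = 0.0  # nothing was trained, avoid dividing by zero
else:
    train_loss = total_loss_scalar / (global_step - skip_global_steps)

print(train_loss)  # 1.2, the mean loss over the 80 trained steps
```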