[Trainer] support sharding for trainer. #3352
Conversation
    paddle.distributed.all_reduce(
        p.bw_storage, group=self.dp_group)
elif (args.recompute and args.local_rank != -1):
    fused_allreduce_gradients(list(model.parameters()),
                              None)

if self.do_grad_scaling:
@haohongxiang The scaler usage experience here should stay consistent with the official scaler.
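For context, a minimal sketch (not the PR's code, and assuming Paddle >= 2.2 where `GradScaler.step`/`update` are available) of the standard `paddle.amp.GradScaler` flow the reviewer is asking to match:

```python
import paddle

model = paddle.nn.Linear(4, 4)
optimizer = paddle.optimizer.AdamW(parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)  # fp16 initial loss scale

x = paddle.randn([8, 4])
with paddle.amp.auto_cast():
    loss = model(x).mean()

scaled = scaler.scale(loss)   # scale the loss to avoid fp16 gradient underflow
scaled.backward()
scaler.step(optimizer)        # unscales gradients, then runs optimizer.step()
scaler.update()               # adjusts the loss scaling for the next iteration
optimizer.clear_grad()
```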
@@ -1117,11 +1266,25 @@ def _save_checkpoint(self, model, metrics=None):

    self.save_model(output_dir)

    if self.sharding is not None:
@haohongxiang Please provide an interface that gathers the parameters to CPU on the rank 0 card.
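One possible shape of such an interface, as a hypothetical sketch only: the helper name and the assumption that each rank holds an even dim-0 shard of every parameter are mine, and it is meant to run under `python -m paddle.distributed.launch`, not to mirror the PR's actual implementation.

```python
import paddle
import paddle.distributed as dist

def gather_full_state_dict_to_cpu(model, group=None):
    """Hypothetical helper: every rank contributes its shard, rank 0 keeps a CPU copy."""
    full_state = {}
    for name, param in model.state_dict().items():
        shards = []
        dist.all_gather(shards, param, group=group)  # collect this parameter's shards
        if dist.get_rank() == 0:
            # reassemble along dim 0 and move the full tensor off the GPU
            full_state[name] = paddle.concat(shards, axis=0).cpu()
    return full_state
```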
paddlenlp/trainer/trainer_base.py
Outdated
if self.do_grad_scaling:
    self.scaler.minimize(self.optimizer, tr_loss)
    # TODO: fix sharding stage2 stage3 with original scaler usage.
There is a problem with the API usage here.
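A rough sketch of how sharding stage2/stage3 usually obtains a scaler that works with `minimize()` in Paddle >= 2.3 (public `group_sharded_parallel` API, not the Trainer's internals; it must be launched with `python -m paddle.distributed.launch` on GPUs):

```python
import paddle
from paddle.distributed.sharding import group_sharded_parallel

paddle.distributed.init_parallel_env()

model = paddle.nn.Linear(4, 4)
optimizer = paddle.optimizer.AdamW(parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)

# level="os_g" roughly corresponds to sharding stage2 (optimizer state + gradients);
# the returned scaler knows how to unscale gradients held in sharded storages.
model, optimizer, scaler = group_sharded_parallel(
    model, optimizer, level="os_g", scaler=scaler)
```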
@@ -395,6 +395,35 @@ Trainer is a simple but feature-complete Paddle training and evaluation module, and …

    The value of initial scale_loss for fp16. (default: 32768)

    --sharding
You could add a NOTICE here describing the currently supported states.
Added.
|---------------------|---------|-------|-------|---------|-------|-------|--------|-------|
|                     | mcc     | acc   | acc   | pearson | acc   | acc   | acc    | acc   |
| T5-v1_1-base Paddle | 47.6845 | 94.38 | 84.31 | 87.74   | 88.05 | 85.39 | 90.518 | 65.70 |
| epoch               | 100     | 10    | 100   | 100     | 3     | 3     | 10     | 100   |
Has the T5_v1_1_base accuracy been aligned?
As discussed offline.
if self.label2id:
    label = self.label2id[label]
    if pred not in self.label2id:
        pred = 0
Why is pred = 1 here when the label is 0?
Please explain this logic in more detail.
When the generated label is not in the label list, it ends up being assigned label 0. This situation never occurs with encoder models, so could you explain specifically how the metrics were aligned?
Comments added.
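My reading of the snippet above, as a self-contained sketch (the helper name `to_ids` and the example label list are mine, not the PR's code): for a generation-based model the predicted "label" is decoded text, so it may fall outside `label2id`; the code falls back to id 0 so the metric can still be computed, whereas an encoder-style classifier (argmax over a fixed label set) never produces an out-of-vocabulary prediction.

```python
label2id = {"entailment": 0, "not_entailment": 1}  # example label list

def to_ids(label_text, pred_text):
    label = label2id[label_text]      # gold labels are always in the list
    if pred_text not in label2id:     # decoded text can be arbitrary
        pred = 0                      # fall back to a fixed id; counted as wrong
    else:                             # unless the gold label happens to be 0
        pred = label2id[pred_text]
    return label, pred

print(to_ids("not_entailment", "maybe"))  # -> (1, 0)
```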
LGTM for sharding+dp
@@ -971,7 +971,8 @@ def greedy_search(self, input_ids, logits_processors, max_length,
         probs = F.softmax(logits)
         probs = paddle.log(probs)
         next_tokens = paddle.argmax(probs, axis=-1).unsqueeze(-1)
-        next_scores = paddle.index_sample(probs, next_tokens)
+        next_scores = paddle.index_sample(probs.astype("float32"),
+                                          next_tokens)
Here, index_sample has no fp16/bf16 kernel.
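A small illustration of the workaround, assuming a GPU device where fp16 softmax/log kernels are available; the point is only that `paddle.index_sample` is reported to lack fp16/bf16 kernels, so its input is cast to float32 first:

```python
import paddle
import paddle.nn.functional as F

logits = paddle.randn([2, 8]).astype("float16")  # half-precision decoder logits (AMP case)
probs = paddle.log(F.softmax(logits))
next_tokens = paddle.argmax(probs, axis=-1).unsqueeze(-1)
# index_sample has no fp16/bf16 kernel, so cast its input to float32 before sampling
next_scores = paddle.index_sample(probs.astype("float32"), next_tokens)
```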
    f"{self.dtype} not recognized. `dtype` should be set to either `paddle.float32` or `paddle.float16`"
)
encoder_extended_attention_mask = (
    1.0 - encoder_extended_attention_mask) * -1e4
For bf16 dtype
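A sketch of the additive attention-mask bias in mixed precision (my illustration, not the file's final code): positions to keep get 0, masked positions get a large negative value. -1e4 is representable in both fp16 and bf16, so supporting bf16 is mainly a matter of accepting it in the dtype check rather than choosing a different constant.

```python
import paddle

attention_mask = paddle.to_tensor([[1, 1, 0, 0]], dtype="float32")  # 1 = keep, 0 = mask out
extended_mask = (1.0 - attention_mask) * -1e4                        # added to attention scores
print(extended_mask)  # [[0., 0., -10000., -10000.]]
```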
            labels.flatten())
        loss = loss_fct(
            lm_logits.reshape(
                shape=[-1, lm_logits.shape[-1]]).astype("float32"),
CrossEntropyLoss has no fp16/bf16 kernel.
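A small standalone illustration of the cast above (shapes and tensors are made up): the logits are cast to float32 before computing the loss because CrossEntropyLoss lacks fp16/bf16 kernels.

```python
import paddle

loss_fct = paddle.nn.CrossEntropyLoss()
lm_logits = paddle.randn([2, 3, 8]).astype("float16")  # [batch, seq_len, vocab] under AMP
labels = paddle.randint(0, 8, shape=[2, 3])

loss = loss_fct(
    lm_logits.reshape(shape=[-1, lm_logits.shape[-1]]).astype("float32"),
    labels.flatten())
```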
LGTM
LGTM
PR types
New features
PR changes
APIs
Description
Support sharding for trainer.
stage1: supported
stage2: partially supported
stage3: not supported yet
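For reference, a minimal sketch of enabling the new option through `TrainingArguments`; the field names `fp16`, `scale_loss`, and `sharding` are taken from the docs diff and PR description above, and the exact accepted values should be treated as assumptions rather than the final API.

```python
from paddlenlp.trainer import TrainingArguments

# Assumed usage of the flag added in this PR: stage1 is fully supported,
# stage2 partially, stage3 not yet (see the support matrix above).
args = TrainingArguments(
    output_dir="./checkpoints",
    fp16=True,
    scale_loss=32768,    # initial loss scale for fp16, as documented above
    sharding="stage1",   # currently the safest choice per the PR description
)
```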