
[Trainer] support sharding for trainer. #3352


Merged: 24 commits merged into PaddlePaddle:develop on Nov 15, 2022

Conversation

@ZHUI (Collaborator) commented Sep 22, 2022

PR types

New features

PR changes

APIs

Description

support sharding for trainer.

stage1: supported

stage2: partially supported

  • offload is not supported yet; pure_fp16 needs to be fixed first

stage3: not supported yet

  • model saving still has issues

(A rough usage sketch follows below.)
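As a rough sketch of how this is meant to be used (the argument names and values below are assumptions based on the `--sharding` flag added by this PR, not a confirmed API):

# Hedged sketch, not a confirmed API: the `sharding` training argument added by
# this PR is assumed to accept a stage name such as "stage1" / "stage2" / "stage3".
from paddlenlp.trainer import TrainingArguments

training_args = TrainingArguments(
    output_dir="./checkpoints",
    fp16=True,            # sharding is usually combined with mixed precision
    sharding="stage2",    # stage1 supported, stage2 partially, stage3 not yet
)
# training_args is then passed to Trainer(model=..., args=training_args, ...) as usual.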

@ZHUI ZHUI marked this pull request as ready for review September 23, 2022 10:03
    paddle.distributed.all_reduce(
        p.bw_storage, group=self.dp_group)

elif (args.recompute and args.local_rank != -1):
    fused_allreduce_gradients(list(model.parameters()), None)

if self.do_grad_scaling:
ZHUI (Collaborator, Author) commented:
@haohongxiang Please keep the scaler usage experience here consistent with the official scaler.
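For reference, a minimal standalone sketch of the official paddle.amp.GradScaler flow the comment asks to match (an illustrative example, not the Trainer code itself):

import paddle

model = paddle.nn.Linear(4, 4)
optimizer = paddle.optimizer.AdamW(parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)

x = paddle.randn([2, 4])
with paddle.amp.auto_cast():
    loss = model(x).mean()

scaled = scaler.scale(loss)         # scale the loss so fp16 gradients do not underflow
scaled.backward()
scaler.minimize(optimizer, scaled)  # unscale grads, skip the step on inf/nan, update the scale
optimizer.clear_grad()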

@@ -1117,11 +1266,25 @@ def _save_checkpoint(self, model, metrics=None):

self.save_model(output_dir)

if self.sharding is not None:
ZHUI (Collaborator, Author) commented:
@haohongxiang Please provide an interface that gathers the parameters to CPU on the rank 0 card.
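A conceptual sketch of what such an interface might do (the helper and its arguments are assumptions, not the actual API that was added): every rank contributes its shards, and rank 0 reassembles them on CPU before saving.

# Conceptual sketch only; function and argument names are assumptions.
import paddle
import paddle.distributed as dist

def gather_sharded_params_to_cpu(sharded_params, path):
    """sharded_params: dict mapping parameter name -> this rank's flat shard."""
    full_state = {}
    for name, shard in sharded_params.items():
        pieces = []
        dist.all_gather(pieces, shard)            # every rank contributes its shard
        if dist.get_rank() == 0:
            full_state[name] = paddle.concat(pieces, axis=0).cpu()
    if dist.get_rank() == 0:
        paddle.save(full_state, path)             # only rank 0 writes the checkpoint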

Comment on lines 683 to 684
if self.do_grad_scaling:
    self.scaler.minimize(self.optimizer, tr_loss)
    # TODO: fix sharding stage2/stage3 with the original scaler usage.
ZHUI (Collaborator, Author) commented:
There is a problem with how the API is used here.

@ZHUI ZHUI requested a review from gongweibao October 9, 2022 09:43
@@ -395,6 +395,35 @@ Trainer is a simple but feature-complete Paddle training and evaluation module, and

The value of initial scale_loss for fp16. (default: 32768)

--sharding
Collaborator commented:
A NOTICE could be added here describing the currently usable state.

ZHUI (Collaborator, Author) replied:
Added.

|--------------------------------|-------|-------|-------------|------------------|-------------|-------------|------|-------|
| | mcc | acc | acc | pearson | acc | acc | acc | acc |
| T5-v1_1-base Paddle | 47.6845 | 94.38 | 84.31 | 87.74 | 88.05 | 85.39 | 90.518 | 65.70 |
| epoch | 100 | 10 | 100 | 100 | 3 | 3 | 10 | 100 |
Collaborator commented:
Have the T5_v1_1_base results been aligned?

ZHUI (Collaborator, Author) replied:
As discussed offline.

if self.label2id:
    label = self.label2id[label]
    if pred not in self.label2id:
        pred = 0
Collaborator commented:
Why is pred = 1 here when the label is 0?
Can you explain this logic in more detail?

When the generated label is not in the label list, it is ultimately assigned label 0. This situation does not arise with encoder models, so can you explain more concretely how the data metrics were aligned here?

ZHUI (Collaborator, Author) replied:
A comment has been added.
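For readers following along, a small sketch of what the commented logic expresses (names are illustrative, not the exact metric code): generated labels outside the label set fall back to id 0, which simply counts as a wrong prediction unless the gold label happens to be 0.

# Illustrative sketch with assumed names, mirroring the snippet above.
def to_ids(label, pred, label2id):
    if label2id:
        label = label2id[label]          # gold labels are always in the label set
        pred = label2id.get(pred, 0)     # out-of-set generations default to id 0
    return label, pred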

@ZHUI ZHUI requested a review from wawltor October 24, 2022 11:44
haohongxiang previously approved these changes Oct 25, 2022
@haohongxiang (Contributor) left a comment:
LGTM for sharding+dp

@@ -971,7 +971,8 @@ def greedy_search(self, input_ids, logits_processors, max_length,
        probs = F.softmax(logits)
        probs = paddle.log(probs)
        next_tokens = paddle.argmax(probs, axis=-1).unsqueeze(-1)
-       next_scores = paddle.index_sample(probs, next_tokens)
+       next_scores = paddle.index_sample(probs.astype("float32"),
+                                         next_tokens)
ZHUI (Collaborator, Author) commented:
The reason here is that index_sample has no fp16/bf16 kernel.
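A minimal standalone illustration of the workaround in the diff above: cast the input of an op that lacks an fp16/bf16 kernel to float32 before calling it (this assumes a GPU where the remaining bf16 kernels are available).

import paddle
import paddle.nn.functional as F

logits = paddle.randn([2, 8]).astype("bfloat16")
probs = paddle.log(F.softmax(logits))
next_tokens = paddle.argmax(probs, axis=-1).unsqueeze(-1)
# index_sample is computed in float32 because it has no fp16/bf16 kernel.
next_scores = paddle.index_sample(probs.astype("float32"), next_tokens)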

f"{self.dtype} not recognized. `dtype` should be set to either `paddle.float32` or `paddle.float16`"
)
encoder_extended_attention_mask = (
1.0 - encoder_extended_attention_mask) * -1e4
ZHUI (Collaborator, Author) commented:
For bf16 dtype
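A small illustration of the additive attention mask this touches: masked positions receive a large negative bias, and -1e4 (rather than something like -1e9) stays representable when the mask is later cast to low-precision dtypes such as float16, whose maximum magnitude is roughly 65504.

import paddle

attention_mask = paddle.to_tensor([[1.0, 1.0, 0.0]])   # 1 = attend, 0 = mask out
extended_mask = attention_mask.unsqueeze([1, 2])        # broadcast over heads and query length
mask_bias = (1.0 - extended_mask) * -1e4                # masked positions get a -1e4 bias
# mask_bias is added to the attention scores before softmax.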

    labels.flatten())
loss = loss_fct(
    lm_logits.reshape(
        shape=[-1, lm_logits.shape[-1]]).astype("float32"),
ZHUI (Collaborator, Author) commented:
CrossEntropyLoss has no fp16/bf16 kernel.
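A standalone sketch of the same workaround for the loss (shapes and the bf16 starting dtype are illustrative): the logits are cast to float32 before the cross entropy, which lacks fp16/bf16 support.

import paddle
import paddle.nn as nn

loss_fct = nn.CrossEntropyLoss()
lm_logits = paddle.randn([2, 4, 1000]).astype("bfloat16")   # [batch, seq_len, vocab]
labels = paddle.randint(0, 1000, shape=[2, 4])

# Compute the loss in float32 because CrossEntropyLoss has no fp16/bf16 kernel.
loss = loss_fct(
    lm_logits.reshape([-1, lm_logits.shape[-1]]).astype("float32"),
    labels.flatten())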

wawltor previously approved these changes Nov 14, 2022
@wawltor (Collaborator) left a comment:
LGTM

@wawltor (Collaborator) left a comment:
LGTM

@ZHUI ZHUI merged commit b35b8d6 into PaddlePaddle:develop Nov 15, 2022