-
Notifications
You must be signed in to change notification settings - Fork 3k
[Trainer] support sharding for trainer. #3352
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 14 commits
Commits
Show all changes
24 commits
Select commit
Hold shift + click to select a range
09b53e4
supprt sharding for trainer.
ZHUI b8e3c4c
fix amp
ZHUI 19ab021
fix sharding.
ZHUI ec2589b
support sharding for t5.
ZHUI 76a7ce2
add t5-3b
ZHUI 150456d
add bf16 support.
ZHUI 280a3e2
fix memory leak
ZHUI fce23ce
optimize stage1 finetune.
ZHUI 65a5eb8
optimize v1_1 finetune.
ZHUI d8c9290
Update README.md
ZHUI b9f6f2f
fix as reviews.
ZHUI c6c568d
Merge branch 'trainer_sharding' of https://github.com/ZHUI/PaddleNLP …
ZHUI 9193cca
Merge branch 'develop' into trainer_sharding
ZHUI 1cc507f
Merge branch 'develop' into trainer_sharding
ZHUI 62b6ccd
fix bugs.
ZHUI 1101e08
Merge remote-tracking branch 'zhui/trainer_sharding' into trainer_sha…
ZHUI a9363fb
fix sharding stage2
ZHUI 769aa77
support iterable dataset.
ZHUI 0368a4e
fix save strategy.
ZHUI 835b773
Merge remote-tracking branch 'origin/develop' into trainer_sharding
ZHUI eac339e
Merge remote-tracking branch 'origin/develop' into trainer_sharding
ZHUI 69c5082
fix trainer.
ZHUI 827de30
fix bug
ZHUI 4cc41bf
Merge branch 'develop' into trainer_sharding
ZHUI File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -46,12 +46,53 @@ python run_glue.py \ | |
- `scheduler_type` scheduler类型,可选linear和cosine,默认linear。 | ||
- `output_dir` 表示模型保存路径。 | ||
|
||
使用trainer进行Fine-tuning: | ||
```shell | ||
python -m paddle.distributed.launch --gpus "0,1,2,3" run_glue_trainer.py \ | ||
--model_name_or_path t5-base \ | ||
--task_name rte \ | ||
--max_seq_length 256 \ | ||
--do_train \ | ||
--do_eval \ | ||
--per_device_train_batch_size 16 \ | ||
--per_device_eval_batch_size 64 \ | ||
--learning_rate 1e-4 \ | ||
--weight_decay 0.01 \ | ||
--warmup_ratio 0.1 \ | ||
--num_train_epochs 10 \ | ||
--eval_steps 200 \ | ||
--logging_steps 20 \ | ||
--save_steps 200 \ | ||
--save_total_limit 3 \ | ||
--metric_for_best_model "eval_accuarcy" \ | ||
--fp16 false \ | ||
--fp16_opt_level "O1" \ | ||
--recompute true \ | ||
--sharding "stage1" \ | ||
--overwrite_output_dir \ | ||
--disable_tqdm true \ | ||
--output_dir outputs/rte/ | ||
``` | ||
具体参数含义请参见: https://paddlenlp.readthedocs.io/zh/latest/trainer.html | ||
|
||
###### t5-base模型在GLUE开发集上的结果: | ||
| Model | cola | sst-2 | mrpc | sts-b | qqp | mnli | qnli | rte | mean | | ||
|--------------------------------|-------|-------|-------------|------------------|-------------|-------------|------|-------|-------| | ||
| | mcc | acc | acc | pearson | acc | acc | acc | acc | | | ||
| T5-base-Paddle | 61.74 | 95.18 | 90.44 | 90.09 | 91.60 | 87.18 | 93.56 | 81.95 | 86.4675 | | ||
|
||
###### t5_v1_1-base模型在GLUE开发集上的结果: | ||
使用`run_glue_trainer.py`运行,由于`t5_v1_1-base`没有在glue任务上进行训练过,直接生成label的策略需要的训练时间需要更长。 | ||
| Model | cola | sst-2 | mrpc | sts-b | qqp | mnli | qnli | rte | | ||
|--------------------------------|-------|-------|-------------|------------------|-------------|-------------|------|-------| | ||
| | mcc | acc | acc | pearson | acc | acc | acc | acc | | ||
| T5-v1_1-base Paddle | 47.6845 | 94.38 | 84.31 | 87.74 | 88.05 | 85.39 | 90.518 | 65.70 | | ||
| epoch | 100 | 10 | 100 | 100 | 3 | 3 | 10 | 100 | | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. T5_v1_1_base效果对齐了吗? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 如线下沟通 |
||
|
||
注: | ||
- 直接生成label的finetune方式难度较大,前期基本学习如何正确生成label标签,后期才学习分类任务。 | ||
- 生成的label标签设计,标签差异大一些,效果会更好一些。 | ||
- `qqp`,`mnli`数据集适当增大训练epoch数,可以取得更好效果。 | ||
|
||
### GLUE Demo测试 | ||
|
||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里可以写一下NOTICE,目前可用的状态
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已补充