
Commit 4ab0df1

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP into add_split_param

2 parents: ae9ddce + 76a118b


48 files changed: +3007 additions, −5424 deletions

README.md

Lines changed: 13 additions & 13 deletions

@@ -115,19 +115,19 @@ The Unified Checkpoint large-model storage format supports dynamic…
 
 * Large-model pretraining, fine-tuning (including SFT and PEFT techniques), alignment, and quantization already support the LLaMA, Baichuan, Bloom, ChatGLM, Mistral, OPT, and Qwen series. The [LLM] pretraining, fine-tuning, alignment, and quantization support matrix is as follows:
 
-| Model / Capability | Pretrain | SFT | LoRA | Prefix Tuning | DPO | RLHF | Quantization | Torch convert |
-|:------------------:|:--------:|:---:|:----:|:-------------:|:---:|:----:|:------------:|:-------------:|
-| Llama              | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| Qwen               | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ |
-| Mixtral            | ✅ | ✅ | ✅ | ❌ | 🚧 | 🚧 | 🚧 | 🚧 |
-| Mistral            | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ |
-| Baichuan/Baichuan2 | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | ✅ | ✅ |
-| ChatGLM-6B         | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ |
-| ChatGLM2/ChatGLM3  | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ |
-| Bloom              | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ |
-| GPT-3              | ✅ | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | 🚧 | ✅ |
-| OPT                | ✅ | ✅ | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | ✅ |
-| Yuan2              | ✅ | ✅ | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | ✅ |
+| Model / Capability | Pretrain | SFT | FlashMask | LoRA | Prefix Tuning | DPO | RLHF | Quantization | Torch convert |
+|:------------------:|:--------:|:---:|:---------:|:----:|:-------------:|:---:|:----:|:------------:|:-------------:|
+| Llama              | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| Qwen               | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ |
+| Mixtral            | ✅ | ✅ | 🚧 | ✅ | ❌ | 🚧 | 🚧 | 🚧 | 🚧 |
+| Mistral            | ✅ | ✅ | 🚧 | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ |
+| Baichuan/Baichuan2 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | ✅ | ✅ |
+| ChatGLM-6B         | ✅ | ✅ | 🚧 | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ |
+| ChatGLM2/ChatGLM3  | ✅ | ✅ | 🚧 | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ |
+| Bloom              | ✅ | ✅ | 🚧 | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ |
+| GPT-3              | ✅ | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | 🚧 | 🚧 | ✅ |
+| OPT                | ✅ | ✅ | 🚧 | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | ✅ |
+| Yuan2              | ✅ | ✅ | 🚧 | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | ✅ |
 ------------------------------------------------------------------------------------------
 
 * [Large-model inference](./llm/docs/predict/inference.md) already supports the LLaMA, Qwen, Mistral, ChatGLM, Bloom, and Baichuan series, with weight-only INT8 and INT4 inference as well as INT8/FP8 quantized inference over WAC (weights, activations, cache KV). The [LLM] inference support matrix is as follows:

docs/index.rst

Lines changed: 1 addition & 0 deletions

@@ -57,6 +57,7 @@
     Unified large-model checkpoint docs <llm/docs/unified_checkpoint.md>
     Hybrid parallel training tutorial <llm/docs/llm_trainer.rst>
     Model weight conversion tutorial <llm/docs/torch2paddle.md>
+    Large-model DPO docs <llm/docs/dpo.md>
 
 .. toctree::
     :maxdepth: 1

docs/llm/docs/dpo.md

Lines changed: 1 addition & 0 deletions. The new file's only content is a relative path, i.e. a symlink-style stub so the docs tree picks up llm/docs/dpo.md:

@@ -0,0 +1 @@
+../../../llm/docs/dpo.md

llm/README.md

Lines changed: 7 additions & 3 deletions

@@ -15,18 +15,21 @@
 
 ## 🛠️ Supported Model List 🛠️
 
-| Model | Pretrain | SFT | LoRA | Prefix Tuning | DPO | RLHF | Quantization | Torch convert |
+| Model | Pretrain | SFT | LoRA | Prefix Tuning | DPO/SimPO/ORPO | RLHF | Quantization | Torch convert |
 |----------------------------------------|----------|-----|------|---------------|-----|------|--------------|---------------|
 | [LLaMA](./config/llama) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
 | [Qwen](./config/qwen) | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ |
-| [Mixtral](./config/mixtral) | ✅ | ✅ | ✅ | ❌ | 🚧 | 🚧 | 🚧 | 🚧 |
+| [Mixtral](./config/mixtral) | ✅ | ✅ | ✅ | ❌ | ✅ | 🚧 | 🚧 | 🚧 |
 | [Mistral](./config/mistral) | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ |
 | [Baichuan/Baichuan2](./config/llama) | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | ✅ | ✅ |
 | [ChatGLM-6B](./config/chatglm) | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ |
-| [ChatGLM2/ChatGLM3](./config/chatglm2) | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ |
+| [ChatGLM2/ChatGLM3](./config/chatglm2) | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | ✅ | ✅ |
 | [Bloom](./config/bloom) | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ |
 | [GPT-3](./config/gpt-3) | ✅ | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | 🚧 | ✅ |
 | [OPT](./config/opt) | 🚧 | ✅ | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | ✅ |
+| [Gemma](./config/gemma) | 🚧 | ✅ | 🚧 | 🚧 | ✅ | 🚧 | 🚧 | 🚧 |
+| [Yuan](./config/yuan) | ✅ | ✅ | ✅ | 🚧 | ✅ | 🚧 | 🚧 | 🚧 |
+
 
 - ✅: Supported
 - 🚧: In Progress

@@ -193,6 +196,7 @@ tar -zxvf ultrafeedback_binarized.tar.gz
 # DPO launch command reference
 python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" ./alignment/dpo/run_dpo.py ./config/llama/dpo_argument.json
 ```
+For more DPO technical details and usage instructions, see the [DPO documentation](./docs/dpo.md).
 
 #### 3.2 RLHF

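The ultrafeedback_binarized archive unpacked above holds preference pairs: a prompt plus a preferred and a rejected response. A generic record might look like the sketch below; the field names ("src", "tgt", "response", "sort") are assumptions for illustration only, so check the DPO documentation linked above for the actual schema.

```python
# Illustrative preference record for DPO-style training; field names are
# assumed for illustration, not confirmed by this diff.
example = {
    "src": ["Explain what DPO optimizes."],   # prompt turn(s)
    "tgt": [],                                # earlier assistant turns, if any
    "response": [
        "DPO fits the policy directly on preference pairs ...",  # preferred
        "DPO trains a separate reward model with PPO ...",       # rejected
    ],
    "sort": [1, 0],  # higher rank marks the preferred response
}
```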
llm/alignment/dpo/dpo_argument.py

Lines changed: 1 addition & 5 deletions

@@ -91,15 +91,11 @@ class DPOConfig:
 
     beta: float = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
     simpo_gamma: float = field(default=0.5, metadata={"help": "the gamma parameter for SimPO loss"})
-    normalize_logps: bool = field(
-        default=True,
-        metadata={"help": "Apply logprobs normalization."},
-    )
     label_smoothing: float = field(default=0.0, metadata={"help": "label_smoothing ratio"})
     loss_type: str = field(default="sigmoid", metadata={"help": "DPO loss type"})
     pref_loss_ratio: float = field(default=1.0, metadata={"help": "DPO loss ratio"})
     sft_loss_ratio: float = field(default=0.0, metadata={"help": "SFT loss ratio"})
-    dpop_lambda: float = field(default=50, metadata={"help": "SFT loss ratio"})
+    dpop_lambda: float = field(default=50, metadata={"help": "dpop_lambda"})
     ref_model_update_steps: int = field(default=-1, metadata={"help": "Update ref model state dict "})
     reference_free: bool = field(default=False, metadata={"help": "No reference model."})
     lora: bool = field(default=False, metadata={"help": "Use LoRA model."})
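These fields steer which preference loss run_dpo.py computes. As a rough picture of how they typically combine, here is a minimal sketch assuming the standard DPO, SimPO, and DPOP formulations (not PaddleNLP's actual implementation): the chosen/rejected log-probabilities feed a logistic loss scaled by beta, and the training objective would then weight this term by pref_loss_ratio and add sft_loss_ratio times an auxiliary SFT loss.

```python
# Minimal sketch (not PaddleNLP's code) of how DPOConfig's loss fields
# usually interact; the log-prob inputs are per-sequence sums.
import math

def log_sigmoid(x: float) -> float:
    # Numerically stable log(sigmoid(x)).
    return x - math.log1p(math.exp(x)) if x < 0 else -math.log1p(math.exp(-x))

def preference_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected,
                    beta=0.1, label_smoothing=0.0, loss_type="sigmoid",
                    simpo_gamma=0.5, dpop_lambda=50.0, reference_free=False):
    pi_logratio = policy_chosen - policy_rejected
    ref_logratio = 0.0 if reference_free else ref_chosen - ref_rejected
    if loss_type == "simpo":
        # SimPO: reference-free margin loss over length-normalized log-probs.
        logits = beta * pi_logratio - simpo_gamma
    else:
        logits = beta * (pi_logratio - ref_logratio)
    # label_smoothing interpolates between the two preference directions.
    loss = -(1 - label_smoothing) * log_sigmoid(logits) \
           - label_smoothing * log_sigmoid(-logits)
    if loss_type == "dpop":
        # DPOP: penalize the policy when its chosen log-prob drops below the reference's.
        loss += dpop_lambda * max(0.0, ref_chosen - policy_chosen)
    return loss
```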
New file (path not shown in this view): a pretraining argument JSON for DeepSeek-V2-Lite

Lines changed: 40 additions & 0 deletions

@@ -0,0 +1,40 @@
+{
+    "model_name_or_path": "deepseek-ai/DeepSeek-V2-Lite",
+    "tokenizer_name_or_path": "deepseek-ai/DeepSeek-V2-Lite",
+    "input_dir": "./data",
+    "output_dir": "./checkpoints/pretrain_ckpts",
+    "per_device_train_batch_size": 1,
+    "gradient_accumulation_steps": 1,
+    "per_device_eval_batch_size": 1,
+    "tensor_parallel_degree": 1,
+    "pipeline_parallel_degree": 1,
+    "sharding_parallel_degree": 1,
+    "sharding": "stage2",
+    "virtual_pp_degree": 1,
+    "sequence_parallel": 0,
+    "use_flash_attention": true,
+    "max_seq_length": 4096,
+    "learning_rate": 3e-05,
+    "min_learning_rate": 3e-06,
+    "warmup_steps": 30,
+    "logging_steps": 1,
+    "max_steps": 10000,
+    "save_steps": 5000,
+    "eval_steps": 1000,
+    "weight_decay": 0.01,
+    "bf16": true,
+    "fp16_opt_level": "O2",
+    "warmup_ratio": 0.01,
+    "max_grad_norm": 1.0,
+    "dataloader_num_workers": 1,
+    "continue_training": 1,
+    "do_train": true,
+    "do_eval": true,
+    "do_predict": true,
+    "disable_tqdm": true,
+    "recompute": true,
+    "distributed_dataloader": 1,
+    "recompute_granularity": "full",
+    "unified_checkpoint": true,
+    "save_total_limit": 2
+}
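With tensor_parallel_degree, pipeline_parallel_degree, and sharding_parallel_degree all set to 1, a multi-GPU launch (e.g. the same `--gpus "0,1,2,3,4,5,6,7"` pattern used for DPO above, presumably through PaddleNLP's usual run_pretrain.py entry point) runs pure data parallelism, with "sharding": "stage2" splitting optimizer state and gradients across the replicas. A quick sanity-check sketch follows; this is a hypothetical helper, not a PaddleNLP utility, and the local filename is assumed.

```python
# Hypothetical sanity check (not a PaddleNLP utility): the product of the
# model-parallel degrees must divide the launched GPU count; the remaining
# factor becomes the data-parallel (here: sharding stage2) replica count.
import json

with open("pretrain_argument.json") as f:  # assumed local filename
    cfg = json.load(f)

model_parallel = cfg["tensor_parallel_degree"] * cfg["pipeline_parallel_degree"]
num_gpus = 8  # e.g. launched with --gpus "0,1,2,3,4,5,6,7"
assert num_gpus % model_parallel == 0, "parallel degrees must divide the GPU count"
print("data-parallel replicas:", num_gpus // model_parallel)  # -> 8
```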
New file (path not shown in this view): an SFT argument JSON for DeepSeek-V2-Lite

Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
+{
+    "model_name_or_path": "deepseek-ai/DeepSeek-V2-Lite",
+    "dataset_name_or_path": "./data",
+    "output_dir": "./checkpoints/sft_ckpts",
+    "per_device_train_batch_size": 1,
+    "gradient_accumulation_steps": 4,
+    "per_device_eval_batch_size": 8,
+    "eval_accumulation_steps": 16,
+    "num_train_epochs": 3,
+    "learning_rate": 3e-05,
+    "warmup_steps": 30,
+    "logging_steps": 1,
+    "evaluation_strategy": "epoch",
+    "save_strategy": "epoch",
+    "src_length": 1024,
+    "max_length": 2048,
+    "bf16": true,
+    "fp16_opt_level": "O2",
+    "do_train": true,
+    "do_eval": true,
+    "disable_tqdm": true,
+    "load_best_model_at_end": true,
+    "eval_with_do_generation": false,
+    "metric_for_best_model": "accuracy",
+    "recompute": true,
+    "save_total_limit": 1,
+    "tensor_parallel_degree": 1,
+    "pipeline_parallel_degree": 1,
+    "sharding": "stage2",
+    "zero_padding": false,
+    "unified_checkpoint": true,
+    "use_flash_attention": true
+}
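Here src_length caps the prompt while max_length caps prompt plus response; with zero_padding disabled, examples are padded individually rather than packed into full-length sequences. A rough sketch of the assumed truncation semantics (illustrative only, not PaddleNLP's exact preprocessing):

```python
# Illustrative sketch of the assumed src_length / max_length semantics
# (not PaddleNLP's exact preprocessing code).
def build_sft_example(prompt_ids, response_ids, src_length=1024, max_length=2048):
    prompt_ids = prompt_ids[-src_length:]   # keep the trailing src_length prompt tokens
    room = max_length - len(prompt_ids)
    response_ids = response_ids[:room]      # response fills the remaining budget
    input_ids = prompt_ids + response_ids
    # Loss is computed on response positions only; prompt positions get -100.
    labels = [-100] * len(prompt_ids) + response_ids
    return input_ids, labels
```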
