Commit 935f102

[LLM] support trainer for gpt and llama pre-training. (#6053)
* support trainer for gpt and llama pretrain.
* update
* fix
* fix use flash.
* support llama pre-train and post-train
* fix
* support use virtual pp degree.
1 parent a37bdfe commit 935f102

16 files changed (+1461, -110 lines)


examples/language_model/llama/README.md

Lines changed: 59 additions & 0 deletions
@@ -13,6 +13,65 @@ Use of the Llama model weights must follow the [License](../../../paddlenlp/transf
 
 <a name="1"></a>
 
+## Pre-training
+
+For pre-training data preparation, see [here](../../../model_zoo/ernie-1.0/preprocess/docs/OpenWebText2.md).
+
+To make it easy to run and test this model, the project provides a preprocessed training sample of 100k docs:
+```shell
+wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k_ids.npy
+wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k_idx.npz
+```
+
+Put all of the preprocessed files into a single folder for training:
+
+```
+mkdir data
+mv llama_openwebtext_100k_ids.npy ./data
+mv llama_openwebtext_100k_idx.npz ./data
+```
+
+With the script below, you can continue pre-training from llama-7b.
+```shell
+python -u -m paddle.distributed.launch \
+    --gpus "0,1,2,3,4,5,6,7" \
+    --log_dir "output/$task_name""_log" \
+    run_pretrain.py \
+    --model_type "llama" \
+    --model_name_or_path "facebook/llama-7b" \
+    --tokenizer_name_or_path "facebook/llama-7b" \
+    --input_dir "./data" \
+    --output_dir "output/$task_name" \
+    --split 949,50,1 \
+    --max_seq_length 2048 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --use_flash_attention 1 \
+    --use_fused_rms_norm 0 \
+    --fp16 \
+    --fp16_opt_level "O2" \
+    --scale_loss 1024 \
+    --learning_rate 0.00001 \
+    --min_learning_rate 0.000005 \
+    --max_steps 10000 \
+    --save_steps 5000 \
+    --weight_decay 0.01 \
+    --warmup_ratio 0.01 \
+    --max_grad_norm 1.0 \
+    --logging_steps 20 \
+    --dataloader_num_workers 1 \
+    --sharding "stage2" \
+    --eval_steps 1000 \
+    --report_to "visualdl" \
+    --disable_tqdm true \
+    --continue_training 1 \
+    --recompute 1 \
+    --do_train \
+    --do_eval \
+    --device "gpu"
+```
+
+
 ## Fine-tuning
 
 ```shell
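Note that the launch command above interpolates `$task_name` into `--log_dir` and `--output_dir` without defining it; it is assumed to be set in the shell beforehand, e.g. `task_name="llama_7b_pretrain"` (that value is only illustrative).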
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+../../..//model_zoo/gpt/dataset.py
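The one-line file added above appears to be a symlink in the llama example directory pointing at the shared GPT dataset loader, `model_zoo/gpt/dataset.py`, so the llama pre-training script can reuse the same dataset utilities.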

examples/language_model/llama/modeling_pp.py

Lines changed: 33 additions & 18 deletions
@@ -121,6 +121,7 @@ class PipelinePretrainedModel(PretrainedModel):
     _pipeline_name_mapping = None
 
     def __init__(self, config, *args, **kwargs):
+        raise ValueError()
         super().__init__(config, *args, **kwargs)
 
     def add_sequential_layer(self, layer_desc, name_prefix=""):
@@ -138,23 +139,39 @@ def _set_pipeline_name_mapping(self, mappings=None):
         else:
             mapping = {}
             state_dict_keys = list(super().state_dict().keys())
+            first_key = state_dict_keys[0].split(".")
+            # if use virtual pp_degree, the prefix is like 0.0.xxx
+            # else it will be like 0.xxx
+            use_virtual_pp_degree = first_key[0].isdigit() and first_key[1].isdigit()
+
             prefixs = self.get_sequential_name_prefixs()
             for k in state_dict_keys:
                 name_splited = k.split(".")
-                name_splited[0] = prefixs[name_splited[0]]
-                mapping[".".join(name_splited)] = k
+                if use_virtual_pp_degree:
+                    idx = str(int(name_splited[0]) + int(name_splited[1]))
+                    single_name = [prefixs[idx]]
+                    single_name.extend(name_splited[2:])
+                else:
+                    idx = name_splited[0]
+                    single_name = [prefixs[idx]]
+                    single_name.extend(name_splited[1:])
+                mapping[".".join(single_name)] = k
+
             self._pipeline_name_mapping = mapping
 
         return self._pipeline_name_mapping
 
     def state_dict(self, *args, **kwargs):
         state_dict = super().state_dict(*args, **kwargs)
-        prefixs = self.get_sequential_name_prefixs()
+
+        if self._pipeline_name_mapping is None:
+            self._set_pipeline_name_mapping()
+        assert len(self._pipeline_name_mapping) > 0, "The pipeline stage must have parameters!"
+        pp_to_single_mapping = {v: k for k, v in self._pipeline_name_mapping.items()}
+
         for k in list(state_dict.keys()):
             v = state_dict.pop(k)
-            name_splited = k.split(".")
-            name_splited[0] = prefixs[name_splited[0]]
-            state_dict[".".join(name_splited)] = v
+            state_dict[pp_to_single_mapping[k]] = v
 
         return state_dict
 
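To make the name mapping concrete: with a virtual pipeline degree the flat parameter keys start with two numeric indices (`0.1.xxx`) instead of one (`1.xxx`), and both forms must resolve to the same single-model prefix. Below is a minimal, self-contained sketch of that lookup; the prefix table and parameter names are made up for illustration and are not the real values returned by `get_sequential_name_prefixs()`.

```python
# Standalone illustration of the lookup in the hunk above (not PaddleNLP code).
def to_single_name(pp_key: str, prefixes: dict, use_virtual_pp: bool) -> str:
    """Map a pipeline-parallel parameter key back to its single-model name."""
    parts = pp_key.split(".")
    if use_virtual_pp:
        # virtual pp keys look like "<chunk>.<layer>.xxx"; the diff sums the
        # two leading indices to find the sequential-layer prefix
        idx = str(int(parts[0]) + int(parts[1]))
        return ".".join([prefixes[idx]] + parts[2:])
    # plain pipeline keys look like "<layer>.xxx"
    return ".".join([prefixes[parts[0]]] + parts[1:])


# made-up prefix table: sequential-layer index -> single-model prefix
prefixes = {"0": "llama.embed_tokens", "1": "llama.layers.0", "2": "llama.layers.1"}

print(to_single_name("1.self_attn.q_proj.weight", prefixes, use_virtual_pp=False))
# llama.layers.0.self_attn.q_proj.weight
print(to_single_name("0.1.self_attn.q_proj.weight", prefixes, use_virtual_pp=True))
# llama.layers.0.self_attn.q_proj.weight
```

`state_dict()` then simply applies the inverse of this mapping (`pp_to_single_mapping`) when exporting weights.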

@@ -169,7 +186,8 @@ def set_state_dict(self, state_dict, *args, **kwargs):
                 continue
             state_dict[self._pipeline_name_mapping[k]] = v
 
-        return super().set_state_dict(state_dict, *args, **kwargs)
+        ret = super().set_state_dict(state_dict, *args, **kwargs)
+        return ret
 
 
 class LlamaForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
@@ -182,28 +200,25 @@ class LlamaForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
     config_class = LlamaConfig
 
     _get_tensor_parallel_mappings = LlamaPretrainedModel._get_tensor_parallel_mappings
+
     # NO base_model_prefix !!!!
 
     def __init__(
         self,
         config,
-        # num_partitions=1,
-        # topology=None,
-        use_recompute=None,
-        # fused_linear=False,
-        # fuse_attn_qkv=False,
+        # use_recompute=None,
         # scale_qk_by_layer_num=True,
-        recompute_granularity="full",
-        virtual_pp_degree=1,
+        # recompute_granularity="full",
+        # virtual_pp_degree=4,
         # sequence_parallel=False,
         # no_recompute_layers=None,
         pp_recompute_interval=1,
-        # use_flash_attn=False,
-        # fused_softmax_with_triangular=False,
     ):
         self.config = config
-        if use_recompute is None:
-            use_recompute = self.config.use_recompute
+
+        use_recompute = self.config.use_recompute
+        recompute_granularity = self.config.recompute_granularity
+        virtual_pp_degree = self.config.virtual_pp_degree
 
         hcg = get_hcg()
         tensor_parallel_degree = max(hcg.get_model_parallel_world_size(), 1)
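The net effect of this last hunk is that `use_recompute`, `recompute_granularity`, and `virtual_pp_degree` are no longer constructor keyword arguments; `LlamaForCausalLMPipe` reads them from the config instead. Below is a hedged sketch of the resulting call pattern: the field names come from the diff, while the imports, the `from_pretrained` loader call, and the concrete values are assumptions, and constructing the model requires a pipeline-parallel environment since `__init__` calls `get_hcg()`.

```python
# Sketch only: assumes paddlenlp with the Llama classes installed, the example
# directory on the path, and a `paddle.distributed.launch` job with pipeline
# parallelism already set up. Concrete values are illustrative.
from paddlenlp.transformers import LlamaConfig

from modeling_pp import LlamaForCausalLMPipe

config = LlamaConfig.from_pretrained("facebook/llama-7b")
config.use_recompute = True
config.recompute_granularity = "full"
config.virtual_pp_degree = 2

# recompute / virtual-pp settings are now picked up from `config`,
# not passed as constructor kwargs
model = LlamaForCausalLMPipe(config)
```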

0 commit comments
