
Commit 6f2eff6

Merge pull request #1 from PaddlePaddle/ppo-4d/support_uc
PPO 4d/support uc
2 parents 860e61d + 757d3a7 commit 6f2eff6

File tree

9 files changed: +824 -70 lines

examples/RLHF/models/score_model.py

Lines changed: 4 additions & 0 deletions
@@ -17,6 +17,7 @@
 import paddle
 from paddle import nn
 
+import paddlenlp
 from paddlenlp.transformers import (
     LlamaConfig,
     LlamaModel,
@@ -132,3 +133,6 @@ def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]:
 
         mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)]
         return mappings
+
+
+paddlenlp.transformers.LlamaModelForScore = LlamaModelForScore
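
Note: the final added line registers the custom score model on the paddlenlp.transformers namespace. A minimal sketch of the effect, assuming only that models.score_model has been imported first (the getattr lookup shown here is illustrative, not the exact resolution path inside PaddleNLP):

import paddlenlp

import models.score_model  # noqa: F401  (importing runs the registration above)

# The class is now reachable as an attribute of paddlenlp.transformers,
# so name-based lookups (e.g. getattr-style auto-class resolution) can find it.
score_cls = getattr(paddlenlp.transformers, "LlamaModelForScore")
assert score_cls is models.score_model.LlamaModelForScore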

examples/RLHF/models/score_model_utils.py

Lines changed: 4 additions & 3 deletions
@@ -49,9 +49,10 @@ class AutoModelForScore(_BaseAutoModelClass):
     _score_module_name: str = "models.score_model"
 
     @classmethod
-    def _get_model_class_from_config(cls, pretrained_model_name_or_path, config_file_path):
-        with io.open(config_file_path, encoding="utf-8") as f:
-            config = json.load(f)
+    def _get_model_class_from_config(cls, pretrained_model_name_or_path, config_file_path, config=None):
+        if config is None:
+            with io.open(config_file_path, encoding="utf-8") as f:
+                config = json.load(f)
 
         # Get class name corresponds to this configuration
         if is_standard_config(config):
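
Note: _get_model_class_from_config now accepts an optional pre-parsed config dict and only reads config_file_path when that argument is None. A hedged usage sketch; the local path and model name below are illustrative:

import io
import json

from models import AutoModelForScore

config_path = "./beaver-7b-v1.0-reward/config.json"  # illustrative path

# Old call style still works: the method opens and parses config.json itself.
cls_from_file = AutoModelForScore._get_model_class_from_config(
    "PKU-Alignment/beaver-7b-v1.0-reward", config_path
)

# New call style: pass a dict that is already in memory (possibly modified),
# so the JSON file is not read a second time.
with io.open(config_path, encoding="utf-8") as f:
    config_dict = json.load(f)
cls_from_dict = AutoModelForScore._get_model_class_from_config(
    "PKU-Alignment/beaver-7b-v1.0-reward", config_path, config=config_dict
)
assert cls_from_file is cls_from_dict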

examples/RLHF/ppo_config.json

Lines changed: 4 additions & 2 deletions
@@ -6,7 +6,7 @@
     "reward_model_name_or_path": "PKU-Alignment/beaver-7b-v1.0-reward",
     "_actor_model_name_or_path": "facebook/llama-7b",
     "_reward_model_name_or_path": "facebook/llama-7b",
-    "output_dir": "/root/paddlejob/workspace/guosheng/checkpoints/ppo-sd14pp2-test",
+    "output_dir": "./ppo-sd14pp2-test",
     "max_length": 512,
     "temperature": 1.0,
     "num_return_sequences":1,
@@ -52,5 +52,7 @@
     "comment-PKU_Beaver-max_grad_norm": 1.0,
     "max_grad_norm": 1.0,
     "adam_beta1": 0.9,
-    "adam_beta2": 0.95
+    "adam_beta2": 0.95,
+    "eval_mode": "tensor_parallel",
+    "offload_level": "eval"
 }
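
Note: the two new keys correspond to the eval_mode and offload_level training arguments added in ppo_main.py below. A minimal sketch of reading them directly from the JSON file (illustrative only; in the example they are normally consumed through the argument parser):

import json

with open("examples/RLHF/ppo_config.json", encoding="utf-8") as f:
    cfg = json.load(f)

# "eval_mode" may be empty/None (disabled), "single", or "tensor_parallel";
# "offload_level" is matched by substring and may mention "eval", "reward",
# "optimizer", or "train_model".
print(cfg.get("eval_mode"), cfg.get("offload_level"))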

examples/RLHF/ppo_main.py

Lines changed: 212 additions & 23 deletions
@@ -12,17 +12,45 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import os
 import sys
+
+# os.environ["http_proxy"] = "http://10.162.37.16:8128"
+# os.environ["https_proxy"] = "http://10.162.37.16:8128"
+# os.environ["no_proxy"] = "localhost,bcebos.com"
+# launch would unset http_proxy
+# export https_proxy=http://172.19.57.45:3128
+
+# os.environ["http_proxy"] = "http://172.19.56.199:3128"
+# os.environ["https_proxy"] = "http://172.19.56.199:3128"
+
+# os.environ["http_proxy"] = "http://172.19.57.45:3128"
+# os.environ["https_proxy"] = "http://172.19.57.45:3128"
+
+os.environ["http_proxy"] = "http://10.162.37.16:8128"
+os.environ["https_proxy"] = "http://10.162.37.16:8128"
+os.environ["no_proxy"] = "localhost,bcebos.com"
+
+# os.environ["http_proxy"] = "agent.baidu.com:8118"
+# os.environ["https_proxy"] = "agent.baidu.com:8118"
+
 from dataclasses import dataclass, field
 from typing import Any, Dict, Tuple
 
 import paddle
 from data import PromptOnlyDataset, SupervisedDataset, parse_dataset
-from ppo_trainer import PPOTrainer
+from models import AutoModelForScore
+from models.score_model import LlamaModelForScore  # noqa
+from ppo_trainer import PPOTrainer, cleanup_tensor_space, offload_tensor_to_cpu
 
 from paddlenlp.trainer import PdArgumentParser, TrainingArguments, get_last_checkpoint
-from paddlenlp.transformers import AutoConfig, AutoTokenizer, LlamaTokenizer
+from paddlenlp.transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    LlamaTokenizer,
+)
 from paddlenlp.utils.log import logger
 
 
@@ -108,6 +136,18 @@ class TrainingArguments(TrainingArguments):
         default=16,
         metadata={"help": "Batch size (per device) for the training dataloader."},
     )
+    eval_mode: str = field(
+        default=None,
+        metadata={
+            "help": "eval mode for actor model and reward_critic_model, optional for: None, single, tensor_parallel."
+        },
+    )
+
+    offload_level: str = field(
+        default=None,
+        metadata={"help": "Offload model, optional for: eval, reward, optimizer, train_model"},
+    )
+
     # save_generation_output: bool = field(
     #     default=False,
     #     metadata={"help": "Whether to save generated text to file when eval"},
@@ -179,6 +219,10 @@ def main():
     model_args, data_args, training_args = parser.parse_args_into_dataclasses()
     training_args.print_config(model_args, "Model")
     training_args.print_config(data_args, "Data")
+    if training_args.eval_mode is not None and len(training_args.eval_mode) == 0:
+        training_args.eval_mode = None
+    if training_args.eval_mode is None and training_args.offload_level is not None:
+        training_args.offload_level = training_args.offload_level.replace("eval", "")
 
     # Setup GPU & distributed training
     paddle.set_device(training_args.device)
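
Note: the two added checks normalize the new switches before any model is built: an empty --eval_mode is treated as disabled, and when no eval model will exist the "eval" token is stripped from offload_level so the offload step later never touches a missing eval model. A small standalone sketch of that normalization, for illustration only:

def normalize(eval_mode, offload_level):
    # Mirrors the two `if` statements in the hunk above.
    if eval_mode is not None and len(eval_mode) == 0:
        eval_mode = None
    if eval_mode is None and offload_level is not None:
        offload_level = offload_level.replace("eval", "")
    return eval_mode, offload_level

print(normalize("", "eval"))                 # (None, "")
print(normalize("tensor_parallel", "eval"))  # ("tensor_parallel", "eval")
print(normalize(None, "eval,reward"))        # (None, ",reward")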
@@ -214,21 +258,17 @@ def main():
         dtype = "float32"
     training_args.max_length = data_args.max_length
 
+    model_class_lm, model_class_score = AutoModelForCausalLM, AutoModelForScore
     if training_args.pipeline_parallel_degree > 1:
-        global AutoModelForCausalLM, AutoModelForScore
         from models.model_pp import LlamaPolicyPipe, LlamaValuePipe
 
-        AutoModelForCausalLM = LlamaPolicyPipe
-        AutoModelForScore = LlamaValuePipe
+        model_class_lm = LlamaPolicyPipe
+        model_class_score = LlamaValuePipe
         extra_args = {
             "ptx_coeff": training_args.ptx_coeff,
             "clip_range_ratio": training_args.clip_range_ratio,
         }
     else:
-        from models import AutoModelForScore
-
-        from paddlenlp.transformers import AutoModelForCausalLM
-
         extra_args = {}
 
     # actor model
@@ -241,18 +281,42 @@ def main():
     )
     if hasattr(model_config, "use_flash_attention"):
         model_config.use_flash_attention = model_args.use_flash_attention
-    actor_model = AutoModelForCausalLM.from_pretrained(
+
+    # model_config.num_hidden_layers = 2
+
+    actor_model = model_class_lm.from_pretrained(
         model_args.actor_model_name_or_path,
         config=model_config,
         **extra_args,
         # ptx_coeff=training_args.ptx_coeff,
         # clip_range_ratio=training_args.clip_range_ratio,
     )
-    # reference model
-    actor_reference_model = AutoModelForCausalLM.from_pretrained(
-        model_args.actor_model_name_or_path,
-        config=model_config,
-    )
+    if training_args.eval_mode is not None:
+        config = copy.deepcopy(actor_model.config)
+        if training_args.eval_mode == "single":
+            config.tensor_parallel_degree = -1
+            config.tensor_parallel_rank = 0
+        actor_eval_model = AutoModelForCausalLM.from_config(config)
+        # actor_eval_model = AutoModelForCausalLM.from_pretrained(model_args.actor_model_name_or_path, config=config)
+    else:
+        actor_eval_model = None
+
+    # todo reference model
+    if training_args.eval_mode is not None:
+        config = copy.deepcopy(model_config)
+        if training_args.eval_mode == "single":
+            config.tensor_parallel_degree = -1
+            config.tensor_parallel_rank = 0
+        actor_reference_model = AutoModelForCausalLM.from_pretrained(
+            model_args.actor_model_name_or_path,
+            config=config,
+        )
+    else:
+        actor_reference_model = model_class_lm.from_pretrained(
+            model_args.actor_model_name_or_path,
+            config=model_config,
+        )
+
     actor_tokenizer = AutoTokenizer.from_pretrained(
         model_args.actor_model_name_or_path, model_max_length=data_args.max_length, padding_side="left"
     )
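
Note: when eval_mode is set, a separate evaluation copy of the policy is built from a deep copy of its config; in "single" mode the tensor-parallel fields are reset so the copy is an unsharded model, and from_config builds the architecture without loading pretrained weights (the weights are presumably synchronized from the training model elsewhere). A minimal sketch of the shared pattern used here for actor_eval_model and later for reward_critic_eval_model:

import copy

from paddlenlp.transformers import AutoModelForCausalLM


def build_eval_copy(train_model, eval_mode):
    # Mirrors the eval-model construction in the hunk above.
    if eval_mode is None:
        return None
    config = copy.deepcopy(train_model.config)
    if eval_mode == "single":
        # Drop tensor-parallel sharding: the eval copy becomes a plain single-card model.
        config.tensor_parallel_degree = -1
        config.tensor_parallel_rank = 0
    return AutoModelForCausalLM.from_config(config)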
@@ -267,19 +331,33 @@ def main():
     )
     if hasattr(model_config, "use_flash_attention"):
         model_config.use_flash_attention = model_args.use_flash_attention
-    reward_model = AutoModelForScore.from_pretrained(
-        model_args.reward_model_name_or_path,
-        config=model_config,
-        score_type="reward",
-        do_normalize=training_args.normalize_reward,
-    )
+    # model_config.num_hidden_layers = 2
+    # todo
+    if training_args.eval_mode is not None:
+        config = copy.deepcopy(model_config)
+        if training_args.eval_mode == "single":
+            config.tensor_parallel_degree = -1
+            config.tensor_parallel_rank = 0
+        reward_model = AutoModelForScore.from_pretrained(
+            model_args.reward_model_name_or_path,
+            config=config,
+            score_type="reward",
+            do_normalize=training_args.normalize_reward,
+        )
+    else:
+        reward_model = model_class_score.from_pretrained(
+            model_args.reward_model_name_or_path,
+            config=model_config,
+            score_type="reward",
+            do_normalize=training_args.normalize_reward,
+        )
     reward_tokenizer = AutoTokenizer.from_pretrained(
         model_args.reward_model_name_or_path, model_max_length=data_args.max_length, padding_side="right"
     )
     # critic model
     if model_args.reward_critic_model_name_or_path is None:
         model_args.reward_critic_model_name_or_path = model_args.reward_model_name_or_path
-    reward_critic_model = AutoModelForScore.from_pretrained(
+    reward_critic_model = model_class_score.from_pretrained(
         model_args.reward_critic_model_name_or_path,
         config=model_config,
         score_type="critic",
@@ -289,6 +367,92 @@ def main():
     reward_critic_tokenizer = AutoTokenizer.from_pretrained(
         model_args.reward_critic_model_name_or_path, model_max_length=data_args.max_length, padding_side="left"
     )
+    if training_args.eval_mode is not None:
+        config = copy.deepcopy(reward_critic_model.config)
+        if training_args.eval_mode == "single":
+            config.tensor_parallel_degree = -1
+            config.tensor_parallel_rank = 0
+        reward_critic_eval_model = AutoModelForScore.from_config(config)
+        # reward_critic_eval_model = AutoModelForScore.from_pretrained(
+        #     model_args.reward_critic_model_name_or_path,config=model_config
+        # )
+    else:
+        reward_critic_eval_model = None
+
+    # # actor model
+    # model_config = AutoConfig.from_pretrained(
+    #     model_args.actor_model_name_or_path,
+    #     tensor_parallel_output=False,
+    #     tensor_parallel_degree=training_args.tensor_parallel_degree,
+    #     tensor_parallel_rank=training_args.tensor_parallel_rank,
+    #     dtype=dtype,
+    # )
+    # model_config.num_hidden_layers = 2
+    # if hasattr(model_config, "use_flash_attention"):
+    #     model_config.use_flash_attention = model_args.use_flash_attention
+    # actor_model = AutoModelForCausalLM.from_pretrained(
+    #     model_args.actor_model_name_or_path,
+    #     config=model_config,
+    # )
+    #
+    # if training_args.eval_mode is not None:
+    #     config = copy.deepcopy(actor_model.config)
+    #     if training_args.eval_mode == "single":
+    #         config.tensor_parallel_degree = -1
+    #         config.tensor_parallel_rank = 0
+    #     actor_eval_model = AutoModelForCausalLM.from_config(config)
+    # else:
+    #     actor_eval_model = None
+    #
+    # # reference model
+    # actor_reference_model = AutoModelForCausalLM.from_pretrained(
+    #     model_args.actor_model_name_or_path,
+    #     config=model_config,
+    # )
+    # actor_tokenizer = AutoTokenizer.from_pretrained(
+    #     model_args.actor_model_name_or_path, model_max_length=data_args.max_length, padding_side="left"
+    # )
+    #
+    # # reward model
+    # model_config = AutoConfig.from_pretrained(
+    #     model_args.reward_model_name_or_path,
+    #     tensor_parallel_output=False,
+    #     tensor_parallel_degree=training_args.tensor_parallel_degree,
+    #     tensor_parallel_rank=training_args.tensor_parallel_rank,
+    #     dtype=dtype,
+    # )
+    # model_config.num_hidden_layers = 2
+    # if hasattr(model_config, "use_flash_attention"):
+    #     model_config.use_flash_attention = model_args.use_flash_attention
+    # reward_model = AutoModelForScore.from_pretrained(
+    #     model_args.reward_model_name_or_path,
+    #     config=model_config,
+    #     score_type="reward",
+    #     do_normalize=training_args.normalize_reward,
+    # )
+    # reward_tokenizer = AutoTokenizer.from_pretrained(
+    #     model_args.reward_model_name_or_path, model_max_length=data_args.max_length, padding_side="right"
+    # )
+    #
+    # # critic model
+    # if model_args.reward_critic_model_name_or_path is None:
+    #     model_args.reward_critic_model_name_or_path = model_args.reward_model_name_or_path
+    # reward_critic_model = AutoModelForScore.from_pretrained(
+    #     model_args.reward_critic_model_name_or_path, config=model_config, score_type="critic", do_normalize=False
+    # )
+    # reward_critic_tokenizer = AutoTokenizer.from_pretrained(
+    #     model_args.reward_critic_model_name_or_path, model_max_length=data_args.max_length, padding_side="left"
+    # )
+    #
+    # if training_args.eval_mode is not None:
+    #     config = copy.deepcopy(reward_critic_model.config)
+    #     if training_args.eval_mode == "single":
+    #         config.tensor_parallel_degree = -1
+    #         config.tensor_parallel_rank = 0
+    #     reward_critic_eval_model = AutoModelForScore.from_config(config)
+    # else:
+    #     reward_critic_eval_model = None
+
     for tokenizer in [actor_tokenizer, reward_tokenizer, reward_critic_tokenizer]:
         if isinstance(tokenizer, LlamaTokenizer) and tokenizer.pad_token_id is None:
             tokenizer.pad_token_id = tokenizer.eos_token_id
@@ -307,8 +471,33 @@ def main():
         else None
     )
 
+    # offload
+    # cleanup actor_eval_model, reward_critic_eval_model
+    # offload actor_reference_model reward_model
+
+    if training_args.offload_level is not None:
+        if "eval" in training_args.offload_level:
+            cleanup_tensor_space(actor_eval_model.state_dict())
+            cleanup_tensor_space(reward_critic_eval_model.state_dict())
+        if "reward" in training_args.offload_level:
+            # if pp mode, should lazy offload
+            offload_tensor_to_cpu(actor_reference_model.state_dict())
+            offload_tensor_to_cpu(reward_model.state_dict())
+
     trainer = PPOTrainer(
-        model=(actor_model, actor_reference_model, reward_model, reward_critic_model),
+        # (policy_model, reference_model, reward_model, value_model)
+        # policy_model, sft_model, reward_model, value_model
+        # (policy_model, reference_model, reward_model, value_model,
+        # (policy_model, reference_model, reward_model, value_model, policy_eval_model, value_eval_model
+        # (actor_model, actor_reference_model, reward_model, reward_critic_model, actor_eval_model, reward_critic_eval_model
+        model=(
+            actor_model,
+            actor_reference_model,
+            reward_model,
+            reward_critic_model,
+            actor_eval_model,
+            reward_critic_eval_model,
+        ),
         args=training_args,
         train_dataset=train_ds,
         eval_dataset=dev_ds,
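
Note: PPOTrainer now receives a six-element model tuple; the two trailing eval models are None when eval_mode is disabled. offload_level is matched by substring, so one value can enable several actions at once, and the helper names suggest that cleanup_tensor_space releases the eval copies' tensors while offload_tensor_to_cpu moves the frozen reference and reward weights to host memory. A hedged sketch of the gating, reusing the helpers imported at the top of the file (everything else is illustrative):

from ppo_trainer import cleanup_tensor_space, offload_tensor_to_cpu


def apply_offload(level, actor_eval_model, reward_critic_eval_model,
                  actor_reference_model, reward_model):
    # Mirrors the substring checks above; `level` may combine tokens, e.g. "eval,reward".
    if level is None:
        return
    if "eval" in level:
        cleanup_tensor_space(actor_eval_model.state_dict())
        cleanup_tensor_space(reward_critic_eval_model.state_dict())
    if "reward" in level:
        # In pipeline-parallel mode this should be a lazy offload (see the comment above).
        offload_tensor_to_cpu(actor_reference_model.state_dict())
        offload_tensor_to_cpu(reward_model.state_dict())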
