# See the License for the specific language governing permissions and
# limitations under the License.

+ import copy
import os
import sys
+
+ # os.environ["http_proxy"] = "http://10.162.37.16:8128"
+ # os.environ["https_proxy"] = "http://10.162.37.16:8128"
+ # os.environ["no_proxy"] = "localhost,bcebos.com"
+ # launch would unset http_proxy
+ # export https_proxy=http://172.19.57.45:3128
+
+ # os.environ["http_proxy"] = "http://172.19.56.199:3128"
+ # os.environ["https_proxy"] = "http://172.19.56.199:3128"
+
+ # os.environ["http_proxy"] = "http://172.19.57.45:3128"
+ # os.environ["https_proxy"] = "http://172.19.57.45:3128"
+
+ os.environ["http_proxy"] = "http://10.162.37.16:8128"
+ os.environ["https_proxy"] = "http://10.162.37.16:8128"
+ os.environ["no_proxy"] = "localhost,bcebos.com"
+
+ # os.environ["http_proxy"] = "agent.baidu.com:8118"
+ # os.environ["https_proxy"] = "agent.baidu.com:8118"
+
from dataclasses import dataclass, field
from typing import Any, Dict, Tuple

import paddle
from data import PromptOnlyDataset, SupervisedDataset, parse_dataset
- from ppo_trainer import PPOTrainer
+ from models import AutoModelForScore
+ from models.score_model import LlamaModelForScore  # noqa
+ from ppo_trainer import PPOTrainer, cleanup_tensor_space, offload_tensor_to_cpu

from paddlenlp.trainer import PdArgumentParser, TrainingArguments, get_last_checkpoint
- from paddlenlp.transformers import AutoConfig, AutoTokenizer, LlamaTokenizer
+ from paddlenlp.transformers import (
+     AutoConfig,
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     LlamaTokenizer,
+ )
from paddlenlp.utils.log import logger

@@ -108,6 +136,18 @@ class TrainingArguments(TrainingArguments):
        default=16,
        metadata={"help": "Batch size (per device) for the training dataloader."},
    )
+     eval_mode: str = field(
+         default=None,
+         metadata={
+             "help": "Evaluation mode for the actor and reward_critic models; one of None, 'single', or 'tensor_parallel'."
+         },
+     )
+
+     offload_level: str = field(
+         default=None,
+         metadata={"help": "Which model groups to offload; any combination of 'eval', 'reward', 'optimizer', 'train_model'."},
+     )
+
    # save_generation_output: bool = field(
    #     default=False,
    #     metadata={"help": "Whether to save generated text to file when eval"},
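Both new options are plain TrainingArguments fields, so they are supplied like any other flag handled by PdArgumentParser. A minimal sketch, assuming PdArgumentParser mirrors HfArgumentParser's parse_args_into_dataclasses(args=...) signature; the EvalOffloadArguments dataclass and the values are illustrative, not part of the PR:

    from dataclasses import dataclass, field

    from paddlenlp.trainer import PdArgumentParser


    @dataclass
    class EvalOffloadArguments:
        # Mirrors the two fields added above; None keeps the previous behaviour.
        eval_mode: str = field(default=None)
        offload_level: str = field(default=None)


    parser = PdArgumentParser(EvalOffloadArguments)
    (opts,) = parser.parse_args_into_dataclasses(args=["--eval_mode", "single", "--offload_level", "eval reward"])
    print(opts.eval_mode, opts.offload_level)  # -> single eval reward
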
@@ -179,6 +219,10 @@ def main():
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    training_args.print_config(model_args, "Model")
    training_args.print_config(data_args, "Data")
+     if training_args.eval_mode is not None and len(training_args.eval_mode) == 0:
+         training_args.eval_mode = None
+     if training_args.eval_mode is None and training_args.offload_level is not None:
+         training_args.offload_level = training_args.offload_level.replace("eval", "")

    # Setup GPU & distributed training
    paddle.set_device(training_args.device)
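A quick trace of the normalization above, in plain Python with illustrative values: an empty --eval_mode collapses to None, and once no eval models will be built, "eval" is stripped from offload_level so the later offload step has nothing to release there.

    eval_mode, offload_level = "", "eval reward"
    if eval_mode is not None and len(eval_mode) == 0:
        eval_mode = None
    if eval_mode is None and offload_level is not None:
        offload_level = offload_level.replace("eval", "")
    print(eval_mode, repr(offload_level))  # -> None ' reward'
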
@@ -214,21 +258,17 @@ def main():
        dtype = "float32"
    training_args.max_length = data_args.max_length

+     model_class_lm, model_class_score = AutoModelForCausalLM, AutoModelForScore
    if training_args.pipeline_parallel_degree > 1:
-         global AutoModelForCausalLM, AutoModelForScore
        from models.model_pp import LlamaPolicyPipe, LlamaValuePipe

-         AutoModelForCausalLM = LlamaPolicyPipe
-         AutoModelForScore = LlamaValuePipe
+         model_class_lm = LlamaPolicyPipe
+         model_class_score = LlamaValuePipe
        extra_args = {
            "ptx_coeff": training_args.ptx_coeff,
            "clip_range_ratio": training_args.clip_range_ratio,
        }
    else:
-         from models import AutoModelForScore
-
-         from paddlenlp.transformers import AutoModelForCausalLM
-
        extra_args = {}

    # actor model
@@ -241,18 +281,42 @@ def main():
    )
    if hasattr(model_config, "use_flash_attention"):
        model_config.use_flash_attention = model_args.use_flash_attention
-     actor_model = AutoModelForCausalLM.from_pretrained(
+
+     # model_config.num_hidden_layers = 2
+
+     actor_model = model_class_lm.from_pretrained(
        model_args.actor_model_name_or_path,
        config=model_config,
        **extra_args,
        # ptx_coeff=training_args.ptx_coeff,
        # clip_range_ratio=training_args.clip_range_ratio,
    )
-     # reference model
-     actor_reference_model = AutoModelForCausalLM.from_pretrained(
-         model_args.actor_model_name_or_path,
-         config=model_config,
-     )
+     if training_args.eval_mode is not None:
+         config = copy.deepcopy(actor_model.config)
+         if training_args.eval_mode == "single":
+             config.tensor_parallel_degree = -1
+             config.tensor_parallel_rank = 0
+         actor_eval_model = AutoModelForCausalLM.from_config(config)
+         # actor_eval_model = AutoModelForCausalLM.from_pretrained(model_args.actor_model_name_or_path, config=config)
+     else:
+         actor_eval_model = None
+
+     # TODO: reference model
+     if training_args.eval_mode is not None:
+         config = copy.deepcopy(model_config)
+         if training_args.eval_mode == "single":
+             config.tensor_parallel_degree = -1
+             config.tensor_parallel_rank = 0
+         actor_reference_model = AutoModelForCausalLM.from_pretrained(
+             model_args.actor_model_name_or_path,
+             config=config,
+         )
+     else:
+         actor_reference_model = model_class_lm.from_pretrained(
+             model_args.actor_model_name_or_path,
+             config=model_config,
+         )
+
    actor_tokenizer = AutoTokenizer.from_pretrained(
        model_args.actor_model_name_or_path, model_max_length=data_args.max_length, padding_side="left"
    )
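Both the eval copy and the reference model above repeat the same idiom: deep-copy the config and, when eval_mode is "single", switch tensor parallelism off so the copy can run on a single card. A hypothetical helper (make_single_card_config is not in the PR) distilling that idiom:

    import copy


    def make_single_card_config(config):
        # Work on a copy so the training model's config is left untouched,
        # then disable tensor parallelism for the single-card replica.
        cfg = copy.deepcopy(config)
        cfg.tensor_parallel_degree = -1
        cfg.tensor_parallel_rank = 0
        return cfg

Note that AutoModelForCausalLM.from_config builds the eval copy with freshly initialized weights rather than loading a checkpoint, so the trainer presumably syncs the trained actor's weights into it later; the commented-out from_pretrained call is the eager-loading alternative.
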
@@ -267,19 +331,33 @@ def main():
    )
    if hasattr(model_config, "use_flash_attention"):
        model_config.use_flash_attention = model_args.use_flash_attention
-     reward_model = AutoModelForScore.from_pretrained(
-         model_args.reward_model_name_or_path,
-         config=model_config,
-         score_type="reward",
-         do_normalize=training_args.normalize_reward,
-     )
+     # model_config.num_hidden_layers = 2
+     # TODO
+     if training_args.eval_mode is not None:
+         config = copy.deepcopy(model_config)
+         if training_args.eval_mode == "single":
+             config.tensor_parallel_degree = -1
+             config.tensor_parallel_rank = 0
+         reward_model = AutoModelForScore.from_pretrained(
+             model_args.reward_model_name_or_path,
+             config=config,
+             score_type="reward",
+             do_normalize=training_args.normalize_reward,
+         )
+     else:
+         reward_model = model_class_score.from_pretrained(
+             model_args.reward_model_name_or_path,
+             config=model_config,
+             score_type="reward",
+             do_normalize=training_args.normalize_reward,
+         )
    reward_tokenizer = AutoTokenizer.from_pretrained(
        model_args.reward_model_name_or_path, model_max_length=data_args.max_length, padding_side="right"
    )
    # critic model
    if model_args.reward_critic_model_name_or_path is None:
        model_args.reward_critic_model_name_or_path = model_args.reward_model_name_or_path
-     reward_critic_model = AutoModelForScore.from_pretrained(
+     reward_critic_model = model_class_score.from_pretrained(
        model_args.reward_critic_model_name_or_path,
        config=model_config,
        score_type="critic",
@@ -289,6 +367,92 @@ def main():
    reward_critic_tokenizer = AutoTokenizer.from_pretrained(
        model_args.reward_critic_model_name_or_path, model_max_length=data_args.max_length, padding_side="left"
    )
+     if training_args.eval_mode is not None:
+         config = copy.deepcopy(reward_critic_model.config)
+         if training_args.eval_mode == "single":
+             config.tensor_parallel_degree = -1
+             config.tensor_parallel_rank = 0
+         reward_critic_eval_model = AutoModelForScore.from_config(config)
+         # reward_critic_eval_model = AutoModelForScore.from_pretrained(
+         #     model_args.reward_critic_model_name_or_path, config=model_config
+         # )
+     else:
+         reward_critic_eval_model = None
+
+     # # actor model
+     # model_config = AutoConfig.from_pretrained(
+     #     model_args.actor_model_name_or_path,
+     #     tensor_parallel_output=False,
+     #     tensor_parallel_degree=training_args.tensor_parallel_degree,
+     #     tensor_parallel_rank=training_args.tensor_parallel_rank,
+     #     dtype=dtype,
+     # )
+     # model_config.num_hidden_layers = 2
+     # if hasattr(model_config, "use_flash_attention"):
+     #     model_config.use_flash_attention = model_args.use_flash_attention
+     # actor_model = AutoModelForCausalLM.from_pretrained(
+     #     model_args.actor_model_name_or_path,
+     #     config=model_config,
+     # )
+     #
+     # if training_args.eval_mode is not None:
+     #     config = copy.deepcopy(actor_model.config)
+     #     if training_args.eval_mode == "single":
+     #         config.tensor_parallel_degree = -1
+     #         config.tensor_parallel_rank = 0
+     #     actor_eval_model = AutoModelForCausalLM.from_config(config)
+     # else:
+     #     actor_eval_model = None
+     #
+     # # reference model
+     # actor_reference_model = AutoModelForCausalLM.from_pretrained(
+     #     model_args.actor_model_name_or_path,
+     #     config=model_config,
+     # )
+     # actor_tokenizer = AutoTokenizer.from_pretrained(
+     #     model_args.actor_model_name_or_path, model_max_length=data_args.max_length, padding_side="left"
+     # )
+     #
+     # # reward model
+     # model_config = AutoConfig.from_pretrained(
+     #     model_args.reward_model_name_or_path,
+     #     tensor_parallel_output=False,
+     #     tensor_parallel_degree=training_args.tensor_parallel_degree,
+     #     tensor_parallel_rank=training_args.tensor_parallel_rank,
+     #     dtype=dtype,
+     # )
+     # model_config.num_hidden_layers = 2
+     # if hasattr(model_config, "use_flash_attention"):
+     #     model_config.use_flash_attention = model_args.use_flash_attention
+     # reward_model = AutoModelForScore.from_pretrained(
+     #     model_args.reward_model_name_or_path,
+     #     config=model_config,
+     #     score_type="reward",
+     #     do_normalize=training_args.normalize_reward,
+     # )
+     # reward_tokenizer = AutoTokenizer.from_pretrained(
+     #     model_args.reward_model_name_or_path, model_max_length=data_args.max_length, padding_side="right"
+     # )
+     #
+     # # critic model
+     # if model_args.reward_critic_model_name_or_path is None:
+     #     model_args.reward_critic_model_name_or_path = model_args.reward_model_name_or_path
+     # reward_critic_model = AutoModelForScore.from_pretrained(
+     #     model_args.reward_critic_model_name_or_path, config=model_config, score_type="critic", do_normalize=False
+     # )
+     # reward_critic_tokenizer = AutoTokenizer.from_pretrained(
+     #     model_args.reward_critic_model_name_or_path, model_max_length=data_args.max_length, padding_side="left"
+     # )
+     #
+     # if training_args.eval_mode is not None:
+     #     config = copy.deepcopy(reward_critic_model.config)
+     #     if training_args.eval_mode == "single":
+     #         config.tensor_parallel_degree = -1
+     #         config.tensor_parallel_rank = 0
+     #     reward_critic_eval_model = AutoModelForScore.from_config(config)
+     # else:
+     #     reward_critic_eval_model = None
+
    for tokenizer in [actor_tokenizer, reward_tokenizer, reward_critic_tokenizer]:
        if isinstance(tokenizer, LlamaTokenizer) and tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
@@ -307,8 +471,33 @@ def main():
        else None
    )

+     # offload according to offload_level:
+     #   "eval"   -> release the tensors of actor_eval_model and reward_critic_eval_model
+     #   "reward" -> offload actor_reference_model and reward_model to CPU
+
+     if training_args.offload_level is not None:
+         if "eval" in training_args.offload_level:
+             cleanup_tensor_space(actor_eval_model.state_dict())
+             cleanup_tensor_space(reward_critic_eval_model.state_dict())
+         if "reward" in training_args.offload_level:
+             # in pipeline-parallel mode this should be a lazy offload
+             offload_tensor_to_cpu(actor_reference_model.state_dict())
+             offload_tensor_to_cpu(reward_model.state_dict())
+
    trainer = PPOTrainer(
-         model=(actor_model, actor_reference_model, reward_model, reward_critic_model),
+         # (policy_model, reference_model, reward_model, value_model, policy_eval_model, value_eval_model);
+         # the two *_eval_model entries are None unless eval_mode is set.
+         model=(
+             actor_model,
+             actor_reference_model,
+             reward_model,
+             reward_critic_model,
+             actor_eval_model,
+             reward_critic_eval_model,
+         ),
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=dev_ds,
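cleanup_tensor_space and offload_tensor_to_cpu are imported from ppo_trainer and their implementations are not part of this diff. As a rough mental model only (the helper below is an assumption, not the PR's code), offloading a state dict amounts to keeping host-side copies of the weights so the GPU allocations can be reclaimed:

    import paddle


    def offload_state_dict_sketch(state_dict):
        # Keep a host-memory copy of every weight; once the caller drops its
        # GPU-side references, the allocator can reuse that memory.
        cpu_copies = {name: tensor.cpu() for name, tensor in state_dict.items()}
        paddle.device.cuda.empty_cache()  # release cached, unused GPU blocks
        return cpu_copies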