from functools import partial

import paddle
-from dpo_argument import DPODataArgument, DPOModelArgument, DPOTrainingArguments
-
-from paddlenlp.datasets import ZeroPaddingMapDataset, load_dataset
-from paddlenlp.trainer import (
-    IntervalStrategy,
-    PdArgumentParser,
-    get_last_checkpoint,
-    set_seed,
+from dpo_argument import (
+    DPOConfig,
+    DPODataArgument,
+    DPOModelArgument,
+    DPOTrainingArguments,
)
+
+from paddlenlp.datasets import (
+    ZeroPaddingIterableDataset,
+    ZeroPaddingMapDataset,
+    load_dataset,
+)
+from paddlenlp.peft import LoRAConfig, LoRAModel
+from paddlenlp.trainer import PdArgumentParser, get_last_checkpoint, set_seed
from paddlenlp.transformers import (
    AutoConfig,
    AutoModelForCausalLM,
+    AutoModelForCausalLMPipe,
    AutoTokenizer,
    LlamaForCausalLM,
    LlamaForCausalLMPipe,
    preference_collate_fn,
    preprocess_preference_data,
)
+from paddlenlp.utils.llm_utils import get_lora_target_modules
from paddlenlp.utils.log import logger

flash_mask_support_list = [LlamaForCausalLM, LlamaForCausalLMPipe]


def main():
    """main"""
-    parser = PdArgumentParser((DPOModelArgument, DPODataArgument, DPOTrainingArguments))
-    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
-        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+    parser = PdArgumentParser((DPOModelArgument, DPODataArgument, DPOTrainingArguments, DPOConfig))
+    if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"):
+        model_args, data_args, training_args, dpo_config = parser.parse_json_file_and_cmd_lines()
    else:
-        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
-
-    training_args.print_config(model_args, "Model")
-    training_args.print_config(data_args, "Data")
-    if training_args.max_steps > 0:
-        training_args.num_train_epochs = 1
-    if data_args.autotuner_benchmark:
-        training_args.num_train_epochs = 1
-        training_args.max_steps = 5
-        training_args.do_train = True
-        training_args.do_export = False
-        training_args.do_predict = False
-        training_args.do_eval = False
-        training_args.overwrite_output_dir = True
-        training_args.load_best_model_at_end = False
-        training_args.report_to = []
-        training_args.save_strategy = IntervalStrategy.NO
-        training_args.evaluation_strategy = IntervalStrategy.NO
-    if data_args.benchmark:
-        training_args.do_train = True
-        training_args.do_export = False
-        training_args.do_predict = False
-        training_args.do_eval = False
-        training_args.overwrite_output_dir = True
-        training_args.load_best_model_at_end = False
-        training_args.save_strategy = IntervalStrategy.NO
-        training_args.evaluation_strategy = IntervalStrategy.NO
+        model_args, data_args, training_args, dpo_config = parser.parse_args_into_dataclasses()
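The parser now accepts a `DPOConfig` dataclass alongside the model/data/training arguments, and `parse_json_file_and_cmd_lines` lets a JSON config file be combined with additional command-line overrides. The following is a minimal stdlib-only sketch of that merge behaviour, not the PaddleNLP implementation; the `DPOConfigSketch` fields are an illustrative subset.

```python
import json
import sys
from dataclasses import dataclass


@dataclass
class DPOConfigSketch:
    # Illustrative subset of fields; the real DPOConfig defines many more options.
    loss_type: str = "sigmoid"
    beta: float = 0.1
    reference_free: bool = False


def parse_json_then_cli(argv):
    """Read defaults from a JSON file, then apply simple --key value overrides."""
    defaults = {}
    if argv and argv[0].endswith(".json"):
        with open(argv[0]) as f:
            defaults = json.load(f)
        argv = argv[1:]
    # Pair up remaining tokens as --key value overrides (no type coercion here).
    overrides = {k.lstrip("-"): v for k, v in zip(argv[::2], argv[1::2])}
    defaults.update(overrides)
    known = {k: v for k, v in defaults.items() if k in DPOConfigSketch.__dataclass_fields__}
    return DPOConfigSketch(**known)


if __name__ == "__main__":
    print(parse_json_then_cli(sys.argv[1:]))
```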

    paddle.set_device(training_args.device)
    set_seed(training_args.seed)
+    if dpo_config.loss_type == "orpo":
+        dpo_config.reference_free = True
+        dpo_config.sft_loss_ratio = 1.0
+        dpo_config.loss_type = "or"
+        logger.info("orpo loss_type is equal to sft_loss + pref_loss_ratio * or_loss.")
+    if dpo_config.loss_type in ["or", "simpo"] and not dpo_config.reference_free:
+        dpo_config.reference_free = True
+        logger.warning(f"{dpo_config.loss_type} loss_type only supports reference_free. Set reference_free to True.")
+
+    training_args.print_config(model_args, "Model")
+    training_args.print_config(data_args, "Data")
+    training_args.print_config(dpo_config, "DPOConfig")
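When `loss_type` is `orpo`, the run becomes reference-free: an SFT term plus an odds-ratio preference term, as the log message above states. Below is a minimal paddle sketch of that composition, assuming per-sequence average log-probabilities for the chosen and rejected responses are already available; the function and argument names are illustrative, not the trainer's internals.

```python
import paddle
import paddle.nn.functional as F


def orpo_style_loss(chosen_logps, rejected_logps, sft_nll_loss, pref_loss_ratio=0.1):
    """sft_loss + pref_loss_ratio * or_loss, mirroring the log message above.

    chosen_logps / rejected_logps: average per-token log-probabilities of the
    chosen and rejected responses under the policy (no reference model needed).
    """
    # log-odds of each response: log(p / (1 - p)), computed from log p.
    log_odds_chosen = chosen_logps - paddle.log1p(-paddle.exp(chosen_logps))
    log_odds_rejected = rejected_logps - paddle.log1p(-paddle.exp(rejected_logps))
    # Odds-ratio term: push the chosen response's odds above the rejected one's.
    or_loss = -F.log_sigmoid(log_odds_chosen - log_odds_rejected).mean()
    return sft_nll_loss + pref_loss_ratio * or_loss


# Toy usage with made-up numbers.
chosen = paddle.to_tensor([-0.9, -1.2])
rejected = paddle.to_tensor([-1.8, -2.0])
sft_loss = paddle.to_tensor(0.9)
print(orpo_style_loss(chosen, rejected, sft_loss))
```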

    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: "
@@ -116,51 +109,102 @@ def main():
        tensor_parallel_rank=training_args.tensor_parallel_rank,
        recompute_granularity=model_args.recompute_granularity,
        use_flash_attention=model_args.use_flash_attention,
-        tensor_parallel_output=True,
+        tensor_parallel_output=model_args.tensor_parallel_output,
    )
    if training_args.pipeline_parallel_degree > 1:
        raise ValueError("DPO does not support pipeline parallelism yet.")
-
-    if not data_args.autotuner_benchmark:
-        ref_model = AutoModelForCausalLM.from_pretrained(**model_kwargs)
-        config = AutoConfig.from_pretrained(**model_kwargs)
-        model = AutoModelForCausalLM.from_config(config)
-        model.set_state_dict(ref_model.state_dict())
+    if training_args.pipeline_parallel_degree > 1:
+        model_class = AutoModelForCausalLMPipe
+    else:
+        model_class = AutoModelForCausalLM
+    if not training_args.autotuner_benchmark or model_args.weight_quantize_algo is not None:
+        model = model_class.from_pretrained(**model_kwargs)
+        # for DPO save
+        model.config.dpo_config = None
+        if not dpo_config.reference_free and not dpo_config.lora:
+            config = AutoConfig.from_pretrained(**model_kwargs)
+            ref_model = model_class.from_config(config, dtype=dtype)
+            ref_model.set_state_dict(model.state_dict())
+        else:
+            ref_model = None
    else:
        config = AutoConfig.from_pretrained(**model_kwargs)
-        model = AutoModelForCausalLM.from_config(config)
-        ref_config = AutoConfig.from_pretrained(**model_kwargs)
-        ref_model = AutoModelForCausalLM.from_config(ref_config)
-        model.set_state_dict(ref_model.state_dict())
+        model = model_class.from_config(config, dtype=dtype)
+        if not dpo_config.reference_free and not dpo_config.lora:
+            ref_model = model_class.from_config(config, dtype=dtype)
+        else:
+            ref_model = None
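For reference-based losses the script builds `ref_model` from the same config and copies the policy weights into it, so both networks start identical but only the policy is updated. A small self-contained paddle sketch of that pattern, using a toy layer rather than the actual CausalLM classes:

```python
import paddle


class TinyPolicy(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.linear = paddle.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


policy = TinyPolicy()
reference = TinyPolicy()
# Same pattern as above: build from config, then copy the policy's weights in.
reference.set_state_dict(policy.state_dict())
reference.eval()
# Freeze the reference so only the policy receives gradients.
for param in reference.parameters():
    param.stop_gradient = True

x = paddle.randn([2, 4])
assert paddle.allclose(policy(x), reference(x))
```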

    if model_args.flash_mask and not model.config.use_flash_attention:
        logger.warning("`flash_mask` must use with zero padding and flash attention.")
        model.config.use_flash_attention = True

    if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list):
        raise NotImplementedError(f"{model.__class__} not support flash mask.")
-    if training_args.sequence_parallel:
+
+    if model_args.sequence_parallel:
        register_sequence_parallel_allreduce_hooks(
-            model, training_args.gradient_accumulation_steps, training_args.fuse_sequence_parallel_allreduce
+            model, training_args.gradient_accumulation_steps, model_args.fuse_sequence_parallel_allreduce
        )
    if model_args.tokenizer_name_or_path is not None:
        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    # TODO: support chat template in next pr
-    # tokenizer.chat_template = None
+    tokenizer.chat_template = None
    logger.info("Loading model & tokenizer successfully !")

+    if dpo_config.lora:
+        if training_args.sharding_parallel_degree > 1:
+            assert (
+                "enable_stage1_overlap" not in training_args.sharding_parallel_config
+            ), "Currently not support enabling sharding_stage1_overlap in lora mode."
+        if model_args.lora_path is None:
+            target_modules = get_lora_target_modules(model)
+            if model_args.rslora_plus:
+                model_args.rslora = True
+                model_args.lora_plus_scale = 4
+                model_args.lora_alpha = 4
+            if model_args.weight_quantize_algo is not None:
+                if model_args.rslora or model_args.lora_plus_scale != 1.0:
+                    logger.info("Weight quantization is not supported in LoRA+ and RsLoRA.")
+            if model_args.lora_alpha == -1:
+                if model_args.rslora:
+                    model_args.lora_alpha = 4
+                else:
+                    model_args.lora_alpha = 2 * model_args.lora_rank
+            lora_config = LoRAConfig(
+                target_modules=target_modules,
+                r=model_args.lora_rank,
+                lora_alpha=2 * model_args.lora_rank if not model_args.rslora else 4,
+                rslora=model_args.rslora,
+                lora_plus_scale=model_args.lora_plus_scale,
+                tensor_parallel_degree=training_args.tensor_parallel_degree,
+                dtype=dtype,
+                base_model_name_or_path=model_args.model_name_or_path,
+                use_quick_lora=model_args.use_quick_lora,
+            )
+            model = LoRAModel(model, lora_config)
+        else:
+            model = LoRAModel.from_pretrained(model=model, lora_path=model_args.lora_path)
+
+        model.print_trainable_parameters()
+
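The alpha handling above reduces to a scaling factor applied to the low-rank update: classic LoRA scales by alpha / r (so alpha = 2 * r gives a constant factor of 2), while rsLoRA scales by alpha / sqrt(r). A small sketch of the resulting factors in plain Python, independent of the PaddleNLP `LoRAConfig` implementation:

```python
import math


def lora_scaling(lora_alpha: float, r: int, rslora: bool) -> float:
    """Effective multiplier applied to the low-rank update B @ A."""
    if rslora:
        # rsLoRA: scale by alpha / sqrt(r) so the update magnitude is rank-stable.
        return lora_alpha / math.sqrt(r)
    # Classic LoRA: scale by alpha / r.
    return lora_alpha / r


# With the defaults chosen above: alpha = 2 * r for LoRA, alpha = 4 for rsLoRA.
for r in (8, 16, 64):
    print(r, lora_scaling(2 * r, r, rslora=False), lora_scaling(4, r, rslora=True))
```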
    logger.info("Start to create dataset")
    trans_func = partial(preprocess_preference_data, tokenizer=tokenizer, data_args=data_args, model_args=model_args)
+    if data_args.lazy:
+        zero_padding_dataset = ZeroPaddingIterableDataset
+    else:
+        zero_padding_dataset = ZeroPaddingMapDataset
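Setting `lazy=True` switches both the JSON loader and the zero-padding wrapper to a streaming, iterable style that avoids materialising the whole preference dataset in memory, while the map-style path keeps `len()` and random access. A stdlib-only sketch of the difference between the two styles; these toy classes are illustrative, not the PaddleNLP ones.

```python
from typing import Iterator, List


class MapStyleDataset:
    """Everything loaded up front; supports len() and random access."""

    def __init__(self, records: List[dict]):
        self.records = records

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx: int) -> dict:
        return self.records[idx]


class IterableStyleDataset:
    """Examples produced on demand; one pass, no random access, low memory."""

    def __init__(self, record_iter_factory):
        self.record_iter_factory = record_iter_factory

    def __iter__(self) -> Iterator[dict]:
        yield from self.record_iter_factory()


records = [{"chosen": "a", "rejected": "b"}, {"chosen": "c", "rejected": "d"}]
map_ds = MapStyleDataset(records)
iter_ds = IterableStyleDataset(lambda: iter(records))
print(len(map_ds), map_ds[0])
print(next(iter(iter_ds)))
```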
    if training_args.do_train and training_args.should_load_dataset:
        train_ds = load_dataset(
            "json",
            data_files=data_args.train_dataset_path,
+            lazy=data_args.lazy,
        )[0]
        logger.info("Creating train Zero Padding Data Stream. This may take a few minutes.")
        train_ds = (
-            ZeroPaddingMapDataset(
+            zero_padding_dataset(
                train_ds.map(trans_func),
                tokenizer=tokenizer,
                max_length=data_args.max_seq_len,
@@ -176,10 +220,11 @@ def main():
        eval_ds = load_dataset(
            "json",
            data_files=data_args.dev_dataset_path,
+            lazy=data_args.lazy,
        )[0]
        logger.info("Creating dev Zero Padding Data Stream. This may take a few minutes.")
        eval_ds = (
-            ZeroPaddingMapDataset(
+            zero_padding_dataset(
                eval_ds.map(trans_func),
                tokenizer=tokenizer,
                max_length=data_args.max_seq_len,
@@ -194,6 +239,7 @@ def main():
    trainer = DPOTrainer(
        model=model,
        ref_model=ref_model,
+        dpo_config=dpo_config,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
@@ -202,17 +248,18 @@ def main():
            preference_collate_fn,
            max_seq_len=data_args.max_seq_len,
        ),
+        ignore_eos_token=True,
    )

    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=last_checkpoint)

-        if not data_args.autotuner_benchmark and not data_args.benchmark:
+        if not training_args.autotuner_benchmark and not training_args.benchmark:
            trainer.save_model(merge_tensor_parallel=training_args.tensor_parallel_degree > 1)
            trainer.log_metrics("train", train_result.metrics)
            trainer.save_metrics("train", train_result.metrics)
            trainer.save_state()
-        if data_args.benchmark:
+        if training_args.benchmark:
            total_effective_tokens, total_tokens = calculate_effective_tokens(
                training_args, train_ds, data_args.max_seq_len
            )
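In benchmark mode the script compares effective tokens (tokens that carry real example data) against the total padded token budget to report packing efficiency and throughput. A hedged sketch of that arithmetic follows; the real `calculate_effective_tokens` helper and the metric names may differ.

```python
def effective_token_stats(seq_lens, max_seq_len, train_runtime_s):
    """Toy version of the benchmark summary: padding efficiency and tokens/sec."""
    total_effective_tokens = sum(seq_lens)       # tokens that carry data
    total_tokens = len(seq_lens) * max_seq_len   # tokens after padding to max_seq_len
    efficiency = total_effective_tokens / total_tokens
    tokens_per_second = total_effective_tokens / train_runtime_s
    return efficiency, tokens_per_second


eff, tps = effective_token_stats([1800, 2400, 4096], max_seq_len=4096, train_runtime_s=12.5)
print(f"effective token ratio: {eff:.2%}, effective tokens/s: {tps:.1f}")
```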