
Commit bc55104

[LLM][TRL] Support DPO with Pipeline Parallel (#9039)
* support dpo/kto pp
1 parent: 8212b53 · commit: bc55104


46 files changed: +2241 additions, -758 deletions

llm/alignment/dpo/dpo_argument.py

Lines changed: 83 additions & 11 deletions
@@ -17,6 +17,7 @@
 from typing import Optional
 
 from paddlenlp.trainer import TrainingArguments
+from paddlenlp.trainer.trainer_utils import IntervalStrategy
 
 
 def add_start_docstrings(*docstr):
@@ -42,9 +43,66 @@ class DPOTrainingArguments(TrainingArguments):
         default="",
         metadata={"help": "Configs to unify hybrid parallel checkpoint.\n"},
     )
-    dpo_beta: float = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
-    dpo_label_smoothing: float = field(default=0.0, metadata={"help": "label_smoothing ratio"})
-    dpo_loss_type: str = field(default="sigmoid", metadata={"help": "DPO loss type"})
+    autotuner_benchmark: bool = field(
+        default=False,
+        metadata={"help": "Whether to run benchmark by autotuner. True for from_scratch."},
+    )
+    benchmark: bool = field(
+        default=False,
+        metadata={"help": "Whether to run benchmark by autotuner. True for from_scratch."},
+    )
+
+    def __post_init__(self):
+        super().__post_init__()
+        if self.autotuner_benchmark:
+            self.num_train_epochs = 1
+            self.max_steps = 5
+            self.do_train = True
+            self.do_export = False
+            self.do_predict = False
+            self.do_eval = False
+            self.overwrite_output_dir = True
+            self.load_best_model_at_end = False
+            self.report_to = []
+            self.save_strategy = IntervalStrategy.NO
+            self.evaluation_strategy = IntervalStrategy.NO
+            if not self.disable_tqdm:
+                self.logging_steps = 1
+                self.logging_strategy = IntervalStrategy.STEPS
+        if self.benchmark:
+            self.do_train = True
+            self.do_export = False
+            self.do_predict = False
+            self.do_eval = False
+            self.overwrite_output_dir = True
+            self.load_best_model_at_end = False
+            self.save_strategy = IntervalStrategy.NO
+            self.evaluation_strategy = IntervalStrategy.NO
+            if not self.disable_tqdm:
+                self.logging_steps = 1
+                self.logging_strategy = IntervalStrategy.STEPS
+        if self.max_steps > 0:
+            self.num_train_epochs = 1
+
+
+@dataclass
+class DPOConfig:
+    """DPOConfig"""
+
+    beta: float = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
+    simpo_gamma: float = field(default=0.5, metadata={"help": "the gamma parameter for SimPO loss"})
+    normalize_logps: bool = field(
+        default=True,
+        metadata={"help": "Apply logprobs normalization."},
+    )
+    label_smoothing: float = field(default=0.0, metadata={"help": "label_smoothing ratio"})
+    loss_type: str = field(default="sigmoid", metadata={"help": "DPO loss type"})
+    pref_loss_ratio: float = field(default=1.0, metadata={"help": "DPO loss ratio"})
+    sft_loss_ratio: float = field(default=0.0, metadata={"help": "SFT loss ratio"})
+    dpop_lambda: float = field(default=50, metadata={"help": "SFT loss ratio"})
+    ref_model_update_steps: int = field(default=-1, metadata={"help": "Update ref model state dict"})
+    reference_free: bool = field(default=False, metadata={"help": "No reference model."})
+    lora: bool = field(default=False, metadata={"help": "Use LoRA model."})
 
 
 @dataclass
@@ -55,18 +113,16 @@ class DPODataArgument:
     dev_dataset_path: str = field(default="./data/dev.jsonl", metadata={"help": "Path to the dev dataset dir."})
     max_seq_len: int = field(default=4096, metadata={"help": "Maximum sequence length."})
     max_prompt_len: int = field(default=2048, metadata={"help": "Maximum prompt length."})
-    autotuner_benchmark: bool = field(
-        default=False,
-        metadata={"help": "Whether to run benchmark by autotuner. True for from_scratch."},
-    )
-    benchmark: bool = field(
-        default=False,
-        metadata={"help": "Whether to run benchmark by autotuner. True for from_scratch."},
-    )
     greedy_zero_padding: bool = field(
         default=False,
         metadata={"help": "Whether to use Greedy Zero Padding data stream."},
    )
+    lazy: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to return `MapDataset` or an `IterDataset`. True for `IterDataset`. False for `MapDataset`."
+        },
+    )
 
 
 @dataclass
@@ -95,3 +151,19 @@ class DPOModelArgument:
         default=False,
         metadata={"help": "whether to use sequence parallel"},
     )
+    tensor_parallel_output: bool = field(
+        default=True,
+        metadata={"help": "whether to use tensor_parallel_output"},
+    )
+    weight_quantize_algo: str = field(
+        default=None,
+        metadata={"help": "Model weight quantization algorithm including 'nf4'(qlora), 'weight_only_int8'."},
+    )
+    # LoRA
+    lora_rank: int = field(default=8, metadata={"help": "Lora rank."})
+    lora_path: str = field(default=None, metadata={"help": "Initialize lora state dict."})
+    rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"})
+    lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+ technique"})
+    lora_alpha: int = field(default=-1, metadata={"help": "lora_alpha"})
+    rslora_plus: bool = field(default=False, metadata={"help": "Strengthen lora performance"})
+    use_quick_lora: bool = field(default=True, metadata={"help": "quick lora"})
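
For orientation, here is a minimal sketch of how DPOConfig knobs such as beta, label_smoothing, reference_free, pref_loss_ratio and sft_loss_ratio conventionally combine in a sigmoid-style DPO objective. The function and argument names are illustrative and this is not the PaddleNLP criterion; the trainer's exact formulas (and the handling of normalize_logps, simpo_gamma, dpop_lambda, etc.) may differ.

import paddle
import paddle.nn.functional as F


def sketch_dpo_loss(
    policy_chosen_logps,
    policy_rejected_logps,
    ref_chosen_logps,
    ref_rejected_logps,
    sft_loss,
    beta=0.1,
    label_smoothing=0.0,
    pref_loss_ratio=1.0,
    sft_loss_ratio=0.0,
    reference_free=False,
):
    """Illustrative sigmoid DPO loss with label smoothing and an optional SFT term."""
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    if reference_free:
        # reference_free drops the reference-model term entirely
        ref_logratios = paddle.zeros_like(pi_logratios)
    else:
        ref_logratios = ref_chosen_logps - ref_rejected_logps
    logits = pi_logratios - ref_logratios
    # standard label-smoothed binary preference loss
    pref_loss = (
        -F.log_sigmoid(beta * logits) * (1 - label_smoothing)
        - F.log_sigmoid(-beta * logits) * label_smoothing
    ).mean()
    # pref_loss_ratio / sft_loss_ratio weight the preference and SFT terms
    return pref_loss_ratio * pref_loss + sft_loss_ratio * sft_loss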

llm/alignment/dpo/run_dpo.py

Lines changed: 103 additions & 56 deletions
@@ -20,18 +20,24 @@
 from functools import partial
 
 import paddle
-from dpo_argument import DPODataArgument, DPOModelArgument, DPOTrainingArguments
-
-from paddlenlp.datasets import ZeroPaddingMapDataset, load_dataset
-from paddlenlp.trainer import (
-    IntervalStrategy,
-    PdArgumentParser,
-    get_last_checkpoint,
-    set_seed,
+from dpo_argument import (
+    DPOConfig,
+    DPODataArgument,
+    DPOModelArgument,
+    DPOTrainingArguments,
 )
+
+from paddlenlp.datasets import (
+    ZeroPaddingIterableDataset,
+    ZeroPaddingMapDataset,
+    load_dataset,
+)
+from paddlenlp.peft import LoRAConfig, LoRAModel
+from paddlenlp.trainer import PdArgumentParser, get_last_checkpoint, set_seed
 from paddlenlp.transformers import (
     AutoConfig,
     AutoModelForCausalLM,
+    AutoModelForCausalLMPipe,
     AutoTokenizer,
     LlamaForCausalLM,
     LlamaForCausalLMPipe,
@@ -43,47 +49,34 @@
     preference_collate_fn,
     preprocess_preference_data,
 )
+from paddlenlp.utils.llm_utils import get_lora_target_modules
 from paddlenlp.utils.log import logger
 
 flash_mask_support_list = [LlamaForCausalLM, LlamaForCausalLMPipe]
 
 
 def main():
     """main"""
-    parser = PdArgumentParser((DPOModelArgument, DPODataArgument, DPOTrainingArguments))
-    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
-        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+    parser = PdArgumentParser((DPOModelArgument, DPODataArgument, DPOTrainingArguments, DPOConfig))
+    if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"):
+        model_args, data_args, training_args, dpo_config = parser.parse_json_file_and_cmd_lines()
     else:
-        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
-
-    training_args.print_config(model_args, "Model")
-    training_args.print_config(data_args, "Data")
-    if training_args.max_steps > 0:
-        training_args.num_train_epochs = 1
-    if data_args.autotuner_benchmark:
-        training_args.num_train_epochs = 1
-        training_args.max_steps = 5
-        training_args.do_train = True
-        training_args.do_export = False
-        training_args.do_predict = False
-        training_args.do_eval = False
-        training_args.overwrite_output_dir = True
-        training_args.load_best_model_at_end = False
-        training_args.report_to = []
-        training_args.save_strategy = IntervalStrategy.NO
-        training_args.evaluation_strategy = IntervalStrategy.NO
-    if data_args.benchmark:
-        training_args.do_train = True
-        training_args.do_export = False
-        training_args.do_predict = False
-        training_args.do_eval = False
-        training_args.overwrite_output_dir = True
-        training_args.load_best_model_at_end = False
-        training_args.save_strategy = IntervalStrategy.NO
-        training_args.evaluation_strategy = IntervalStrategy.NO
+        model_args, data_args, training_args, dpo_config = parser.parse_args_into_dataclasses()
 
     paddle.set_device(training_args.device)
     set_seed(training_args.seed)
+    if dpo_config.loss_type == "orpo":
+        dpo_config.reference_free = True
+        dpo_config.sft_loss_ratio = 1.0
+        dpo_config.loss_type = "or"
+        logger.info("orpo loss_type is equal to sft_loss + pref_loss_ratio * or_loss.")
+    if dpo_config.loss_type in ["or", "simpo"] and not dpo_config.reference_free:
+        dpo_config.reference_free = True
+        logger.warning(f"{dpo_config.loss_type} loss_type only supports reference_free. Set reference_free to True.")
+
+    training_args.print_config(model_args, "Model")
+    training_args.print_config(data_args, "Data")
+    training_args.print_config(dpo_config, "DPOConfig")
 
     logger.warning(
         f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: "
@@ -116,51 +109,102 @@ def main():
         tensor_parallel_rank=training_args.tensor_parallel_rank,
         recompute_granularity=model_args.recompute_granularity,
         use_flash_attention=model_args.use_flash_attention,
-        tensor_parallel_output=True,
+        tensor_parallel_output=model_args.tensor_parallel_output,
     )
     if training_args.pipeline_parallel_degree > 1:
         raise ValueError("DPO does not support pipeline parallelism yet.")
-
-    if not data_args.autotuner_benchmark:
-        ref_model = AutoModelForCausalLM.from_pretrained(**model_kwargs)
-        config = AutoConfig.from_pretrained(**model_kwargs)
-        model = AutoModelForCausalLM.from_config(config)
-        model.set_state_dict(ref_model.state_dict())
+    if training_args.pipeline_parallel_degree > 1:
+        model_class = AutoModelForCausalLMPipe
+    else:
+        model_class = AutoModelForCausalLM
+    if not training_args.autotuner_benchmark or model_args.weight_quantize_algo is not None:
+        model = model_class.from_pretrained(**model_kwargs)
+        # for DPO save
+        model.config.dpo_config = None
+        if not dpo_config.reference_free and not dpo_config.lora:
+            config = AutoConfig.from_pretrained(**model_kwargs)
+            ref_model = model_class.from_config(config, dtype=dtype)
+            ref_model.set_state_dict(model.state_dict())
+        else:
+            ref_model = None
     else:
         config = AutoConfig.from_pretrained(**model_kwargs)
-        model = AutoModelForCausalLM.from_config(config)
-        ref_config = AutoConfig.from_pretrained(**model_kwargs)
-        ref_model = AutoModelForCausalLM.from_config(ref_config)
-        model.set_state_dict(ref_model.state_dict())
+        model = model_class.from_config(config, dtype=dtype)
+        if not dpo_config.reference_free and not dpo_config.lora:
+            ref_model = model_class.from_config(config, dtype=dtype)
+        else:
+            ref_model = None
 
     if model_args.flash_mask and not model.config.use_flash_attention:
         logger.warning("`flash_mask` must use with zero padding and flash attention.")
         model.config.use_flash_attention = True
 
     if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list):
         raise NotImplementedError(f"{model.__class__} not support flash mask.")
-    if training_args.sequence_parallel:
+
+    if model_args.sequence_parallel:
         register_sequence_parallel_allreduce_hooks(
-            model, training_args.gradient_accumulation_steps, training_args.fuse_sequence_parallel_allreduce
+            model, training_args.gradient_accumulation_steps, model_args.fuse_sequence_parallel_allreduce
         )
     if model_args.tokenizer_name_or_path is not None:
        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
     # TODO: support chat template in next pr
-    # tokenizer.chat_template = None
+    tokenizer.chat_template = None
     logger.info("Loading model & tokenizer successfully !")
 
+    if dpo_config.lora:
+        if training_args.sharding_parallel_degree > 1:
+            assert (
+                "enable_stage1_overlap" not in training_args.sharding_parallel_config
+            ), "Currently not support enabling sharding_stage1_overlap in lora mode."
+        if model_args.lora_path is None:
+            target_modules = get_lora_target_modules(model)
+            if model_args.rslora_plus:
+                model_args.rslora = True
+                model_args.lora_plus_scale = 4
+                model_args.lora_alpha = 4
+            if model_args.weight_quantize_algo is not None:
+                if model_args.rslora or model_args.lora_plus_scale != 1.0:
+                    logger.info("Weight quantization is not supported in LoRA+ and RsLoRA.")
+            if model_args.lora_alpha == -1:
+                if model_args.rslora:
+                    model_args.lora_alpha = 4
+                else:
+                    model_args.lora_alpha = 2 * model_args.lora_rank
+            lora_config = LoRAConfig(
+                target_modules=target_modules,
+                r=model_args.lora_rank,
+                lora_alpha=2 * model_args.lora_rank if not model_args.rslora else 4,
+                rslora=model_args.rslora,
+                lora_plus_scale=model_args.lora_plus_scale,
+                tensor_parallel_degree=training_args.tensor_parallel_degree,
+                dtype=dtype,
+                base_model_name_or_path=model_args.model_name_or_path,
+                use_quick_lora=model_args.use_quick_lora,
+            )
+            model = LoRAModel(model, lora_config)
+        else:
+            model = LoRAModel.from_pretrained(model=model, lora_path=model_args.lora_path)
+
+        model.print_trainable_parameters()
+
     logger.info("Start to create dataset")
     trans_func = partial(preprocess_preference_data, tokenizer=tokenizer, data_args=data_args, model_args=model_args)
+    if data_args.lazy:
+        zero_padding_dataset = ZeroPaddingIterableDataset
+    else:
+        zero_padding_dataset = ZeroPaddingMapDataset
     if training_args.do_train and training_args.should_load_dataset:
         train_ds = load_dataset(
             "json",
             data_files=data_args.train_dataset_path,
+            lazy=data_args.lazy,
         )[0]
         logger.info("Creating train Zero Padding Data Stream. This may take a few minutes.")
         train_ds = (
-            ZeroPaddingMapDataset(
+            zero_padding_dataset(
                 train_ds.map(trans_func),
                 tokenizer=tokenizer,
                 max_length=data_args.max_seq_len,
@@ -176,10 +220,11 @@ def main():
         eval_ds = load_dataset(
             "json",
             data_files=data_args.dev_dataset_path,
+            lazy=data_args.lazy,
         )[0]
         logger.info("Creating dev Zero Padding Data Stream. This may take a few minutes.")
         eval_ds = (
-            ZeroPaddingMapDataset(
+            zero_padding_dataset(
                 eval_ds.map(trans_func),
                 tokenizer=tokenizer,
                 max_length=data_args.max_seq_len,
@@ -194,6 +239,7 @@ def main():
     trainer = DPOTrainer(
         model=model,
         ref_model=ref_model,
+        dpo_config=dpo_config,
         args=training_args,
         train_dataset=train_ds,
         eval_dataset=eval_ds,
@@ -202,17 +248,18 @@ def main():
             preference_collate_fn,
             max_seq_len=data_args.max_seq_len,
         ),
+        ignore_eos_token=True,
     )
 
     if training_args.do_train:
         train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
 
-        if not data_args.autotuner_benchmark and not data_args.benchmark:
+        if not training_args.autotuner_benchmark and not training_args.benchmark:
            trainer.save_model(merge_tensor_parallel=training_args.tensor_parallel_degree > 1)
            trainer.log_metrics("train", train_result.metrics)
            trainer.save_metrics("train", train_result.metrics)
            trainer.save_state()
-        if data_args.benchmark:
+        if training_args.benchmark:
            total_effective_tokens, total_tokens = calculate_effective_tokens(
                training_args, train_ds, data_args.max_seq_len
            )
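
run_dpo.py forces reference_free=True for the "or" and "simpo" loss types. As a rough illustration of why no reference model is needed in that case, the sketch below follows the published SimPO formulation (length-normalized policy log-probability plus a target margin simpo_gamma). It is an assumption-laden sketch, not the PaddleNLP dpo_criterion, and the names are illustrative only.

import paddle.nn.functional as F


def sketch_simpo_loss(chosen_logps, rejected_logps, chosen_len, rejected_len, beta=0.1, simpo_gamma=0.5):
    """Illustrative reference-free, length-normalized SimPO-style preference loss."""
    # length-normalized implicit rewards; no reference-model term appears
    chosen_reward = beta * chosen_logps / chosen_len
    rejected_reward = beta * rejected_logps / rejected_len
    return -F.log_sigmoid(chosen_reward - rejected_reward - simpo_gamma).mean()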

llm/config/baichuan/dpo_argument.json

Lines changed: 0 additions & 1 deletion
@@ -24,7 +24,6 @@
     "disable_tqdm": true,
     "load_best_model_at_end": true,
     "tensor_parallel_degree": 8,
-    "sharding_parallel_degree": 1,
     "sharding": "stage1",
     "use_flash_attention": true,
     "recompute": false,
