[llm]support dpo pp for qwen & llama #9695

Merged (3 commits, Dec 27, 2024)
26 changes: 20 additions & 6 deletions llm/alignment/dpo/run_dpo.py
@@ -42,6 +42,7 @@
LlamaForCausalLM,
LlamaForCausalLMPipe,
Qwen2ForCausalLM,
+    Qwen2ForCausalLMPipe,
register_sequence_parallel_allreduce_hooks,
)
from paddlenlp.trl import (
@@ -53,7 +54,7 @@
from paddlenlp.trl.llm_utils import get_lora_target_modules
from paddlenlp.utils.log import logger

-flash_mask_support_list = [Qwen2ForCausalLM, LlamaForCausalLM, LlamaForCausalLMPipe]
+flash_mask_support_list = [Qwen2ForCausalLM, Qwen2ForCausalLMPipe, LlamaForCausalLM, LlamaForCausalLMPipe]


def main():
@@ -74,7 +75,20 @@ def main():
if dpo_config.loss_type in ["or", "simpo"] and not dpo_config.reference_free:
dpo_config.reference_free = True
logger.warning(f"{dpo_config.loss_type} loss_type only supports reference_free. Set reference_free to True.")

+if training_args.pipeline_parallel_degree > 1:
+    assert (
+        hasattr(training_args, "pipeline_parallel_config")
+        and "enable_clear_every_step_cache" in training_args.pipeline_parallel_config
+    ), "Should set '--pipeline_parallel_config enable_clear_every_step_cache' in bash script for pp."
+if model_args.sequence_parallel:
+    if training_args.pipeline_parallel_degree > 1:
+        assert (
+            hasattr(training_args, "pipeline_parallel_config")
+            and "disable_partial_send_recv" in training_args.pipeline_parallel_config
+        ), "Should set '--pipeline_parallel_config disable_partial_send_recv' in bash script for pp with sp."
+    if training_args.tensor_parallel_degree <= 1:
+        model_args.sequence_parallel = False
+        logger.info("Tensor_parallel_degree = 1. Set sequence_parallel to False.")
training_args.print_config(model_args, "Model")
training_args.print_config(data_args, "Data")
training_args.print_config(dpo_config, "DPOConfig")
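The two assertions above encode launch-time requirements rather than code-level ones. A minimal sketch of the hybrid-parallel arguments they expect (all values hypothetical; only the two pipeline_parallel_config entries are what the checks actually look for):

# Hypothetical launch settings for DPO with pipeline parallelism (pp) and
# sequence parallelism (sp); the entries mirror the assertion messages above.
training_config = {
    "pipeline_parallel_degree": 4,   # pp > 1 triggers the first assertion
    "tensor_parallel_degree": 2,     # sp is silently disabled when tp <= 1
    "sequence_parallel": True,
    # pp requires enable_clear_every_step_cache; pp with sp additionally
    # requires disable_partial_send_recv
    "pipeline_parallel_config": "enable_clear_every_step_cache disable_partial_send_recv",
}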
@@ -112,16 +126,15 @@ def main():
use_flash_attention=model_args.use_flash_attention,
tensor_parallel_output=model_args.tensor_parallel_output,
)
-if training_args.pipeline_parallel_degree > 1:
-    raise ValueError("DPO does not support pipeline parallelism yet.")

+if training_args.pipeline_parallel_degree > 1:
+    model_class = AutoModelForCausalLMPipe
+    model_kwargs["dpo_config"] = dpo_config
else:
model_class = AutoModelForCausalLM
if not training_args.autotuner_benchmark or model_args.weight_quantize_algo is not None:
model = model_class.from_pretrained(**model_kwargs)
# for DPO save
model.config.dpo_config = None
if not dpo_config.reference_free and not dpo_config.lora:
config = AutoConfig.from_pretrained(**model_kwargs)
ref_model = model_class.from_config(config, dtype=dtype)
Expand All @@ -135,7 +148,8 @@ def main():
ref_model = model_class.from_config(config, dtype=dtype)
else:
ref_model = None

+if training_args.pipeline_parallel_degree > 1:
+    model.config.dpo_config = None
if model_args.flash_mask and not model.config.use_flash_attention:
logger.warning("`flash_mask` must use with zero padding and flash attention.")
model.config.use_flash_attention = True
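With the hard error removed, pipeline runs now route through AutoModelForCausalLMPipe and carry dpo_config inside model_kwargs. A condensed sketch of that path (the checkpoint name is hypothetical; dpo_config is the DPOConfig parsed by this script):

# dpo_config rides along in model_kwargs so the pipeline model can pick a
# DPO criterion in get_loss_fn (see modeling_pp.py below); under pp the
# config field is reset only after the reference model is built.
from paddlenlp.transformers import AutoModelForCausalLMPipe

model = AutoModelForCausalLMPipe.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",  # hypothetical checkpoint
    dpo_config=dpo_config,
)
model.config.dpo_config = None  # for clean saving, as the diff above does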
File renamed without changes.
262 changes: 262 additions & 0 deletions paddlenlp/transformers/kto_criterion.py
@@ -0,0 +1,262 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import ParallelCrossEntropy

from paddlenlp.transformers import (
AllGatherVarlenOp,
fused_head_and_loss_fn,
parallel_linear,
parallel_matmul,
sequence_parallel_sparse_mask_labels,
)
from paddlenlp.utils import infohub


class KTOCriterion(nn.Layer):
"""KTO Criterion"""

def __init__(self, config, kto_config=None, ignore_label=0, use_infohub=False):
super(KTOCriterion, self).__init__()
self.config = config
if kto_config is None:
if getattr(self.config, "kto_config", None) is None:
raise ValueError("KTO Criterion requires model_config.kto_config.")
self.kto_config = copy.deepcopy(config.kto_config)
else:
self.kto_config = kto_config
if self.config.tensor_parallel_output and self.config.tensor_parallel_degree > 1:
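# with tensor-parallel output the vocab dimension is sharded across tp ranks,
# so per-token logprobs need the parallel cross entropy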
self.logprobs = ParallelCrossEntropy()
else:
self.logprobs = nn.CrossEntropyLoss(reduction="none")
self.use_infohub = use_infohub
self.ignore_label = ignore_label
# allgather kl in criterion
topo = fleet.get_hybrid_communicate_group()._topo
parallel_groups = topo.get_comm_list("pipe")
ranks = []
for group in parallel_groups:
ranks.append(group[-1])
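# the group collects the last rank of every pipeline, i.e. the stages that
# hold the logits; kto_loss later averages its kl estimate across them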
self.comm_group = paddle.distributed.new_group(ranks=ranks)

def _nested_gather(self, tensors):
"""
Gather `tensors` (a tensor or a list/tuple of nested tensors) across the
pipeline-stage group and concatenate them into `gathered`.
"""
local_rank = -1
env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
if env_local_rank != -1 and env_local_rank != local_rank and paddle.distributed.get_world_size() > 1:
local_rank = env_local_rank
if tensors is None:
return
if local_rank != -1:
output_tensors = []
paddle.distributed.all_gather(
output_tensors, paddle.tile(tensors, repeat_times=[1, 1]), group=self.comm_group
)
tensors = paddle.concat(output_tensors, axis=0)
return tensors

def kto_logps(self, logits, response_labels, response_kl_labels, response_indexs):
"""KTO logprobs"""
labels = response_labels + response_kl_labels
if self.config.use_fused_head_and_loss_fn:
hidden_states, weight, bias, transpose_y = logits
elif self.config.use_sparse_head_and_loss_fn:
hidden_states, weight, bias = logits
if self.config.use_sparse_head_and_loss_fn:
if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel:
labels, sparse_tgt_idx = sequence_parallel_sparse_mask_labels(labels, self.ignore_label)

hidden_states = paddle.take_along_axis(hidden_states, sparse_tgt_idx, axis=0)
hidden_states = AllGatherVarlenOp.apply(hidden_states)
else:
labels = labels.flatten()
sparse_tgt_idx = paddle.nonzero(labels != self.ignore_label).flatten()
labels = paddle.take_along_axis(labels, sparse_tgt_idx, axis=0)

hidden_states = hidden_states.reshape([-1, hidden_states.shape[-1]])
hidden_states = paddle.take_along_axis(hidden_states, sparse_tgt_idx.unsqueeze(-1), axis=0)
if self.config.use_fused_head_and_loss_fn:
per_token_logps = -fused_head_and_loss_fn(
hidden_states,
weight,
bias,
labels,
None,
transpose_y,
self.config.vocab_size,
self.config.tensor_parallel_degree,
self.config.tensor_parallel_output,
self.config.fused_linear,
getattr(self.config, "chunk_size", 1024),
return_token_loss=True,
ignore_index=self.ignore_label,
)
elif self.config.use_sparse_head_and_loss_fn:
if bias is None:
logits = parallel_matmul(hidden_states, weight, self.config.tensor_parallel_output)
else:
logits = parallel_linear(
hidden_states,
weight,
bias,
self.config.tensor_parallel_output,
)
logits = logits.astype("float32")
per_token_logps = -self.logprobs(logits, labels)
else:
logits = logits.astype("float32")
if logits.shape[:-1] != labels.shape:
raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
# bs, seq
per_token_logps = -self.logprobs(logits, labels.unsqueeze(2)).squeeze(2)

if len(response_indexs.shape) == 3:
response_indexs = response_indexs[0]
if self.config.use_sparse_head_and_loss_fn:
chosen_logps_list = [
(per_token_logps[response_index[1] : response_index[2]]).sum()
for response_index in response_indexs
if response_index[4] == 1
]
rejected_logps_list = [
(per_token_logps[response_index[1] : response_index[2]]).sum()
for response_index in response_indexs
if response_index[4] == 0
]
kl_logps_list = [
(per_token_logps[response_index[2] : response_index[3]]).sum() for response_index in response_indexs
]
else:
chosen_logps_list = [
(per_token_logps[response_index[0]][response_index[1] : response_index[2]]).sum()
for response_index in response_indexs
if response_index[4] == 1
]
rejected_logps_list = [
(per_token_logps[response_index[0]][response_index[1] : response_index[2]]).sum()
for response_index in response_indexs
if response_index[4] == 0
]
kl_logps_list = [
(per_token_logps[response_index[0]][response_index[2] : response_index[3]]).sum()
for response_index in response_indexs
]
if len(chosen_logps_list) == 0:
chosen_logps = paddle.zeros([0], dtype="float32")
else:
chosen_logps = paddle.stack(chosen_logps_list, axis=0)
if len(rejected_logps_list) == 0:
rejected_logps = paddle.zeros([0], dtype="float32")
else:
rejected_logps = paddle.stack(rejected_logps_list, axis=0)
kl_logps = paddle.stack(kl_logps_list, axis=0)
return chosen_logps, rejected_logps, kl_logps

def kto_loss(
self,
policy_chosen_logps,
policy_rejected_logps,
policy_kl_logps,
reference_chosen_logps,
reference_rejected_logps,
reference_kl_logps,
):
"""KTO Loss"""
kl = (policy_kl_logps - reference_kl_logps).mean().detach()
kl = self._nested_gather(paddle.tile(kl, repeat_times=[1, 1])).mean().clip(min=0)
if policy_chosen_logps.shape[0] == 0 or reference_chosen_logps.shape[0] == 0:
chosen_losses = paddle.zeros([0])
else:
chosen_logratios = policy_chosen_logps - reference_chosen_logps
chosen_losses = 1 - F.sigmoid(self.kto_config.beta * (chosen_logratios - kl))
if policy_rejected_logps.shape[0] == 0 or reference_rejected_logps.shape[0] == 0:
rejected_losses = paddle.zeros([0])
else:
rejected_logratios = policy_rejected_logps - reference_rejected_logps
rejected_losses = 1 - F.sigmoid(self.kto_config.beta * (kl - rejected_logratios))
losses = paddle.concat(
(
self.kto_config.desirable_weight * chosen_losses,
self.kto_config.undesirable_weight * rejected_losses,
),
0,
)
return losses.mean(), kl

def forward(
self,
logits,
labels,
):
"""Forward"""
(
response_labels,
response_kl_labels,
response_indexs,
reference_chosen_logps,
reference_rejected_logps,
reference_kl_logps,
) = labels
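# Two-pass scheme under pipeline parallelism: the reference model's forward
# arrives here with reference_*_logps still None, computes them, stashes
# them in infohub and returns a dummy loss; the policy model's forward then
# gets the stashed values back through `labels` and computes the real loss.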
if reference_chosen_logps is None or reference_rejected_logps is None or reference_kl_logps is None:
(
reference_chosen_logps,
reference_rejected_logps,
reference_kl_logps,
) = self.kto_logps(logits, response_labels, response_kl_labels, response_indexs)
if self.use_infohub:
infohub.reference_chosen_logps.append(reference_chosen_logps)
infohub.reference_rejected_logps.append(reference_rejected_logps)
infohub.reference_kl_logps.append(reference_kl_logps)
# pipeline mode requires return loss when self._compute_loss is True
return paddle.zeros([1])
else:
return (
reference_chosen_logps,
reference_rejected_logps,
reference_kl_logps,
)
policy_chosen_logps, policy_rejected_logps, policy_kl_logps = self.kto_logps(
logits, response_labels, response_kl_labels, response_indexs
)
loss, kl = self.kto_loss(
policy_chosen_logps,
policy_rejected_logps,
policy_kl_logps,
reference_chosen_logps,
reference_rejected_logps,
reference_kl_logps,
)
if self.use_infohub:
infohub.policy_chosen_logps.append(policy_chosen_logps.detach())
infohub.policy_rejected_logps.append(policy_rejected_logps.detach())
infohub.policy_kl_logps.append(policy_kl_logps.detach())
infohub.kl.append(kl.detach())
return loss
else:
return (
policy_chosen_logps,
policy_rejected_logps,
policy_kl_logps,
loss,
kl,
)
2 changes: 2 additions & 0 deletions paddlenlp/transformers/llama/configuration.py
@@ -158,6 +158,7 @@ def __init__(
use_flash_attention_for_generation=False,
use_last_token_for_generation=False,
immediate_clear_past_key_value=False,
dpo_config=None,
**kwargs,
):
self.vocab_size = vocab_size
@@ -195,6 +196,7 @@ def __init__(
self.use_flash_attention_for_generation = use_flash_attention_for_generation
self.use_last_token_for_generation = use_last_token_for_generation
self.immediate_clear_past_key_value = immediate_clear_past_key_value
+    self.dpo_config = dpo_config

super().__init__(
pad_token_id=pad_token_id,
6 changes: 5 additions & 1 deletion paddlenlp/transformers/llama/modeling_pp.py
@@ -30,6 +30,7 @@
)
from paddlenlp.utils.tools import get_env_device

+from ..dpo_criterion import DPOCriterion
from .modeling import (
LlamaConfig,
LlamaDecoderLayer,
@@ -423,4 +424,7 @@
# PipelinePretrainedModel.__init__(self.super(), config=config)

def get_loss_fn(self, config):
-    return LlamaPretrainingCriterion(config)
+    if config.dpo_config is not None:
+        return DPOCriterion(config, use_infohub=True)
+    else:
+        return LlamaPretrainingCriterion(config)
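A minimal usage sketch of the dispatch above, assuming a pipeline model whose config carries the dpo_config field added in the configuration changes below:

# With dpo_config set, the pipeline scheduler's loss function becomes
# DPOCriterion (use_infohub=True lets it hand logps between the reference
# and policy passes); with dpo_config=None, behavior is unchanged and the
# pretraining criterion is returned.
loss_fn = model.get_loss_fn(model.config)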
2 changes: 2 additions & 0 deletions paddlenlp/transformers/qwen/configuration.py
@@ -49,6 +49,7 @@ def __init__(
long_sequence_strategy_name=None,
long_sequence_init_args=None,
use_long_sequence_strategies=False,
+    dpo_config=None,
**kwargs,
):
self.vocab_size = vocab_size
@@ -74,6 +75,7 @@ def __init__(
self.long_sequence_strategy_name = long_sequence_strategy_name
self.long_sequence_init_args = {} if long_sequence_init_args is None else long_sequence_init_args
self.use_long_sequence_strategies = use_long_sequence_strategies
+    self.dpo_config = dpo_config

super().__init__(
pad_token_id=pad_token_id,