diff --git a/llm/alignment/dpo/dpo_argument.py b/llm/alignment/dpo/dpo_argument.py
index 63229100c466..86173be4ae2c 100644
--- a/llm/alignment/dpo/dpo_argument.py
+++ b/llm/alignment/dpo/dpo_argument.py
@@ -87,8 +87,8 @@ class DPOModelArgument:
             "help": "The granularity of recompute training can be selected as `full` or `full_attn` or `core_attn`."
         },
     )
-    use_attn_mask_start_row_indices: bool = field(
-        default=False, metadata={"help": "Whether to use attn_mask_start_row_indices in flash attention."}
+    flash_mask: bool = field(
+        default=False, metadata={"help": "Whether to use flash mask in flash attention."}
     )
     virtual_pp_degree: int = field(
         default=1,
diff --git a/llm/alignment/dpo/run_dpo.py b/llm/alignment/dpo/run_dpo.py
index aa7a09f16ad3..0d542444fe55 100644
--- a/llm/alignment/dpo/run_dpo.py
+++ b/llm/alignment/dpo/run_dpo.py
@@ -17,6 +17,7 @@
 import os
 import sys
 import time
+import inspect
 from functools import partial
 
 import paddle
@@ -36,8 +37,14 @@
     preference_collate_fn,
     preprocess_preference_data,
 )
+from paddlenlp.transformers import (
+    LlamaForCausalLM,
+    LlamaForCausalLMPipe,
+)
 from paddlenlp.utils.log import logger
 
+flash_mask_support_list = [LlamaForCausalLM, LlamaForCausalLMPipe]
+
 
 def main():
     """main"""
@@ -124,6 +131,15 @@ def main():
         ref_model = AutoModelForCausalLM.from_config(ref_config)
         model.set_state_dict(ref_model.state_dict())
 
+    if model_args.flash_mask and not model.config.use_flash_attention:
+        logger.warning(
+            "`flash_mask` must be used with zero padding and flash attention."
+        )
+        model.config.use_flash_attention = True
+
+    if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list):
+        raise NotImplementedError(f"{model.__class__} does not support flash mask.")
+
     if model_args.tokenizer_name_or_path is not None:
         tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
     else:
diff --git a/llm/config/llama/dpo_argument.json b/llm/config/llama/dpo_argument.json
index b30fcc86478c..cc2b000dfcef 100644
--- a/llm/config/llama/dpo_argument.json
+++ b/llm/config/llama/dpo_argument.json
@@ -27,7 +27,7 @@
     "sharding_parallel_degree": 1,
     "sharding": "stage1",
     "use_flash_attention": true,
-    "use_attn_mask_start_row_indices":false,
+    "flash_mask":true,
     "recompute": false,
     "recompute_granularity": "full",
     "dpo_beta": 0.1,
diff --git a/llm/config/qwen/dpo_argument.json b/llm/config/qwen/dpo_argument.json
index 716cdba59da6..39f9a1f5c447 100644
--- a/llm/config/qwen/dpo_argument.json
+++ b/llm/config/qwen/dpo_argument.json
@@ -27,7 +27,7 @@
     "sharding_parallel_degree": 1,
     "sharding": "stage1",
     "use_flash_attention": true,
-    "use_attn_mask_start_row_indices":false,
+    "flash_mask":false,
     "recompute": false,
     "recompute_granularity": "full",
     "dpo_beta": 0.1,
diff --git a/llm/run_finetune.py b/llm/run_finetune.py
index 683bd871946b..8df3705fe335 100644
--- a/llm/run_finetune.py
+++ b/llm/run_finetune.py
@@ -14,6 +14,7 @@
 import json
 import os
 import sys
+import inspect
 from functools import partial
 
 import paddle
@@ -51,6 +52,8 @@
     AutoTokenizer,
     Llama3Tokenizer,
     LlamaTokenizer,
+    LlamaForCausalLM,
+    LlamaForCausalLMPipe,
 )
 from paddlenlp.transformers.configuration_utils import LlmMetaConfig
 from paddlenlp.utils.log import logger
@@ -58,6 +61,8 @@
 # Fine-tune Environment Variables to support sharding stage1 overlap optimization.
os.environ["USE_CASUAL_MASK"] = "False" +flash_mask_support_list = [LlamaForCausalLM, LlamaForCausalLMPipe] + def main(): # Arguments @@ -77,6 +82,7 @@ def main(): raise ValueError( "--do_train, --do_ptq, --do_gptq and --do_qat cannot work at the same time. Please choose only one at a time" ) + # Setup GPU & distributed training paddle.set_device(training_args.device) @@ -160,6 +166,16 @@ def main(): # NOTE(gongenlei): new add autotuner_benchmark model = model_class.from_config(model_config, dtype=dtype) + if model_args.flash_mask and (not data_args.zero_padding or not model.config.use_flash_attention): + logger.warning( + "`flash_mask` must use with zero padding and flash attention." + ) + data_args.zero_padding = True + model.config.use_flash_attention = True + + if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list): + raise NotImplementedError(f"{model.__class__} not support flash mask.") + if training_args.do_train and model_args.neftune: # Inspired by https://github.com/neelsjain/NEFTune if hasattr(model, "get_input_embeddings"): @@ -329,12 +345,12 @@ def neft_post_hook(module, input, output): "Zero Padding data stream is only implemented for LLaMA, Bloom, ChatGLM and QWen so far." ) train_ds = ( - train_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding)) + train_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.flash_mask)) if train_ds is not None else None ) ptq_ds = ( - ptq_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding)) + ptq_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.flash_mask)) if ptq_ds is not None else None ) @@ -345,7 +361,7 @@ def neft_post_hook(module, input, output): ) eval_zero_padding = False dev_ds = ( - dev_ds.map(partial(trans_func, is_test=data_args.eval_with_do_generation, zero_padding=eval_zero_padding)) + dev_ds.map(partial(trans_func, is_test=data_args.eval_with_do_generation, zero_padding=eval_zero_padding, flash_mask=model_args.flash_mask)) if dev_ds is not None else None ) @@ -498,6 +514,7 @@ def compute_metrics_do_generation(eval_preds): padding=padding, max_label_length=max_length, return_tensors="np", + return_attention_mask=not model_args.flash_mask, pad_to_multiple_of=data_args.pad_to_multiple_of, ), do_generation=data_args.eval_with_do_generation, diff --git a/llm/utils/argument.py b/llm/utils/argument.py index 67ad7c5dbe2a..63a14e4126ef 100644 --- a/llm/utils/argument.py +++ b/llm/utils/argument.py @@ -209,6 +209,9 @@ class ModelArgument: aistudio_token: str = field(default=None, metadata={"help": "The token of aistudio"}) neftune: bool = field(default=False, metadata={"help": "Whether to apply NEFT"}) neftune_noise_alpha: float = field(default=5.0, metadata={"help": "NEFT noise alpha"}) + flash_mask: bool = field( + default=False, metadata={"help": "Whether to use flash_mask in flash attention."} + ) @dataclass diff --git a/llm/utils/data.py b/llm/utils/data.py index eabac7456cbe..f06fa4033fc9 100644 --- a/llm/utils/data.py +++ b/llm/utils/data.py @@ -173,11 +173,12 @@ def tokenize_rounds_example(tokenizer, example, data_args, **kwargs): return tokenized_source, labels -def convert_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False): +def convert_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False, flash_mask=False): if tokenizer.chat_template is not None: - return 
+        return convert_rounds_example_common(example, tokenizer, data_args, is_test, zero_padding, flash_mask)
 
     tokenized_source, tokenized_target_input_ids = tokenize_example(tokenizer, example, data_args)
+
     if is_test:
         return {
             **tokenized_source,
@@ -194,12 +195,17 @@ def convert_example_common(example, tokenizer, data_args, is_test=True, zero_pad
         if "position_ids" in tokenized_source:
             features["position_ids"] = list(range(seq_length))
         if zero_padding:
-            features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool)
+            if flash_mask:
+                features["attn_mask_startend_row_indices"] = (
+                    [seq_length] * seq_length
+                )
+            else:
+                features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool)
 
         return features
 
 
-def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False):
+def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False, flash_mask=False):
     """convert multi-rounds conversation example
 
     Args:
@@ -227,7 +233,13 @@ def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, z
     seq_length = len(input_ids)
     features = {"input_ids": input_ids, "labels": labels}
     if zero_padding:
-        features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool)
+        if flash_mask:
+            features["attn_mask_startend_row_indices"] = (
+                [seq_length] * seq_length
+            )
+        else:
+            features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool)
+
     if "position_ids" in rounds_inputs:
         rounds_inputs["position_ids"] = rounds_inputs["position_ids"][:-1]
 
diff --git a/paddlenlp/data/data_collator.py b/paddlenlp/data/data_collator.py
index aecd186a91e4..351c44867b28 100644
--- a/paddlenlp/data/data_collator.py
+++ b/paddlenlp/data/data_collator.py
@@ -370,6 +370,11 @@ def __call__(self, features, return_tensors=None):
         if return_tensors is None:
             return_tensors = self.return_tensors
         labels = [feature["labels"] for feature in batch] if "labels" in batch[0].keys() else None
+        use_attn_mask_startend_row_indices = (
+            [feature["attn_mask_startend_row_indices"] for feature in batch]
+            if "attn_mask_startend_row_indices" in batch[0].keys()
+            else None
+        )
         # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
         # same length to return tensors.
         if labels is not None:
@@ -396,6 +401,29 @@ def __call__(self, features, return_tensors=None):
                     feature["labels"] = np.concatenate([feature["labels"], remainder]).astype(np.int64)
                 else:
                     feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64)
+        if use_attn_mask_startend_row_indices is not None:
+            if self.max_length is not None:
+                max_length = self.max_length
+            else:
+                max_length = max(len(l) for l in use_attn_mask_startend_row_indices)
+            if self.pad_to_multiple_of is not None:
+                max_length = (
+                    (max_length + self.pad_to_multiple_of - 1) // self.pad_to_multiple_of * self.pad_to_multiple_of
+                )
+
+            for feature in batch:
+                pad_len = max_length - len(feature["attn_mask_startend_row_indices"])
+                remainder = np.zeros([1, pad_len], dtype=np.int32)
+                feature["attn_mask_startend_row_indices"] = (
+                    np.concatenate(
+                        [remainder, np.array([feature["attn_mask_startend_row_indices"]], dtype=np.int32) + pad_len],
+                        axis=-1,
+                    )
+                    if padding_side == "left"
+                    else np.concatenate(
+                        [np.array([feature["attn_mask_startend_row_indices"]], dtype=np.int32), remainder], axis=-1
+                    )
+                )
 
         batch = self.tokenizer.pad(
             batch,
diff --git a/paddlenlp/datasets/zero_padding_dataset.py b/paddlenlp/datasets/zero_padding_dataset.py
index 37b85ea86428..51794e35d4dc 100644
--- a/paddlenlp/datasets/zero_padding_dataset.py
+++ b/paddlenlp/datasets/zero_padding_dataset.py
@@ -28,14 +28,14 @@ class ZeroPadding:
         "chosen_labels",
         "rejected_labels",
         "response_indexs",
-        "attn_mask_start_row_indices",
+        "attn_mask_startend_row_indices",
     ]
 
     @classmethod
     def _pad_batch_records(cls, batch_records):
         # Only consider supported input keys
         input_keys = [key for key in batch_records[0].keys() if key in cls.supported_input_keys]
-        if "attn_mask_start_row_indices" not in input_keys and "attention_mask" not in input_keys:
+        if "attn_mask_startend_row_indices" not in input_keys and "attention_mask" not in input_keys:
             input_keys.append("attention_mask")
         batched_features = {key: [] for key in input_keys}
         sequence_sum = 0
@@ -57,9 +57,9 @@ def _pad_batch_records(cls, batch_records):
             seq_length = len(record["input_ids"])
 
             # If attention_mask is not given, assume it's causal mask
-            if "attn_mask_start_row_indices" in record:
-                attn_mask_start_row_indices = [i + sequence_sum for i in record["attn_mask_start_row_indices"]]
-                batched_features["attn_mask_start_row_indices"].extend(attn_mask_start_row_indices)
+            if "attn_mask_startend_row_indices" in record:
+                attn_mask_startend_row_indices = [i + sequence_sum for i in record["attn_mask_startend_row_indices"]]
+                batched_features["attn_mask_startend_row_indices"].extend(attn_mask_startend_row_indices)
             else:
                 attention_mask = record.get("attention_mask", np.tril(np.ones([seq_length, seq_length], dtype=bool)))
                 batched_features["attention_mask"].append(attention_mask)
diff --git a/paddlenlp/transformers/llama/fusion_ops.py b/paddlenlp/transformers/llama/fusion_ops.py
index 2a273489e59b..290538912c9d 100644
--- a/paddlenlp/transformers/llama/fusion_ops.py
+++ b/paddlenlp/transformers/llama/fusion_ops.py
@@ -211,6 +211,8 @@ def fusion_flash_attention(
     else:
         if attn_mask_startend_row_indices is not None:
             assert alibi is None, "flash_attention_with_sparse_mask not support alibi"
+            if len(attn_mask_startend_row_indices.shape) == 2:
+                attn_mask_startend_row_indices = paddle.unsqueeze(attn_mask_startend_row_indices, axis=1)
             attn_output = F.flash_attention_with_sparse_mask(
                 query_states,
                 key_states,
diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index 0e7623f12e94..dd38504a0e6f 100755
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -1906,6 +1906,14 @@ def forward(
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if attn_mask_startend_row_indices is not None and attention_mask is not None:
+            logger.warning(
+                "You have provided both attn_mask_startend_row_indices and attention_mask. "
+                "The attn_mask_startend_row_indices will be used."
+            )
+            attention_mask = None
+
         outputs = self.llama(
             input_ids,  # [bs, seq_len]
             position_ids=position_ids,
diff --git a/paddlenlp/trl/dpo_trainer.py b/paddlenlp/trl/dpo_trainer.py
index 144ceb816fdc..0429c9faa769 100644
--- a/paddlenlp/trl/dpo_trainer.py
+++ b/paddlenlp/trl/dpo_trainer.py
@@ -177,8 +177,8 @@ def get_batch_metrics(self, model, batch, train_eval="train"):
             }
             if "attention_mask" in batch:
                 dpo_inputs["attention_mask"] = batch["attention_mask"]
-            if "attn_mask_start_row_indices" in batch:
-                dpo_inputs["attn_mask_start_row_indices"] = batch["attn_mask_start_row_indices"]
+            if "attn_mask_startend_row_indices" in batch:
+                dpo_inputs["attn_mask_startend_row_indices"] = batch["attn_mask_startend_row_indices"]
             if self.reference_free:
                 reference_chosen_logps, reference_rejected_logps = None, None
             else:
@@ -194,8 +194,8 @@ def get_batch_metrics(self, model, batch, train_eval="train"):
             }
             if "attention_mask" in batch:
                 dpo_inputs["attention_mask"] = batch["attention_mask"]
-            if "attn_mask_start_row_indices" in batch:
-                dpo_inputs["attn_mask_start_row_indices"] = batch["attn_mask_start_row_indices"]
+            if "attn_mask_startend_row_indices" in batch:
+                dpo_inputs["attn_mask_startend_row_indices"] = batch["attn_mask_startend_row_indices"]
             if self.reference_free:
                 reference_chosen_logps, reference_rejected_logps = None, None
             else:
@@ -522,7 +522,7 @@ def prepare_pipeline_dpo_inputs_func(inputs):
         else:
             first_stage_keys = [
                 "input_ids",
-                "attn_mask_start_row_indices",
+                "attn_mask_startend_row_indices",
                 "position_ids",
             ]
 
diff --git a/paddlenlp/trl/trl_data.py b/paddlenlp/trl/trl_data.py
index ca3a1ae40f7e..7afc3cb527cf 100644
--- a/paddlenlp/trl/trl_data.py
+++ b/paddlenlp/trl/trl_data.py
@@ -159,8 +159,8 @@ def preprocess_preference_data(data, tokenizer, data_args, model_args):
     }
 
     # attention mask
-    if model_args.use_attn_mask_start_row_indices:
-        output_dict["attn_mask_start_row_indices"] = (
+    if model_args.flash_mask:
+        output_dict["attn_mask_startend_row_indices"] = (
             [seq_len] * prompt_len + [prompt_len + chosen_len] * chosen_len + [seq_len] * rejected_len
         )
     else:
@@ -183,14 +183,14 @@ def preference_collate_fn(batch, max_seq_len=None):
         "response_indexs": [],
     }
     sequence = batch[0]
-    if "attn_mask_start_row_indices" in sequence:
-        input_dict["attn_mask_start_row_indices"] = []
-        use_attn_mask_start_row_indices = True
+    if "attn_mask_startend_row_indices" in sequence:
+        input_dict["attn_mask_startend_row_indices"] = []
+        use_attn_mask_startend_row_indices = True
     elif "attention_mask" in sequence:
         input_dict["attention_mask"] = []
-        use_attn_mask_start_row_indices = False
+        use_attn_mask_startend_row_indices = False
     else:
-        raise ValueError("attention_mask and attn_mask_start_row_indices are both None.")
+        raise ValueError("attention_mask and attn_mask_startend_row_indices are both None.")
 
     for i, sequence in enumerate(batch):
         difference = max_seq_len - len(sequence["input_ids"])
@@ -199,9 +199,12 @@ def preference_collate_fn(batch, max_seq_len=None):
         input_dict["position_ids"].append(sequence["position_ids"] + [0] * difference)
         input_dict["chosen_labels"].append(sequence["chosen_labels"] + [0] * difference)
         input_dict["rejected_labels"].append(sequence["rejected_labels"] + [0] * difference)
-        if use_attn_mask_start_row_indices:
-            input_dict["attn_mask_start_row_indices"].append(
-                [sequence["attn_mask_start_row_indices"] + [sequence["attn_mask_start_row_indices"][-1]] * difference]
+        if use_attn_mask_startend_row_indices:
+            input_dict["attn_mask_startend_row_indices"].append(
+                [
+                    sequence["attn_mask_startend_row_indices"]
+                    + [sequence["attn_mask_startend_row_indices"][-1]] * difference
+                ]
             )
         else:
             input_dict["attention_mask"].append(
@@ -225,7 +228,7 @@ def preference_collate_fn(batch, max_seq_len=None):
     for key in input_dict:
         if key == "attention_mask":
             input_dict[key] = np.array(input_dict[key], dtype=bool)
-        elif key == "attn_mask_start_row_indices":
+        elif key == "attn_mask_startend_row_indices":
             input_dict[key] = np.array(input_dict[key], dtype=np.int32)
         else:
             input_dict[key] = np.array(input_dict[key])
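
A minimal sketch (not part of the patch) of what the 1-D `attn_mask_startend_row_indices` values built in `llm/utils/data.py` and `paddlenlp/trl/trl_data.py` encode, assuming the start-row semantics consumed by `F.flash_attention_with_sparse_mask`: query row `i` may attend to key column `j` iff `j <= i < indices[j]`. The helper name below is hypothetical and exists only for illustration.

import numpy as np


def startrow_indices_to_dense_mask(indices):
    # Hypothetical helper: expand 1-D start-row indices into the equivalent
    # dense boolean mask (True = query i may attend to key j).
    indices = np.asarray(indices)
    seq_len = len(indices)
    rows = np.arange(seq_len)[:, None]  # query positions i
    cols = np.arange(seq_len)[None, :]  # key positions j
    return (cols <= rows) & (rows < indices[None, :])


# Packed DPO layout from preprocess_preference_data: prompt, chosen, rejected.
prompt_len, chosen_len, rejected_len = 2, 2, 2
seq_len = prompt_len + chosen_len + rejected_len
indices = [seq_len] * prompt_len + [prompt_len + chosen_len] * chosen_len + [seq_len] * rejected_len
print(startrow_indices_to_dense_mask(indices).astype(int))

Under this expansion, the SFT value `[seq_length] * seq_length` reduces to a plain causal (lower-triangular) mask, which is why the zero-padding path can drop the dense `attention_mask` when `flash_mask` is enabled, while the chosen-response columns (start row `prompt_len + chosen_len`) remain invisible to the rejected-response rows even though both share one packed sequence.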