From 8308cae1784cfefdf4342be3769d072b0520150c Mon Sep 17 00:00:00 2001 From: Ting Liu Date: Mon, 24 Jun 2024 18:57:56 +0800 Subject: [PATCH 1/9] support sft flash mask --- llm/alignment/dpo/dpo_argument.py | 4 ++-- llm/config/llama/dpo_argument.json | 2 +- llm/config/qwen/dpo_argument.json | 2 +- llm/run_finetune.py | 6 +++--- llm/utils/argument.py | 3 +++ llm/utils/data.py | 22 ++++++++++++++----- paddlenlp/data/data_collator.py | 23 ++++++++++++++++++++ paddlenlp/datasets/zero_padding_dataset.py | 10 ++++----- paddlenlp/transformers/llama/fusion_ops.py | 4 ++++ paddlenlp/transformers/llama/modeling.py | 11 ++++++++++ paddlenlp/trl/dpo_trainer.py | 10 ++++----- paddlenlp/trl/trl_data.py | 25 ++++++++++++---------- 12 files changed, 89 insertions(+), 33 deletions(-) diff --git a/llm/alignment/dpo/dpo_argument.py b/llm/alignment/dpo/dpo_argument.py index 63229100c466..6939ff63c397 100644 --- a/llm/alignment/dpo/dpo_argument.py +++ b/llm/alignment/dpo/dpo_argument.py @@ -87,8 +87,8 @@ class DPOModelArgument: "help": "The granularity of recompute training can be selected as `full` or `full_attn` or `core_attn`." }, ) - use_attn_mask_start_row_indices: bool = field( - default=False, metadata={"help": "Whether to use attn_mask_start_row_indices in flash attention."} + use_attn_mask_startend_row_indices: bool = field( + default=False, metadata={"help": "Whether to use attn_mask_startend_row_indices in flash attention."} ) virtual_pp_degree: int = field( default=1, diff --git a/llm/config/llama/dpo_argument.json b/llm/config/llama/dpo_argument.json index b30fcc86478c..b31656392d2a 100644 --- a/llm/config/llama/dpo_argument.json +++ b/llm/config/llama/dpo_argument.json @@ -27,7 +27,7 @@ "sharding_parallel_degree": 1, "sharding": "stage1", "use_flash_attention": true, - "use_attn_mask_start_row_indices":false, + "use_attn_mask_startend_row_indices":false, "recompute": false, "recompute_granularity": "full", "dpo_beta": 0.1, diff --git a/llm/config/qwen/dpo_argument.json b/llm/config/qwen/dpo_argument.json index 716cdba59da6..263a43ee3fda 100644 --- a/llm/config/qwen/dpo_argument.json +++ b/llm/config/qwen/dpo_argument.json @@ -27,7 +27,7 @@ "sharding_parallel_degree": 1, "sharding": "stage1", "use_flash_attention": true, - "use_attn_mask_start_row_indices":false, + "use_attn_mask_startend_row_indices":false, "recompute": false, "recompute_granularity": "full", "dpo_beta": 0.1, diff --git a/llm/run_finetune.py b/llm/run_finetune.py index 683bd871946b..953cb58df0b8 100644 --- a/llm/run_finetune.py +++ b/llm/run_finetune.py @@ -329,12 +329,12 @@ def neft_post_hook(module, input, output): "Zero Padding data stream is only implemented for LLaMA, Bloom, ChatGLM and QWen so far." 
) train_ds = ( - train_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding)) + train_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.use_attn_mask_startend_row_indices)) if train_ds is not None else None ) ptq_ds = ( - ptq_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding)) + ptq_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.use_attn_mask_startend_row_indices)) if ptq_ds is not None else None ) @@ -345,7 +345,7 @@ def neft_post_hook(module, input, output): ) eval_zero_padding = False dev_ds = ( - dev_ds.map(partial(trans_func, is_test=data_args.eval_with_do_generation, zero_padding=eval_zero_padding)) + dev_ds.map(partial(trans_func, is_test=data_args.eval_with_do_generation, zero_padding=eval_zero_padding, flash_mask=model_args.use_attn_mask_startend_row_indices)) if dev_ds is not None else None ) diff --git a/llm/utils/argument.py b/llm/utils/argument.py index 67ad7c5dbe2a..802a16b831d4 100644 --- a/llm/utils/argument.py +++ b/llm/utils/argument.py @@ -209,6 +209,9 @@ class ModelArgument: aistudio_token: str = field(default=None, metadata={"help": "The token of aistudio"}) neftune: bool = field(default=False, metadata={"help": "Whether to apply NEFT"}) neftune_noise_alpha: float = field(default=5.0, metadata={"help": "NEFT noise alpha"}) + use_attn_mask_startend_row_indices: bool = field( + default=False, metadata={"help": "Whether to use attn_mask_startend_row_indices in flash attention."} + ) @dataclass diff --git a/llm/utils/data.py b/llm/utils/data.py index eabac7456cbe..f06fa4033fc9 100644 --- a/llm/utils/data.py +++ b/llm/utils/data.py @@ -173,11 +173,12 @@ def tokenize_rounds_example(tokenizer, example, data_args, **kwargs): return tokenized_source, labels -def convert_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False): +def convert_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False, flash_mask=False): if tokenizer.chat_template is not None: - return convert_rounds_example_common(example, tokenizer, data_args, is_test, zero_padding) + return convert_rounds_example_common(example, tokenizer, data_args, is_test, zero_padding, flash_mask) tokenized_source, tokenized_target_input_ids = tokenize_example(tokenizer, example, data_args) + if is_test: return { **tokenized_source, @@ -194,12 +195,17 @@ def convert_example_common(example, tokenizer, data_args, is_test=True, zero_pad if "position_ids" in tokenized_source: features["position_ids"] = list(range(seq_length)) if zero_padding: - features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool) + if flash_mask: + features["attn_mask_startend_row_indices"] = ( + [seq_length] * seq_length + ) + else: + features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool) return features -def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False): +def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False, flash_mask=False): """convert multi-rounds conversation example Args: @@ -227,7 +233,13 @@ def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, z seq_length = len(input_ids) features = {"input_ids": input_ids, "labels": labels} if zero_padding: - features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool) + if flash_mask: + features["attn_mask_startend_row_indices"] = ( + [seq_length] * 
seq_length + ) + else: + features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool) + if "position_ids" in rounds_inputs: rounds_inputs["position_ids"] = rounds_inputs["position_ids"][:-1] diff --git a/paddlenlp/data/data_collator.py b/paddlenlp/data/data_collator.py index aecd186a91e4..74490ac86f3b 100644 --- a/paddlenlp/data/data_collator.py +++ b/paddlenlp/data/data_collator.py @@ -370,6 +370,11 @@ def __call__(self, features, return_tensors=None): if return_tensors is None: return_tensors = self.return_tensors labels = [feature["labels"] for feature in batch] if "labels" in batch[0].keys() else None + use_attn_mask_startend_row_indices = ( + [feature["attn_mask_startend_row_indices"] for feature in batch] + if "attn_mask_startend_row_indices" in batch[0].keys() + else None + ) # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the # same length to return tensors. if labels is not None: @@ -396,6 +401,24 @@ def __call__(self, features, return_tensors=None): feature["labels"] = np.concatenate([feature["labels"], remainder]).astype(np.int64) else: feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64) + if use_attn_mask_startend_row_indices is not None: + max_length = max(len(l) for l in use_attn_mask_startend_row_indices) + if self.pad_to_multiple_of is not None: + max_length = ( + (max_length + self.pad_to_multiple_of - 1) // self.pad_to_multiple_of * self.pad_to_multiple_of + ) + + padding_side = self.tokenizer.padding_side + for feature in batch: + remainder = [feature["attn_mask_startend_row_indices"][-1]] * ( + max_length - len(feature["attn_mask_startend_row_indices"]) + ) + if isinstance(feature["attn_mask_startend_row_indices"], list): + feature["attn_mask_startend_row_indices"] = ( + feature["attn_mask_startend_row_indices"] + remainder + if padding_side == "right" + else remainder + feature["attn_mask_startend_row_indices"] + ) batch = self.tokenizer.pad( batch, diff --git a/paddlenlp/datasets/zero_padding_dataset.py b/paddlenlp/datasets/zero_padding_dataset.py index 37b85ea86428..51794e35d4dc 100644 --- a/paddlenlp/datasets/zero_padding_dataset.py +++ b/paddlenlp/datasets/zero_padding_dataset.py @@ -28,14 +28,14 @@ class ZeroPadding: "chosen_labels", "rejected_labels", "response_indexs", - "attn_mask_start_row_indices", + "attn_mask_startend_row_indices", ] @classmethod def _pad_batch_records(cls, batch_records): # Only consider supported input keys input_keys = [key for key in batch_records[0].keys() if key in cls.supported_input_keys] - if "attn_mask_start_row_indices" not in input_keys and "attention_mask" not in input_keys: + if "attn_mask_startend_row_indices" not in input_keys and "attention_mask" not in input_keys: input_keys.append("attention_mask") batched_features = {key: [] for key in input_keys} sequence_sum = 0 @@ -57,9 +57,9 @@ def _pad_batch_records(cls, batch_records): seq_length = len(record["input_ids"]) # If attention_mask is not given, assume it's causal mask - if "attn_mask_start_row_indices" in record: - attn_mask_start_row_indices = [i + sequence_sum for i in record["attn_mask_start_row_indices"]] - batched_features["attn_mask_start_row_indices"].extend(attn_mask_start_row_indices) + if "attn_mask_startend_row_indices" in record: + attn_mask_startend_row_indices = [i + sequence_sum for i in record["attn_mask_startend_row_indices"]] + batched_features["attn_mask_startend_row_indices"].extend(attn_mask_startend_row_indices) else: attention_mask = 
record.get("attention_mask", np.tril(np.ones([seq_length, seq_length], dtype=bool))) batched_features["attention_mask"].append(attention_mask) diff --git a/paddlenlp/transformers/llama/fusion_ops.py b/paddlenlp/transformers/llama/fusion_ops.py index 2a273489e59b..7a6b0b51a42a 100644 --- a/paddlenlp/transformers/llama/fusion_ops.py +++ b/paddlenlp/transformers/llama/fusion_ops.py @@ -211,6 +211,10 @@ def fusion_flash_attention( else: if attn_mask_startend_row_indices is not None: assert alibi is None, "flash_attention_with_sparse_mask not support alibi" + if len(attn_mask_startend_row_indices.shape) == 2: + attn_mask_startend_row_indices = paddle.expand( + paddle.unsqueeze(attn_mask_startend_row_indices, axis=1), shape=[-1, num_heads, -1] + ) attn_output = F.flash_attention_with_sparse_mask( query_states, key_states, diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 56c9713f0118..28b292502175 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -1906,6 +1906,17 @@ def forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.dtype == paddle.int64: + attn_mask_startend_row_indices = attn_mask_startend_row_indices.astype(paddle.int32) + logger.info("cast attn_mask_startend_row_indices to paddle.int32") + + if attn_mask_startend_row_indices is not None and attention_mask is not None: + logger.warning( + "You have provided both attn_mask_startend_row_indices and attention_mask. " + "The attn_mask_startend_row_indices will be used." + ) + outputs = self.llama( input_ids, # [bs, seq_len] position_ids=position_ids, diff --git a/paddlenlp/trl/dpo_trainer.py b/paddlenlp/trl/dpo_trainer.py index 144ceb816fdc..0429c9faa769 100644 --- a/paddlenlp/trl/dpo_trainer.py +++ b/paddlenlp/trl/dpo_trainer.py @@ -177,8 +177,8 @@ def get_batch_metrics(self, model, batch, train_eval="train"): } if "attention_mask" in batch: dpo_inputs["attention_mask"] = batch["attention_mask"] - if "attn_mask_start_row_indices" in batch: - dpo_inputs["attn_mask_start_row_indices"] = batch["attn_mask_start_row_indices"] + if "attn_mask_startend_row_indices" in batch: + dpo_inputs["attn_mask_startend_row_indices"] = batch["attn_mask_startend_row_indices"] if self.reference_free: reference_chosen_logps, reference_rejected_logps = None, None else: @@ -194,8 +194,8 @@ def get_batch_metrics(self, model, batch, train_eval="train"): } if "attention_mask" in batch: dpo_inputs["attention_mask"] = batch["attention_mask"] - if "attn_mask_start_row_indices" in batch: - dpo_inputs["attn_mask_start_row_indices"] = batch["attn_mask_start_row_indices"] + if "attn_mask_startend_row_indices" in batch: + dpo_inputs["attn_mask_startend_row_indices"] = batch["attn_mask_startend_row_indices"] if self.reference_free: reference_chosen_logps, reference_rejected_logps = None, None else: @@ -522,7 +522,7 @@ def prepare_pipeline_dpo_inputs_func(inputs): else: first_stage_keys = [ "input_ids", - "attn_mask_start_row_indices", + "attn_mask_startend_row_indices", "position_ids", ] diff --git a/paddlenlp/trl/trl_data.py b/paddlenlp/trl/trl_data.py index ca3a1ae40f7e..44758bbe50a1 100644 --- a/paddlenlp/trl/trl_data.py +++ b/paddlenlp/trl/trl_data.py @@ -159,8 +159,8 @@ def preprocess_preference_data(data, tokenizer, data_args, model_args): } # 
attention mask - if model_args.use_attn_mask_start_row_indices: - output_dict["attn_mask_start_row_indices"] = ( + if model_args.use_attn_mask_startend_row_indices: + output_dict["attn_mask_startend_row_indices"] = ( [seq_len] * prompt_len + [prompt_len + chosen_len] * chosen_len + [seq_len] * rejected_len ) else: @@ -183,14 +183,14 @@ def preference_collate_fn(batch, max_seq_len=None): "response_indexs": [], } sequence = batch[0] - if "attn_mask_start_row_indices" in sequence: - input_dict["attn_mask_start_row_indices"] = [] - use_attn_mask_start_row_indices = True + if "attn_mask_startend_row_indices" in sequence: + input_dict["attn_mask_startend_row_indices"] = [] + use_attn_mask_startend_row_indices = True elif "attention_mask" in sequence: input_dict["attention_mask"] = [] - use_attn_mask_start_row_indices = False + use_attn_mask_startend_row_indices = False else: - raise ValueError("attention_mask and attn_mask_start_row_indices are both None.") + raise ValueError("attention_mask and attn_mask_startend_row_indices are both None.") for i, sequence in enumerate(batch): difference = max_seq_len - len(sequence["input_ids"]) @@ -199,9 +199,12 @@ def preference_collate_fn(batch, max_seq_len=None): input_dict["position_ids"].append(sequence["position_ids"] + [0] * difference) input_dict["chosen_labels"].append(sequence["chosen_labels"] + [0] * difference) input_dict["rejected_labels"].append(sequence["rejected_labels"] + [0] * difference) - if use_attn_mask_start_row_indices: - input_dict["attn_mask_start_row_indices"].append( - [sequence["attn_mask_start_row_indices"] + [sequence["attn_mask_start_row_indices"][-1]] * difference] + if use_attn_mask_startend_row_indices: + input_dict["attn_mask_startend_row_indices"].append( + [ + sequence["attn_mask_startend_row_indices"] + + [sequence["attn_mask_startend_row_indices"][-1]] * difference + ] ) else: input_dict["attention_mask"].append( @@ -225,7 +228,7 @@ def preference_collate_fn(batch, max_seq_len=None): for key in input_dict: if key == "attention_mask": input_dict[key] = np.array(input_dict[key], dtype=bool) - elif key == "attn_mask_start_row_indices": + elif key == "attn_mask_startend_row_indices": input_dict[key] = np.array(input_dict[key], dtype=np.int32) else: input_dict[key] = np.array(input_dict[key]) From 42ccb0d3b3deae21e5a653f5a8883772c2155667 Mon Sep 17 00:00:00 2001 From: Ting Liu Date: Wed, 26 Jun 2024 14:53:57 +0800 Subject: [PATCH 2/9] dpo support --- paddlenlp/data/data_collator.py | 20 +++++++++++--------- paddlenlp/transformers/llama/fusion_ops.py | 4 +--- paddlenlp/transformers/llama/modeling.py | 1 + 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/paddlenlp/data/data_collator.py b/paddlenlp/data/data_collator.py index 74490ac86f3b..d58daee1ac65 100644 --- a/paddlenlp/data/data_collator.py +++ b/paddlenlp/data/data_collator.py @@ -408,17 +408,19 @@ def __call__(self, features, return_tensors=None): (max_length + self.pad_to_multiple_of - 1) // self.pad_to_multiple_of * self.pad_to_multiple_of ) - padding_side = self.tokenizer.padding_side for feature in batch: - remainder = [feature["attn_mask_startend_row_indices"][-1]] * ( - max_length - len(feature["attn_mask_startend_row_indices"]) - ) - if isinstance(feature["attn_mask_startend_row_indices"], list): - feature["attn_mask_startend_row_indices"] = ( - feature["attn_mask_startend_row_indices"] + remainder - if padding_side == "right" - else remainder + feature["attn_mask_startend_row_indices"] + pad_len = max_length - 
len(feature["attn_mask_startend_row_indices"]) + remainder = np.zeros([1, pad_len], dtype=np.int32) + feature["attn_mask_startend_row_indices"] = ( + np.concatenate( + [remainder, np.array([feature["attn_mask_startend_row_indices"]], dtype=np.int32) + pad_len], + axis=-1, ) + if padding_side == "left" + else np.concatenate( + [np.array([feature["attn_mask_startend_row_indices"]], dtype=np.int32), remainder], axis=-1 + ) + ) batch = self.tokenizer.pad( batch, diff --git a/paddlenlp/transformers/llama/fusion_ops.py b/paddlenlp/transformers/llama/fusion_ops.py index 7a6b0b51a42a..290538912c9d 100644 --- a/paddlenlp/transformers/llama/fusion_ops.py +++ b/paddlenlp/transformers/llama/fusion_ops.py @@ -212,9 +212,7 @@ def fusion_flash_attention( if attn_mask_startend_row_indices is not None: assert alibi is None, "flash_attention_with_sparse_mask not support alibi" if len(attn_mask_startend_row_indices.shape) == 2: - attn_mask_startend_row_indices = paddle.expand( - paddle.unsqueeze(attn_mask_startend_row_indices, axis=1), shape=[-1, num_heads, -1] - ) + attn_mask_startend_row_indices = paddle.unsqueeze(attn_mask_startend_row_indices, axis=1) attn_output = F.flash_attention_with_sparse_mask( query_states, key_states, diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 28b292502175..0dd672019239 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -1916,6 +1916,7 @@ def forward( "You have provided both attn_mask_startend_row_indices and attention_mask. " "The attn_mask_startend_row_indices will be used." ) + attention_mask = None outputs = self.llama( input_ids, # [bs, seq_len] From 21b790acf171c05e7801e0e8c217b9fa537a3a9f Mon Sep 17 00:00:00 2001 From: Ting Liu Date: Wed, 26 Jun 2024 16:05:23 +0800 Subject: [PATCH 3/9] update --- llm/alignment/dpo/run_dpo.py | 4 ++++ llm/run_finetune.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/llm/alignment/dpo/run_dpo.py b/llm/alignment/dpo/run_dpo.py index aa7a09f16ad3..52dcf23a748d 100644 --- a/llm/alignment/dpo/run_dpo.py +++ b/llm/alignment/dpo/run_dpo.py @@ -17,6 +17,7 @@ import os import sys import time +import inspect from functools import partial import paddle @@ -124,6 +125,9 @@ def main(): ref_model = AutoModelForCausalLM.from_config(ref_config) model.set_state_dict(ref_model.state_dict()) + if model_args.use_attn_mask_startend_row_indices and "attn_mask_startend_row_indices" not in inspect.signature(model.forward).parameters: + raise NotImplementedError(f"{model.__class__} not support flash mask.") + if model_args.tokenizer_name_or_path is not None: tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path) else: diff --git a/llm/run_finetune.py b/llm/run_finetune.py index 953cb58df0b8..910c745aa905 100644 --- a/llm/run_finetune.py +++ b/llm/run_finetune.py @@ -14,6 +14,7 @@ import json import os import sys +import inspect from functools import partial import paddle @@ -160,6 +161,9 @@ def main(): # NOTE(gongenlei): new add autotuner_benchmark model = model_class.from_config(model_config, dtype=dtype) + if model_args.use_attn_mask_startend_row_indices and "attn_mask_startend_row_indices" not in inspect.signature(model.forward).parameters: + raise NotImplementedError(f"{model.__class__} not support flash mask.") + if training_args.do_train and model_args.neftune: # Inspired by https://github.com/neelsjain/NEFTune if hasattr(model, "get_input_embeddings"): From 4578145f3e421004065b45d1e57aaf482052c6d6 Mon Sep 17 00:00:00 
2001 From: Ting Liu Date: Wed, 26 Jun 2024 17:46:33 +0800 Subject: [PATCH 4/9] remove dense mask when using flash mask --- llm/run_finetune.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llm/run_finetune.py b/llm/run_finetune.py index 910c745aa905..25e6a22e6b9f 100644 --- a/llm/run_finetune.py +++ b/llm/run_finetune.py @@ -502,6 +502,7 @@ def compute_metrics_do_generation(eval_preds): padding=padding, max_label_length=max_length, return_tensors="np", + return_attention_masks=not model_args.use_attn_mask_startend_row_indices, pad_to_multiple_of=data_args.pad_to_multiple_of, ), do_generation=data_args.eval_with_do_generation, From 56f3ba31c690d97766fd28c1385c12f3ca4a415d Mon Sep 17 00:00:00 2001 From: Ting Liu Date: Wed, 26 Jun 2024 17:50:55 +0800 Subject: [PATCH 5/9] bugfix --- llm/run_finetune.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llm/run_finetune.py b/llm/run_finetune.py index 25e6a22e6b9f..ef324bd155e2 100644 --- a/llm/run_finetune.py +++ b/llm/run_finetune.py @@ -502,7 +502,7 @@ def compute_metrics_do_generation(eval_preds): padding=padding, max_label_length=max_length, return_tensors="np", - return_attention_masks=not model_args.use_attn_mask_startend_row_indices, + return_attention_mask=not model_args.use_attn_mask_startend_row_indices, pad_to_multiple_of=data_args.pad_to_multiple_of, ), do_generation=data_args.eval_with_do_generation, From 3bf634be7df36476ae6b156248803a4d89987814 Mon Sep 17 00:00:00 2001 From: Ting Liu Date: Thu, 27 Jun 2024 12:41:24 +0800 Subject: [PATCH 6/9] support pp --- llm/alignment/dpo/dpo_argument.py | 4 ++-- llm/alignment/dpo/run_dpo.py | 15 +++++++++++++- llm/config/llama/dpo_argument.json | 2 +- llm/config/qwen/dpo_argument.json | 2 +- llm/run_finetune.py | 22 ++++++++++++++++----- llm/utils/argument.py | 4 ++-- paddlenlp/data/data_collator.py | 5 ++++- paddlenlp/transformers/llama/modeling.py | 4 ---- paddlenlp/transformers/llama/modeling_pp.py | 20 ++++++++++++++----- paddlenlp/trl/trl_data.py | 2 +- 10 files changed, 57 insertions(+), 23 deletions(-) diff --git a/llm/alignment/dpo/dpo_argument.py b/llm/alignment/dpo/dpo_argument.py index 6939ff63c397..86173be4ae2c 100644 --- a/llm/alignment/dpo/dpo_argument.py +++ b/llm/alignment/dpo/dpo_argument.py @@ -87,8 +87,8 @@ class DPOModelArgument: "help": "The granularity of recompute training can be selected as `full` or `full_attn` or `core_attn`." 
}, ) - use_attn_mask_startend_row_indices: bool = field( - default=False, metadata={"help": "Whether to use attn_mask_startend_row_indices in flash attention."} + flash_mask: bool = field( + default=False, metadata={"help": "Whether to use flash mask in flash attention."} ) virtual_pp_degree: int = field( default=1, diff --git a/llm/alignment/dpo/run_dpo.py b/llm/alignment/dpo/run_dpo.py index 52dcf23a748d..4bfaf24ecbd8 100644 --- a/llm/alignment/dpo/run_dpo.py +++ b/llm/alignment/dpo/run_dpo.py @@ -37,8 +37,14 @@ preference_collate_fn, preprocess_preference_data, ) +from paddlenlp.transformers import ( + LlamaForCausalLM, + LlamaForCausalLMPipe, +) from paddlenlp.utils.log import logger +flash_mask_support_list = [LlamaForCausalLM, LlamaForCausalLMPipe] + def main(): """main""" @@ -125,7 +131,14 @@ def main(): ref_model = AutoModelForCausalLM.from_config(ref_config) model.set_state_dict(ref_model.state_dict()) - if model_args.use_attn_mask_startend_row_indices and "attn_mask_startend_row_indices" not in inspect.signature(model.forward).parameters: + if model_args.flash_mask and (not data_args.zero_padding or not model.config.use_flash_attention): + logger.warning( + "`flash_mask` must use with zero padding and flash attention." + ) + data_args.zero_padding = True + model.config.use_flash_attention = True + + if not any(isinstance(model, cls) for cls in flash_mask_support_list): raise NotImplementedError(f"{model.__class__} not support flash mask.") if model_args.tokenizer_name_or_path is not None: diff --git a/llm/config/llama/dpo_argument.json b/llm/config/llama/dpo_argument.json index b31656392d2a..cc2b000dfcef 100644 --- a/llm/config/llama/dpo_argument.json +++ b/llm/config/llama/dpo_argument.json @@ -27,7 +27,7 @@ "sharding_parallel_degree": 1, "sharding": "stage1", "use_flash_attention": true, - "use_attn_mask_startend_row_indices":false, + "flash_mask":true, "recompute": false, "recompute_granularity": "full", "dpo_beta": 0.1, diff --git a/llm/config/qwen/dpo_argument.json b/llm/config/qwen/dpo_argument.json index 263a43ee3fda..39f9a1f5c447 100644 --- a/llm/config/qwen/dpo_argument.json +++ b/llm/config/qwen/dpo_argument.json @@ -27,7 +27,7 @@ "sharding_parallel_degree": 1, "sharding": "stage1", "use_flash_attention": true, - "use_attn_mask_startend_row_indices":false, + "flash_mask":false, "recompute": false, "recompute_granularity": "full", "dpo_beta": 0.1, diff --git a/llm/run_finetune.py b/llm/run_finetune.py index ef324bd155e2..08375fc8d132 100644 --- a/llm/run_finetune.py +++ b/llm/run_finetune.py @@ -52,6 +52,8 @@ AutoTokenizer, Llama3Tokenizer, LlamaTokenizer, + LlamaForCausalLM, + LlamaForCausalLMPipe, ) from paddlenlp.transformers.configuration_utils import LlmMetaConfig from paddlenlp.utils.log import logger @@ -59,6 +61,8 @@ # Fine-tune Environment Variables to support sharding stage1 overlap optimization. os.environ["USE_CASUAL_MASK"] = "False" +flash_mask_support_list = [LlamaForCausalLM, LlamaForCausalLMPipe] + def main(): # Arguments @@ -78,6 +82,7 @@ def main(): raise ValueError( "--do_train, --do_ptq, --do_gptq and --do_qat cannot work at the same time. 
Please choose only one at a time" ) + # Setup GPU & distributed training paddle.set_device(training_args.device) @@ -161,7 +166,14 @@ def main(): # NOTE(gongenlei): new add autotuner_benchmark model = model_class.from_config(model_config, dtype=dtype) - if model_args.use_attn_mask_startend_row_indices and "attn_mask_startend_row_indices" not in inspect.signature(model.forward).parameters: + if model_args.flash_mask and (not data_args.zero_padding or not model.config.use_flash_attention): + logger.warning( + "`flash_mask` must use with zero padding and flash attention." + ) + data_args.zero_padding = True + model.config.use_flash_attention = True + + if not any(isinstance(model, cls) for cls in flash_mask_support_list): raise NotImplementedError(f"{model.__class__} not support flash mask.") if training_args.do_train and model_args.neftune: @@ -333,12 +345,12 @@ def neft_post_hook(module, input, output): "Zero Padding data stream is only implemented for LLaMA, Bloom, ChatGLM and QWen so far." ) train_ds = ( - train_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.use_attn_mask_startend_row_indices)) + train_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.flash_mask)) if train_ds is not None else None ) ptq_ds = ( - ptq_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.use_attn_mask_startend_row_indices)) + ptq_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.flash_mask)) if ptq_ds is not None else None ) @@ -349,7 +361,7 @@ def neft_post_hook(module, input, output): ) eval_zero_padding = False dev_ds = ( - dev_ds.map(partial(trans_func, is_test=data_args.eval_with_do_generation, zero_padding=eval_zero_padding, flash_mask=model_args.use_attn_mask_startend_row_indices)) + dev_ds.map(partial(trans_func, is_test=data_args.eval_with_do_generation, zero_padding=eval_zero_padding, flash_mask=model_args.flash_mask)) if dev_ds is not None else None ) @@ -502,7 +514,7 @@ def compute_metrics_do_generation(eval_preds): padding=padding, max_label_length=max_length, return_tensors="np", - return_attention_mask=not model_args.use_attn_mask_startend_row_indices, + return_attention_mask=not model_args.flash_mask, pad_to_multiple_of=data_args.pad_to_multiple_of, ), do_generation=data_args.eval_with_do_generation, diff --git a/llm/utils/argument.py b/llm/utils/argument.py index 802a16b831d4..63a14e4126ef 100644 --- a/llm/utils/argument.py +++ b/llm/utils/argument.py @@ -209,8 +209,8 @@ class ModelArgument: aistudio_token: str = field(default=None, metadata={"help": "The token of aistudio"}) neftune: bool = field(default=False, metadata={"help": "Whether to apply NEFT"}) neftune_noise_alpha: float = field(default=5.0, metadata={"help": "NEFT noise alpha"}) - use_attn_mask_startend_row_indices: bool = field( - default=False, metadata={"help": "Whether to use attn_mask_startend_row_indices in flash attention."} + flash_mask: bool = field( + default=False, metadata={"help": "Whether to use flash_mask in flash attention."} ) diff --git a/paddlenlp/data/data_collator.py b/paddlenlp/data/data_collator.py index d58daee1ac65..90f705d22056 100644 --- a/paddlenlp/data/data_collator.py +++ b/paddlenlp/data/data_collator.py @@ -402,7 +402,10 @@ def __call__(self, features, return_tensors=None): else: feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64) if 
use_attn_mask_startend_row_indices is not None: - max_length = max(len(l) for l in use_attn_mask_startend_row_indices) + if self.max_label_length is not None: + max_length = self.max_length + else: + max_length = max(len(l) for l in use_attn_mask_startend_row_indices) if self.pad_to_multiple_of is not None: max_length = ( (max_length + self.pad_to_multiple_of - 1) // self.pad_to_multiple_of * self.pad_to_multiple_of diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index b9401b43e61a..dd38504a0e6f 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -1907,10 +1907,6 @@ def forward( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.dtype == paddle.int64: - attn_mask_startend_row_indices = attn_mask_startend_row_indices.astype(paddle.int32) - logger.info("cast attn_mask_startend_row_indices to paddle.int32") - if attn_mask_startend_row_indices is not None and attention_mask is not None: logger.warning( "You have provided both attn_mask_startend_row_indices and attention_mask. " diff --git a/paddlenlp/transformers/llama/modeling_pp.py b/paddlenlp/transformers/llama/modeling_pp.py index 87c59be03cdd..97b9493b9548 100644 --- a/paddlenlp/transformers/llama/modeling_pp.py +++ b/paddlenlp/transformers/llama/modeling_pp.py @@ -208,11 +208,21 @@ def forward(self, args): alibi = position_ids position_ids = attn_mask_startend_row_indices attn_mask_startend_row_indices = None - elif not self.config.alibi and position_ids is None and attn_mask_startend_row_indices is not None: - # hidden_states, attention_mask, position_ids - position_ids = attn_mask_startend_row_indices - attn_mask_startend_row_indices = None - alibi = None + elif not self.config.alibi: + if get_env_device() in ["gpu"]: + if attention_mask.dtype == paddle.int32: + attention_mask, attn_mask_startend_row_indices, position_ids = ( + None, + attention_mask, + attn_mask_startend_row_indices, + ) + elif attention_mask.dtype == paddle.int64: + attention_mask, attn_mask_startend_row_indices, position_ids = None, None, attention_mask + elif attn_mask_startend_row_indices == paddle.int64: + attn_mask_startend_row_indices, position_ids = None, attn_mask_startend_row_indices + elif position_ids is None and attn_mask_startend_row_indices is not None: + position_ids = attn_mask_startend_row_indices + attn_mask_startend_row_indices = None has_gradient = not hidden_states.stop_gradient if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient: diff --git a/paddlenlp/trl/trl_data.py b/paddlenlp/trl/trl_data.py index 44758bbe50a1..7afc3cb527cf 100644 --- a/paddlenlp/trl/trl_data.py +++ b/paddlenlp/trl/trl_data.py @@ -159,7 +159,7 @@ def preprocess_preference_data(data, tokenizer, data_args, model_args): } # attention mask - if model_args.use_attn_mask_startend_row_indices: + if model_args.flash_mask: output_dict["attn_mask_startend_row_indices"] = ( [seq_len] * prompt_len + [prompt_len + chosen_len] * chosen_len + [seq_len] * rejected_len ) From 82e4f6a35909387c8522cc0e01a9cf4d711f1d31 Mon Sep 17 00:00:00 2001 From: Ting Liu Date: Thu, 27 Jun 2024 12:44:12 +0800 Subject: [PATCH 7/9] bugfix --- llm/alignment/dpo/run_dpo.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/llm/alignment/dpo/run_dpo.py b/llm/alignment/dpo/run_dpo.py index 4bfaf24ecbd8..f354023538ab 100644 --- 
a/llm/alignment/dpo/run_dpo.py +++ b/llm/alignment/dpo/run_dpo.py @@ -131,11 +131,10 @@ def main(): ref_model = AutoModelForCausalLM.from_config(ref_config) model.set_state_dict(ref_model.state_dict()) - if model_args.flash_mask and (not data_args.zero_padding or not model.config.use_flash_attention): + if model_args.flash_mask and not model.config.use_flash_attention: logger.warning( "`flash_mask` must use with zero padding and flash attention." ) - data_args.zero_padding = True model.config.use_flash_attention = True if not any(isinstance(model, cls) for cls in flash_mask_support_list): From 62b74d8c904ce3c5ae23c3886e7072dcaf415e90 Mon Sep 17 00:00:00 2001 From: Ting Liu Date: Thu, 27 Jun 2024 14:29:34 +0800 Subject: [PATCH 8/9] bugfix --- paddlenlp/data/data_collator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddlenlp/data/data_collator.py b/paddlenlp/data/data_collator.py index 90f705d22056..351c44867b28 100644 --- a/paddlenlp/data/data_collator.py +++ b/paddlenlp/data/data_collator.py @@ -402,7 +402,7 @@ def __call__(self, features, return_tensors=None): else: feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64) if use_attn_mask_startend_row_indices is not None: - if self.max_label_length is not None: + if self.max_length is not None: max_length = self.max_length else: max_length = max(len(l) for l in use_attn_mask_startend_row_indices) From 74bc7ad880a94e2a4bff0fd5606af175de9c51c9 Mon Sep 17 00:00:00 2001 From: Ting Liu Date: Thu, 27 Jun 2024 16:58:31 +0800 Subject: [PATCH 9/9] bugfix --- llm/alignment/dpo/run_dpo.py | 2 +- llm/run_finetune.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/llm/alignment/dpo/run_dpo.py b/llm/alignment/dpo/run_dpo.py index f354023538ab..0d542444fe55 100644 --- a/llm/alignment/dpo/run_dpo.py +++ b/llm/alignment/dpo/run_dpo.py @@ -137,7 +137,7 @@ def main(): ) model.config.use_flash_attention = True - if not any(isinstance(model, cls) for cls in flash_mask_support_list): + if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list): raise NotImplementedError(f"{model.__class__} not support flash mask.") if model_args.tokenizer_name_or_path is not None: diff --git a/llm/run_finetune.py b/llm/run_finetune.py index 08375fc8d132..8df3705fe335 100644 --- a/llm/run_finetune.py +++ b/llm/run_finetune.py @@ -173,7 +173,7 @@ def main(): data_args.zero_padding = True model.config.use_flash_attention = True - if not any(isinstance(model, cls) for cls in flash_mask_support_list): + if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list): raise NotImplementedError(f"{model.__class__} not support flash mask.") if training_args.do_train and model_args.neftune:
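
Supplementary illustration (not part of any patch above): a minimal NumPy sketch of the index encoding this series switches to. The helper names `pack_startend_row_indices`, `dense_equivalent`, and `dpo_startend_row_indices` are hypothetical; only the list expressions inside them mirror `convert_example_common`, `ZeroPadding._pad_batch_records`, and `preprocess_preference_data` in the diffs. Each token stores the row index at which attention to it is cut off, so the fused kernel can apply a block-diagonal causal mask without materializing the dense `[seq, seq]` boolean mask; the dense reconstruction below exists only to check the semantics.

    # Illustrative only -- helper names are hypothetical, not APIs from the patches.
    import numpy as np

    def pack_startend_row_indices(seq_lens):
        """Per-token mask-start rows for sequences packed back to back (zero padding).
        Mirrors `[seq_length] * seq_length` plus the `i + sequence_sum` offset in
        ZeroPadding._pad_batch_records."""
        indices, offset = [], 0
        for seq_len in seq_lens:
            indices.extend([offset + seq_len] * seq_len)  # each token: end row of its own segment
            offset += seq_len
        return np.array(indices, dtype=np.int32)

    def dpo_startend_row_indices(prompt_len, chosen_len, rejected_len):
        """Mirrors preprocess_preference_data: the prompt stays visible to both responses,
        while rejected tokens are cut off from the chosen span."""
        seq_len = prompt_len + chosen_len + rejected_len
        return np.array(
            [seq_len] * prompt_len
            + [prompt_len + chosen_len] * chosen_len
            + [seq_len] * rejected_len,
            dtype=np.int32,
        )

    def dense_equivalent(startend_row_indices):
        """Dense bool mask the indices stand for: causal AND below the cut-off row."""
        total = len(startend_row_indices)
        rows = np.arange(total)[:, None]   # query positions i
        cols = np.arange(total)[None, :]   # key positions j
        causal = cols <= rows
        visible = rows < startend_row_indices[None, :]  # masked from the cut-off row onward
        return causal & visible

    # Two sequences of length 3 and 2 packed into one sample.
    indices = pack_startend_row_indices([3, 2])      # -> [3, 3, 3, 5, 5], int32
    mask = dense_equivalent(indices)                 # 5x5 block-diagonal causal mask
    assert not mask[3:, :3].any()                    # sequence 2 never attends to sequence 1

    # DPO layout: prompt + chosen + rejected in one row.
    dpo = dpo_startend_row_indices(prompt_len=2, chosen_len=2, rejected_len=2)  # [6, 6, 4, 4, 6, 6]
    dpo_mask = dense_equivalent(dpo)
    assert not dpo_mask[4:, 2:4].any()               # rejected cannot see chosen
    assert dpo_mask[4:, :2].all()                    # both responses see the shared prompt

Either int32 vector costs O(seq) memory versus O(seq^2) for the boolean mask, which is why the data collator and ZeroPadding changes above carry only the indices once `flash_mask` (formerly `use_attn_mask_startend_row_indices`) is enabled.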