From e80ac29752e06b54c43b352803747e2d3299eb5e Mon Sep 17 00:00:00 2001
From: w5688414
Date: Thu, 21 Mar 2024 04:27:59 +0000
Subject: [PATCH 1/4] Add matryoshka training

---
 .../examples/constrative_train/README.md      | 108 +++++++++
 .../examples/constrative_train/arguments.py   | 101 ++++++++
 pipelines/examples/constrative_train/data.py  | 139 +++++++++++
 .../data/toy_finetune_data.jsonl              |  10 +
 .../constrative_train/evaluation/__init__.py  |  13 ++
 .../evaluation/benchmarks.py                  | 216 ++++++++++++++++++
 .../evaluation/prediction.py                  | 153 +++++++++++++
 .../constrative_train/models/__init__.py      |  13 ++
 .../constrative_train/models/modeling.py      | 194 ++++++++++++++++
 .../constrative_train/requirements.txt        |   5 +
 pipelines/examples/constrative_train/train.py | 157 +++++++++++++
 pipelines/examples/constrative_train/utils.py |  28 +++
 12 files changed, 1137 insertions(+)
 create mode 100644 pipelines/examples/constrative_train/README.md
 create mode 100644 pipelines/examples/constrative_train/arguments.py
 create mode 100644 pipelines/examples/constrative_train/data.py
 create mode 100644 pipelines/examples/constrative_train/data/toy_finetune_data.jsonl
 create mode 100644 pipelines/examples/constrative_train/evaluation/__init__.py
 create mode 100644 pipelines/examples/constrative_train/evaluation/benchmarks.py
 create mode 100644 pipelines/examples/constrative_train/evaluation/prediction.py
 create mode 100644 pipelines/examples/constrative_train/models/__init__.py
 create mode 100644 pipelines/examples/constrative_train/models/modeling.py
 create mode 100644 pipelines/examples/constrative_train/requirements.txt
 create mode 100644 pipelines/examples/constrative_train/train.py
 create mode 100644 pipelines/examples/constrative_train/utils.py

diff --git a/pipelines/examples/constrative_train/README.md b/pipelines/examples/constrative_train/README.md
new file mode 100644
index 000000000000..a1f82a3d33ea
--- /dev/null
+++ b/pipelines/examples/constrative_train/README.md
@@ -0,0 +1,108 @@
+# generative-search
+
+## Installation
+
+We recommend installing the GPU build of [PaddlePalle](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html). Taking the CUDA 11.7 build of Paddle as an example, the install command is:
+
+```
+python -m pip install paddlepaddle-gpu==2.6.0.post117 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
+```
+Install the remaining dependencies:
+```
+pip install -r requirements.txt
+```
+
+Download the dureader Chinese dataset:
+
+```
+cd data
+wget https://paddlenlp.bj.bcebos.com/datasets/dureader_dual.train.jsonl
+```
+
+## Training
+
+### Single-GPU training
+
+```
+export CUDA_VISIBLE_DEVICES=0
+python train.py --do_train \
+    --model_name_or_path rocketqa-zh-base-query-encoder \
+    --output_dir ./checkpoints \
+    --train_data ./data/dureader_dual.train.jsonl \
+    --overwrite_output_dir \
+    --fine_tune_type sft \
+    --sentence_pooling_method cls \
+    --num_train_epochs 3 \
+    --per_device_train_batch_size 64 \
+    --learning_rate 3e-5 \
+    --train_group_size 4 \
+    --recompute \
+    --passage_max_len 512 \
+    --use_matryoshka
+```
+
+- `model_name_or_path`: the pretrained model to start from, e.g. rocketqa-zh-base-query-encoder
+- `output_dir`: directory where checkpoints are saved
+- `train_data`: path to the training data; here we use the dureader Chinese dataset
+- `overwrite_output_dir`: whether to overwrite the output directory; defaults to False
+- `fine_tune_type`: fine-tuning strategy; options include sft, lora, and bitfit
+- `sentence_pooling_method`: sentence pooling method, either cls or mean; cls uses the [CLS] token representation, mean uses mean pooling
+- `num_train_epochs`: number of training epochs
+- `per_device_train_batch_size`: training batch size per device
+- `learning_rate`: learning rate
+- `train_group_size`: number of positive plus negative passages sampled per query; defaults to 8. For example, train_group_size=4 means each group contains 1 positive and 3 negatives
+- `max_example_num_per_dataset`: maximum number of examples per dataset; defaults to 100000000
+- `recompute`: whether to enable recompute (activation checkpointing); defaults to False
+- `query_max_len`: maximum query length; defaults to 32
+- `query_instruction_for_retrieval`: instruction prepended to each query for retrieval; defaults to None
+- `passage_instruction_for_retrieval`: instruction prepended to each passage for retrieval; defaults to None
+- `passage_max_len`: maximum passage length; defaults to 512
+- `use_matryoshka`: whether to use the matryoshka (Matryoshka Representation Learning) strategy; defaults to False (see the sketch after this list)
+- `matryoshka_dims`: embedding dimensions used by the matryoshka strategy; defaults to [64, 128, 256, 512, 768]
+- `matryoshka_loss_weights`: loss weights for the matryoshka dimensions; defaults to [1, 1, 1, 1, 1]
+- `use_inbatch_neg`: whether to use the in-batch negatives strategy; defaults to False
+- `use_flash_attention`: whether to use flash attention; defaults to False
+- `temperature`: temperature for the in-batch negatives loss; defaults to 0.02
+- `negatives_cross_device`: whether to share in-batch negatives across devices; defaults to False
+- `margin`: margin for the in-batch negatives loss; defaults to 0.2
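+
+The matryoshka options only change the training loss: the embeddings from a single forward pass are truncated to each size in `matryoshka_dims`, re-normalized, and an InfoNCE-style loss is accumulated with the matching weight. The snippet below is a minimal sketch of this idea, simplified from `models/modeling.py` with in-batch positives on the diagonal; it is an illustration, not the training entry point:
+
+```
+import paddle
+import paddle.nn.functional as F
+
+def matryoshka_loss(q, p, dims=(64, 128, 256, 512, 768), weights=(1, 1, 1, 1, 1), temperature=0.02):
+    # q, p: [batch_size, hidden_size] query/passage embeddings from one forward pass
+    target = paddle.arange(q.shape[0], dtype="int64")  # positives sit on the diagonal
+    loss = 0.0
+    for w, d in zip(weights, dims):
+        qd = F.normalize(q[:, :d], axis=-1)  # truncate first, then re-normalize
+        pd = F.normalize(p[:, :d], axis=-1)
+        scores = paddle.matmul(qd, pd.transpose([1, 0])) / temperature
+        loss += w * F.cross_entropy(scores, target)
+    return loss
+```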
+
+### Multi-GPU training
+
+Single-GPU training is inefficient and limits the batch size. Contrastive learning benefits from a large batch size, so multi-GPU training is recommended. Example command:
+
+```
+python -m paddle.distributed.launch --gpus "1,2,3,4" train.py --do_train \
+    --model_name_or_path rocketqa-zh-base-query-encoder \
+    --output_dir ./checkpoints \
+    --train_data ./data/dureader_dual.train.jsonl \
+    --overwrite_output_dir \
+    --fine_tune_type sft \
+    --sentence_pooling_method cls \
+    --num_train_epochs 3 \
+    --per_device_train_batch_size 32 \
+    --learning_rate 3e-5 \
+    --train_group_size 8 \
+    --recompute \
+    --passage_max_len 512 \
+    --use_matryoshka
+```
+
+## Evaluation
+
+Evaluation script:
+
+```
+export CUDA_VISIBLE_DEVICES=0
+python evaluation/benchmarks.py --model_type bert \
+    --query_model checkpoints/checkpoint-1500 \
+    --passage_model checkpoints/checkpoint-1500 \
+    --query_max_length 64 \
+    --passage_max_length 512 \
+    --evaluate_all
+```
+- `model_type`: the model type, e.g. bert or roberta
+- `query_model`: path to the query embedding model
+- `passage_model`: path to the passage embedding model
+- `query_max_length`: maximum query length
+- `passage_max_length`: maximum passage length
+- `evaluate_all`: whether to evaluate all checkpoints; defaults to False, i.e. only the specified checkpoint is evaluated
diff --git a/pipelines/examples/constrative_train/arguments.py b/pipelines/examples/constrative_train/arguments.py
new file mode 100644
index 000000000000..6a9fe71dd542
--- /dev/null
+++ b/pipelines/examples/constrative_train/arguments.py
@@ -0,0 +1,101 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from dataclasses import dataclass, field
+from typing import List, Optional
+
+from paddlenlp.trainer import TrainingArguments
+
+
+@dataclass
+class ModelArguments:
+    """
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
+    """
+
+    model_name_or_path: str = field(
+        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
+    )
+    tokenizer_name: Optional[str] = field(
+        default=None,
+        metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
+    )
+
+    normalized: bool = field(default=True)
+    use_flash_attention: bool = field(default=False, metadata={"help": "Whether to use flash attention"})
+
+
+@dataclass
+class DataArguments:
+    train_data: str = field(default=None, metadata={"help": "Path to train data"})
+    train_group_size: int = field(default=8)
+
+    query_max_len: int = field(
+        default=32,
+        metadata={
+            "help": "The maximum total input sequence length after tokenization for the query. Sequences longer "
+            "than this will be truncated, sequences shorter will be padded."
+        },
+    )
+
+    passage_max_len: int = field(
+        default=128,
+        metadata={
+            "help": "The maximum total input sequence length after tokenization for the passage. Sequences longer "
+            "than this will be truncated, sequences shorter will be padded."
+        },
+    )
+
+    max_example_num_per_dataset: int = field(
+        default=100000000,
+        metadata={"help": "the max number of examples for each dataset"},
+    )
+
+    query_instruction_for_retrieval: str = field(default=None, metadata={"help": "instruction for query"})
+    passage_instruction_for_retrieval: str = field(default=None, metadata={"help": "instruction for passage"})
+
+    def __post_init__(self):
+        if not os.path.exists(self.train_data):
+            raise FileNotFoundError(f"cannot find file: {self.train_data}, please set a valid path")
+
+
+@dataclass
+class RetrieverTrainingArguments(TrainingArguments):
+    negatives_cross_device: bool = field(default=False, metadata={"help": "share negatives across devices"})
+    temperature: Optional[float] = field(default=0.02)
+    margin: Optional[float] = field(default=0.2)
+    fix_position_embedding: bool = field(
+        default=False, metadata={"help": "Freeze the parameters of position embeddings"}
+    )
+    sentence_pooling_method: str = field(
+        default="mean",
+        metadata={"help": "the pooling method, either cls or mean"},
+    )
+    fine_tune_type: str = field(
+        default="sft",
+        metadata={"help": "fine-tune type for retrieval, e.g. sft, bitfit, lora"},
+    )
+    use_inbatch_neg: bool = field(default=False, metadata={"help": "use passages in the same batch as negatives"})
+
+    use_matryoshka: bool = field(default=False, metadata={"help": "use matryoshka for flexible embedding size"})
+
+    matryoshka_dims: List[int] = field(
+        default_factory=lambda: [64, 128, 256, 512, 768],
+        metadata={"help": "matryoshka dims"},
+    )
+    matryoshka_loss_weights: List[float] = field(
+        default_factory=lambda: [1, 1, 1, 1, 1],
+        metadata={"help": "matryoshka loss weights"},
+    )
diff --git a/pipelines/examples/constrative_train/data.py b/pipelines/examples/constrative_train/data.py
new file mode 100644
index 000000000000..8d98c9ab56fc
--- /dev/null
+++ b/pipelines/examples/constrative_train/data.py
@@ -0,0 +1,139 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os.path +import random +from dataclasses import dataclass + +import datasets +from arguments import DataArguments +from paddle.io import Dataset + +from paddlenlp.data import DataCollatorWithPadding +from paddlenlp.transformers import PretrainedTokenizer + + +class TrainDatasetForEmbedding(Dataset): + def __init__( + self, + args: DataArguments, + tokenizer: PretrainedTokenizer, + query_max_len: int = 64, + passage_max_len: int = 1048, + is_batch_negative: bool = False, + ): + if os.path.isdir(args.train_data): + train_datasets = [] + for file in os.listdir(args.train_data): + temp_dataset = datasets.load_dataset( + "json", + data_files=os.path.join(args.train_data, file), + split="train", + ) + if len(temp_dataset) > args.max_example_num_per_dataset: + temp_dataset = temp_dataset.select( + random.sample( + list(range(len(temp_dataset))), + args.max_example_num_per_dataset, + ) + ) + train_datasets.append(temp_dataset) + self.dataset = datasets.concatenate_datasets(train_datasets) + else: + self.dataset = datasets.load_dataset("json", data_files=args.train_data, split="train") + self.tokenizer = tokenizer + self.args = args + self.total_len = len(self.dataset) + self.query_max_len = query_max_len + self.passage_max_len = passage_max_len + self.is_batch_negative = is_batch_negative + + def __len__(self): + return self.total_len + + def __getitem__(self, item): + query = self.dataset[item]["query"] + if self.args.query_instruction_for_retrieval is not None: + query = self.args.query_instruction_for_retrieval + query + query = self.tokenizer( + query, + truncation=True, + max_length=self.query_max_len, + return_attention_mask=False, + truncation_side="right", + ) + passages = [] + pos = random.choice(self.dataset[item]["pos"]) + passages.append(pos) + # Add negative examples + if not self.is_batch_negative: + if len(self.dataset[item]["neg"]) < self.args.train_group_size - 1: + num = math.ceil((self.args.train_group_size - 1) / len(self.dataset[item]["neg"])) + negs = random.sample(self.dataset[item]["neg"] * num, self.args.train_group_size - 1) + else: + negs = random.sample(self.dataset[item]["neg"], self.args.train_group_size - 1) + passages.extend(negs) + + if self.args.passage_instruction_for_retrieval is not None: + passages = [self.args.passage_instruction_for_retrieval + p for p in passages] + passages = self.tokenizer( + passages, + truncation=True, + max_length=self.passage_max_len, + return_attention_mask=False, + truncation_side="right", + ) + # Convert passages to input_ids + passages_tackle = [] + for i in range(len(passages["input_ids"])): + passages_tackle.append({"input_ids": passages["input_ids"][i]}) + return query, passages_tackle + + +@dataclass +class EmbedCollator(DataCollatorWithPadding): + """ + Wrapper that does conversion from List[Tuple[encode_qry, encode_psg]] to List[qry], List[psg] + and pass batch separately to the actual collator. + Abstract out data detail for the model. 
+ """ + + query_max_len: int = 32 + passage_max_len: int = 128 + + def __call__(self, features): + query = [f[0] for f in features] + passage = [f[1] for f in features] + if isinstance(query[0], list): + query = sum(query, []) + if isinstance(passage[0], list): + passage = sum(passage, []) + q_collated = self.tokenizer.pad( + query, + padding="max_length", + max_length=self.query_max_len, + return_attention_mask=True, + pad_to_multiple_of=None, + return_tensors="pd", + ) + d_collated = self.tokenizer.pad( + passage, + padding="max_length", + max_length=self.passage_max_len, + return_attention_mask=True, + pad_to_multiple_of=None, + return_tensors="pd", + ) + return {"query": q_collated, "passage": d_collated} diff --git a/pipelines/examples/constrative_train/data/toy_finetune_data.jsonl b/pipelines/examples/constrative_train/data/toy_finetune_data.jsonl new file mode 100644 index 000000000000..51cad7d16e63 --- /dev/null +++ b/pipelines/examples/constrative_train/data/toy_finetune_data.jsonl @@ -0,0 +1,10 @@ +{"query": "Five women walk along a beach wearing flip-flops.", "pos": ["Some women with flip-flops on, are walking along the beach"], "neg": ["The 4 women are sitting on the beach.", "There was a reform in 1996.", "She's not going to court to clear her record.", "The man is talking about hawaii.", "A woman is standing outside.", "The battle was over. ", "A group of people plays volleyball."]} +{"query": "A woman standing on a high cliff on one leg looking over a river.", "pos": ["A woman is standing on a cliff."], "neg": ["A woman sits on a chair.", "George Bush told the Republicans there was no way he would let them even consider this foolish idea, against his top advisors advice.", "The family was falling apart.", "no one showed up to the meeting", "A boy is sitting outside playing in the sand.", "Ended as soon as I received the wire.", "A child is reading in her bedroom."]} +{"query": "Two woman are playing instruments; one a clarinet, the other a violin.", "pos": ["Some people are playing a tune."], "neg": ["Two women are playing a guitar and drums.", "A man is skiing down a mountain.", "The fatal dose was not taken when the murderer thought it would be.", "Person on bike", "The girl is standing, leaning against the archway.", "A group of women watch soap operas.", "No matter how old people get they never forget. "]} +{"query": "A girl with a blue tank top sitting watching three dogs.", "pos": ["A girl is wearing blue."], "neg": ["A girl is with three cats.", "The people are watching a funeral procession.", "The child is wearing black.", "Financing is an issue for us in public schools.", "Kids at a pool.", "It is calming to be assaulted.", "I face a serious problem at eighteen years old. "]} +{"query": "A yellow dog running along a forest path.", "pos": ["a dog is running"], "neg": ["a cat is running", "Steele did not keep her original story.", "The rule discourages people to pay their child support.", "A man in a vest sits in a car.", "Person in black clothing, with white bandanna and sunglasses waits at a bus stop.", "Neither the Globe or Mail had comments on the current state of Canada's road system. 
", "The Spring Creek facility is old and outdated."]} +{"query": "It sets out essential activities in each phase along with critical factors related to those activities.", "pos": ["Critical factors for essential activities are set out."], "neg": ["It lays out critical activities but makes no provision for critical factors related to those activities.", "People are assembled in protest.", "The state would prefer for you to do that.", "A girl sits beside a boy.", "Two males are performing.", "Nobody is jumping", "Conrad was being plotted against, to be hit on the head."]} +{"query": "A man giving a speech in a restaurant.", "pos": ["A person gives a speech."], "neg": ["The man sits at the table and eats food.", "This is definitely not an endorsement.", "They sold their home because they were retiring and not because of the loan.", "The seal of Missouri is perfect.", "Someone is raising their hand.", "An athlete is competing in the 1500 meter swimming competition.", "Two men watching a magic show."]} +{"query": "Indians having a gathering with coats and food and drinks.", "pos": ["A group of Indians are having a gathering with food and drinks"], "neg": ["A group of Indians are having a funeral", "It is only staged on Winter afternoons in Palma's large bullring.", "Right information can empower the legal service practices and the justice system. ", "Meanwhile, the mainland was empty of population.", "Two children is sleeping.", "a fisherman is trying to catch a monkey", "the people are in a train"]} +{"query": "A woman with violet hair rides her bicycle outside.", "pos": ["A woman is riding her bike."], "neg": ["A woman is jogging in the park.", "The street was lined with white-painted houses.", "A group watches a movie inside.", "man at picnics cut steak", "Several chefs are sitting down and talking about food.", "The Commission notes that no significant alternatives were considered.", "We ran out of firewood and had to use pine needles for the fire."]} +{"query": "A man pulls two women down a city street in a rickshaw.", "pos": ["A man is in a city."], "neg": ["A man is a pilot of an airplane.", "It is boring and mundane.", "The morning sunlight was shining brightly and it was warm. ", "Two people jumped off the dock.", "People watching a spaceship launch.", "Mother Teresa is an easy choice.", "It's worth being able to go at a pace you prefer."]} \ No newline at end of file diff --git a/pipelines/examples/constrative_train/evaluation/__init__.py b/pipelines/examples/constrative_train/evaluation/__init__.py new file mode 100644 index 000000000000..fd05a9208165 --- /dev/null +++ b/pipelines/examples/constrative_train/evaluation/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/pipelines/examples/constrative_train/evaluation/benchmarks.py b/pipelines/examples/constrative_train/evaluation/benchmarks.py
new file mode 100644
index 000000000000..62f00a11570f
--- /dev/null
+++ b/pipelines/examples/constrative_train/evaluation/benchmarks.py
@@ -0,0 +1,216 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import csv
+import glob
+import math
+import time
+from collections import defaultdict
+from typing import Dict, List, cast
+
+from datasets import load_dataset
+from mteb.abstasks import AbsTaskRetrieval
+from prediction import EvalModel
+
+csv.field_size_limit(500 * 1024 * 1024)
+
+# yapf: disable
+parser = argparse.ArgumentParser()
+parser.add_argument('--model_type', choices=['bloom', 'llama', 'baichuan', "bert", 'roberta', 'ernie'], default="bloom", help="The model type")
+parser.add_argument("--query_model", default="bigscience/bloomz-7b1-mt", type=str, help="Path of the query embedding model")
+parser.add_argument("--passage_model", default="bigscience/bloomz-7b1-mt", type=str, help="Path of the passage embedding model")
+parser.add_argument("--query_max_length", default=64, type=int, help="Maximum query sequence length")
+parser.add_argument("--passage_max_length", default=512, type=int, help="Maximum passage sequence length")
+parser.add_argument("--evaluate_all", action="store_true", help="Evaluate all checkpoints")
+parser.add_argument("--checkpoint_dir", default="checkpoints", type=str, help="The checkpoints root directory")
+
+args = parser.parse_args()
+# yapf: enable
+
+
+class PaddleModel:
+    def __init__(
+        self,
+        query_model,
+        corpus_model,
+        model_type="bloom",
+        batch_size=1,
+        max_seq_len=512,
+        sep=" ",
+        pooling_mode="mean_tokens",
+        **kwargs,
+    ):
+        # Note: a single encoder is constructed here and used for both queries
+        # and the corpus; corpus_model is currently unused.
+        self.query_model = EvalModel(
+            model=query_model,
+            max_seq_len=max_seq_len,
+            batch_size=batch_size,
+            model_type=model_type,
+        )
+        self.sep = sep
+
+    def encode_queries(self, queries: List[str], batch_size: int, **kwargs):
+        return self.query_model.run(queries, batch_size=batch_size, max_seq_len=args.query_max_length, **kwargs)
+
+    def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int, **kwargs):
+        if type(corpus) is dict:
+            sentences = [
+                (corpus["title"][i] + self.sep + corpus["text"][i]).strip()
+                if "title" in corpus
+                else corpus["text"][i].strip()
+                for i in range(len(corpus["text"]))
+            ]
+        else:
+            sentences = [
+                (doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
+                for doc in corpus
+            ]
+        return self.query_model.run(
+            sentences,
+            batch_size=batch_size,
+            max_seq_len=args.passage_max_length,
+            **kwargs,
+        )
+
+
+class T2RRetrieval(AbsTaskRetrieval):
+    def __init__(self, num_max_passages: "int | None" = None, **kwargs):
+        super().__init__(**kwargs)
+        self.num_max_passages = num_max_passages or math.inf
+
+    @property
+    def description(self):
+        return {
+            "name": "T2RankingRetrieval",
+            "reference": "https://huggingface.co/datasets/THUIR/T2Ranking",
"type": "Retrieval", + "category": "s2p", + "eval_splits": ["dev"], + "eval_langs": ["zh"], + "main_score": "ndcg_at_10", + } + + def evaluate( + self, + model_query, + model_corpus, + model_type="bloom", + split="test", + batch_size=32, + corpus_chunk_size=None, + target_devices=None, + score_function="cos_sim", + **kwargs, + ): + from beir.retrieval.evaluation import EvaluateRetrieval + + if not self.data_loaded: + self.load_data() + corpus, queries, relevant_docs = ( + self.corpus[split], + self.queries[split], + self.relevant_docs[split], + ) + + from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES + + model = PaddleModel(model_query, model_corpus, model_type) + + model = DRES( + model, + batch_size=batch_size, + corpus_chunk_size=corpus_chunk_size if corpus_chunk_size is not None else 50000, + **kwargs, + ) + retriever = EvaluateRetrieval(model, score_function=score_function) # or "cos_sim" or "dot" + start_time = time.time() + results = retriever.retrieve(corpus, queries) + end_time = time.time() + print("Time taken to retrieve: {:.2f} seconds".format(end_time - start_time)) + + ndcg, _map, recall, precision = retriever.evaluate(relevant_docs, results, retriever.k_values) + mrr = retriever.evaluate_custom(relevant_docs, results, retriever.k_values, "mrr") + + scores = { + **{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()}, + **{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()}, + **{f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()}, + **{f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()}, + **{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr.items()}, + } + print(scores) + return scores + + def load_data(self, **kwargs): + corpus, queries, qrels = load_t2ranking_for_retraviel(self.num_max_passages) + self.corpus, self.queries, self.relevant_docs = {}, {}, {} + self.corpus["dev"] = corpus + self.queries["dev"] = queries + self.relevant_docs["dev"] = qrels + self.data_loaded = True + + +def load_t2ranking_for_retraviel(num_max_passages: float): + collection_dataset = load_dataset("THUIR/T2Ranking", "collection")["train"] # type: ignore + dev_queries_dataset = load_dataset("THUIR/T2Ranking", "queries.dev")["train"] # type: ignore + dev_rels_dataset = load_dataset("THUIR/T2Ranking", "qrels.dev")["train"] # type: ignore + corpus = {} + for index in range(min(len(collection_dataset), num_max_passages)): + record = collection_dataset[index] + record = cast(dict, record) + pid: int = record["pid"] + corpus[str(pid)] = {"text": record["text"]} + queries = {} + for record in dev_queries_dataset: + record = cast(dict, record) + queries[str(record["qid"])] = record["text"] + + all_qrels = defaultdict(dict) + for record in dev_rels_dataset: + record = cast(dict, record) + pid: int = record["pid"] + if pid > num_max_passages: + continue + all_qrels[str(record["qid"])][str(record["pid"])] = record["rel"] + valid_qrels = {} + for qid, qrels in all_qrels.items(): + if len(set(list(qrels.values())) - set([0])) >= 1: + valid_qrels[qid] = qrels + valid_queries = {} + for qid, query in queries.items(): + if qid in valid_qrels: + valid_queries[qid] = query + print(f"valid qrels: {len(valid_qrels)}") + return corpus, valid_queries, valid_qrels + + +if __name__ == "__main__": + tasks = T2RRetrieval(num_max_passages=10000) + if args.evaluate_all: + checkpoints = glob.glob(f"{args.checkpoint_dir}/checkpoint-*") + checkpoints.sort() + for checkpoint in checkpoints: + tasks.evaluate( + model_query=checkpoint, + model_corpus=checkpoint, + 
+                model_type=args.model_type,
+                split="dev",
+            )
+
+    else:
+        tasks.evaluate(
+            model_query=args.query_model,
+            model_corpus=args.passage_model,
+            model_type=args.model_type,
+            split="dev",
+        )
diff --git a/pipelines/examples/constrative_train/evaluation/prediction.py b/pipelines/examples/constrative_train/evaluation/prediction.py
new file mode 100644
index 000000000000..7dcf9513e8bf
--- /dev/null
+++ b/pipelines/examples/constrative_train/evaluation/prediction.py
@@ -0,0 +1,153 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+
+import numpy as np
+import paddle
+
+from paddlenlp.data import DataCollatorWithPadding
+from paddlenlp.transformers import AutoTokenizer
+
+sys.path.append(os.path.abspath("."))
+from models.modeling import BiEncoderModel
+
+
+class EvalModel:
+    def __init__(
+        self,
+        model: str = None,
+        batch_size: int = 1,
+        max_seq_len: int = 512,
+        return_tensors: str = "np",
+        model_type: str = "bloom",
+    ):
+        self.model = model
+        self.batch_size = batch_size
+        # Store the default so _batchify/_preprocess can fall back to it.
+        self.max_seq_len = max_seq_len
+        self.return_tensors = return_tensors
+        self.model_type = model_type
+        self._construct_model()
+        self._construct_tokenizer()
+
+    def _construct_model(self):
+        """
+        Construct the inference model for the predictor.
+        """
+        if self.model_type in ["bert", "roberta", "ernie"]:
+            self._model = BiEncoderModel.from_pretrained(
+                model_name_or_path=self.model,
+                normalized=True,
+                sentence_pooling_method="cls",
+            )
+            print(f"loading checkpoints {self.model}")
+        else:
+            raise NotImplementedError
+
+        self._model.eval()
+
+    def _construct_tokenizer(self):
+        """
+        Construct the tokenizer for the predictor.
+        """
+        self._tokenizer = AutoTokenizer.from_pretrained(self.model)
+        self._tokenizer.padding_side = "right"
+        self.pad_token_id = self._tokenizer.convert_tokens_to_ids(self._tokenizer.pad_token)
+        # Fix windows dtype bug
+        self._collator = DataCollatorWithPadding(self._tokenizer, return_tensors="pd")
+
+    def _batchify(self, data, batch_size, max_seq_len=None):
+        """
+        Generate input batches.
+        """
+
+        def _parse_batch(batch_examples, max_seq_len=None):
+            if isinstance(batch_examples[0], str):
+                to_tokenize = [batch_examples]
+            to_tokenize = [[str(s).strip() for s in col] for col in to_tokenize]
+            if max_seq_len is None:
+                max_seq_len = self.max_seq_len
+            tokenized_inputs = self._tokenizer(
+                to_tokenize[0],
+                padding=True,
+                truncation=True,
+                max_seq_len=max_seq_len,
+                return_attention_mask=True,
+            )
+            return tokenized_inputs
+
+        # Separates the data into batches.
+        if max_seq_len is None:
+            max_seq_len = self.max_seq_len
+        one_batch = []
+        for example in range(len(data)):
+            one_batch.append(data[example])
+            if len(one_batch) == batch_size:
+                yield _parse_batch(one_batch, max_seq_len)
+                one_batch = []
+        if one_batch:
+            yield _parse_batch(one_batch, max_seq_len)
+
+    def _check_input_text(self, inputs):
+        """
+        Check whether the input text meets the requirements.
+ """ + # inputs = inputs[0] + if isinstance(inputs, str): + if len(inputs) == 0: + raise ValueError("Invalid inputs, input text should not be empty text, please check your input.") + inputs = [inputs] + elif isinstance(inputs, list): + if not (isinstance(inputs[0], str) and len(inputs[0].strip()) > 0): + raise TypeError( + "Invalid inputs, input text should be list of str, and first element of list should not be empty text." + ) + else: + raise TypeError( + "Invalid inputs, input text should be str or list of str, but type of {} found!".format(type(inputs)) + ) + return inputs + + def _preprocess(self, inputs, batch_size=None, max_seq_len=None, **kwargs): + """ + Transform the raw inputs to the model inputs, two steps involved: + 1) Transform the raw text/image to token ids/pixel_values. + 2) Generate the other model inputs from the raw text/image and token ids/pixel_values. + """ + inputs = self._check_input_text(inputs) + if batch_size is None: + batch_size = self.batch_size + if max_seq_len is None: + max_seq_len = self.max_seq_len + batches = self._batchify(inputs, batch_size, max_seq_len) + outputs = {"batches": batches, "inputs": inputs} + return outputs + + def _run_model(self, inputs, **kwargs): + all_feats = [] + with paddle.no_grad(): + for batch_inputs in inputs["batches"]: + batch_inputs = self._collator(batch_inputs) + token_embeddings = self._model.encode(batch_inputs) + all_feats.append(token_embeddings.detach().cpu().numpy()) + return all_feats + + def _postprocess(self, inputs): + inputs = np.concatenate(inputs, axis=0) + return inputs + + def run(self, *args, **kwargs): + inputs = self._preprocess(*args, **kwargs) + outputs = self._run_model(inputs, **kwargs) + results = self._postprocess(outputs) + return results diff --git a/pipelines/examples/constrative_train/models/__init__.py b/pipelines/examples/constrative_train/models/__init__.py new file mode 100644 index 000000000000..fd05a9208165 --- /dev/null +++ b/pipelines/examples/constrative_train/models/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/pipelines/examples/constrative_train/models/modeling.py b/pipelines/examples/constrative_train/models/modeling.py new file mode 100644 index 000000000000..8f1a719d5151 --- /dev/null +++ b/pipelines/examples/constrative_train/models/modeling.py @@ -0,0 +1,194 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass +from typing import Dict, List, Optional + +import paddle +import paddle.distributed as dist +import paddle.nn as nn + +from paddlenlp.transformers import AutoConfig, AutoModel, PretrainedModel +from paddlenlp.transformers.model_outputs import ModelOutput +from paddlenlp.utils.log import logger + + +@dataclass +class EncoderOutput(ModelOutput): + q_reps: Optional[paddle.Tensor] = None + p_reps: Optional[paddle.Tensor] = None + loss: Optional[paddle.Tensor] = None + scores: Optional[paddle.Tensor] = None + + +class BiEncoderModel(PretrainedModel): + def __init__( + self, + model_name_or_path: str = None, + normalized: bool = False, + sentence_pooling_method: str = "cls", + negatives_cross_device: bool = False, + temperature: float = 1.0, + use_inbatch_neg: bool = True, + margin: float = 0.3, + matryoshka_dims: Optional[List[int]] = None, + matryoshka_loss_weights: Optional[List[float]] = None, + ): + super().__init__() + self.model = AutoModel.from_pretrained(model_name_or_path) + self.model_config = AutoConfig.from_pretrained(model_name_or_path) + self.cross_entropy = nn.CrossEntropyLoss(reduction="mean") + + self.normalized = normalized + self.sentence_pooling_method = sentence_pooling_method + self.temperature = temperature + self.use_inbatch_neg = use_inbatch_neg + self.config = self.model_config + self.margin = margin + self.matryoshka_dims = matryoshka_dims + + if self.matryoshka_dims: + self.matryoshka_loss_weights = ( + matryoshka_loss_weights if matryoshka_loss_weights else [1] * len(self.matryoshka_dims) + ) + else: + self.matryoshka_loss_weights = None + + if not normalized: + self.temperature = 1.0 + logger.info("reset temperature = 1.0 due to using inner product to compute similarity") + + self.negatives_cross_device = negatives_cross_device + if self.negatives_cross_device: + if not dist.is_initialized(): + raise ValueError("Distributed training has not been initialized for representation all gather.") + self.process_rank = dist.get_rank() + self.world_size = dist.get_world_size() + + def sentence_embedding(self, hidden_state, mask): + if self.sentence_pooling_method == "mean": + s = paddle.sum(hidden_state * mask.unsqueeze(-1).float(), axis=1) + d = mask.sum(axis=1, keepdim=True).float() + return s / d + elif self.sentence_pooling_method == "cls": + return hidden_state[:, 0] + + def get_model_config( + self, + ): + return self.model_config.to_dict() + + def encode(self, features): + psg_out = self.model(**features, return_dict=True) + p_reps = self.sentence_embedding(psg_out.last_hidden_state, features["attention_mask"]) + return p_reps + + def compute_similarity(self, q_reps, p_reps): + # q_reps [batch_size, embedding_dim] + # p_reps [batch_size, embedding_dim] + return paddle.matmul(q_reps, p_reps.transpose([1, 0])) + + def forward( + self, + inputs: Dict[str, paddle.Tensor] = None, + teacher_score: paddle.Tensor = None, + ): + query = inputs["query"] + passage = inputs["passage"] + q_reps = self.encode(query) + p_reps = self.encode(passage) + + if self.training: + # Cross device negatives + if self.negatives_cross_device: + q_reps = self._dist_gather_tensor(q_reps) + p_reps = self._dist_gather_tensor(p_reps) + + if self.matryoshka_dims: + loss = 0.0 + for loss_weight, dim in zip(self.matryoshka_loss_weights, self.matryoshka_dims): + reduced_q = q_reps[:, :dim] + reduced_d = p_reps[:, :dim] + if self.normalized: + reduced_q = paddle.nn.functional.normalize(reduced_q, axis=-1) + reduced_d = paddle.nn.functional.normalize(reduced_d, 
axis=-1) + scores = self.compute_similarity(reduced_q, reduced_d) + scores = scores / self.temperature + scores = scores.reshape([q_reps.shape[0], -1]) + + target = paddle.arange(scores.shape[0], dtype="int64") + target = target * (p_reps.shape[0] // q_reps.shape[0]) + dim_loss = self.compute_loss(scores, target) + loss += loss_weight * dim_loss + + elif self.use_inbatch_neg: + if self.normalized: + q_reps = paddle.nn.functional.normalize(q_reps, axis=-1) + p_reps = paddle.nn.functional.normalize(p_reps, axis=-1) + # In batch negatives + scores = self.compute_similarity(q_reps, p_reps) + # Substract margin from all positive samples cosine_sim() + margin_diag = paddle.full(shape=[q_reps.shape[0]], fill_value=self.margin, dtype=q_reps.dtype) + scores = scores - paddle.diag(margin_diag) + # Scale cosine to ease training converge + scores = scores / self.temperature + target = paddle.arange(0, q_reps.shape[0], dtype="int64") + loss = self.compute_loss(scores, target) + else: + if self.normalized: + q_reps = paddle.nn.functional.normalize(q_reps, axis=-1) + p_reps = paddle.nn.functional.normalize(p_reps, axis=-1) + scores = self.compute_similarity(q_reps, p_reps) + scores = scores / self.temperature + scores = scores.reshape([q_reps.shape[0], -1]) + + target = paddle.arange(scores.shape[0], dtype="int64") + target = target * (p_reps.shape[0] // q_reps.shape[0]) + loss = self.compute_loss(scores, target) + + else: + scores = self.compute_similarity(q_reps, p_reps) + loss = None + return EncoderOutput( + loss=loss, + scores=scores, + q_reps=q_reps, + p_reps=p_reps, + ) + + def compute_loss(self, scores, target): + return self.cross_entropy(scores, target) + + def _dist_gather_tensor(self, t: Optional[paddle.Tensor]): + if t is None: + return None + + all_tensors = [paddle.empty_like(t) for _ in range(self.world_size)] + dist.all_gather(all_tensors, t) + + all_tensors[self.process_rank] = t + all_tensors = paddle.concat(all_tensors, axis=0) + + return all_tensors + + @classmethod + def from_pretrained(cls, **kwargs): + # Instantiate model. + model = cls(**kwargs) + return model + + def save_pretrained(self, output_dir: str, **kwargs): + state_dict = self.model.state_dict() + state_dict = type(state_dict)({k: v.clone().cpu() for k, v in state_dict.items()}) + self.model.save_pretrained(output_dir, state_dict=state_dict) diff --git a/pipelines/examples/constrative_train/requirements.txt b/pipelines/examples/constrative_train/requirements.txt new file mode 100644 index 000000000000..5409f81204e2 --- /dev/null +++ b/pipelines/examples/constrative_train/requirements.txt @@ -0,0 +1,5 @@ +paddlenlp>2.6.1 +datasets +torch==2.0.1 +mteb[beir] +typer==0.9.0 \ No newline at end of file diff --git a/pipelines/examples/constrative_train/train.py b/pipelines/examples/constrative_train/train.py new file mode 100644 index 000000000000..da494062df18 --- /dev/null +++ b/pipelines/examples/constrative_train/train.py @@ -0,0 +1,157 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from arguments import DataArguments, ModelArguments +from arguments import RetrieverTrainingArguments as TrainingArguments +from data import EmbedCollator, TrainDatasetForEmbedding +from models.modeling import BiEncoderModel +from utils import BiTrainer + +from paddlenlp.peft import LoRAConfig, LoRAModel +from paddlenlp.trainer import PdArgumentParser, get_last_checkpoint, set_seed +from paddlenlp.transformers import AutoTokenizer +from paddlenlp.utils.log import logger + + +def main(): + parser = PdArgumentParser((ModelArguments, DataArguments, TrainingArguments)) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # Set the dtype for loading model + dtype = None + if training_args.fp16_opt_level == "O2": + if training_args.fp16: + dtype = "float16" + if training_args.bf16: + dtype = "bfloat16" + else: + dtype = "float32" + + if ( + os.path.exists(training_args.output_dir) + and os.listdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." + ) + + if training_args.pipeline_parallel_degree > 1 and training_args.negatives_cross_device: + raise ValueError("Pipeline parallelism does not support cross batch negatives.") + # Setup logging + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}," + + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}", + ) + logger.info(f"Training/evaluation parameters {training_args}") + logger.info(f"Model parameters {model_args}") + logger.info(f"Data parameters {data_args}") + + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 1: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 
+ ) + # Set seed + set_seed(training_args.seed) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + use_fast=False, + ) + tokenizer.padding_side = "right" + + model = BiEncoderModel.from_pretrained( + model_name_or_path=model_args.model_name_or_path, + normalized=model_args.normalized, + sentence_pooling_method=training_args.sentence_pooling_method, + negatives_cross_device=training_args.negatives_cross_device, + temperature=training_args.temperature, + margin=training_args.margin, + use_inbatch_neg=training_args.use_inbatch_neg, + matryoshka_dims=training_args.matryoshka_dims if training_args.use_matryoshka else None, + matryoshka_loss_weights=training_args.matryoshka_loss_weights if training_args.use_matryoshka else None, + ) + + if training_args.fix_position_embedding: + for k, v in model.named_parameters(): + if "position_embeddings" in k: + logger.info(f"Freeze the parameters for {k}") + v.stop_gradient = True + + if training_args.fine_tune_type == "bitfit": + for k, v in model.named_parameters(): + # Only bias are allowed for training + if "bias" in k: + v.stop_gradient = False + else: + logger.info(f"Freeze the parameters for {k} shape: {v.shape}") + v.stop_gradient = True + + if training_args.fine_tune_type == "lora": + if "llama" in model_args.model_name_or_path or "baichuan" in model_args.model_name_or_path: + target_modules = [".*q_proj.*", ".*k_proj.*", ".*v_proj.*"] + else: + target_modules = [".*query_key_value.*"] + + lora_config = LoRAConfig( + target_modules=target_modules, + r=8, + lora_alpha=32, + dtype=dtype, + ) + model = LoRAModel(model, lora_config) + model.mark_only_lora_as_trainable() + model.print_trainable_parameters() + + train_dataset = TrainDatasetForEmbedding( + args=data_args, + tokenizer=tokenizer, + query_max_len=data_args.query_max_len, + passage_max_len=data_args.passage_max_len, + is_batch_negative=model_args.is_batch_negative, + ) + + trainer = BiTrainer( + model=model, + args=training_args, + train_dataset=train_dataset, + data_collator=EmbedCollator( + tokenizer, + query_max_len=data_args.query_max_len, + passage_max_len=data_args.passage_max_len, + ), + tokenizer=tokenizer, + ) + if training_args.do_train: + train_result = trainer.train(resume_from_checkpoint=last_checkpoint) + trainer.save_model(merge_tensor_parallel=training_args.tensor_parallel_degree > 1) + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + + +if __name__ == "__main__": + main() diff --git a/pipelines/examples/constrative_train/utils.py b/pipelines/examples/constrative_train/utils.py new file mode 100644 index 000000000000..c5077e48a8ac --- /dev/null +++ b/pipelines/examples/constrative_train/utils.py @@ -0,0 +1,28 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from paddlenlp.trainer import Trainer
+
+
+class BiTrainer(Trainer):
+    def compute_loss(self, model, inputs, return_outputs=False):
+        """
+        How the loss is computed by Trainer. By default, all models return the loss in the first element.
+
+        Subclass and override for custom behavior.
+        """
+        outputs = model(inputs)
+        loss = outputs.loss
+
+        return (loss, outputs) if return_outputs else loss

From 1680b103be522a4e17a2a0ee97b2d4c090753624 Mon Sep 17 00:00:00 2001
From: w5688414
Date: Thu, 21 Mar 2024 04:34:16 +0000
Subject: [PATCH 2/4] Update README.md

---
 pipelines/examples/constrative_train/README.md | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/pipelines/examples/constrative_train/README.md b/pipelines/examples/constrative_train/README.md
index a1f82a3d33ea..311de4ca74b7 100644
--- a/pipelines/examples/constrative_train/README.md
+++ b/pipelines/examples/constrative_train/README.md
@@ -1,4 +1,4 @@
-# generative-search
+# Vector Retrieval Model Training
 
 ## Installation
@@ -106,3 +106,7 @@ python evaluation/benchmarks.py --model_type bert \
 - `query_max_length`: maximum query length
 - `passage_max_length`: maximum passage length
 - `evaluate_all`: whether to evaluate all checkpoints; defaults to False, i.e. only the specified checkpoint is evaluated
+
+## Reference
+
+[1] Aditya Kusupati, Gantavya Bhatt, Aniket Rege, Matthew Wallingford, Aditya Sinha, Vivek Ramanujan, William Howard-Snyder, Kaifeng Chen, Sham M. Kakade, Prateek Jain, Ali Farhadi: Matryoshka Representation Learning. NeurIPS 2022

From 5c96a7c77b5ebce5c524e01169b2c2980051e72b Mon Sep 17 00:00:00 2001
From: w5688414
Date: Thu, 21 Mar 2024 09:35:27 +0000
Subject: [PATCH 3/4] Update modeling.py

---
 .../README.md                                 |  4 +-
 .../arguments.py                              |  0
 .../data.py                                   |  0
 .../data/toy_finetune_data.jsonl              |  0
 .../evaluation/__init__.py                    |  0
 .../evaluation/benchmarks.py                  |  0
 .../evaluation/prediction.py                  |  0
 .../models/__init__.py                        |  0
 .../models/modeling.py                        | 64 +++++++++++--------
 .../requirements.txt                          |  0
 .../train.py                                  |  2 +-
 .../utils.py                                  |  0
 12 files changed, 39 insertions(+), 31 deletions(-)
 rename pipelines/examples/{constrative_train => contrastive_training}/README.md (95%)
 rename pipelines/examples/{constrative_train => contrastive_training}/arguments.py (100%)
 rename pipelines/examples/{constrative_train => contrastive_training}/data.py (100%)
 rename pipelines/examples/{constrative_train => contrastive_training}/data/toy_finetune_data.jsonl (100%)
 rename pipelines/examples/{constrative_train => contrastive_training}/evaluation/__init__.py (100%)
 rename pipelines/examples/{constrative_train => contrastive_training}/evaluation/benchmarks.py (100%)
 rename pipelines/examples/{constrative_train => contrastive_training}/evaluation/prediction.py (100%)
 rename pipelines/examples/{constrative_train => contrastive_training}/models/__init__.py (100%)
 rename pipelines/examples/{constrative_train => contrastive_training}/models/modeling.py (77%)
 rename pipelines/examples/{constrative_train => contrastive_training}/requirements.txt (100%)
 rename pipelines/examples/{constrative_train => contrastive_training}/train.py (99%)
 rename pipelines/examples/{constrative_train => contrastive_training}/utils.py (100%)

diff --git a/pipelines/examples/constrative_train/README.md b/pipelines/examples/contrastive_training/README.md
similarity index 95%
rename from pipelines/examples/constrative_train/README.md
rename to pipelines/examples/contrastive_training/README.md
index 311de4ca74b7..13bd26e5c183 100644
--- a/pipelines/examples/constrative_train/README.md
+++ b/pipelines/examples/contrastive_training/README.md
@@ -2,7 +2,7 @@
 
 ## Installation
 
-We recommend installing the GPU build of [PaddlePalle](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html). Taking the CUDA 11.7 build of Paddle as an example, the install command is:
+We recommend installing the GPU build of [PaddlePaddle](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html). Taking the CUDA 11.7 build of Paddle as an example, the install command is:
 
 ```
 python -m pip install paddlepaddle-gpu==2.6.0.post117 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
@@ -12,7 +12,7 @@ python -m pip install paddlepaddle-gpu==2.6.0.post117 -f https://www.paddlepaddl
 pip install -r requirements.txt
 ```
 
-Download the dureader Chinese dataset:
+Download the DuReader-Retrieval Chinese dataset:
 
 ```
 cd data
diff --git a/pipelines/examples/constrative_train/arguments.py b/pipelines/examples/contrastive_training/arguments.py
similarity index 100%
rename from pipelines/examples/constrative_train/arguments.py
rename to pipelines/examples/contrastive_training/arguments.py
diff --git a/pipelines/examples/constrative_train/data.py b/pipelines/examples/contrastive_training/data.py
similarity index 100%
rename from pipelines/examples/constrative_train/data.py
rename to pipelines/examples/contrastive_training/data.py
diff --git a/pipelines/examples/constrative_train/data/toy_finetune_data.jsonl b/pipelines/examples/contrastive_training/data/toy_finetune_data.jsonl
similarity index 100%
rename from pipelines/examples/constrative_train/data/toy_finetune_data.jsonl
rename to pipelines/examples/contrastive_training/data/toy_finetune_data.jsonl
diff --git a/pipelines/examples/constrative_train/evaluation/__init__.py b/pipelines/examples/contrastive_training/evaluation/__init__.py
similarity index 100%
rename from pipelines/examples/constrative_train/evaluation/__init__.py
rename to pipelines/examples/contrastive_training/evaluation/__init__.py
diff --git a/pipelines/examples/constrative_train/evaluation/benchmarks.py b/pipelines/examples/contrastive_training/evaluation/benchmarks.py
similarity index 100%
rename from pipelines/examples/constrative_train/evaluation/benchmarks.py
rename to pipelines/examples/contrastive_training/evaluation/benchmarks.py
diff --git a/pipelines/examples/constrative_train/evaluation/prediction.py b/pipelines/examples/contrastive_training/evaluation/prediction.py
similarity index 100%
rename from pipelines/examples/constrative_train/evaluation/prediction.py
rename to pipelines/examples/contrastive_training/evaluation/prediction.py
diff --git a/pipelines/examples/constrative_train/models/__init__.py b/pipelines/examples/contrastive_training/models/__init__.py
similarity index 100%
rename from pipelines/examples/constrative_train/models/__init__.py
rename to pipelines/examples/contrastive_training/models/__init__.py
diff --git a/pipelines/examples/constrative_train/models/modeling.py b/pipelines/examples/contrastive_training/models/modeling.py
similarity index 77%
rename from pipelines/examples/constrative_train/models/modeling.py
rename to pipelines/examples/contrastive_training/models/modeling.py
index 8f1a719d5151..54b5c0f6b35f 100644
--- a/pipelines/examples/constrative_train/models/modeling.py
+++ b/pipelines/examples/contrastive_training/models/modeling.py
@@ -99,6 +99,28 @@ def compute_similarity(self, q_reps, p_reps):
         # p_reps [batch_size, embedding_dim]
         return paddle.matmul(q_reps, p_reps.transpose([1, 0]))
 
+    def hard_negative_loss(self, q_reps, p_reps):
+        scores = self.compute_similarity(q_reps, p_reps)
+        scores = scores / self.temperature
+        scores = scores.reshape([q_reps.shape[0], -1])
+
+        target = paddle.arange(scores.shape[0], dtype="int64")
+        target = target * (p_reps.shape[0] // q_reps.shape[0])
+        loss = self.compute_loss(scores, target)
+        return scores, loss
+
+    def in_batch_negative_loss(self, q_reps, p_reps):
+        # In-batch negatives
+        scores = self.compute_similarity(q_reps, p_reps)
+        # Subtract the margin from all positive samples' cosine similarities
+        margin_diag = paddle.full(shape=[q_reps.shape[0]], fill_value=self.margin, dtype=q_reps.dtype)
+        scores = scores - paddle.diag(margin_diag)
+        # Scale the cosine scores to ease training convergence
+        scores = scores / self.temperature
+        target = paddle.arange(0, q_reps.shape[0], dtype="int64")
+        loss = self.compute_loss(scores, target)
+        return scores, loss
+
     def forward(
         self,
         inputs: Dict[str, paddle.Tensor] = None,
@@ -109,6 +131,12 @@ def forward(
         q_reps = self.encode(query)
         p_reps = self.encode(passage)
 
+        # For non-matryoshka loss, we normalize the representations
+        if not self.matryoshka_dims:
+            if self.normalized:
+                q_reps = paddle.nn.functional.normalize(q_reps, axis=-1)
+                p_reps = paddle.nn.functional.normalize(p_reps, axis=-1)
+
         if self.training:
             # Cross device negatives
             if self.negatives_cross_device:
@@ -117,45 +145,25 @@ def forward(
 
             if self.matryoshka_dims:
                 loss = 0.0
+                scores = 0.0
                 for loss_weight, dim in zip(self.matryoshka_loss_weights, self.matryoshka_dims):
                     reduced_q = q_reps[:, :dim]
                     reduced_d = p_reps[:, :dim]
                     if self.normalized:
                         reduced_q = paddle.nn.functional.normalize(reduced_q, axis=-1)
                         reduced_d = paddle.nn.functional.normalize(reduced_d, axis=-1)
-                    scores = self.compute_similarity(reduced_q, reduced_d)
-                    scores = scores / self.temperature
-                    scores = scores.reshape([q_reps.shape[0], -1])
 
-                    target = paddle.arange(scores.shape[0], dtype="int64")
-                    target = target * (p_reps.shape[0] // q_reps.shape[0])
-                    dim_loss = self.compute_loss(scores, target)
+                    if self.use_inbatch_neg:
+                        dim_score, dim_loss = self.in_batch_negative_loss(reduced_q, reduced_d)
+                    else:
+                        dim_score, dim_loss = self.hard_negative_loss(reduced_q, reduced_d)
+                    scores += dim_score
                     loss += loss_weight * dim_loss
 
             elif self.use_inbatch_neg:
-                if self.normalized:
-                    q_reps = paddle.nn.functional.normalize(q_reps, axis=-1)
-                    p_reps = paddle.nn.functional.normalize(p_reps, axis=-1)
-                # In batch negatives
-                scores = self.compute_similarity(q_reps, p_reps)
-                # Substract margin from all positive samples cosine_sim()
-                margin_diag = paddle.full(shape=[q_reps.shape[0]], fill_value=self.margin, dtype=q_reps.dtype)
-                scores = scores - paddle.diag(margin_diag)
-                # Scale cosine to ease training converge
-                scores = scores / self.temperature
-                target = paddle.arange(0, q_reps.shape[0], dtype="int64")
-                loss = self.compute_loss(scores, target)
+                scores, loss = self.in_batch_negative_loss(q_reps, p_reps)
             else:
-                if self.normalized:
-                    q_reps = paddle.nn.functional.normalize(q_reps, axis=-1)
-                    p_reps = paddle.nn.functional.normalize(p_reps, axis=-1)
-                scores = self.compute_similarity(q_reps, p_reps)
-                scores = scores / self.temperature
-                scores = scores.reshape([q_reps.shape[0], -1])
-
-                target = paddle.arange(scores.shape[0], dtype="int64")
-                target = target * (p_reps.shape[0] // q_reps.shape[0])
-                loss = self.compute_loss(scores, target)
+                scores, loss = self.hard_negative_loss(q_reps, p_reps)
 
         else:
             scores = self.compute_similarity(q_reps, p_reps)
diff --git a/pipelines/examples/constrative_train/requirements.txt b/pipelines/examples/contrastive_training/requirements.txt
similarity index 100%
rename from pipelines/examples/constrative_train/requirements.txt
rename to pipelines/examples/contrastive_training/requirements.txt
diff --git 
a/pipelines/examples/constrative_train/train.py b/pipelines/examples/contrastive_training/train.py similarity index 99% rename from pipelines/examples/constrative_train/train.py rename to pipelines/examples/contrastive_training/train.py index da494062df18..b040f3e088ae 100644 --- a/pipelines/examples/constrative_train/train.py +++ b/pipelines/examples/contrastive_training/train.py @@ -131,7 +131,7 @@ def main(): tokenizer=tokenizer, query_max_len=data_args.query_max_len, passage_max_len=data_args.passage_max_len, - is_batch_negative=model_args.is_batch_negative, + is_batch_negative=training_args.use_inbatch_neg, ) trainer = BiTrainer( diff --git a/pipelines/examples/constrative_train/utils.py b/pipelines/examples/contrastive_training/utils.py similarity index 100% rename from pipelines/examples/constrative_train/utils.py rename to pipelines/examples/contrastive_training/utils.py From 36a79b24d978e9cb88e3d982a3755f8b9e58c67a Mon Sep 17 00:00:00 2001 From: w5688414 Date: Thu, 21 Mar 2024 09:51:29 +0000 Subject: [PATCH 4/4] reformat trainer --- .../contrastive_training/models/modeling.py | 5 ++-- .../examples/contrastive_training/train.py | 5 ++-- .../examples/contrastive_training/utils.py | 28 ------------------- 3 files changed, 4 insertions(+), 34 deletions(-) delete mode 100644 pipelines/examples/contrastive_training/utils.py diff --git a/pipelines/examples/contrastive_training/models/modeling.py b/pipelines/examples/contrastive_training/models/modeling.py index 54b5c0f6b35f..e8665d1ff398 100644 --- a/pipelines/examples/contrastive_training/models/modeling.py +++ b/pipelines/examples/contrastive_training/models/modeling.py @@ -123,11 +123,10 @@ def in_batch_negative_loss(self, q_reps, p_reps): def forward( self, - inputs: Dict[str, paddle.Tensor] = None, + query: Dict[str, paddle.Tensor] = None, + passage: Dict[str, paddle.Tensor] = None, teacher_score: paddle.Tensor = None, ): - query = inputs["query"] - passage = inputs["passage"] q_reps = self.encode(query) p_reps = self.encode(passage) diff --git a/pipelines/examples/contrastive_training/train.py b/pipelines/examples/contrastive_training/train.py index b040f3e088ae..b5891b9e36da 100644 --- a/pipelines/examples/contrastive_training/train.py +++ b/pipelines/examples/contrastive_training/train.py @@ -18,10 +18,9 @@ from arguments import RetrieverTrainingArguments as TrainingArguments from data import EmbedCollator, TrainDatasetForEmbedding from models.modeling import BiEncoderModel -from utils import BiTrainer from paddlenlp.peft import LoRAConfig, LoRAModel -from paddlenlp.trainer import PdArgumentParser, get_last_checkpoint, set_seed +from paddlenlp.trainer import PdArgumentParser, Trainer, get_last_checkpoint, set_seed from paddlenlp.transformers import AutoTokenizer from paddlenlp.utils.log import logger @@ -134,7 +133,7 @@ def main(): is_batch_negative=training_args.use_inbatch_neg, ) - trainer = BiTrainer( + trainer = Trainer( model=model, args=training_args, train_dataset=train_dataset, diff --git a/pipelines/examples/contrastive_training/utils.py b/pipelines/examples/contrastive_training/utils.py deleted file mode 100644 index c5077e48a8ac..000000000000 --- a/pipelines/examples/contrastive_training/utils.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from paddlenlp.trainer import Trainer - - -class BiTrainer(Trainer): - def compute_loss(self, model, inputs, return_outputs=False): - """ - How the loss is computed by Trainer. By default, all models return the loss in the first element. - - Subclass and override for custom behavior. - """ - outputs = model(inputs) - loss = outputs.loss - - return (loss, outputs) if return_outputs else loss
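
After this series, a trained checkpoint can be used directly for embedding inference. The sketch below is a hedged example of loading the final bi-encoder and truncating its embeddings to a smaller matryoshka size; the checkpoint path and the 128-dimension cut are illustrative assumptions, and the loading mirrors what `evaluation/prediction.py` does:

```
import paddle
import paddle.nn.functional as F
from paddlenlp.transformers import AutoTokenizer

# Run from the contrastive_training/ example directory so this import resolves.
from models.modeling import BiEncoderModel

ckpt = "./checkpoints/checkpoint-1500"  # hypothetical path from the training run
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = BiEncoderModel.from_pretrained(model_name_or_path=ckpt, normalized=True, sentence_pooling_method="cls")
model.eval()

texts = ["什么是深度学习?", "深度学习是机器学习的一个分支。"]
batch = tokenizer(texts, padding=True, return_attention_mask=True, return_tensors="pd")
with paddle.no_grad():
    reps = model.encode(batch)              # [2, hidden_size] full-size embeddings
reps = F.normalize(reps[:, :128], axis=-1)  # matryoshka: keep only the first 128 dims
print(paddle.matmul(reps[0:1], reps[1:2].transpose([1, 0])))  # cosine similarity
```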