Skip to content

[Paddle-Pipelines] Add matryoshka representation learning #8165

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions pipelines/examples/contrastive_training/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# 向量检索模型训练

## 安装

推荐安装gpu版本的[PaddlePaddle](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html),以cuda11.7的paddle为例,安装命令如下:

```
python -m pip install paddlepaddle-gpu==2.6.0.post117 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
```
安装其他依赖:
```
pip install -r requirements.txt
```

下载DuReader-Retrieval中文数据集:

```
cd data
wget https://paddlenlp.bj.bcebos.com/datasets/dureader_dual.train.jsonl
```

## 运行

### 单卡训练

```
export CUDA_VISIBLE_DEVICES=0
python train.py --do_train \
--model_name_or_path rocketqa-zh-base-query-encoder \
--output_dir ./checkpoints \
--train_data ./data/dureader_dual.train.jsonl \
--overwrite_output_dir \
--fine_tune_type sft \
--sentence_pooling_method cls \
--num_train_epochs 3 \
--per_device_train_batch_size 64 \
--learning_rate 3e-5 \
--train_group_size 4 \
--recompute \
--passage_max_len 512 \
--use_matryoshka
```

- `model_name_or_path`: 选择预训练模型,可选rocketqa-zh-base-query-encoder
- `output_dir`: 模型保存路径
- `train_data`: 训练数据集路径,这里使用的是dureader中文数据集
- `overwrite_output_dir`: 是否覆盖模型保存路径,默认为False
- `fine_tune_type`: 训练模式,可选sft和lora, bitfit等策略
- `sentence_pooling_method`: 句子池化方法,可选cls和mean, cls为CLS层,mean为平均池化
- `num_train_epochs`: 训练轮数
- `per_device_train_batch_size`: 单卡训练batch大小
- `learning_rate`: 学习率
- `train_group_size`: 每个训练集正负样本的数据,默认为8,例如train_group_size=4,则每个训练集包含1个正样本和3个负样本
- `max_example_num_per_dataset`: 每个训练集的最大样本数,默认为100000000
- `recompute`: 是否重新计算,默认为False
- `query_max_len`: query的最大长度,默认为32
- `query_instruction_for_retrieval`: query的检索指令,默认为None
- `passage_instruction_for_retrieval`: passage的检索指令,默认为None
- `passage_max_len`: passage的最大长度,默认为512
- `use_matryoshka`: 是否使用俄罗斯套娃策略(matryoshka),默认为False
- `matryoshka_dims`: 俄罗斯套娃策略的维度,默认为[64, 128, 256, 512, 768]
- `matryoshka_loss_weights`: 俄罗斯套娃策略的损失权重,默认为[1, 1, 1, 1, 1]
- `use_inbatch_neg`: 是否使用in batch negatives策略,默认为False
- `use_flash_attention`: 是否使用flash attention,默认为False
- `temperature`: in batch negatives策略的temperature参数,默认为0.02
- `negatives_cross_device`: 跨设备in batch negatives策略,默认为False
- `margin`: in batch negatives策略的margin参数,默认为0.2

### 多卡训练

单卡训练效率过低,batch_size较小,建议使用多卡训练,对于对比学习训练推荐使用大batch_size,多卡训练,示例命令如下:

```
python -m paddle.distributed.launch --gpus "1,2,3,4" train.py --do_train \
--model_name_or_path rocketqa-zh-base-query-encoder \
--output_dir ./checkpoints \
--train_data ./data/dual.train.json \
--overwrite_output_dir \
--fine_tune_type sft \
--sentence_pooling_method cls \
--num_train_epochs 3 \
--per_device_train_batch_size 32 \
--learning_rate 3e-5 \
--train_group_size 8 \
--recompute \
--passage_max_len 512 \
--use_matryoshka
```

## 评估

评估脚本:

```
export CUDA_VISIBLE_DEVICES=0
python evaluation/benchmarks.py --model_type bert \
--query_model checkpoints/checkpoint-1500 \
--passage_model checkpoints/checkpoint-1500 \
--query_max_length 64 \
--passage_max_length 512 \
--evaluate_all
```
- `model_type`: 模型的类似,可选bert或roberta等等
- `query_model`: query向量模型的路径
- `passage_model`: passage向量模型的路径
- `query_max_length`: query的最大长度
- `passage_max_length`: passage的最大长度
- `evaluate_all`: 是否评估所有的checkpoint,默认为False,即只评估指定的checkpoint

## Reference

[1] Aditya Kusupati, Gantavya Bhatt, Aniket Rege, Matthew Wallingford, Aditya Sinha, Vivek Ramanujan, William Howard-Snyder, Kaifeng Chen, Sham M. Kakade, Prateek Jain, Ali Farhadi: Matryoshka Representation Learning. NeurIPS 2022
101 changes: 101 additions & 0 deletions pipelines/examples/contrastive_training/arguments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from dataclasses import dataclass, field
from typing import List, Optional

from paddlenlp.trainer import TrainingArguments


@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""

model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
tokenizer_name: Optional[str] = field(
default=None,
metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
)

normalized: bool = field(default=True)
use_flash_attention: bool = field(default=False, metadata={"help": "Whether to use flash attention"})


@dataclass
class DataArguments:
train_data: str = field(default=None, metadata={"help": "Path to train data"})
train_group_size: int = field(default=8)

query_max_len: int = field(
default=32,
metadata={
"help": "The maximum total input sequence length after tokenization for passage. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)

passage_max_len: int = field(
default=128,
metadata={
"help": "The maximum total input sequence length after tokenization for passage. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)

max_example_num_per_dataset: int = field(
default=100000000,
metadata={"help": "the max number of examples for each dataset"},
)

query_instruction_for_retrieval: str = field(default=None, metadata={"help": "instruction for query"})
passage_instruction_for_retrieval: str = field(default=None, metadata={"help": "instruction for passage"})

def __post_init__(self):
if not os.path.exists(self.train_data):
raise FileNotFoundError(f"cannot find file: {self.train_data}, please set a true path")


@dataclass
class RetrieverTrainingArguments(TrainingArguments):
negatives_cross_device: bool = field(default=False, metadata={"help": "share negatives across devices"})
temperature: Optional[float] = field(default=0.02)
margin: Optional[float] = field(default=0.2)
fix_position_embedding: bool = field(
default=False, metadata={"help": "Freeze the parameters of position embeddings"}
)
sentence_pooling_method: str = field(
default="mean",
metadata={"help": "the pooling method, should be weighted_mean"},
)
fine_tune_type: str = field(
default="sft",
metadata={"help": "fine-tune type for retrieval,eg: sft, bitfit, lora"},
)
use_inbatch_neg: bool = field(default=False, metadata={"help": "use passages in the same batch as negatives"})

use_matryoshka: bool = field(default=False, metadata={"help": "use matryoshka for flexible embedding size"})

matryoshka_dims: List[int] = field(
default_factory=lambda: [64, 128, 256, 512, 768],
metadata={"help": "matryoshka dims"},
)
matryoshka_loss_weights: List[float] = field(
default_factory=lambda: [1, 1, 1, 1, 1],
metadata={"help": "matryoshka loss weights"},
)
139 changes: 139 additions & 0 deletions pipelines/examples/contrastive_training/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os.path
import random
from dataclasses import dataclass

import datasets
from arguments import DataArguments
from paddle.io import Dataset

from paddlenlp.data import DataCollatorWithPadding
from paddlenlp.transformers import PretrainedTokenizer


class TrainDatasetForEmbedding(Dataset):
def __init__(
self,
args: DataArguments,
tokenizer: PretrainedTokenizer,
query_max_len: int = 64,
passage_max_len: int = 1048,
is_batch_negative: bool = False,
):
if os.path.isdir(args.train_data):
train_datasets = []
for file in os.listdir(args.train_data):
temp_dataset = datasets.load_dataset(
"json",
data_files=os.path.join(args.train_data, file),
split="train",
)
if len(temp_dataset) > args.max_example_num_per_dataset:
temp_dataset = temp_dataset.select(
random.sample(
list(range(len(temp_dataset))),
args.max_example_num_per_dataset,
)
)
train_datasets.append(temp_dataset)
self.dataset = datasets.concatenate_datasets(train_datasets)
else:
self.dataset = datasets.load_dataset("json", data_files=args.train_data, split="train")
self.tokenizer = tokenizer
self.args = args
self.total_len = len(self.dataset)
self.query_max_len = query_max_len
self.passage_max_len = passage_max_len
self.is_batch_negative = is_batch_negative

def __len__(self):
return self.total_len

def __getitem__(self, item):
query = self.dataset[item]["query"]
if self.args.query_instruction_for_retrieval is not None:
query = self.args.query_instruction_for_retrieval + query
query = self.tokenizer(
query,
truncation=True,
max_length=self.query_max_len,
return_attention_mask=False,
truncation_side="right",
)
passages = []
pos = random.choice(self.dataset[item]["pos"])
passages.append(pos)
# Add negative examples
if not self.is_batch_negative:
if len(self.dataset[item]["neg"]) < self.args.train_group_size - 1:
num = math.ceil((self.args.train_group_size - 1) / len(self.dataset[item]["neg"]))
negs = random.sample(self.dataset[item]["neg"] * num, self.args.train_group_size - 1)
else:
negs = random.sample(self.dataset[item]["neg"], self.args.train_group_size - 1)
passages.extend(negs)

if self.args.passage_instruction_for_retrieval is not None:
passages = [self.args.passage_instruction_for_retrieval + p for p in passages]
passages = self.tokenizer(
passages,
truncation=True,
max_length=self.passage_max_len,
return_attention_mask=False,
truncation_side="right",
)
# Convert passages to input_ids
passages_tackle = []
for i in range(len(passages["input_ids"])):
passages_tackle.append({"input_ids": passages["input_ids"][i]})
return query, passages_tackle


@dataclass
class EmbedCollator(DataCollatorWithPadding):
"""
Wrapper that does conversion from List[Tuple[encode_qry, encode_psg]] to List[qry], List[psg]
and pass batch separately to the actual collator.
Abstract out data detail for the model.
"""

query_max_len: int = 32
passage_max_len: int = 128

def __call__(self, features):
query = [f[0] for f in features]
passage = [f[1] for f in features]
if isinstance(query[0], list):
query = sum(query, [])
if isinstance(passage[0], list):
passage = sum(passage, [])
q_collated = self.tokenizer.pad(
query,
padding="max_length",
max_length=self.query_max_len,
return_attention_mask=True,
pad_to_multiple_of=None,
return_tensors="pd",
)
d_collated = self.tokenizer.pad(
passage,
padding="max_length",
max_length=self.passage_max_len,
return_attention_mask=True,
pad_to_multiple_of=None,
return_tensors="pd",
)
return {"query": q_collated, "passage": d_collated}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{"query": "Five women walk along a beach wearing flip-flops.", "pos": ["Some women with flip-flops on, are walking along the beach"], "neg": ["The 4 women are sitting on the beach.", "There was a reform in 1996.", "She's not going to court to clear her record.", "The man is talking about hawaii.", "A woman is standing outside.", "The battle was over. ", "A group of people plays volleyball."]}
{"query": "A woman standing on a high cliff on one leg looking over a river.", "pos": ["A woman is standing on a cliff."], "neg": ["A woman sits on a chair.", "George Bush told the Republicans there was no way he would let them even consider this foolish idea, against his top advisors advice.", "The family was falling apart.", "no one showed up to the meeting", "A boy is sitting outside playing in the sand.", "Ended as soon as I received the wire.", "A child is reading in her bedroom."]}
{"query": "Two woman are playing instruments; one a clarinet, the other a violin.", "pos": ["Some people are playing a tune."], "neg": ["Two women are playing a guitar and drums.", "A man is skiing down a mountain.", "The fatal dose was not taken when the murderer thought it would be.", "Person on bike", "The girl is standing, leaning against the archway.", "A group of women watch soap operas.", "No matter how old people get they never forget. "]}
{"query": "A girl with a blue tank top sitting watching three dogs.", "pos": ["A girl is wearing blue."], "neg": ["A girl is with three cats.", "The people are watching a funeral procession.", "The child is wearing black.", "Financing is an issue for us in public schools.", "Kids at a pool.", "It is calming to be assaulted.", "I face a serious problem at eighteen years old. "]}
{"query": "A yellow dog running along a forest path.", "pos": ["a dog is running"], "neg": ["a cat is running", "Steele did not keep her original story.", "The rule discourages people to pay their child support.", "A man in a vest sits in a car.", "Person in black clothing, with white bandanna and sunglasses waits at a bus stop.", "Neither the Globe or Mail had comments on the current state of Canada's road system. ", "The Spring Creek facility is old and outdated."]}
{"query": "It sets out essential activities in each phase along with critical factors related to those activities.", "pos": ["Critical factors for essential activities are set out."], "neg": ["It lays out critical activities but makes no provision for critical factors related to those activities.", "People are assembled in protest.", "The state would prefer for you to do that.", "A girl sits beside a boy.", "Two males are performing.", "Nobody is jumping", "Conrad was being plotted against, to be hit on the head."]}
{"query": "A man giving a speech in a restaurant.", "pos": ["A person gives a speech."], "neg": ["The man sits at the table and eats food.", "This is definitely not an endorsement.", "They sold their home because they were retiring and not because of the loan.", "The seal of Missouri is perfect.", "Someone is raising their hand.", "An athlete is competing in the 1500 meter swimming competition.", "Two men watching a magic show."]}
{"query": "Indians having a gathering with coats and food and drinks.", "pos": ["A group of Indians are having a gathering with food and drinks"], "neg": ["A group of Indians are having a funeral", "It is only staged on Winter afternoons in Palma's large bullring.", "Right information can empower the legal service practices and the justice system. ", "Meanwhile, the mainland was empty of population.", "Two children is sleeping.", "a fisherman is trying to catch a monkey", "the people are in a train"]}
{"query": "A woman with violet hair rides her bicycle outside.", "pos": ["A woman is riding her bike."], "neg": ["A woman is jogging in the park.", "The street was lined with white-painted houses.", "A group watches a movie inside.", "man at picnics cut steak", "Several chefs are sitting down and talking about food.", "The Commission notes that no significant alternatives were considered.", "We ran out of firewood and had to use pine needles for the fire."]}
{"query": "A man pulls two women down a city street in a rickshaw.", "pos": ["A man is in a city."], "neg": ["A man is a pilot of an airplane.", "It is boring and mundane.", "The morning sunlight was shining brightly and it was warm. ", "Two people jumped off the dock.", "People watching a spaceship launch.", "Mother Teresa is an easy choice.", "It's worth being able to go at a pace you prefer."]}
13 changes: 13 additions & 0 deletions pipelines/examples/contrastive_training/evaluation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading