Use vidore benchmark to monitor performance during training #195


Open
wants to merge 16 commits into base: main
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).

### Added

- Add the possibility for a user to evaluate a model on retrieval datasets (e.g. ViDoRe benchmark) during its training.
Collaborator:

Suggested change
- Add the possibility for a user to evaluate a model on retrieval datasets (e.g. ViDoRe benchmark) during its training.
- Add `BenchmarkEvalCallback` to evaluate a model on retrieval datasets (e.g. the ViDoRe benchmark) during training and display the metrics in Weights & Biases.

- Add `LambdaTokenPooler` to allow for custom token pooling functions.
- Added training losses with negatives to InfoNCE-type losses

71 changes: 67 additions & 4 deletions colpali_engine/trainer/colmodel_training.py
@@ -13,19 +13,43 @@
ColbertLoss,
)
from colpali_engine.trainer.contrastive_trainer import ContrastiveTrainer
from colpali_engine.trainer.eval_utils import BenchmarkEvalCallback, evaluate_dataset
from colpali_engine.utils.gpu_stats import print_gpu_utilization, print_summary
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor


@dataclass
class ColModelTrainingConfig:
"""Configuration for training a ColVision model.

Args:
model (Union[PreTrainedModel, PeftModel]): Base model to train.
processor (BaseVisualRetrieverProcessor): Processor for visual data processing.
tr_args (Optional[TrainingArguments]): Transformers training arguments. If not provided, uses default values.
output_dir (Optional[str]): Output directory to save the model.
If not provided, creates a path based on model name.
max_length (int): Maximum sequence length for inputs. Default: 256.
run_eval (bool): If True, runs evaluation. Default: True.
run_train (bool): If True, runs training. Default: True.
vidore_eval_frequency (int): ViDoRe evaluation frequency; must be a multiple of tr_args.eval_steps.
Collaborator:

Maybe an assert to guarantee this?
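
For illustration, a standalone check along those lines might look like the sketch below; the helper name is hypothetical, and in practice the guard could live wherever the config is finalized (e.g. a post-init hook):

```python
from typing import Optional


def check_vidore_eval_frequency(vidore_eval_frequency: int, eval_steps: Optional[int]) -> None:
    """Sketch of the suggested guard: the frequency must be -1 (disabled) or a
    positive multiple of the Trainer's `eval_steps`."""
    if vidore_eval_frequency == -1:
        return
    if vidore_eval_frequency <= 0 or not eval_steps or vidore_eval_frequency % eval_steps != 0:
        raise ValueError(
            "vidore_eval_frequency must be -1 or a positive multiple of "
            f"tr_args.eval_steps, got {vidore_eval_frequency} (eval_steps={eval_steps})"
        )
```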

Pass -1 to disable. Default: -1.
eval_dataset_format (str): Evaluation dataset format ("qa" or "beir"). Default: "qa".
peft_config (Optional[LoraConfig]): PEFT configuration if used. Default: None.
loss_func (Optional[Callable]): Custom loss function. Default: ColbertLoss().
dataset_loading_func (Optional[Callable]): Dataset loading function. Default: None.
eval_dataset_loader (Optional[Dict[str, Callable]]): Evaluation dataset loaders. Default: None.
pretrained_peft_model_name_or_path (Optional[str]): Path to a pretrained PEFT model. Default: None.
"""

model: Union[PreTrainedModel, PeftModel]
processor: BaseVisualRetrieverProcessor
tr_args: Optional[TrainingArguments] = None
output_dir: Optional[str] = None
max_length: int = 256
run_eval: bool = True
run_train: bool = True
vidore_eval_frequency: int = -1
eval_dataset_format: str = "qa"
peft_config: Optional[LoraConfig] = None
loss_func: Optional[Callable] = ColbertLoss()
dataset_loading_func: Optional[Callable] = None
@@ -118,16 +142,55 @@ def train(self) -> None:

trainer.args.remove_unused_columns = False

if self.config.processor is not None and self.config.vidore_eval_frequency > 0:
trainer.add_callback(
BenchmarkEvalCallback(
processor=self.config.processor,
model=self.model,
eval_dataset_loader=self.config.eval_dataset_loader,
batch_query=self.config.tr_args.per_device_eval_batch_size,
batch_passage=self.config.tr_args.per_device_eval_batch_size,
batch_score=4,
Collaborator:

That's super low, you can probably push this to 256 at least

run_frequency=self.config.vidore_eval_frequency,
dataset_format=self.config.eval_dataset_format,
)
)

result = trainer.train(resume_from_checkpoint=self.config.tr_args.resume_from_checkpoint)
print_summary(result)

def eval(self) -> None:
raise NotImplementedError("Evaluation is not implemented yet.")
all_metrics = {}

all_metrics["validation_set"] = evaluate_dataset(
model=self.model,
processor=self.config.processor,
dataset=self.dataset["test"],
format="qa",
batch_passage=self.config.tr_args.per_device_eval_batch_size,
batch_query=self.config.tr_args.per_device_eval_batch_size,
batch_score=self.config.tr_args.per_device_eval_batch_size,
)

if self.config.eval_dataset_loader is not None:
# Create a vision retriever with the current model checkpoint.
eval_dataset_format = getattr(self.config.tr_args, "eval_dataset_format", "beir")

for test_name, test_dataset_loading_func in self.config.eval_dataset_loader.items():
print(f"Evaluating {test_name}")
all_metrics[test_name] = evaluate_dataset(
model=self.model,
processor=self.config.processor,
dataset=test_dataset_loading_func(),
format=eval_dataset_format,
batch_passage=self.config.tr_args.per_device_eval_batch_size,
batch_query=self.config.tr_args.per_device_eval_batch_size,
batch_score=self.config.tr_args.per_device_eval_batch_size,
)
print(f"Metrics for {test_name}: {all_metrics[test_name]}")

def save(self, config_file: str):
"""
Save the model with its training config, as well as the tokenizer and processor if provided.
"""
# save model
self.model.save_pretrained(self.config.output_dir)
self.config.processor.save_pretrained(self.config.output_dir)

151 changes: 151 additions & 0 deletions colpali_engine/trainer/eval_utils.py
@@ -0,0 +1,151 @@
# from mteb.evaluation.evaluators.RetrievalEvaluator
from __future__ import annotations

import logging
from typing import Dict, Union

from datasets import Dataset
from peft import PeftModel
from transformers import PreTrainedModel, TrainerControl, TrainerState, TrainingArguments
from transformers.integrations import WandbCallback
from vidore_benchmark.evaluation.vidore_evaluators import ViDoReEvaluatorBEIR, ViDoReEvaluatorQA
Collaborator:

No circular imports anymore?
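
Should a cycle ever resurface, a common workaround is a deferred import inside the callback module, e.g. (sketch; the helper name is illustrative):

```python
def _load_vidore_evaluators():
    # Deferred import: resolved only when the callback actually runs, which
    # breaks any import cycle between colpali_engine and vidore_benchmark.
    from vidore_benchmark.evaluation.vidore_evaluators import (
        ViDoReEvaluatorBEIR,
        ViDoReEvaluatorQA,
    )

    return ViDoReEvaluatorQA, ViDoReEvaluatorBEIR
```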

from vidore_benchmark.retrievers import VisionRetriever

from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor

logger = logging.getLogger(__name__)


METRICS_TO_TRACK = [
"ndcg_at_1",
"ndcg_at_3",
"ndcg_at_5",
"ndcg_at_10",
"ndcg_at_50",
"ndcg_at_100",
"recall_at_1",
"recall_at_3",
"recall_at_5",
"recall_at_10",
"recall_at_50",
"recall_at_100",
"map_at_1",
"map_at_3",
"map_at_5",
"map_at_10",
"map_at_50",
"map_at_100",
]
Comment on lines +19 to +38 (Collaborator, author):

That is quite a lot to keep track of in a wandb window, especially on multiple datasets.

Which few should we keep?

Collaborator:

I would pick nDCG@k and recall@k for $k \in \{1, 5, 10\}$ or $k \in \{1, 3, 5, 10\}$.
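
Following that suggestion, the trimmed list might look like the sketch below (not necessarily the list the PR settles on):

```python
# nDCG@k and recall@k for k in {1, 5, 10} only, to keep the wandb panels readable.
METRICS_TO_TRACK = [
    "ndcg_at_1",
    "ndcg_at_5",
    "ndcg_at_10",
    "recall_at_1",
    "recall_at_5",
    "recall_at_10",
]
```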



class BenchmarkEvalCallback(WandbCallback):
def __init__(
self,
processor,
model,
eval_dataset_loader,
batch_query: int = 4,
batch_passage: int = 4,
batch_score: int = 4,
run_frequency: int = 5,
dataset_format: str = "beir",
):
"""
Callback to evaluate the model on a collection of datasets during training.
Collaborator:

Add mention of Wandb in the docstring.


Args:
processor: The processor to use for the model.
model: The model to evaluate.
eval_dataset_loader: A dictionary of dataset loading functions.
batch_query: Batch size for query.
batch_passage: Batch size for passage.
batch_score: Batch size for scoring.
run_frequency: Frequency of evaluation in steps.
dataset_format: Format of the evaluation dataset, either "beir" or "qa".
"""
self.processor = processor
self.model = model
self.eval_dataset_loader = eval_dataset_loader
self.batch_query = batch_query
self.batch_passage = batch_passage
self.batch_score = batch_score
self.eval_steps_frequency = run_frequency
self.counter_eval = 0
self.eval_dataset_format = dataset_format
super().__init__()

def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
if state.global_step % self.eval_steps_frequency != 0:
self.counter_eval += 1
Collaborator:

When do you use this?

Collaborator (author):

No need, it's an artefact from a previous implementation.

return
else:
self.counter_eval = 1

if self.processor is None:
print("Processor not provided. Skipping benchmark evaluation.")
return

print(f"\n=== Running benchmark evaluation at global step {state.global_step} ===")

# Evaluate on a collection.
if self.eval_dataset_loader is not None:
Collaborator:

Any reason for this to be None?

Collaborator (author):

Not really. Since this eval is deactivated by default, I'd expect a user who enables it to have specified eval datasets.

try:
metrics_collection = {}
for test_name, test_dataset_loading_func in self.eval_dataset_loader.items():
metrics = evaluate_dataset(
model=self.model,
processor=self.processor,
dataset=test_dataset_loading_func(),
format=self.eval_dataset_format,
batch_passage=self.batch_passage,
batch_query=self.batch_query,
batch_score=self.batch_score,
)
metrics_collection[test_name] = {k: v for k, v in metrics.items() if k in METRICS_TO_TRACK}
print(f"Benchmark metrics for tests datasets at step {state.global_step}:")
print(metrics_collection)
Collaborator:

What does this look like? If it's not already formatted (since it's a dict), you can try pprint.pprint if needed!
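
For reference, a minimal `pprint`-based sketch; the dataset names and numbers below are purely illustrative:

```python
from pprint import pprint

# Illustrative nested metrics dict, mimicking the shape of metrics_collection.
metrics_collection = {
    "docvqa_test": {"ndcg_at_5": 0.52, "recall_at_10": 0.78},
    "tabfquad_test": {"ndcg_at_5": 0.61, "recall_at_10": 0.85},
}
pprint(metrics_collection, sort_dicts=False)  # readable, one top-level key per line
```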

Collaborator (author):

I actually think we could remove the print entirely; it looks a bit messy during training.

print("logging metrics to wandb")
Collaborator:

Suggested change (remove this line):
print("logging metrics to wandb")

self._wandb.log(metrics_collection)
except Exception as e:
print(f"Error during benchmark evaluation on collection '{self.eval_collection}': {e}")
Comment on lines +109 to +110 (Collaborator, author):

Update: not relevant anymore.


# Set model back to train mode.
self.model.train()
return


def evaluate_dataset(
model: Union[PreTrainedModel, PeftModel],
processor: BaseVisualRetrieverProcessor,
dataset: Dataset,
format: str = "beir",
batch_passage: int = 4,
batch_query: int = 4,
batch_score: int = 4,
) -> Dict[str, float]:
"""
Evaluate a dataset using the vidore-benchmark library.
"""
model.eval()

# Create a vision retriever with the current model checkpoint.
vision_retriever = VisionRetriever(
model=model,
processor=processor,
)

if format == "qa":
vidore_evaluator = ViDoReEvaluatorQA(vision_retriever)
elif format == "beir":
vidore_evaluator = ViDoReEvaluatorBEIR(vision_retriever)
else:
raise ValueError(f"Invalid dataset format: {format}, must be 'qa' or 'beir'")

metrics = vidore_evaluator.evaluate_dataset(
ds=dataset,
batch_query=batch_query,
batch_passage=batch_passage,
batch_score=batch_score,
)

return metrics
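
For reference, a minimal standalone usage sketch of `evaluate_dataset`; the checkpoint and dataset names are illustrative, and any ColVision model/processor pair from `colpali_engine.models` should work the same way:

```python
from datasets import load_dataset

from colpali_engine.models import ColQwen2, ColQwen2Processor
from colpali_engine.trainer.eval_utils import evaluate_dataset

# Illustrative checkpoint and dataset names.
model = ColQwen2.from_pretrained("vidore/colqwen2-v1.0")
processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v1.0")
ds = load_dataset("vidore/tabfquad_test_subsampled", split="test")

metrics = evaluate_dataset(
    model=model,
    processor=processor,
    dataset=ds,
    format="qa",
    batch_passage=8,
    batch_query=8,
    batch_score=8,
)
print({k: v for k, v in metrics.items() if k.startswith("ndcg")})
```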
15 changes: 15 additions & 0 deletions colpali_engine/utils/dataset_transformation.py
@@ -210,6 +210,21 @@ def __call__(self, *args, **kwargs):
return dataset


class TestSetFactoryBEIR:
Collaborator:

2 things here:

  1. Need a short docstring here
  2. It's not a factory per se (a factory is a design pattern that provides a way to create objects without specifying the exact class of object that will be created, which is not the case here). I think we should keep the implementation as simple as possible. Here is a recommendation below.
def load_beir_test_dataset(dataset_path: str, split: str = "test") -> Dict[str, Dataset]:
    return {
        "corpus": cast(Dataset, load_dataset(dataset_path, name="corpus", split=split)),
        "queries": cast(Dataset, load_dataset(dataset_path, name="queries", split=split)),
        "qrels": cast(Dataset, load_dataset(dataset_path, name="qrels", split=split)),
    }

Wdyt?
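
If that helper were adopted, the `eval_dataset_loader` entries could be wired up with `functools.partial`, building on the recommendation above (dataset names and paths are placeholders):

```python
from functools import partial

# Hypothetical wiring: one loader per dataset, matching the Dict[str, Callable]
# shape expected for eval_dataset_loader. Dataset paths are placeholders.
eval_dataset_loader = {
    "docvqa_beir": partial(load_beir_test_dataset, "vidore/docvqa_beir_placeholder"),
    "tabfquad_beir": partial(load_beir_test_dataset, "vidore/tabfquad_beir_placeholder"),
}
```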

def __init__(self, dataset_path):
self.dataset_path = dataset_path

def __call__(self, *args, **kwargs):
Collaborator:

2 questions here:

  1. Is there a reason for keeping *args, **kwargs?
  2. Can we expose split: str = "test" in the __call__ args?
Suggested change
def __call__(self, *args, **kwargs):
def __call__(self, split: str = "test", *args, **kwargs):

split = "test"
dataset = {
"corpus": cast(Dataset, load_dataset(self.dataset_path, name="corpus", split=split)),
"queries": cast(Dataset, load_dataset(self.dataset_path, name="queries", split=split)),
"qrels": cast(Dataset, load_dataset(self.dataset_path, name="qrels", split=split)),
}

return dataset


if __name__ == "__main__":
ds = TestSetFactory("vidore/tabfquad_test_subsampled")()
print(ds)
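
A quick interactive check for the BEIR variant, analogous to the `TestSetFactory` example above; the dataset path is a placeholder, not a real dataset name:

```python
from colpali_engine.utils.dataset_transformation import TestSetFactoryBEIR

# Placeholder path: substitute any BEIR-format ViDoRe dataset with corpus/queries/qrels configs.
ds_beir = TestSetFactoryBEIR("vidore/some_beir_dataset")()
print(ds_beir.keys())  # expected: dict_keys(['corpus', 'queries', 'qrels'])
```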
65 changes: 65 additions & 0 deletions scripts/configs/qwen2/train_colqwen2_model_eval_vidore.yaml
@@ -0,0 +1,65 @@
config:
(): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig
output_dir: !path ../../../models/colqwen2-ba256-5e-0304
processor:
(): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
class_to_instanciate: !ext colpali_engine.models.ColQwen2Processor
pretrained_model_name_or_path: "./models/base_models/colqwen2-base"
max_num_visual_tokens: 1024

model:
(): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
class_to_instanciate: !ext colpali_engine.models.ColQwen2
pretrained_model_name_or_path: "./models/base_models/colqwen2-base"
torch_dtype: !ext torch.bfloat16
use_cache: false
attn_implementation: "flash_attention_2"

dataset_loading_func: !ext colpali_engine.utils.dataset_transformation.load_train_set
eval_dataset_loader: !import ../data/test_data.yaml
vidore_eval_frequency: 200
eval_dataset_format: "qa"

# max_length: 50
run_eval: true
loss_func:
(): colpali_engine.loss.late_interaction_losses.ColbertPairwiseCELoss
tr_args:
(): transformers.training_args.TrainingArguments
output_dir: null
overwrite_output_dir: true
num_train_epochs: 5
per_device_train_batch_size: 64
gradient_checkpointing: true
gradient_checkpointing_kwargs: { "use_reentrant": false }
# 6 x 8 gpus = 48 batch size
# gradient_accumulation_steps: 4
per_device_eval_batch_size: 8
eval_strategy: "steps"
dataloader_num_workers: 8
# bf16: true
save_steps: 500
logging_steps: 10
eval_steps: 100
warmup_steps: 100
learning_rate: 2e-4
save_total_limit: 1
# resume_from_checkpoint: true
# optim: "paged_adamw_8bit"
# wandb logging
# wandb_project: "colqwen2"
# run_name: "colqwen2-ba32-nolora"
report_to: "wandb"


peft_config:
(): peft.LoraConfig
r: 32
lora_alpha: 32
lora_dropout: 0.1
init_lora_weights: "gaussian"
bias: "none"
task_type: "FEATURE_EXTRACTION"
target_modules: '(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
# target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'