From bf100eabc6d4f087c38a787a3ed2cd15bb05ac0d Mon Sep 17 00:00:00 2001 From: Quentin Mace Date: Fri, 14 Feb 2025 16:47:35 +0100 Subject: [PATCH 01/16] first draft wroking eval benchmark --- colpali_engine/trainer/colmodel_training.py | 13 ++ colpali_engine/trainer/eval_utils.py | 244 ++++++++++++++++++++ 2 files changed, 257 insertions(+) create mode 100644 colpali_engine/trainer/eval_utils.py diff --git a/colpali_engine/trainer/colmodel_training.py b/colpali_engine/trainer/colmodel_training.py index 3a5fe913..02f094cc 100644 --- a/colpali_engine/trainer/colmodel_training.py +++ b/colpali_engine/trainer/colmodel_training.py @@ -13,6 +13,7 @@ ColbertLoss, ) from colpali_engine.trainer.contrastive_trainer import ContrastiveTrainer +from colpali_engine.trainer.eval_utils import BenchmarkEvalCallback from colpali_engine.utils.gpu_stats import print_gpu_utilization, print_summary from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor @@ -118,6 +119,18 @@ def train(self) -> None: trainer.args.remove_unused_columns = False + if self.config.processor is not None: + trainer.add_callback( + BenchmarkEvalCallback( + processor=self.config.processor, + model=self.model, + eval_dataset_loader=self.config.eval_dataset_loader, + batch_query=self.config.tr_args.per_device_eval_batch_size, + batch_passage=4, + batch_score=4, + ) + ) + result = trainer.train(resume_from_checkpoint=self.config.tr_args.resume_from_checkpoint) print_summary(result) diff --git a/colpali_engine/trainer/eval_utils.py b/colpali_engine/trainer/eval_utils.py new file mode 100644 index 00000000..4ca51d7c --- /dev/null +++ b/colpali_engine/trainer/eval_utils.py @@ -0,0 +1,244 @@ +# from mteb.evaluation.evaluators.RetrievalEvaluator +from __future__ import annotations + +import logging +from typing import Dict + +import numpy as np +import pytrec_eval +from mteb.evaluation.evaluators.RetrievalEvaluator import RetrievalEvaluator +from mteb.evaluation.evaluators.utils import ( + confidence_scores, + hole, + mrr, + nAUC, + recall_cap, + top_k_accuracy, +) +from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments +from vidore_benchmark.evaluation.vidore_evaluators import ViDoReEvaluatorQA +from vidore_benchmark.retrievers import VisionRetriever + +logger = logging.getLogger(__name__) + + +class CustomRetrievalEvaluator: + """ + Wrapper class for the MTEB retrieval evaluator. + """ + + def __init__(self, k_values: list[int] = [1, 3, 5, 10, 20, 50, 100]): + self.k_values = k_values + + def compute_mteb_metrics( + self, + relevant_docs: Dict[str, dict[str, int]], + results: Dict[str, dict[str, float]], + **kwargs, + ) -> Dict[str, float]: + """ + Compute the MTEB retrieval metrics. 
+ """ + ndcg, _map, recall, precision, naucs = self.evaluate( + relevant_docs, + results, + self.k_values, + ignore_identical_ids=kwargs.get("ignore_identical_ids", True), + ) + + mrr = self.evaluate_custom(relevant_docs, results, self.k_values, "mrr") + + scores = { + **{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()}, + **{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()}, + **{f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()}, + **{f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()}, + **{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr[0].items()}, + **{f"naucs_at_{k.split('@')[1]}": v for (k, v) in naucs.items()}, + } + return scores + + @staticmethod + def evaluate( + qrels: dict[str, dict[str, int]], + results: dict[str, dict[str, float]], + k_values: list[int], + ignore_identical_ids: bool = False, + ) -> tuple[ + dict[str, float], + dict[str, float], + dict[str, float], + dict[str, float], + dict[str, float], + ]: + if ignore_identical_ids: + logger.debug( + "For evaluation, ``ignore_identical_ids=True`` is set to True, the evaluator will ignore " + "identical query and document ids." + ) + # Remove identical ids from results dict + for qid, rels in results.items(): + for pid in list(rels): + if qid == pid: + results[qid].pop(pid) + else: + logger.debug( + "For evaluation, we DO NOT ignore identical query and document ids (default), please explicitly " + "set ``ignore_identical_ids=True`` to ignore this." + ) + + all_ndcgs, all_aps, all_recalls, all_precisions = {}, {}, {}, {} + + for k in k_values: + all_ndcgs[f"NDCG@{k}"] = [] + all_aps[f"MAP@{k}"] = [] + all_recalls[f"Recall@{k}"] = [] + all_precisions[f"P@{k}"] = [] + + map_string = "map_cut." + ",".join([str(k) for k in k_values]) + ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values]) + recall_string = "recall." + ",".join([str(k) for k in k_values]) + precision_string = "P." 
+ ",".join([str(k) for k in k_values]) + evaluator = pytrec_eval.RelevanceEvaluator(qrels, {map_string, ndcg_string, recall_string, precision_string}) + scores = evaluator.evaluate(results) + + for query_id in scores.keys(): + for k in k_values: + all_ndcgs[f"NDCG@{k}"].append(scores[query_id]["ndcg_cut_" + str(k)]) + all_aps[f"MAP@{k}"].append(scores[query_id]["map_cut_" + str(k)]) + all_recalls[f"Recall@{k}"].append(scores[query_id]["recall_" + str(k)]) + all_precisions[f"P@{k}"].append(scores[query_id]["P_" + str(k)]) + + ndcg, _map, recall, precision = ( + all_ndcgs.copy(), + all_aps.copy(), + all_recalls.copy(), + all_precisions.copy(), + ) + + for k in k_values: + ndcg[f"NDCG@{k}"] = round(sum(ndcg[f"NDCG@{k}"]) / len(scores), 5) + _map[f"MAP@{k}"] = round(sum(_map[f"MAP@{k}"]) / len(scores), 5) + recall[f"Recall@{k}"] = round(sum(recall[f"Recall@{k}"]) / len(scores), 5) + precision[f"P@{k}"] = round(sum(precision[f"P@{k}"]) / len(scores), 5) + + naucs = RetrievalEvaluator.evaluate_abstention( + results, {**all_ndcgs, **all_aps, **all_recalls, **all_precisions} + ) + + return ndcg, _map, recall, precision, naucs + + @staticmethod + def evaluate_custom( + qrels: dict[str, dict[str, int]], + results: dict[str, dict[str, float]], + k_values: list[int], + metric: str, + output_type: str = "all", + ) -> tuple[dict[str, float], dict[str, float]]: + if metric.lower() in ["mrr", "mrr@k", "mrr_cut"]: + metric_scores = mrr(qrels, results, k_values, output_type) + + elif metric.lower() in ["recall_cap", "r_cap", "r_cap@k"]: + metric_scores = recall_cap(qrels, results, k_values, output_type) + + elif metric.lower() in ["hole", "hole@k"]: + metric_scores = hole(qrels, results, k_values, output_type) + + elif metric.lower() in [ + "acc", + "top_k_acc", + "accuracy", + "accuracy@k", + "top_k_accuracy", + ]: + metric_scores = top_k_accuracy(qrels, results, k_values, output_type) + + naucs = RetrievalEvaluator.evaluate_abstention(results, metric_scores) + metric_scores_avg = {k: sum(v) / len(v) for k, v in metric_scores.items()} + + return metric_scores_avg, naucs + + @staticmethod + def evaluate_abstention( + results: dict[str, dict[str, float]], + metric_scores: dict[str, list[float]], + ) -> dict[str, float]: + """Computes normalized Area Under the Curve on a set of evaluated instances as presented in + the paper https://arxiv.org/abs/2402.12997""" + all_sim_scores = [list(results[qid].values()) for qid in list(results.keys())] + all_conf_scores = [confidence_scores(sim_scores) for sim_scores in all_sim_scores] + conf_fcts = list(all_conf_scores[0].keys()) + all_conf_scores = {fct: np.array([x[fct] for x in all_conf_scores]) for fct in conf_fcts} + metric_scores = {k: np.array(v) for k, v in metric_scores.items()} + naucs = {} + + for metric_name, scores in metric_scores.items(): + for fct, conf_scores in all_conf_scores.items(): + naucs[f"nAUC_{metric_name}_{fct}"] = nAUC(conf_scores, scores) + + return naucs + + +class BenchmarkEvalCallback(TrainerCallback): + def __init__( + self, + processor, + model, + eval_dataset_loader, + batch_query: int = 4, + batch_passage: int = 4, + batch_score: int = 4, + ): + """ + :param processor: The processor instance (e.g., ColIdefics3Processor) needed for retrieval. + :param eval_dataset_name: The name of the single benchmark dataset to evaluate on. + :param eval_collection: The name of the collection (e.g., from Hugging Face Hub) to evaluate. + :param batch_query: Batch size for queries. + :param batch_passage: Batch size for passages. 
+ :param batch_score: Batch size for scoring. + """ + self.processor = processor + self.model = model + self.eval_dataset_loader = eval_dataset_loader + self.batch_query = batch_query + self.batch_passage = batch_passage + self.batch_score = batch_score + + def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + if self.processor is None: + print("Processor not provided. Skipping benchmark evaluation.") + return + + print(f"\n=== Running benchmark evaluation at global step {state.global_step} ===") + # Set model to evaluation mode. + self.model.eval() + + # Create a vision retriever with the current model checkpoint. + vision_retriever = VisionRetriever( + model=self.model, + processor=self.processor, + ) + vidore_evaluator = ViDoReEvaluatorQA(vision_retriever) + + # Evaluate on a collection. + if self.eval_dataset_loader is not None: + try: + metrics_collection = {} + for test_name, test_dataset_loading_func in self.eval_dataset_loader.items(): + ds_coll = test_dataset_loading_func() + metrics = vidore_evaluator.evaluate_dataset( + ds=ds_coll, + batch_query=self.batch_query, + batch_passage=self.batch_passage, + batch_score=self.batch_score, + ) + metrics_collection[test_name] = metrics + print(f"Benchmark metrics for tests datasets at step {state.global_step}:") + print(metrics_collection) + except Exception as e: + print(f"Error during benchmark evaluation on collection '{self.eval_collection}': {e}") + + # Set model back to train mode. + self.model.train() + return From 1361ae0cc5137004f2112320545837160e4d3b21 Mon Sep 17 00:00:00 2001 From: Quentin Mace Date: Mon, 17 Feb 2025 17:18:24 +0100 Subject: [PATCH 02/16] beir support + QoL changes --- colpali_engine/trainer/colmodel_training.py | 41 +++++++++++-------- colpali_engine/trainer/eval_utils.py | 28 +++++++++++-- .../utils/dataset_transformation.py | 15 +++++++ 3 files changed, 63 insertions(+), 21 deletions(-) diff --git a/colpali_engine/trainer/colmodel_training.py b/colpali_engine/trainer/colmodel_training.py index 02f094cc..dd7c2a60 100644 --- a/colpali_engine/trainer/colmodel_training.py +++ b/colpali_engine/trainer/colmodel_training.py @@ -119,28 +119,35 @@ def train(self) -> None: trainer.args.remove_unused_columns = False - if self.config.processor is not None: - trainer.add_callback( - BenchmarkEvalCallback( - processor=self.config.processor, - model=self.model, - eval_dataset_loader=self.config.eval_dataset_loader, - batch_query=self.config.tr_args.per_device_eval_batch_size, - batch_passage=4, - batch_score=4, + if self.config.processor is not None and self.config.tr_args.get("run_vidore_evaluator", False): + vidore_eval_dataset_name = self.config.tr_args.get("vidore_eval_dataset_name", None) + vidore_eval_collection_name = self.config.tr_args.get("vidore_eval_collection_name", None) + + if vidore_eval_dataset_name is not None and vidore_eval_collection_name is not None: + raise ValueError( + "Both vidore_eval_dataset_name and vidore_eval_collection_name are provided. " + "You should only provide one of the two" + ) + elif vidore_eval_dataset_name is None and vidore_eval_collection_name is None: + print("WARNING : No dataset provided for ViDoRe evaluation. 
Skipping evaluation.") + else: + trainer.add_callback( + BenchmarkEvalCallback( + processor=self.config.processor, + model=self.model, + eval_dataset_loader=self.config.eval_dataset_loader, + batch_query=self.config.tr_args.per_device_eval_batch_size, + batch_passage=4, + batch_score=4, + run_frequency=self.config.tr_args.get("eval_steps_frequency", 5), + ) ) - ) result = trainer.train(resume_from_checkpoint=self.config.tr_args.resume_from_checkpoint) print_summary(result) - def eval(self) -> None: - raise NotImplementedError("Evaluation is not implemented yet.") - - def save(self, config_file: str): - """ - Save the model with its training config, as well as the tokenizer and processor if provided. - """ + def save(self, config_file): + # save model self.model.save_pretrained(self.config.output_dir) self.config.processor.save_pretrained(self.config.output_dir) diff --git a/colpali_engine/trainer/eval_utils.py b/colpali_engine/trainer/eval_utils.py index 4ca51d7c..3efd3ca0 100644 --- a/colpali_engine/trainer/eval_utils.py +++ b/colpali_engine/trainer/eval_utils.py @@ -16,7 +16,7 @@ top_k_accuracy, ) from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments -from vidore_benchmark.evaluation.vidore_evaluators import ViDoReEvaluatorQA +from vidore_benchmark.evaluation.vidore_evaluators import ViDoReEvaluatorBEIR, ViDoReEvaluatorQA from vidore_benchmark.retrievers import VisionRetriever logger = logging.getLogger(__name__) @@ -189,14 +189,17 @@ def __init__( batch_query: int = 4, batch_passage: int = 4, batch_score: int = 4, + run_frequency: int = 5, + dataset_format: str = "beir", ): """ :param processor: The processor instance (e.g., ColIdefics3Processor) needed for retrieval. - :param eval_dataset_name: The name of the single benchmark dataset to evaluate on. - :param eval_collection: The name of the collection (e.g., from Hugging Face Hub) to evaluate. + :eval_dataset_loader: A dictionary with the test dataset names as keys and functions that load the datasets as + values. :param batch_query: Batch size for queries. :param batch_passage: Batch size for passages. :param batch_score: Batch size for scoring. + :param run_frequency: Frequency of evaluation ver the evaluation triggers. """ self.processor = processor self.model = model @@ -204,8 +207,17 @@ def __init__( self.batch_query = batch_query self.batch_passage = batch_passage self.batch_score = batch_score + self.eval_steps_frequency = run_frequency + self.counter_eval = 0 + self.eval_dataset_format = dataset_format def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + if self.counter_eval % self.eval_steps_frequency != 0: + self.counter_eval += 1 + return + else: + self.counter_eval = 1 + if self.processor is None: print("Processor not provided. Skipping benchmark evaluation.") return @@ -219,7 +231,6 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra model=self.model, processor=self.processor, ) - vidore_evaluator = ViDoReEvaluatorQA(vision_retriever) # Evaluate on a collection. 
if self.eval_dataset_loader is not None: @@ -227,6 +238,15 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra metrics_collection = {} for test_name, test_dataset_loading_func in self.eval_dataset_loader.items(): ds_coll = test_dataset_loading_func() + + # Temporary before we are caplable of detecting ds format + if self.eval_dataset_format == "beir": + vidore_evaluator = ViDoReEvaluatorBEIR(vision_retriever=vision_retriever) + elif self.eval_dataset_format == "qa": + vidore_evaluator = ViDoReEvaluatorQA(vision_retriever=vision_retriever) + else: + raise ValueError(f"Invalid eval dataset format: {self.eval_dataset_format}") + metrics = vidore_evaluator.evaluate_dataset( ds=ds_coll, batch_query=self.batch_query, diff --git a/colpali_engine/utils/dataset_transformation.py b/colpali_engine/utils/dataset_transformation.py index 9ee0b62d..8f1bfee1 100644 --- a/colpali_engine/utils/dataset_transformation.py +++ b/colpali_engine/utils/dataset_transformation.py @@ -210,6 +210,21 @@ def __call__(self, *args, **kwargs): return dataset +class TestSetFactoryBEIR: + def __init__(self, dataset_path): + self.dataset_path = dataset_path + + def __call__(self, *args, **kwargs): + split = "test" + dataset = { + "corpus": cast(Dataset, load_dataset(self.dataset_path, name="corpus", split=split)), + "queries": cast(Dataset, load_dataset(self.dataset_path, name="queries", split=split)), + "qrels": cast(Dataset, load_dataset(self.dataset_path, name="qrels", split=split)), + } + + return dataset + + if __name__ == "__main__": ds = TestSetFactory("vidore/tabfquad_test_subsampled")() print(ds) From 8f1966ddaed839d3918d544e0cb14a21a67484ef Mon Sep 17 00:00:00 2001 From: Quentin Mace Date: Mon, 17 Feb 2025 17:20:06 +0100 Subject: [PATCH 03/16] fix beir support --- colpali_engine/trainer/colmodel_training.py | 1 + 1 file changed, 1 insertion(+) diff --git a/colpali_engine/trainer/colmodel_training.py b/colpali_engine/trainer/colmodel_training.py index dd7c2a60..9db7e382 100644 --- a/colpali_engine/trainer/colmodel_training.py +++ b/colpali_engine/trainer/colmodel_training.py @@ -140,6 +140,7 @@ def train(self) -> None: batch_passage=4, batch_score=4, run_frequency=self.config.tr_args.get("eval_steps_frequency", 5), + dataset_format=self.config.tr_args.get("eval_dataset_format", "beir"), ) ) From 8285b51f4fc277052062bce004e8aec73c68f920 Mon Sep 17 00:00:00 2001 From: Quentin Mace Date: Wed, 5 Mar 2025 16:28:11 +0100 Subject: [PATCH 04/16] wandb callback --- colpali_engine/trainer/eval_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/colpali_engine/trainer/eval_utils.py b/colpali_engine/trainer/eval_utils.py index 3efd3ca0..2268fb5c 100644 --- a/colpali_engine/trainer/eval_utils.py +++ b/colpali_engine/trainer/eval_utils.py @@ -15,7 +15,7 @@ recall_cap, top_k_accuracy, ) -from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments +from transformers import TrainerControl, TrainerState, TrainingArguments, WandbCallback from vidore_benchmark.evaluation.vidore_evaluators import ViDoReEvaluatorBEIR, ViDoReEvaluatorQA from vidore_benchmark.retrievers import VisionRetriever @@ -180,7 +180,7 @@ def evaluate_abstention( return naucs -class BenchmarkEvalCallback(TrainerCallback): +class BenchmarkEvalCallback(WandbCallback): def __init__( self, processor, @@ -212,7 +212,7 @@ def __init__( self.eval_dataset_format = dataset_format def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): - 
if self.counter_eval % self.eval_steps_frequency != 0: + if state.global_step % self.eval_steps_frequency != 0: self.counter_eval += 1 return else: @@ -256,6 +256,8 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra metrics_collection[test_name] = metrics print(f"Benchmark metrics for tests datasets at step {state.global_step}:") print(metrics_collection) + print("logging metrics to wandb") + self._wandb.log(metrics_collection) except Exception as e: print(f"Error during benchmark evaluation on collection '{self.eval_collection}': {e}") From dc84ce431ece9d03ca02403d2509c657280e8147 Mon Sep 17 00:00:00 2001 From: Quentin Mace Date: Wed, 5 Mar 2025 17:48:11 +0100 Subject: [PATCH 05/16] only keep a few metrics --- colpali_engine/trainer/eval_utils.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/colpali_engine/trainer/eval_utils.py b/colpali_engine/trainer/eval_utils.py index 2268fb5c..2f920169 100644 --- a/colpali_engine/trainer/eval_utils.py +++ b/colpali_engine/trainer/eval_utils.py @@ -22,6 +22,28 @@ logger = logging.getLogger(__name__) +METRICS_TO_TRACK = [ + "ndcg_at_1", + "ndcg_at_3", + "ndcg_at_5", + "ndcg_at_10", + "ndcg_at_50", + "ndcg_at_100", + "recall_at_1", + "recall_at_3", + "recall_at_5", + "recall_at_10", + "recall_at_50", + "recall_at_100", + "map_at_1", + "map_at_3", + "map_at_5", + "map_at_10", + "map_at_50", + "map_at_100", +] + + class CustomRetrievalEvaluator: """ Wrapper class for the MTEB retrieval evaluator. @@ -253,7 +275,7 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra batch_passage=self.batch_passage, batch_score=self.batch_score, ) - metrics_collection[test_name] = metrics + metrics_collection[test_name] = {k: v for k, v in metrics.items() if k in METRICS_TO_TRACK} print(f"Benchmark metrics for tests datasets at step {state.global_step}:") print(metrics_collection) print("logging metrics to wandb") From 553548aaeaf1faf5f7e2b3b1031b3ea4e224bca1 Mon Sep 17 00:00:00 2001 From: QuentinJGMace Date: Thu, 6 Mar 2025 15:21:05 +0100 Subject: [PATCH 06/16] working wandb --- .../models/idefics3/colidefics3/old_file.py | 51 +++++++++++++++++++ .../colidefics3/publish_base_colidefics3.py | 51 +++++++++++++++++++ colpali_engine/trainer/colmodel_training.py | 46 ++++++++--------- colpali_engine/trainer/eval_utils.py | 4 +- 4 files changed, 128 insertions(+), 24 deletions(-) create mode 100644 colpali_engine/models/idefics3/colidefics3/old_file.py create mode 100644 colpali_engine/models/idefics3/colidefics3/publish_base_colidefics3.py diff --git a/colpali_engine/models/idefics3/colidefics3/old_file.py b/colpali_engine/models/idefics3/colidefics3/old_file.py new file mode 100644 index 00000000..04ab816d --- /dev/null +++ b/colpali_engine/models/idefics3/colidefics3/old_file.py @@ -0,0 +1,51 @@ +from typing import Annotated, cast + +import torch +import typer +from transformers.models.idefics3 import Idefics3ForConditionalGeneration + +from colpali_engine.models.idefics3.colidefics3.modeling_colidefics3 import ColIdefics3 +from colpali_engine.utils.torch_utils import get_torch_device + + +def main( + vlm_backbone_name: Annotated[str, typer.Option(help="The name of the VLM backbone model to use.")], + new_base_model_name: Annotated[str, typer.Option(help="The name of the base model to push to the hub.")], +): + """ + Publish the base ColIdefics3 model to the hub. + + Args: + - vlm_backbone_name (str): The name of the VLM backbone model to use. 
+ - new_base_model_name (str): The name of the base model to push to the hub. + + Example usage: + ```bash + python colpali_engine/models/idefics3/colidefics3/publish_base_colidefics3.py \ + --vlm-backbone-name smol-explorers/SmolVLM-256M-Base-25750 \ + --new-base-model-name vidore/colsmolvlm-256M-base + ``` + """ + device = get_torch_device("auto") + + vlm_backbone = cast( + Idefics3ForConditionalGeneration, + Idefics3ForConditionalGeneration.from_pretrained( + vlm_backbone_name, + torch_dtype=torch.bfloat16, + device_map=device, + ), + ).eval() + + model = ColIdefics3(config=vlm_backbone.config).to(device).to(torch.bfloat16).eval() + + # Copy pre-trained weights from old model + model.load_state_dict(vlm_backbone.state_dict(), strict=False) + + model.push_to_hub(new_base_model_name, private=True) + + return + + +if __name__ == "__main__": + typer.run(main) diff --git a/colpali_engine/models/idefics3/colidefics3/publish_base_colidefics3.py b/colpali_engine/models/idefics3/colidefics3/publish_base_colidefics3.py new file mode 100644 index 00000000..04ab816d --- /dev/null +++ b/colpali_engine/models/idefics3/colidefics3/publish_base_colidefics3.py @@ -0,0 +1,51 @@ +from typing import Annotated, cast + +import torch +import typer +from transformers.models.idefics3 import Idefics3ForConditionalGeneration + +from colpali_engine.models.idefics3.colidefics3.modeling_colidefics3 import ColIdefics3 +from colpali_engine.utils.torch_utils import get_torch_device + + +def main( + vlm_backbone_name: Annotated[str, typer.Option(help="The name of the VLM backbone model to use.")], + new_base_model_name: Annotated[str, typer.Option(help="The name of the base model to push to the hub.")], +): + """ + Publish the base ColIdefics3 model to the hub. + + Args: + - vlm_backbone_name (str): The name of the VLM backbone model to use. + - new_base_model_name (str): The name of the base model to push to the hub. + + Example usage: + ```bash + python colpali_engine/models/idefics3/colidefics3/publish_base_colidefics3.py \ + --vlm-backbone-name smol-explorers/SmolVLM-256M-Base-25750 \ + --new-base-model-name vidore/colsmolvlm-256M-base + ``` + """ + device = get_torch_device("auto") + + vlm_backbone = cast( + Idefics3ForConditionalGeneration, + Idefics3ForConditionalGeneration.from_pretrained( + vlm_backbone_name, + torch_dtype=torch.bfloat16, + device_map=device, + ), + ).eval() + + model = ColIdefics3(config=vlm_backbone.config).to(device).to(torch.bfloat16).eval() + + # Copy pre-trained weights from old model + model.load_state_dict(vlm_backbone.state_dict(), strict=False) + + model.push_to_hub(new_base_model_name, private=True) + + return + + +if __name__ == "__main__": + typer.run(main) diff --git a/colpali_engine/trainer/colmodel_training.py b/colpali_engine/trainer/colmodel_training.py index 9db7e382..f7b37ffc 100644 --- a/colpali_engine/trainer/colmodel_training.py +++ b/colpali_engine/trainer/colmodel_training.py @@ -119,30 +119,30 @@ def train(self) -> None: trainer.args.remove_unused_columns = False - if self.config.processor is not None and self.config.tr_args.get("run_vidore_evaluator", False): - vidore_eval_dataset_name = self.config.tr_args.get("vidore_eval_dataset_name", None) - vidore_eval_collection_name = self.config.tr_args.get("vidore_eval_collection_name", None) - - if vidore_eval_dataset_name is not None and vidore_eval_collection_name is not None: - raise ValueError( - "Both vidore_eval_dataset_name and vidore_eval_collection_name are provided. 
" - "You should only provide one of the two" - ) - elif vidore_eval_dataset_name is None and vidore_eval_collection_name is None: - print("WARNING : No dataset provided for ViDoRe evaluation. Skipping evaluation.") - else: - trainer.add_callback( - BenchmarkEvalCallback( - processor=self.config.processor, - model=self.model, - eval_dataset_loader=self.config.eval_dataset_loader, - batch_query=self.config.tr_args.per_device_eval_batch_size, - batch_passage=4, - batch_score=4, - run_frequency=self.config.tr_args.get("eval_steps_frequency", 5), - dataset_format=self.config.tr_args.get("eval_dataset_format", "beir"), - ) + if self.config.processor is not None: # and getattr(self.config.tr_args, "run_vidore_evaluator", False): + # vidore_eval_dataset_name = getattr(self.config.tr_args, "vidore_eval_dataset_name", None) + # vidore_eval_collection_name = getattr(self.config.tr_args, "vidore_eval_collection_name", None) + + # if vidore_eval_dataset_name is not None and vidore_eval_collection_name is not None: + # raise ValueError( + # "Both vidore_eval_dataset_name and vidore_eval_collection_name are provided. " + # "You should only provide one of the two" + # ) + # elif vidore_eval_dataset_name is None and vidore_eval_collection_name is None: + # print("WARNING : No dataset provided for ViDoRe evaluation. Skipping evaluation.") + # else: + trainer.add_callback( + BenchmarkEvalCallback( + processor=self.config.processor, + model=self.model, + eval_dataset_loader=self.config.eval_dataset_loader, + batch_query=self.config.tr_args.per_device_eval_batch_size, + batch_passage=4, + batch_score=4, + run_frequency=getattr(self.config.tr_args, "eval_steps_frequency", 5), + dataset_format=getattr(self.config.tr_args, "eval_dataset_format", "qa"), ) + ) result = trainer.train(resume_from_checkpoint=self.config.tr_args.resume_from_checkpoint) print_summary(result) diff --git a/colpali_engine/trainer/eval_utils.py b/colpali_engine/trainer/eval_utils.py index 2f920169..ef281976 100644 --- a/colpali_engine/trainer/eval_utils.py +++ b/colpali_engine/trainer/eval_utils.py @@ -15,7 +15,8 @@ recall_cap, top_k_accuracy, ) -from transformers import TrainerControl, TrainerState, TrainingArguments, WandbCallback +from transformers import TrainerControl, TrainerState, TrainingArguments +from transformers.integrations import WandbCallback from vidore_benchmark.evaluation.vidore_evaluators import ViDoReEvaluatorBEIR, ViDoReEvaluatorQA from vidore_benchmark.retrievers import VisionRetriever @@ -232,6 +233,7 @@ def __init__( self.eval_steps_frequency = run_frequency self.counter_eval = 0 self.eval_dataset_format = dataset_format + super().__init__() def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): if state.global_step % self.eval_steps_frequency != 0: From 8f5dbb01a538d1a03197f4a1bf9dcd18b46133f3 Mon Sep 17 00:00:00 2001 From: Quentin Mace Date: Thu, 6 Mar 2025 17:36:07 +0100 Subject: [PATCH 07/16] minor changes --- colpali_engine/trainer/colmodel_training.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/colpali_engine/trainer/colmodel_training.py b/colpali_engine/trainer/colmodel_training.py index f7b37ffc..a7861906 100644 --- a/colpali_engine/trainer/colmodel_training.py +++ b/colpali_engine/trainer/colmodel_training.py @@ -119,18 +119,7 @@ def train(self) -> None: trainer.args.remove_unused_columns = False - if self.config.processor is not None: # and getattr(self.config.tr_args, "run_vidore_evaluator", False): - # 
vidore_eval_dataset_name = getattr(self.config.tr_args, "vidore_eval_dataset_name", None) - # vidore_eval_collection_name = getattr(self.config.tr_args, "vidore_eval_collection_name", None) - - # if vidore_eval_dataset_name is not None and vidore_eval_collection_name is not None: - # raise ValueError( - # "Both vidore_eval_dataset_name and vidore_eval_collection_name are provided. " - # "You should only provide one of the two" - # ) - # elif vidore_eval_dataset_name is None and vidore_eval_collection_name is None: - # print("WARNING : No dataset provided for ViDoRe evaluation. Skipping evaluation.") - # else: + if self.config.processor is not None: trainer.add_callback( BenchmarkEvalCallback( processor=self.config.processor, @@ -139,8 +128,8 @@ def train(self) -> None: batch_query=self.config.tr_args.per_device_eval_batch_size, batch_passage=4, batch_score=4, - run_frequency=getattr(self.config.tr_args, "eval_steps_frequency", 5), - dataset_format=getattr(self.config.tr_args, "eval_dataset_format", "qa"), + run_frequency=getattr(self.config.tr_args, "eval_steps_frequency", 500), + dataset_format=getattr(self.config.tr_args, "eval_dataset_format", "beir"), ) ) From 3617b77d52c59288fbc3fb6b5f4ec1f5e01bda26 Mon Sep 17 00:00:00 2001 From: Quentin Mace Date: Thu, 6 Mar 2025 18:04:18 +0100 Subject: [PATCH 08/16] put back old functions --- colpali_engine/trainer/colmodel_training.py | 112 ++++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/colpali_engine/trainer/colmodel_training.py b/colpali_engine/trainer/colmodel_training.py index a7861906..9e2c5a7f 100644 --- a/colpali_engine/trainer/colmodel_training.py +++ b/colpali_engine/trainer/colmodel_training.py @@ -2,11 +2,17 @@ from dataclasses import dataclass from typing import Callable, Dict, Optional, Tuple, Union +import torch +from datasets import concatenate_datasets from peft import LoraConfig, PeftModel, get_peft_model +from torch.utils.data import DataLoader +from tqdm import tqdm from transformers import ( PreTrainedModel, TrainingArguments, ) +from vidore_benchmark.evaluation.vidore_evaluators import ViDoReEvaluatorBEIR, ViDoReEvaluatorQA +from vidore_benchmark.retrievers import VisionRetriever from colpali_engine.collators import CorpusQueryCollator, VisualRetrieverCollator from colpali_engine.loss.late_interaction_losses import ( @@ -136,6 +142,112 @@ def train(self) -> None: result = trainer.train(resume_from_checkpoint=self.config.tr_args.resume_from_checkpoint) print_summary(result) + def eval_dataset(self, test_dataset): + self.model.eval() + + idx_with_query = [idx for idx, sample in enumerate(test_dataset["query"]) if sample is not None] + idx_without_query = [idx for idx, sample in enumerate(test_dataset["query"]) if sample is None] + + dataloader_with_query = DataLoader( + test_dataset.select(idx_with_query), + batch_size=self.config.tr_args.per_device_eval_batch_size, + shuffle=False, + collate_fn=self.collator, + ) + dataloader_without_query = DataLoader( + test_dataset.select(idx_without_query), + batch_size=self.config.tr_args.per_device_eval_batch_size, + shuffle=False, + collate_fn=self.collator, + ) + + # dataset is ordered so that non-null queries come first + test_dataset = concatenate_datasets( + [test_dataset.select(idx_with_query), test_dataset.select(idx_without_query)] + ) + + relevant_docs = {} + docidx_2_docid = {} + qsidx_2_query = [] + for idx, sample in enumerate(test_dataset): + doc_id = sample["image_filename"] if "image_filename" in sample else str(hash(sample["doc"])) + # query_id = 
sample["query_id"] if "query_id" in sample else str(hash(sample["query"])) + if sample["query"] is not None: + relevant_docs[str(idx)] = {doc_id: 1} + qsidx_2_query.append(str(idx)) + docidx_2_docid[str(idx)] = doc_id + + qs = [] + ps = [] + + device = self.model.device + with torch.no_grad(): + for dataloader in [dataloader_with_query, dataloader_without_query]: + for batch in tqdm(dataloader): + # feed only kwargs with 'doc_' prefix + doc = self.model(**{k[4:]: v.to(device) for k, v in batch.items() if k.startswith("doc")}) + ps.extend(list(torch.unbind(doc.to("cpu")))) + + if "query_input_ids" in batch: + query = self.model( + input_ids=batch["query_input_ids"].to(device), + attention_mask=batch["query_attention_mask"].to(device), + ) + # variable len + qs.extend(list(torch.unbind(query.to("cpu")))) + + print("Embeddings computed, evaluating") + scores = self.config.processor.score(qs, ps, device=self.model.device) + # scores is 2d array of shape (n_queries, n_docs) + # turn it into a dict + results = {} + assert scores.shape[0] == len(qsidx_2_query) + for idx, scores_per_query in enumerate(scores): + results[qsidx_2_query[idx]] = { + docidx_2_docid[str(docidx)]: float(score) for docidx, score in enumerate(scores_per_query) + } + + # evaluate + metrics = self.retrieval_evaluator.compute_mteb_metrics(relevant_docs, results) + print("MTEB metrics:", metrics) + + return metrics + + def eval(self) -> None: + all_metrics = {} + try: + print("Evaluating on validation set") + metrics = self.eval_dataset(self.dataset["test"]) + print(f"Metrics for validation set: {metrics}") + all_metrics["validation_set"] = metrics + except Exception as e: + print(f"Error evaluating validation set: {e}") + + if self.config.eval_dataset_loader is not None: + # Create a vision retriever with the current model checkpoint. 
+ vision_retriever = VisionRetriever( + model=self.model, + processor=self.config.processor, + ) + if getattr(self.config.tr_args, "eval_dataset_format", "beir") == "beir": + vidore_evaluator = ViDoReEvaluatorBEIR(vision_retriever) + elif getattr(self.config.tr_args, "eval_dataset_format", "beir") == "qa": + vidore_evaluator = ViDoReEvaluatorQA(vision_retriever) + else: + raise ValueError("eval_dataset_format must be 'beir' or 'qa'") + + for test_name, test_dataset_loading_func in self.config.eval_dataset_loader.items(): + print(f"Evaluating {test_name}") + test_ds = test_dataset_loading_func() + metrics = vidore_evaluator.evaluate_dataset( + ds=test_ds, + batch_query=self.config.tr_args.per_device_eval_batch_size, + batch_passage=self.config.tr_args.per_device_eval_batch_size, + batch_score=self.config.tr_args.per_device_eval_batch_size, + ) + all_metrics[test_name] = metrics + print(f"Metrics for {test_name}: {metrics}") + def save(self, config_file): # save model self.model.save_pretrained(self.config.output_dir) From 4033d4b58b7a68283128aad553b9baad9b9e8519 Mon Sep 17 00:00:00 2001 From: Tony Wu <28306721+tonywu71@users.noreply.github.com> Date: Fri, 4 Apr 2025 16:20:03 +0200 Subject: [PATCH 09/16] fix: fix artifact code from merge conflict --- colpali_engine/trainer/eval_utils.py | 170 --------------------------- 1 file changed, 170 deletions(-) diff --git a/colpali_engine/trainer/eval_utils.py b/colpali_engine/trainer/eval_utils.py index ef281976..9ad19f77 100644 --- a/colpali_engine/trainer/eval_utils.py +++ b/colpali_engine/trainer/eval_utils.py @@ -2,19 +2,7 @@ from __future__ import annotations import logging -from typing import Dict -import numpy as np -import pytrec_eval -from mteb.evaluation.evaluators.RetrievalEvaluator import RetrievalEvaluator -from mteb.evaluation.evaluators.utils import ( - confidence_scores, - hole, - mrr, - nAUC, - recall_cap, - top_k_accuracy, -) from transformers import TrainerControl, TrainerState, TrainingArguments from transformers.integrations import WandbCallback from vidore_benchmark.evaluation.vidore_evaluators import ViDoReEvaluatorBEIR, ViDoReEvaluatorQA @@ -45,164 +33,6 @@ ] -class CustomRetrievalEvaluator: - """ - Wrapper class for the MTEB retrieval evaluator. - """ - - def __init__(self, k_values: list[int] = [1, 3, 5, 10, 20, 50, 100]): - self.k_values = k_values - - def compute_mteb_metrics( - self, - relevant_docs: Dict[str, dict[str, int]], - results: Dict[str, dict[str, float]], - **kwargs, - ) -> Dict[str, float]: - """ - Compute the MTEB retrieval metrics. 
- """ - ndcg, _map, recall, precision, naucs = self.evaluate( - relevant_docs, - results, - self.k_values, - ignore_identical_ids=kwargs.get("ignore_identical_ids", True), - ) - - mrr = self.evaluate_custom(relevant_docs, results, self.k_values, "mrr") - - scores = { - **{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()}, - **{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()}, - **{f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()}, - **{f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()}, - **{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr[0].items()}, - **{f"naucs_at_{k.split('@')[1]}": v for (k, v) in naucs.items()}, - } - return scores - - @staticmethod - def evaluate( - qrels: dict[str, dict[str, int]], - results: dict[str, dict[str, float]], - k_values: list[int], - ignore_identical_ids: bool = False, - ) -> tuple[ - dict[str, float], - dict[str, float], - dict[str, float], - dict[str, float], - dict[str, float], - ]: - if ignore_identical_ids: - logger.debug( - "For evaluation, ``ignore_identical_ids=True`` is set to True, the evaluator will ignore " - "identical query and document ids." - ) - # Remove identical ids from results dict - for qid, rels in results.items(): - for pid in list(rels): - if qid == pid: - results[qid].pop(pid) - else: - logger.debug( - "For evaluation, we DO NOT ignore identical query and document ids (default), please explicitly " - "set ``ignore_identical_ids=True`` to ignore this." - ) - - all_ndcgs, all_aps, all_recalls, all_precisions = {}, {}, {}, {} - - for k in k_values: - all_ndcgs[f"NDCG@{k}"] = [] - all_aps[f"MAP@{k}"] = [] - all_recalls[f"Recall@{k}"] = [] - all_precisions[f"P@{k}"] = [] - - map_string = "map_cut." + ",".join([str(k) for k in k_values]) - ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values]) - recall_string = "recall." + ",".join([str(k) for k in k_values]) - precision_string = "P." 
+ ",".join([str(k) for k in k_values]) - evaluator = pytrec_eval.RelevanceEvaluator(qrels, {map_string, ndcg_string, recall_string, precision_string}) - scores = evaluator.evaluate(results) - - for query_id in scores.keys(): - for k in k_values: - all_ndcgs[f"NDCG@{k}"].append(scores[query_id]["ndcg_cut_" + str(k)]) - all_aps[f"MAP@{k}"].append(scores[query_id]["map_cut_" + str(k)]) - all_recalls[f"Recall@{k}"].append(scores[query_id]["recall_" + str(k)]) - all_precisions[f"P@{k}"].append(scores[query_id]["P_" + str(k)]) - - ndcg, _map, recall, precision = ( - all_ndcgs.copy(), - all_aps.copy(), - all_recalls.copy(), - all_precisions.copy(), - ) - - for k in k_values: - ndcg[f"NDCG@{k}"] = round(sum(ndcg[f"NDCG@{k}"]) / len(scores), 5) - _map[f"MAP@{k}"] = round(sum(_map[f"MAP@{k}"]) / len(scores), 5) - recall[f"Recall@{k}"] = round(sum(recall[f"Recall@{k}"]) / len(scores), 5) - precision[f"P@{k}"] = round(sum(precision[f"P@{k}"]) / len(scores), 5) - - naucs = RetrievalEvaluator.evaluate_abstention( - results, {**all_ndcgs, **all_aps, **all_recalls, **all_precisions} - ) - - return ndcg, _map, recall, precision, naucs - - @staticmethod - def evaluate_custom( - qrels: dict[str, dict[str, int]], - results: dict[str, dict[str, float]], - k_values: list[int], - metric: str, - output_type: str = "all", - ) -> tuple[dict[str, float], dict[str, float]]: - if metric.lower() in ["mrr", "mrr@k", "mrr_cut"]: - metric_scores = mrr(qrels, results, k_values, output_type) - - elif metric.lower() in ["recall_cap", "r_cap", "r_cap@k"]: - metric_scores = recall_cap(qrels, results, k_values, output_type) - - elif metric.lower() in ["hole", "hole@k"]: - metric_scores = hole(qrels, results, k_values, output_type) - - elif metric.lower() in [ - "acc", - "top_k_acc", - "accuracy", - "accuracy@k", - "top_k_accuracy", - ]: - metric_scores = top_k_accuracy(qrels, results, k_values, output_type) - - naucs = RetrievalEvaluator.evaluate_abstention(results, metric_scores) - metric_scores_avg = {k: sum(v) / len(v) for k, v in metric_scores.items()} - - return metric_scores_avg, naucs - - @staticmethod - def evaluate_abstention( - results: dict[str, dict[str, float]], - metric_scores: dict[str, list[float]], - ) -> dict[str, float]: - """Computes normalized Area Under the Curve on a set of evaluated instances as presented in - the paper https://arxiv.org/abs/2402.12997""" - all_sim_scores = [list(results[qid].values()) for qid in list(results.keys())] - all_conf_scores = [confidence_scores(sim_scores) for sim_scores in all_sim_scores] - conf_fcts = list(all_conf_scores[0].keys()) - all_conf_scores = {fct: np.array([x[fct] for x in all_conf_scores]) for fct in conf_fcts} - metric_scores = {k: np.array(v) for k, v in metric_scores.items()} - naucs = {} - - for metric_name, scores in metric_scores.items(): - for fct, conf_scores in all_conf_scores.items(): - naucs[f"nAUC_{metric_name}_{fct}"] = nAUC(conf_scores, scores) - - return naucs - - class BenchmarkEvalCallback(WandbCallback): def __init__( self, From f53e296c94a095d6a5d7c1a310bb3debe04d28e4 Mon Sep 17 00:00:00 2001 From: Quentin Mace Date: Fri, 4 Apr 2025 17:05:56 +0200 Subject: [PATCH 10/16] refactoring --- colpali_engine/trainer/colmodel_training.py | 120 ++++---------------- colpali_engine/trainer/eval_utils.py | 89 ++++++++++----- 2 files changed, 79 insertions(+), 130 deletions(-) diff --git a/colpali_engine/trainer/colmodel_training.py b/colpali_engine/trainer/colmodel_training.py index 9e2c5a7f..32c9681e 100644 --- 
a/colpali_engine/trainer/colmodel_training.py +++ b/colpali_engine/trainer/colmodel_training.py @@ -2,24 +2,18 @@ from dataclasses import dataclass from typing import Callable, Dict, Optional, Tuple, Union -import torch -from datasets import concatenate_datasets from peft import LoraConfig, PeftModel, get_peft_model -from torch.utils.data import DataLoader -from tqdm import tqdm from transformers import ( PreTrainedModel, TrainingArguments, ) -from vidore_benchmark.evaluation.vidore_evaluators import ViDoReEvaluatorBEIR, ViDoReEvaluatorQA -from vidore_benchmark.retrievers import VisionRetriever from colpali_engine.collators import CorpusQueryCollator, VisualRetrieverCollator from colpali_engine.loss.late_interaction_losses import ( ColbertLoss, ) from colpali_engine.trainer.contrastive_trainer import ContrastiveTrainer -from colpali_engine.trainer.eval_utils import BenchmarkEvalCallback +from colpali_engine.trainer.eval_utils import BenchmarkEvalCallback, evaluate_dataset from colpali_engine.utils.gpu_stats import print_gpu_utilization, print_summary from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor @@ -142,111 +136,35 @@ def train(self) -> None: result = trainer.train(resume_from_checkpoint=self.config.tr_args.resume_from_checkpoint) print_summary(result) - def eval_dataset(self, test_dataset): - self.model.eval() - - idx_with_query = [idx for idx, sample in enumerate(test_dataset["query"]) if sample is not None] - idx_without_query = [idx for idx, sample in enumerate(test_dataset["query"]) if sample is None] - - dataloader_with_query = DataLoader( - test_dataset.select(idx_with_query), - batch_size=self.config.tr_args.per_device_eval_batch_size, - shuffle=False, - collate_fn=self.collator, - ) - dataloader_without_query = DataLoader( - test_dataset.select(idx_without_query), - batch_size=self.config.tr_args.per_device_eval_batch_size, - shuffle=False, - collate_fn=self.collator, - ) - - # dataset is ordered so that non-null queries come first - test_dataset = concatenate_datasets( - [test_dataset.select(idx_with_query), test_dataset.select(idx_without_query)] - ) - - relevant_docs = {} - docidx_2_docid = {} - qsidx_2_query = [] - for idx, sample in enumerate(test_dataset): - doc_id = sample["image_filename"] if "image_filename" in sample else str(hash(sample["doc"])) - # query_id = sample["query_id"] if "query_id" in sample else str(hash(sample["query"])) - if sample["query"] is not None: - relevant_docs[str(idx)] = {doc_id: 1} - qsidx_2_query.append(str(idx)) - docidx_2_docid[str(idx)] = doc_id - - qs = [] - ps = [] - - device = self.model.device - with torch.no_grad(): - for dataloader in [dataloader_with_query, dataloader_without_query]: - for batch in tqdm(dataloader): - # feed only kwargs with 'doc_' prefix - doc = self.model(**{k[4:]: v.to(device) for k, v in batch.items() if k.startswith("doc")}) - ps.extend(list(torch.unbind(doc.to("cpu")))) - - if "query_input_ids" in batch: - query = self.model( - input_ids=batch["query_input_ids"].to(device), - attention_mask=batch["query_attention_mask"].to(device), - ) - # variable len - qs.extend(list(torch.unbind(query.to("cpu")))) - - print("Embeddings computed, evaluating") - scores = self.config.processor.score(qs, ps, device=self.model.device) - # scores is 2d array of shape (n_queries, n_docs) - # turn it into a dict - results = {} - assert scores.shape[0] == len(qsidx_2_query) - for idx, scores_per_query in enumerate(scores): - results[qsidx_2_query[idx]] = { - docidx_2_docid[str(docidx)]: float(score) 
for docidx, score in enumerate(scores_per_query) - } - - # evaluate - metrics = self.retrieval_evaluator.compute_mteb_metrics(relevant_docs, results) - print("MTEB metrics:", metrics) - - return metrics - def eval(self) -> None: all_metrics = {} - try: - print("Evaluating on validation set") - metrics = self.eval_dataset(self.dataset["test"]) - print(f"Metrics for validation set: {metrics}") - all_metrics["validation_set"] = metrics - except Exception as e: - print(f"Error evaluating validation set: {e}") + + all_metrics["validation_set"] = evaluate_dataset( + model=self.model, + processor=self.config.processor, + dataset=self.dataset["test"], + format="qa", + batch_passage=self.config.tr_args.per_device_eval_batch_size, + batch_query=self.config.tr_args.per_device_eval_batch_size, + batch_score=self.config.tr_args.per_device_eval_batch_size, + ) if self.config.eval_dataset_loader is not None: # Create a vision retriever with the current model checkpoint. - vision_retriever = VisionRetriever( - model=self.model, - processor=self.config.processor, - ) - if getattr(self.config.tr_args, "eval_dataset_format", "beir") == "beir": - vidore_evaluator = ViDoReEvaluatorBEIR(vision_retriever) - elif getattr(self.config.tr_args, "eval_dataset_format", "beir") == "qa": - vidore_evaluator = ViDoReEvaluatorQA(vision_retriever) - else: - raise ValueError("eval_dataset_format must be 'beir' or 'qa'") + eval_dataset_format = getattr(self.config.tr_args, "eval_dataset_format", "beir") for test_name, test_dataset_loading_func in self.config.eval_dataset_loader.items(): print(f"Evaluating {test_name}") - test_ds = test_dataset_loading_func() - metrics = vidore_evaluator.evaluate_dataset( - ds=test_ds, - batch_query=self.config.tr_args.per_device_eval_batch_size, + all_metrics[test_name] = evaluate_dataset( + model=self.model, + processor=self.config.processor, + dataset=test_dataset_loading_func(), + format=eval_dataset_format, batch_passage=self.config.tr_args.per_device_eval_batch_size, + batch_query=self.config.tr_args.per_device_eval_batch_size, batch_score=self.config.tr_args.per_device_eval_batch_size, ) - all_metrics[test_name] = metrics - print(f"Metrics for {test_name}: {metrics}") + print(f"Metrics for {test_name}: {all_metrics[test_name]}") def save(self, config_file): # save model diff --git a/colpali_engine/trainer/eval_utils.py b/colpali_engine/trainer/eval_utils.py index 9ad19f77..232dcd79 100644 --- a/colpali_engine/trainer/eval_utils.py +++ b/colpali_engine/trainer/eval_utils.py @@ -2,12 +2,17 @@ from __future__ import annotations import logging +from typing import Dict, Union -from transformers import TrainerControl, TrainerState, TrainingArguments +from datasets import Dataset +from peft import PeftModel +from transformers import PreTrainedModel, TrainerControl, TrainerState, TrainingArguments from transformers.integrations import WandbCallback from vidore_benchmark.evaluation.vidore_evaluators import ViDoReEvaluatorBEIR, ViDoReEvaluatorQA from vidore_benchmark.retrievers import VisionRetriever +from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor + logger = logging.getLogger(__name__) @@ -46,13 +51,17 @@ def __init__( dataset_format: str = "beir", ): """ - :param processor: The processor instance (e.g., ColIdefics3Processor) needed for retrieval. - :eval_dataset_loader: A dictionary with the test dataset names as keys and functions that load the datasets as - values. - :param batch_query: Batch size for queries. - :param batch_passage: Batch size for passages. 
- :param batch_score: Batch size for scoring. - :param run_frequency: Frequency of evaluation ver the evaluation triggers. + Callback to evaluate the model on a collection of datasets during training. + + Args: + processor: The processor to use for the model. + model: The model to evaluate. + eval_dataset_loader: A dictionary of dataset loading functions. + batch_query: Batch size for query. + batch_passage: Batch size for passage. + batch_score: Batch size for scoring. + run_frequency: Frequency of evaluation in steps. + dataset_format: Format of the evaluation dataset, either "beir" or "qa". """ self.processor = processor self.model = model @@ -77,34 +86,19 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra return print(f"\n=== Running benchmark evaluation at global step {state.global_step} ===") - # Set model to evaluation mode. - self.model.eval() - - # Create a vision retriever with the current model checkpoint. - vision_retriever = VisionRetriever( - model=self.model, - processor=self.processor, - ) # Evaluate on a collection. if self.eval_dataset_loader is not None: try: metrics_collection = {} for test_name, test_dataset_loading_func in self.eval_dataset_loader.items(): - ds_coll = test_dataset_loading_func() - - # Temporary before we are caplable of detecting ds format - if self.eval_dataset_format == "beir": - vidore_evaluator = ViDoReEvaluatorBEIR(vision_retriever=vision_retriever) - elif self.eval_dataset_format == "qa": - vidore_evaluator = ViDoReEvaluatorQA(vision_retriever=vision_retriever) - else: - raise ValueError(f"Invalid eval dataset format: {self.eval_dataset_format}") - - metrics = vidore_evaluator.evaluate_dataset( - ds=ds_coll, - batch_query=self.batch_query, + metrics = evaluate_dataset( + model=self.model, + processor=self.processor, + dataset=test_dataset_loading_func(), + format=self.eval_dataset_format, batch_passage=self.batch_passage, + batch_query=self.batch_query, batch_score=self.batch_score, ) metrics_collection[test_name] = {k: v for k, v in metrics.items() if k in METRICS_TO_TRACK} @@ -118,3 +112,40 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra # Set model back to train mode. self.model.train() return + + +def evaluate_dataset( + model: Union[PreTrainedModel, PeftModel], + processor: BaseVisualRetrieverProcessor, + dataset: Dataset, + format: str = "beir", + batch_passage: int = 4, + batch_query: int = 4, + batch_score: int = 4, +) -> Dict[str, float]: + """ + Evaluate a dataset using the vidore-benchmark library. + """ + model.eval() + + # Create a vision retriever with the current model checkpoint. 
+ vision_retriever = VisionRetriever( + model=model, + processor=processor, + ) + + if format == "qa": + vidore_evaluator = ViDoReEvaluatorQA(vision_retriever) + elif format == "beir": + vidore_evaluator = ViDoReEvaluatorBEIR(vision_retriever) + else: + raise ValueError(f"Invalid dataset format: {format}, must be 'qa' or 'beir'") + + metrics = vidore_evaluator.evaluate_dataset( + ds=dataset, + batch_query=batch_query, + batch_passage=batch_passage, + batch_score=batch_score, + ) + + return metrics From f06ef103dc9081e2aa4627b2f9e304cd3c6909d0 Mon Sep 17 00:00:00 2001 From: Quentin Mace Date: Fri, 4 Apr 2025 17:10:12 +0200 Subject: [PATCH 11/16] fix --- .../models/idefics3/colidefics3/old_file.py | 51 ------------------- .../colidefics3/publish_base_colidefics3.py | 51 ------------------- 2 files changed, 102 deletions(-) delete mode 100644 colpali_engine/models/idefics3/colidefics3/old_file.py delete mode 100644 colpali_engine/models/idefics3/colidefics3/publish_base_colidefics3.py diff --git a/colpali_engine/models/idefics3/colidefics3/old_file.py b/colpali_engine/models/idefics3/colidefics3/old_file.py deleted file mode 100644 index 04ab816d..00000000 --- a/colpali_engine/models/idefics3/colidefics3/old_file.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import Annotated, cast - -import torch -import typer -from transformers.models.idefics3 import Idefics3ForConditionalGeneration - -from colpali_engine.models.idefics3.colidefics3.modeling_colidefics3 import ColIdefics3 -from colpali_engine.utils.torch_utils import get_torch_device - - -def main( - vlm_backbone_name: Annotated[str, typer.Option(help="The name of the VLM backbone model to use.")], - new_base_model_name: Annotated[str, typer.Option(help="The name of the base model to push to the hub.")], -): - """ - Publish the base ColIdefics3 model to the hub. - - Args: - - vlm_backbone_name (str): The name of the VLM backbone model to use. - - new_base_model_name (str): The name of the base model to push to the hub. 
- - Example usage: - ```bash - python colpali_engine/models/idefics3/colidefics3/publish_base_colidefics3.py \ - --vlm-backbone-name smol-explorers/SmolVLM-256M-Base-25750 \ - --new-base-model-name vidore/colsmolvlm-256M-base - ``` - """ - device = get_torch_device("auto") - - vlm_backbone = cast( - Idefics3ForConditionalGeneration, - Idefics3ForConditionalGeneration.from_pretrained( - vlm_backbone_name, - torch_dtype=torch.bfloat16, - device_map=device, - ), - ).eval() - - model = ColIdefics3(config=vlm_backbone.config).to(device).to(torch.bfloat16).eval() - - # Copy pre-trained weights from old model - model.load_state_dict(vlm_backbone.state_dict(), strict=False) - - model.push_to_hub(new_base_model_name, private=True) - - return - - -if __name__ == "__main__": - typer.run(main) diff --git a/colpali_engine/models/idefics3/colidefics3/publish_base_colidefics3.py b/colpali_engine/models/idefics3/colidefics3/publish_base_colidefics3.py deleted file mode 100644 index 04ab816d..00000000 --- a/colpali_engine/models/idefics3/colidefics3/publish_base_colidefics3.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import Annotated, cast - -import torch -import typer -from transformers.models.idefics3 import Idefics3ForConditionalGeneration - -from colpali_engine.models.idefics3.colidefics3.modeling_colidefics3 import ColIdefics3 -from colpali_engine.utils.torch_utils import get_torch_device - - -def main( - vlm_backbone_name: Annotated[str, typer.Option(help="The name of the VLM backbone model to use.")], - new_base_model_name: Annotated[str, typer.Option(help="The name of the base model to push to the hub.")], -): - """ - Publish the base ColIdefics3 model to the hub. - - Args: - - vlm_backbone_name (str): The name of the VLM backbone model to use. - - new_base_model_name (str): The name of the base model to push to the hub. 
- - Example usage: - ```bash - python colpali_engine/models/idefics3/colidefics3/publish_base_colidefics3.py \ - --vlm-backbone-name smol-explorers/SmolVLM-256M-Base-25750 \ - --new-base-model-name vidore/colsmolvlm-256M-base - ``` - """ - device = get_torch_device("auto") - - vlm_backbone = cast( - Idefics3ForConditionalGeneration, - Idefics3ForConditionalGeneration.from_pretrained( - vlm_backbone_name, - torch_dtype=torch.bfloat16, - device_map=device, - ), - ).eval() - - model = ColIdefics3(config=vlm_backbone.config).to(device).to(torch.bfloat16).eval() - - # Copy pre-trained weights from old model - model.load_state_dict(vlm_backbone.state_dict(), strict=False) - - model.push_to_hub(new_base_model_name, private=True) - - return - - -if __name__ == "__main__": - typer.run(main) From 400c3ec2c9d688b501a8aefe80341353d84b0d41 Mon Sep 17 00:00:00 2001 From: Quentin Mace Date: Fri, 4 Apr 2025 17:12:29 +0200 Subject: [PATCH 12/16] f --- colpali_engine/trainer/colmodel_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colpali_engine/trainer/colmodel_training.py b/colpali_engine/trainer/colmodel_training.py index 32c9681e..73b6a96f 100644 --- a/colpali_engine/trainer/colmodel_training.py +++ b/colpali_engine/trainer/colmodel_training.py @@ -166,7 +166,7 @@ def eval(self) -> None: ) print(f"Metrics for {test_name}: {all_metrics[test_name]}") - def save(self, config_file): + def save(self, config_file: str): # save model self.model.save_pretrained(self.config.output_dir) self.config.processor.save_pretrained(self.config.output_dir) From 9c4d1d0426f46d168dc3efc5e3c7fc405e4a0be2 Mon Sep 17 00:00:00 2001 From: QuentinJGMace Date: Thu, 17 Apr 2025 16:17:57 +0200 Subject: [PATCH 13/16] doc --- colpali_engine/trainer/colmodel_training.py | 29 ++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/colpali_engine/trainer/colmodel_training.py b/colpali_engine/trainer/colmodel_training.py index 73b6a96f..2e69d69e 100644 --- a/colpali_engine/trainer/colmodel_training.py +++ b/colpali_engine/trainer/colmodel_training.py @@ -20,6 +20,27 @@ @dataclass class ColModelTrainingConfig: + """Configuration for training a ColVision model. + + Args: + model (Union[PreTrainedModel, PeftModel]): Base model to train. + processor (BaseVisualRetrieverProcessor): Processor for visual data processing. + tr_args (Optional[TrainingArguments]): Transformers training arguments. If not provided, uses default values. + output_dir (Optional[str]): Output directory to save the model. + If not provided, creates a path based on model name. + max_length (int): Maximum sequence length for inputs. Default: 256. + run_eval (bool): If True, runs evaluation. Default: True. + run_train (bool): If True, runs training. Default: True. + vidore_eval_frequency (int): Vidore evaluation frequency, must be a multiple of tr_args.eval_steps. + Pass -1 to disable. Default: -1. + eval_dataset_format (str): Evaluation dataset format ("qa" or "beir"). Default: "qa". + peft_config (Optional[LoraConfig]): PEFT configuration if used. Default: None. + loss_func (Optional[Callable]): Custom loss function. Default: ColbertLoss(). + dataset_loading_func (Optional[Callable]): Dataset loading function. Default: None. + eval_dataset_loader (Optional[Dict[str, Callable]]): Evaluation dataset loaders. Default: None. + pretrained_peft_model_name_or_path (Optional[str]): Path to a pretrained PEFT model. Default: None. 
+ """ + model: Union[PreTrainedModel, PeftModel] processor: BaseVisualRetrieverProcessor tr_args: Optional[TrainingArguments] = None @@ -27,6 +48,8 @@ class ColModelTrainingConfig: max_length: int = 256 run_eval: bool = True run_train: bool = True + vidore_eval_frequency: int = -1 + eval_dataset_format: str = "qa" peft_config: Optional[LoraConfig] = None loss_func: Optional[Callable] = ColbertLoss() dataset_loading_func: Optional[Callable] = None @@ -119,7 +142,7 @@ def train(self) -> None: trainer.args.remove_unused_columns = False - if self.config.processor is not None: + if self.config.processor is not None and self.config.vidore_eval_frequency > 0: trainer.add_callback( BenchmarkEvalCallback( processor=self.config.processor, @@ -128,8 +151,8 @@ def train(self) -> None: batch_query=self.config.tr_args.per_device_eval_batch_size, batch_passage=4, batch_score=4, - run_frequency=getattr(self.config.tr_args, "eval_steps_frequency", 500), - dataset_format=getattr(self.config.tr_args, "eval_dataset_format", "beir"), + run_frequency=self.config.vidore_eval_frequency, + dataset_format=self.config.eval_dataset_format, ) ) From 3c40b4afc1de1c7d04b394530054aaae2d8e1217 Mon Sep 17 00:00:00 2001 From: QuentinJGMace Date: Thu, 17 Apr 2025 16:27:17 +0200 Subject: [PATCH 14/16] add example --- .../train_colqwen2_model_eval_vidore.yaml | 65 +++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 scripts/configs/qwen2/train_colqwen2_model_eval_vidore.yaml diff --git a/scripts/configs/qwen2/train_colqwen2_model_eval_vidore.yaml b/scripts/configs/qwen2/train_colqwen2_model_eval_vidore.yaml new file mode 100644 index 00000000..ff4cd14f --- /dev/null +++ b/scripts/configs/qwen2/train_colqwen2_model_eval_vidore.yaml @@ -0,0 +1,65 @@ +config: + (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig + output_dir: !path ../../../models/colqwen2-ba256-5e-0304 + processor: + (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper + class_to_instanciate: !ext colpali_engine.models.ColQwen2Processor + pretrained_model_name_or_path: "./models/base_models/colqwen2-base" + max_num_visual_tokens: 1024 + + model: + (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper + class_to_instanciate: !ext colpali_engine.models.ColQwen2 + pretrained_model_name_or_path: "./models/base_models/colqwen2-base" + torch_dtype: !ext torch.bfloat16 + use_cache: false + attn_implementation: "flash_attention_2" + + dataset_loading_func: !ext colpali_engine.utils.dataset_transformation.load_train_set + eval_dataset_loader: !import ../data/test_data.yaml + vidore_eval_frequency: 200 + eval_dataset_format: "qa" + + # max_length: 50 + run_eval: true + loss_func: + (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseCELoss + tr_args: + (): transformers.training_args.TrainingArguments + output_dir: null + overwrite_output_dir: true + num_train_epochs: 5 + per_device_train_batch_size: 64 + gradient_checkpointing: true + gradient_checkpointing_kwargs: { "use_reentrant": false } + # 6 x 8 gpus = 48 batch size + # gradient_accumulation_steps: 4 + per_device_eval_batch_size: 8 + eval_strategy: "steps" + dataloader_num_workers: 8 + # bf16: true + save_steps: 500 + logging_steps: 10 + eval_steps: 100 + warmup_steps: 100 + learning_rate: 2e-4 + save_total_limit: 1 + # resume_from_checkpoint: true + # optim: "paged_adamw_8bit" + # wandb logging + # wandb_project: "colqwen2" + # run_name: "colqwen2-ba32-nolora" + report_to: "wandb" + + + peft_config: + (): peft.LoraConfig + r: 32 + lora_alpha: 32 + 
lora_dropout: 0.1 + init_lora_weights: "gaussian" + bias: "none" + task_type: "FEATURE_EXTRACTION" + target_modules: '(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' + # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)' + From 78192924a854144ea80c6e7e0fb66c9dc90e56bc Mon Sep 17 00:00:00 2001 From: QuentinJGMace Date: Thu, 17 Apr 2025 16:29:54 +0200 Subject: [PATCH 15/16] f --- colpali_engine/trainer/colmodel_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colpali_engine/trainer/colmodel_training.py b/colpali_engine/trainer/colmodel_training.py index 2e69d69e..c3521e13 100644 --- a/colpali_engine/trainer/colmodel_training.py +++ b/colpali_engine/trainer/colmodel_training.py @@ -149,7 +149,7 @@ def train(self) -> None: model=self.model, eval_dataset_loader=self.config.eval_dataset_loader, batch_query=self.config.tr_args.per_device_eval_batch_size, - batch_passage=4, + batch_passage=self.config.tr_args.per_device_eval_batch_size, batch_score=4, run_frequency=self.config.vidore_eval_frequency, dataset_format=self.config.eval_dataset_format, From 7b4b1aa1a2bd4f569a2fe6ce3c3b0490cec814ae Mon Sep 17 00:00:00 2001 From: QuentinJGMace Date: Thu, 17 Apr 2025 16:34:30 +0200 Subject: [PATCH 16/16] update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f036db2d..1b360f49 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ### Added +- Add the possibility for a user to evaluate a model on retrieval datasets (e.g ViDoRe benchmark) during its training. - Add `LambdaTokenPooler` to allow for custom token pooling functions. - Added training losses with negatives to InfoNCE type losses
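
A minimal Python sketch of how the options introduced in this patch series (`vidore_eval_frequency`, `eval_dataset_format`, and the `BenchmarkEvalCallback` wiring) might be driven programmatically instead of through the YAML example above. The `ColModelTraining` wrapper class, the exact signature expected of each `eval_dataset_loader` callable, and the dataset name are assumptions made for illustration; only `ColModelTrainingConfig` and its fields come directly from the diff:

```python
import torch
from datasets import load_dataset
from transformers import TrainingArguments

from colpali_engine.loss.late_interaction_losses import ColbertPairwiseCELoss
from colpali_engine.models import ColQwen2, ColQwen2Processor
from colpali_engine.trainer.colmodel_training import ColModelTraining, ColModelTrainingConfig
from colpali_engine.utils.dataset_transformation import load_train_set

# Path taken from the example YAML config in this patch series.
model_name = "./models/base_models/colqwen2-base"

config = ColModelTrainingConfig(
    model=ColQwen2.from_pretrained(model_name, torch_dtype=torch.bfloat16),
    processor=ColQwen2Processor.from_pretrained(model_name),
    dataset_loading_func=load_train_set,
    # Hypothetical loader dict: each entry maps a test name to a callable returning that eval dataset.
    eval_dataset_loader={
        "docvqa_subsampled": lambda: load_dataset("vidore/docvqa_test_subsampled", split="test"),
    },
    vidore_eval_frequency=200,   # must be a multiple of tr_args.eval_steps; -1 disables the callback
    eval_dataset_format="qa",    # "qa" or "beir"
    loss_func=ColbertPairwiseCELoss(),
    tr_args=TrainingArguments(
        output_dir="./models/colqwen2-vidore-eval",
        eval_strategy="steps",
        eval_steps=100,
        per_device_train_batch_size=64,
        per_device_eval_batch_size=8,
        num_train_epochs=5,
        learning_rate=2e-4,
    ),
)

# ColModelTraining is assumed to be the trainer wrapper defined in colmodel_training.py.
training_app = ColModelTraining(config)
training_app.train()
# save() takes the config file path per its updated signature in PATCH 12/16.
training_app.save(config_file="scripts/configs/qwen2/train_colqwen2_model_eval_vidore.yaml")
```

The values mirror the example config: with `eval_steps=100` and `vidore_eval_frequency=200`, the retrieval benchmark fires on every second Trainer evaluation.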
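
For readers skimming the callback wiring in `colmodel_training.py`, the following is a plausible sketch — an assumption for illustration, not the actual body of `eval_utils.py` — of how a `TrainerCallback` can gate the expensive ViDoRe-style evaluation on `run_frequency`, which is why `vidore_eval_frequency` has to be a multiple of `tr_args.eval_steps`:

```python
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments


class FrequencyGatedEvalCallback(TrainerCallback):
    """Illustrative stand-in for BenchmarkEvalCallback's step gating (assumed, not the real code)."""

    def __init__(self, run_frequency: int = 500):
        self.run_frequency = run_frequency

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # The Trainer calls on_evaluate every `eval_steps`; only run the heavy retrieval
        # benchmark when the global step is also a multiple of run_frequency.
        if state.global_step % self.run_frequency != 0:
            return
        print(f"Running retrieval benchmark at step {state.global_step}")
        # ... build the retriever, run the ViDoRe evaluator, and log metrics here ...
```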