@@ -2,19 +2,7 @@
 from __future__ import annotations
 
 import logging
-from typing import Dict
 
-import numpy as np
-import pytrec_eval
-from mteb.evaluation.evaluators.RetrievalEvaluator import RetrievalEvaluator
-from mteb.evaluation.evaluators.utils import (
-    confidence_scores,
-    hole,
-    mrr,
-    nAUC,
-    recall_cap,
-    top_k_accuracy,
-)
 from transformers import TrainerControl, TrainerState, TrainingArguments
 from transformers.integrations import WandbCallback
 from vidore_benchmark.evaluation.vidore_evaluators import ViDoReEvaluatorBEIR, ViDoReEvaluatorQA
@@ -45,164 +33,6 @@
 ]
 
 
-class CustomRetrievalEvaluator:
-    """
-    Wrapper class for the MTEB retrieval evaluator.
-    """
-
-    def __init__(self, k_values: list[int] = [1, 3, 5, 10, 20, 50, 100]):
-        self.k_values = k_values
-
-    def compute_mteb_metrics(
-        self,
-        relevant_docs: Dict[str, dict[str, int]],
-        results: Dict[str, dict[str, float]],
-        **kwargs,
-    ) -> Dict[str, float]:
-        """
-        Compute the MTEB retrieval metrics.
-        """
-        ndcg, _map, recall, precision, naucs = self.evaluate(
-            relevant_docs,
-            results,
-            self.k_values,
-            ignore_identical_ids=kwargs.get("ignore_identical_ids", True),
-        )
-
-        mrr = self.evaluate_custom(relevant_docs, results, self.k_values, "mrr")
-
-        scores = {
-            **{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()},
-            **{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()},
-            **{f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()},
-            **{f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()},
-            **{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr[0].items()},
-            **{f"naucs_at_{k.split('@')[1]}": v for (k, v) in naucs.items()},
-        }
-        return scores
-
-    @staticmethod
-    def evaluate(
-        qrels: dict[str, dict[str, int]],
-        results: dict[str, dict[str, float]],
-        k_values: list[int],
-        ignore_identical_ids: bool = False,
-    ) -> tuple[
-        dict[str, float],
-        dict[str, float],
-        dict[str, float],
-        dict[str, float],
-        dict[str, float],
-    ]:
-        if ignore_identical_ids:
-            logger.debug(
-                "For evaluation, ``ignore_identical_ids=True`` is set to True, the evaluator will ignore "
-                "identical query and document ids."
-            )
-            # Remove identical ids from results dict
-            for qid, rels in results.items():
-                for pid in list(rels):
-                    if qid == pid:
-                        results[qid].pop(pid)
-        else:
-            logger.debug(
-                "For evaluation, we DO NOT ignore identical query and document ids (default), please explicitly "
-                "set ``ignore_identical_ids=True`` to ignore this."
-            )
-
-        all_ndcgs, all_aps, all_recalls, all_precisions = {}, {}, {}, {}
-
-        for k in k_values:
-            all_ndcgs[f"NDCG@{k}"] = []
-            all_aps[f"MAP@{k}"] = []
-            all_recalls[f"Recall@{k}"] = []
-            all_precisions[f"P@{k}"] = []
-
-        map_string = "map_cut." + ",".join([str(k) for k in k_values])
-        ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values])
-        recall_string = "recall." + ",".join([str(k) for k in k_values])
-        precision_string = "P." + ",".join([str(k) for k in k_values])
-        evaluator = pytrec_eval.RelevanceEvaluator(qrels, {map_string, ndcg_string, recall_string, precision_string})
-        scores = evaluator.evaluate(results)
-
-        for query_id in scores.keys():
-            for k in k_values:
-                all_ndcgs[f"NDCG@{k}"].append(scores[query_id]["ndcg_cut_" + str(k)])
-                all_aps[f"MAP@{k}"].append(scores[query_id]["map_cut_" + str(k)])
-                all_recalls[f"Recall@{k}"].append(scores[query_id]["recall_" + str(k)])
-                all_precisions[f"P@{k}"].append(scores[query_id]["P_" + str(k)])
-
-        ndcg, _map, recall, precision = (
-            all_ndcgs.copy(),
-            all_aps.copy(),
-            all_recalls.copy(),
-            all_precisions.copy(),
-        )
-
-        for k in k_values:
-            ndcg[f"NDCG@{k}"] = round(sum(ndcg[f"NDCG@{k}"]) / len(scores), 5)
-            _map[f"MAP@{k}"] = round(sum(_map[f"MAP@{k}"]) / len(scores), 5)
-            recall[f"Recall@{k}"] = round(sum(recall[f"Recall@{k}"]) / len(scores), 5)
-            precision[f"P@{k}"] = round(sum(precision[f"P@{k}"]) / len(scores), 5)
-
-        naucs = RetrievalEvaluator.evaluate_abstention(
-            results, {**all_ndcgs, **all_aps, **all_recalls, **all_precisions}
-        )
-
-        return ndcg, _map, recall, precision, naucs
-
-    @staticmethod
-    def evaluate_custom(
-        qrels: dict[str, dict[str, int]],
-        results: dict[str, dict[str, float]],
-        k_values: list[int],
-        metric: str,
-        output_type: str = "all",
-    ) -> tuple[dict[str, float], dict[str, float]]:
-        if metric.lower() in ["mrr", "mrr@k", "mrr_cut"]:
-            metric_scores = mrr(qrels, results, k_values, output_type)
-
-        elif metric.lower() in ["recall_cap", "r_cap", "r_cap@k"]:
-            metric_scores = recall_cap(qrels, results, k_values, output_type)
-
-        elif metric.lower() in ["hole", "hole@k"]:
-            metric_scores = hole(qrels, results, k_values, output_type)
-
-        elif metric.lower() in [
-            "acc",
-            "top_k_acc",
-            "accuracy",
-            "accuracy@k",
-            "top_k_accuracy",
-        ]:
-            metric_scores = top_k_accuracy(qrels, results, k_values, output_type)
-
-        naucs = RetrievalEvaluator.evaluate_abstention(results, metric_scores)
-        metric_scores_avg = {k: sum(v) / len(v) for k, v in metric_scores.items()}
-
-        return metric_scores_avg, naucs
-
-    @staticmethod
-    def evaluate_abstention(
-        results: dict[str, dict[str, float]],
-        metric_scores: dict[str, list[float]],
-    ) -> dict[str, float]:
-        """Computes normalized Area Under the Curve on a set of evaluated instances as presented in
-        the paper https://arxiv.org/abs/2402.12997"""
-        all_sim_scores = [list(results[qid].values()) for qid in list(results.keys())]
-        all_conf_scores = [confidence_scores(sim_scores) for sim_scores in all_sim_scores]
-        conf_fcts = list(all_conf_scores[0].keys())
-        all_conf_scores = {fct: np.array([x[fct] for x in all_conf_scores]) for fct in conf_fcts}
-        metric_scores = {k: np.array(v) for k, v in metric_scores.items()}
-        naucs = {}
-
-        for metric_name, scores in metric_scores.items():
-            for fct, conf_scores in all_conf_scores.items():
-                naucs[f"nAUC_{metric_name}_{fct}"] = nAUC(conf_scores, scores)
-
-        return naucs
-
-
 class BenchmarkEvalCallback(WandbCallback):
     def __init__(
         self,
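For reference, the deleted `CustomRetrievalEvaluator` consumed plain nested dicts: `relevant_docs`/`qrels` maps each query id to `{doc_id: graded relevance}`, and `results` maps each query id to `{doc_id: similarity score}`. The sketch below is illustrative only and not part of this commit; it assumes nothing beyond `pytrec_eval` being installed, and the toy query/document ids are made up. It shows how those structures feed `pytrec_eval.RelevanceEvaluator`, which is what the removed `evaluate` method did before averaging the per-query scores.

```python
# Illustrative sketch only -- not part of this commit.
# Shows the qrels/results shapes the removed CustomRetrievalEvaluator expected
# and how pytrec_eval scores them (the class then averaged over queries).
import pytrec_eval

# query id -> {doc id: graded relevance}
qrels = {
    "q1": {"doc_a": 1, "doc_b": 0},
    "q2": {"doc_c": 1},
}

# query id -> {doc id: retrieval score}, e.g. cosine similarities
results = {
    "q1": {"doc_a": 0.9, "doc_b": 0.4, "doc_c": 0.1},
    "q2": {"doc_c": 0.8, "doc_a": 0.2},
}

k_values = [1, 3]
measures = {
    "ndcg_cut." + ",".join(str(k) for k in k_values),
    "map_cut." + ",".join(str(k) for k in k_values),
}
evaluator = pytrec_eval.RelevanceEvaluator(qrels, measures)
per_query = evaluator.evaluate(results)

# Average per-query scores, mirroring what the deleted evaluate() method did.
ndcg_at_3 = sum(s["ndcg_cut_3"] for s in per_query.values()) / len(per_query)
print(round(ndcg_at_3, 5))
```

The retained `ViDoReEvaluatorBEIR` and `ViDoReEvaluatorQA` imports suggest the callback now delegates this metric computation to `vidore_benchmark` rather than to the in-repo wrapper.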