 from dataclasses import dataclass
 from typing import Callable, Dict, Optional, Tuple

+import torch
+from datasets import concatenate_datasets
 from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+from torch.utils.data import DataLoader
+from tqdm import tqdm
 from transformers import (
     AutoTokenizer,
     PreTrainedModel,
     PreTrainedTokenizer,
     TrainingArguments,
 )
+from vidore_benchmark.evaluation.vidore_evaluators import ViDoReEvaluatorBEIR, ViDoReEvaluatorQA
+from vidore_benchmark.retrievers import VisionRetriever

 from colpali_engine.collators import CorpusQueryCollator, VisualRetrieverCollator
 from colpali_engine.loss.late_interaction_losses import (
@@ -142,6 +148,112 @@ def train(self) -> None:
         result = trainer.train(resume_from_checkpoint=self.config.tr_args.resume_from_checkpoint)
         print_summary(result)

+    def eval_dataset(self, test_dataset):
+        self.model.eval()
+
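+        # Split the eval set into samples that carry a query and document-only samples,
+        # so that query-less documents still contribute passage embeddings.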
+        idx_with_query = [idx for idx, sample in enumerate(test_dataset["query"]) if sample is not None]
+        idx_without_query = [idx for idx, sample in enumerate(test_dataset["query"]) if sample is None]
+
+        dataloader_with_query = DataLoader(
+            test_dataset.select(idx_with_query),
+            batch_size=self.config.tr_args.per_device_eval_batch_size,
+            shuffle=False,
+            collate_fn=self.collator,
+        )
+        dataloader_without_query = DataLoader(
+            test_dataset.select(idx_without_query),
+            batch_size=self.config.tr_args.per_device_eval_batch_size,
+            shuffle=False,
+            collate_fn=self.collator,
+        )
+
+        # Reorder the dataset so that samples with a query come first, matching the
+        # order in which the two dataloaders above produce embeddings.
+        test_dataset = concatenate_datasets(
+            [test_dataset.select(idx_with_query), test_dataset.select(idx_without_query)]
+        )
+
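+        # Build the qrels: each query (keyed by its row index) is relevant to exactly one
+        # document, identified by its image filename (or a hash of the doc as a fallback).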
+        relevant_docs = {}
+        docidx_2_docid = {}
+        qsidx_2_query = []
+        for idx, sample in enumerate(test_dataset):
+            doc_id = sample["image_filename"] if "image_filename" in sample else str(hash(sample["doc"]))
+            # query_id = sample["query_id"] if "query_id" in sample else str(hash(sample["query"]))
+            if sample["query"] is not None:
+                relevant_docs[str(idx)] = {doc_id: 1}
+                qsidx_2_query.append(str(idx))
+            docidx_2_docid[str(idx)] = doc_id
+
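+        # qs: per-query multi-vector embeddings; ps: per-document (passage) embeddings.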
+        qs = []
+        ps = []
+
+        device = self.model.device
+        with torch.no_grad():
+            for dataloader in [dataloader_with_query, dataloader_without_query]:
+                for batch in tqdm(dataloader):
+                    # Keep only the document inputs (keys prefixed with "doc_") and strip the prefix.
+                    doc = self.model(**{k[4:]: v.to(device) for k, v in batch.items() if k.startswith("doc")})
+                    ps.extend(list(torch.unbind(doc.to("cpu"))))
+
+                    if "query_input_ids" in batch:
+                        query = self.model(
+                            input_ids=batch["query_input_ids"].to(device),
+                            attention_mask=batch["query_attention_mask"].to(device),
+                        )
+                        # Query sequences have variable length, so keep a list of per-query tensors.
+                        qs.extend(list(torch.unbind(query.to("cpu"))))
+
+        print("Embeddings computed, evaluating")
+        scores = self.config.processor.score(qs, ps, device=self.model.device)
+        # `scores` is a 2D tensor of shape (n_queries, n_docs); convert it into the
+        # nested {query_id: {doc_id: score}} mapping expected by the evaluator.
+        results = {}
+        assert scores.shape[0] == len(qsidx_2_query)
+        for idx, scores_per_query in enumerate(scores):
+            results[qsidx_2_query[idx]] = {
+                docidx_2_docid[str(docidx)]: float(score) for docidx, score in enumerate(scores_per_query)
+            }
+
+        # Compute MTEB-style retrieval metrics from the qrels and the scored run.
+        metrics = self.retrieval_evaluator.compute_mteb_metrics(relevant_docs, results)
+        print("MTEB metrics:", metrics)
+
+        return metrics
+
+    def eval(self) -> None:
+        all_metrics = {}
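+        # Use the "test" split of the training dataset as the validation set.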
+        try:
+            print("Evaluating on validation set")
+            metrics = self.eval_dataset(self.dataset["test"])
+            print(f"Metrics for validation set: {metrics}")
+            all_metrics["validation_set"] = metrics
+        except Exception as e:
+            print(f"Error evaluating validation set: {e}")
+
+        if self.config.eval_dataset_loader is not None:
+            # Create a vision retriever with the current model checkpoint.
+            vision_retriever = VisionRetriever(
+                model=self.model,
+                processor=self.config.processor,
+            )
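+            # Pick the ViDoRe evaluator matching the declared dataset format (defaults to BEIR-style).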
+            if getattr(self.config.tr_args, "eval_dataset_format", "beir") == "beir":
+                vidore_evaluator = ViDoReEvaluatorBEIR(vision_retriever)
+            elif getattr(self.config.tr_args, "eval_dataset_format", "beir") == "qa":
+                vidore_evaluator = ViDoReEvaluatorQA(vision_retriever)
+            else:
+                raise ValueError("eval_dataset_format must be 'beir' or 'qa'")
+
+            for test_name, test_dataset_loading_func in self.config.eval_dataset_loader.items():
+                print(f"Evaluating {test_name}")
+                test_ds = test_dataset_loading_func()
+                metrics = vidore_evaluator.evaluate_dataset(
+                    ds=test_ds,
+                    batch_query=self.config.tr_args.per_device_eval_batch_size,
+                    batch_passage=self.config.tr_args.per_device_eval_batch_size,
+                    batch_score=self.config.tr_args.per_device_eval_batch_size,
+                )
+                all_metrics[test_name] = metrics
+                print(f"Metrics for {test_name}: {metrics}")
+
     def save(self, config_file):
         # save model
         self.model.save_pretrained(self.config.output_dir)