
Commit b4aba07

put back old functions
1 parent 73434cb commit b4aba07

File tree

1 file changed (+112, -0 lines)


colpali_engine/trainer/colmodel_training.py

Lines changed: 112 additions & 0 deletions
@@ -2,13 +2,19 @@
 from dataclasses import dataclass
 from typing import Callable, Dict, Optional, Tuple
 
+import torch
+from datasets import concatenate_datasets
 from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+from torch.utils.data import DataLoader
+from tqdm import tqdm
 from transformers import (
     AutoTokenizer,
     PreTrainedModel,
     PreTrainedTokenizer,
     TrainingArguments,
 )
+from vidore_benchmark.evaluation.vidore_evaluators import ViDoReEvaluatorBEIR, ViDoReEvaluatorQA
+from vidore_benchmark.retrievers import VisionRetriever
 
 from colpali_engine.collators import CorpusQueryCollator, VisualRetrieverCollator
 from colpali_engine.loss.late_interaction_losses import (
@@ -142,6 +148,112 @@ def train(self) -> None:
         result = trainer.train(resume_from_checkpoint=self.config.tr_args.resume_from_checkpoint)
         print_summary(result)
 
+    def eval_dataset(self, test_dataset):
+        self.model.eval()
+
+        idx_with_query = [idx for idx, sample in enumerate(test_dataset["query"]) if sample is not None]
+        idx_without_query = [idx for idx, sample in enumerate(test_dataset["query"]) if sample is None]
+
+        dataloader_with_query = DataLoader(
+            test_dataset.select(idx_with_query),
+            batch_size=self.config.tr_args.per_device_eval_batch_size,
+            shuffle=False,
+            collate_fn=self.collator,
+        )
+        dataloader_without_query = DataLoader(
+            test_dataset.select(idx_without_query),
+            batch_size=self.config.tr_args.per_device_eval_batch_size,
+            shuffle=False,
+            collate_fn=self.collator,
+        )
+
+        # dataset is ordered so that non-null queries come first
+        test_dataset = concatenate_datasets(
+            [test_dataset.select(idx_with_query), test_dataset.select(idx_without_query)]
+        )
+
+        relevant_docs = {}
+        docidx_2_docid = {}
+        qsidx_2_query = []
+        for idx, sample in enumerate(test_dataset):
+            doc_id = sample["image_filename"] if "image_filename" in sample else str(hash(sample["doc"]))
+            # query_id = sample["query_id"] if "query_id" in sample else str(hash(sample["query"]))
+            if sample["query"] is not None:
+                relevant_docs[str(idx)] = {doc_id: 1}
+                qsidx_2_query.append(str(idx))
+            docidx_2_docid[str(idx)] = doc_id
+
+        qs = []
+        ps = []
+
+        device = self.model.device
+        with torch.no_grad():
+            for dataloader in [dataloader_with_query, dataloader_without_query]:
+                for batch in tqdm(dataloader):
+                    # feed only kwargs with 'doc_' prefix
+                    doc = self.model(**{k[4:]: v.to(device) for k, v in batch.items() if k.startswith("doc")})
+                    ps.extend(list(torch.unbind(doc.to("cpu"))))
+
+                    if "query_input_ids" in batch:
+                        query = self.model(
+                            input_ids=batch["query_input_ids"].to(device),
+                            attention_mask=batch["query_attention_mask"].to(device),
+                        )
+                        # variable len
+                        qs.extend(list(torch.unbind(query.to("cpu"))))
+
+        print("Embeddings computed, evaluating")
+        scores = self.config.processor.score(qs, ps, device=self.model.device)
+        # scores is a 2d array of shape (n_queries, n_docs); turn it into a dict
+        results = {}
+        assert scores.shape[0] == len(qsidx_2_query)
+        for idx, scores_per_query in enumerate(scores):
+            results[qsidx_2_query[idx]] = {
+                docidx_2_docid[str(docidx)]: float(score) for docidx, score in enumerate(scores_per_query)
+            }
+
+        # evaluate
+        metrics = self.retrieval_evaluator.compute_mteb_metrics(relevant_docs, results)
+        print("MTEB metrics:", metrics)
+
+        return metrics
+
+    def eval(self) -> None:
+        all_metrics = {}
+        try:
+            print("Evaluating on validation set")
+            metrics = self.eval_dataset(self.dataset["test"])
+            print(f"Metrics for validation set: {metrics}")
+            all_metrics["validation_set"] = metrics
+        except Exception as e:
+            print(f"Error evaluating validation set: {e}")
+
+        if self.config.eval_dataset_loader is not None:
+            # Create a vision retriever with the current model checkpoint.
+            vision_retriever = VisionRetriever(
+                model=self.model,
+                processor=self.config.processor,
+            )
+            if getattr(self.config.tr_args, "eval_dataset_format", "beir") == "beir":
+                vidore_evaluator = ViDoReEvaluatorBEIR(vision_retriever)
+            elif getattr(self.config.tr_args, "eval_dataset_format", "beir") == "qa":
+                vidore_evaluator = ViDoReEvaluatorQA(vision_retriever)
+            else:
+                raise ValueError("eval_dataset_format must be 'beir' or 'qa'")
+
+            for test_name, test_dataset_loading_func in self.config.eval_dataset_loader.items():
+                print(f"Evaluating {test_name}")
+                test_ds = test_dataset_loading_func()
+                metrics = vidore_evaluator.evaluate_dataset(
+                    ds=test_ds,
+                    batch_query=self.config.tr_args.per_device_eval_batch_size,
+                    batch_passage=self.config.tr_args.per_device_eval_batch_size,
+                    batch_score=self.config.tr_args.per_device_eval_batch_size,
+                )
+                all_metrics[test_name] = metrics
+                print(f"Metrics for {test_name}: {metrics}")
+
     def save(self, config_file):
         # save model
         self.model.save_pretrained(self.config.output_dir)
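
Note: the restored eval_dataset method feeds the model only the collator outputs whose keys start with "doc", stripping that prefix (k[4:]) to recover the model's keyword arguments. A tiny standalone sketch of that key remapping, using invented key names and tensor values, purely for illustration:

import torch

# Invented batch shaped like what a visual-retriever collator might emit.
batch = {
    "doc_input_ids": torch.tensor([[1, 2, 3]]),
    "doc_attention_mask": torch.tensor([[1, 1, 1]]),
    "query_input_ids": torch.tensor([[4, 5]]),
}

# Same trick as in eval_dataset: keep only "doc"-prefixed entries and drop the "doc_" prefix.
doc_kwargs = {k[4:]: v for k, v in batch.items() if k.startswith("doc")}
print(sorted(doc_kwargs))  # ['attention_mask', 'input_ids']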
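
Note: the relevant_docs and results dictionaries built in eval_dataset and passed to compute_mteb_metrics are plain nested dicts keyed by the stringified query index. A hand-written example with made-up document ids and scores (not produced by this commit) could look like:

# Ground truth (qrels): each query index maps to its single relevant document.
relevant_docs = {
    "0": {"page_3.png": 1},
    "1": {"invoice_07.png": 1},
}

# Run: each query index maps to a score for every document in the test set.
results = {
    "0": {"page_3.png": 0.92, "invoice_07.png": 0.13},
    "1": {"page_3.png": 0.08, "invoice_07.png": 0.88},
}

# metrics = self.retrieval_evaluator.compute_mteb_metrics(relevant_docs, results)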
