Use vidore benchmark to monitor performance during training #195
@@ -13,19 +13,43 @@
    ColbertLoss,
)
from colpali_engine.trainer.contrastive_trainer import ContrastiveTrainer
from colpali_engine.trainer.eval_utils import BenchmarkEvalCallback, evaluate_dataset
from colpali_engine.utils.gpu_stats import print_gpu_utilization, print_summary
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor


@dataclass
class ColModelTrainingConfig:
    """Configuration for training a ColVision model.

    Args:
        model (Union[PreTrainedModel, PeftModel]): Base model to train.
        processor (BaseVisualRetrieverProcessor): Processor for visual data processing.
        tr_args (Optional[TrainingArguments]): Transformers training arguments. If not provided, uses default values.
        output_dir (Optional[str]): Output directory to save the model.
            If not provided, creates a path based on the model name.
        max_length (int): Maximum sequence length for inputs. Default: 256.
        run_eval (bool): If True, runs evaluation. Default: True.
        run_train (bool): If True, runs training. Default: True.
        vidore_eval_frequency (int): Vidore evaluation frequency, must be a multiple of tr_args.eval_steps.
[Review comment] Maybe an assert to guarantee this?
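A minimal sketch of the kind of check that comment asks for (illustrative only, not part of this diff; it assumes `tr_args.eval_steps` is set whenever the ViDoRe eval is enabled):

```python
# Sketch: a validation hook that could be added to ColModelTrainingConfig.
def __post_init__(self):
    if self.vidore_eval_frequency > 0:
        eval_steps = getattr(self.tr_args, "eval_steps", None)
        assert eval_steps and self.vidore_eval_frequency % eval_steps == 0, (
            f"vidore_eval_frequency ({self.vidore_eval_frequency}) must be a "
            f"multiple of tr_args.eval_steps ({eval_steps})"
        )
```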
            Pass -1 to disable. Default: -1.
        eval_dataset_format (str): Evaluation dataset format ("qa" or "beir"). Default: "qa".
        peft_config (Optional[LoraConfig]): PEFT configuration if used. Default: None.
        loss_func (Optional[Callable]): Custom loss function. Default: ColbertLoss().
        dataset_loading_func (Optional[Callable]): Dataset loading function. Default: None.
        eval_dataset_loader (Optional[Dict[str, Callable]]): Evaluation dataset loaders. Default: None.
        pretrained_peft_model_name_or_path (Optional[str]): Path to a pretrained PEFT model. Default: None.
    """

    model: Union[PreTrainedModel, PeftModel]
    processor: BaseVisualRetrieverProcessor
    tr_args: Optional[TrainingArguments] = None
    output_dir: Optional[str] = None
    max_length: int = 256
    run_eval: bool = True
    run_train: bool = True
    vidore_eval_frequency: int = -1
    eval_dataset_format: str = "qa"
    peft_config: Optional[LoraConfig] = None
    loss_func: Optional[Callable] = ColbertLoss()
    dataset_loading_func: Optional[Callable] = None
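For orientation, a rough sketch of instantiating this config directly from Python; in this PR the config is built from YAML (see the config file further below), and the checkpoint path, output directory, and eval dataset chosen here are placeholders:

```python
import torch
from transformers import TrainingArguments

from colpali_engine.models import ColQwen2, ColQwen2Processor
from colpali_engine.trainer.colmodel_training import ColModelTrainingConfig
from colpali_engine.utils.dataset_transformation import TestSetFactory, load_train_set

config = ColModelTrainingConfig(
    model=ColQwen2.from_pretrained("./models/base_models/colqwen2-base", torch_dtype=torch.bfloat16),
    processor=ColQwen2Processor.from_pretrained("./models/base_models/colqwen2-base"),
    tr_args=TrainingArguments(output_dir="./models/colqwen2-run", eval_strategy="steps", eval_steps=100),
    dataset_loading_func=load_train_set,
    eval_dataset_loader={"tabfquad": TestSetFactory("vidore/tabfquad_test_subsampled")},
    vidore_eval_frequency=200,  # multiple of eval_steps, per the docstring above
    eval_dataset_format="qa",
)
```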
@@ -118,16 +142,55 @@ def train(self) -> None:

        trainer.args.remove_unused_columns = False

        if self.config.processor is not None and self.config.vidore_eval_frequency > 0:
            trainer.add_callback(
                BenchmarkEvalCallback(
                    processor=self.config.processor,
                    model=self.model,
                    eval_dataset_loader=self.config.eval_dataset_loader,
                    batch_query=self.config.tr_args.per_device_eval_batch_size,
                    batch_passage=self.config.tr_args.per_device_eval_batch_size,
                    batch_score=4,
[Review comment] That's super low, you can probably push this to 256 at least.
                    run_frequency=self.config.vidore_eval_frequency,
                    dataset_format=self.config.eval_dataset_format,
                )
            )

        result = trainer.train(resume_from_checkpoint=self.config.tr_args.resume_from_checkpoint)
        print_summary(result)

    def eval(self) -> None:
        raise NotImplementedError("Evaluation is not implemented yet.")
        all_metrics = {}

        all_metrics["validation_set"] = evaluate_dataset(
            model=self.model,
            processor=self.config.processor,
            dataset=self.dataset["test"],
            format="qa",
            batch_passage=self.config.tr_args.per_device_eval_batch_size,
            batch_query=self.config.tr_args.per_device_eval_batch_size,
            batch_score=self.config.tr_args.per_device_eval_batch_size,
        )

        if self.config.eval_dataset_loader is not None:
            # Create a vision retriever with the current model checkpoint.
            eval_dataset_format = getattr(self.config.tr_args, "eval_dataset_format", "beir")

            for test_name, test_dataset_loading_func in self.config.eval_dataset_loader.items():
                print(f"Evaluating {test_name}")
                all_metrics[test_name] = evaluate_dataset(
                    model=self.model,
                    processor=self.config.processor,
                    dataset=test_dataset_loading_func(),
                    format=eval_dataset_format,
                    batch_passage=self.config.tr_args.per_device_eval_batch_size,
                    batch_query=self.config.tr_args.per_device_eval_batch_size,
                    batch_score=self.config.tr_args.per_device_eval_batch_size,
                )
                print(f"Metrics for {test_name}: {all_metrics[test_name]}")

    def save(self, config_file: str):
        """
        Save the model with its training config, as well as the tokenizer and processor if provided.
        """
        # save model
        self.model.save_pretrained(self.config.output_dir)
        self.config.processor.save_pretrained(self.config.output_dir)

@@ -0,0 +1,151 @@
# from mteb.evaluation.evaluators.RetrievalEvaluator
from __future__ import annotations

import logging
from typing import Dict, Union

from datasets import Dataset
from peft import PeftModel
from transformers import PreTrainedModel, TrainerControl, TrainerState, TrainingArguments
from transformers.integrations import WandbCallback
from vidore_benchmark.evaluation.vidore_evaluators import ViDoReEvaluatorBEIR, ViDoReEvaluatorQA
[Review comment] No circular imports anymore?

from vidore_benchmark.retrievers import VisionRetriever

from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor

logger = logging.getLogger(__name__)


METRICS_TO_TRACK = [
    "ndcg_at_1",
    "ndcg_at_3",
    "ndcg_at_5",
    "ndcg_at_10",
    "ndcg_at_50",
    "ndcg_at_100",
    "recall_at_1",
    "recall_at_3",
    "recall_at_5",
    "recall_at_10",
    "recall_at_50",
    "recall_at_100",
    "map_at_1",
    "map_at_3",
    "map_at_5",
    "map_at_10",
    "map_at_50",
    "map_at_100",
]
[Review comment on lines +19 to +38] That is quite a lot to keep track of in a wandb window, especially on multiple datasets. What are the few ones we should keep?
[Reply] I would pick nDCG@k and recall@k for …
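Following that thread, one possible trimmed-down list, keeping only nDCG@k and recall@k at a few cutoffs (the exact cutoffs were not settled in the discussion, so these values are illustrative):

```python
METRICS_TO_TRACK = [
    "ndcg_at_5",
    "ndcg_at_10",
    "recall_at_5",
    "recall_at_10",
    "recall_at_100",
]
```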


class BenchmarkEvalCallback(WandbCallback):
    def __init__(
        self,
        processor,
        model,
        eval_dataset_loader,
        batch_query: int = 4,
        batch_passage: int = 4,
        batch_score: int = 4,
        run_frequency: int = 5,
        dataset_format: str = "beir",
    ):
        """
        Callback to evaluate the model on a collection of datasets during training.
[Review comment] Add mention of Wandb in the docstring.

        Args:
            processor: The processor to use for the model.
            model: The model to evaluate.
            eval_dataset_loader: A dictionary of dataset loading functions.
            batch_query: Batch size for queries.
            batch_passage: Batch size for passages.
            batch_score: Batch size for scoring.
            run_frequency: Frequency of evaluation in steps.
            dataset_format: Format of the evaluation dataset, either "beir" or "qa".
        """
        self.processor = processor
        self.model = model
        self.eval_dataset_loader = eval_dataset_loader
        self.batch_query = batch_query
        self.batch_passage = batch_passage
        self.batch_score = batch_score
        self.eval_steps_frequency = run_frequency
        self.counter_eval = 0
        self.eval_dataset_format = dataset_format
        super().__init__()

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        if state.global_step % self.eval_steps_frequency != 0:
            self.counter_eval += 1
[Review comment] When do you use this?
[Reply] No need, it's an artefact from a previous implementation. (A simplified gate without the counter is sketched just after this if/else block.)
            return
        else:
            self.counter_eval = 1
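Per the thread above, a simplified version of this gate without the counter might look as follows (a sketch only, not part of the diff):

```python
    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # Only run the ViDoRe benchmark every `run_frequency` steps; the counter is dropped entirely.
        if state.global_step % self.eval_steps_frequency != 0:
            return
```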

        if self.processor is None:
            print("Processor not provided. Skipping benchmark evaluation.")
            return

        print(f"\n=== Running benchmark evaluation at global step {state.global_step} ===")

        # Evaluate on a collection.
        if self.eval_dataset_loader is not None:
[Review comment] Any reason for this to be None?
[Reply] Not really; since this eval is deactivated by default, I guess a user wanting to use this should have specified eval datasets.

            try:
                metrics_collection = {}
                for test_name, test_dataset_loading_func in self.eval_dataset_loader.items():
                    metrics = evaluate_dataset(
                        model=self.model,
                        processor=self.processor,
                        dataset=test_dataset_loading_func(),
                        format=self.eval_dataset_format,
                        batch_passage=self.batch_passage,
                        batch_query=self.batch_query,
                        batch_score=self.batch_score,
                    )
                    metrics_collection[test_name] = {k: v for k, v in metrics.items() if k in METRICS_TO_TRACK}
                print(f"Benchmark metrics for tests datasets at step {state.global_step}:")
                print(metrics_collection)
[Review comment] How does this look like? If it's not already formatted (since it's a dict), you can try …
[Reply] I actually think we could remove the print entirely; it looks a bit messy when training.
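If any console output is kept, one option along the lines of that thread is to go through the module-level logger with the stdlib `pprint` for formatting (a sketch only; the reviewer's exact suggestion was truncated in this view):

```python
import pprint  # would live with the other module-level imports

# Possible drop-in replacement for the two prints above, inside on_evaluate:
logger.info(
    "ViDoRe benchmark metrics at step %s:\n%s",
    state.global_step,
    pprint.pformat(metrics_collection),
)
```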

                print("logging metrics to wandb")

                self._wandb.log(metrics_collection)
            except Exception as e:
                print(f"Error during benchmark evaluation on collection '{self.eval_collection}': {e}")
[Review comment on lines +109 to +110] Update, not relevant anymore.

        # Set model back to train mode.
        self.model.train()
        return


def evaluate_dataset(
    model: Union[PreTrainedModel, PeftModel],
    processor: BaseVisualRetrieverProcessor,
    dataset: Dataset,
    format: str = "beir",
    batch_passage: int = 4,
    batch_query: int = 4,
    batch_score: int = 4,
) -> Dict[str, float]:
    """
    Evaluate a dataset using the vidore-benchmark library.
    """
    model.eval()

    # Create a vision retriever with the current model checkpoint.
    vision_retriever = VisionRetriever(
        model=model,
        processor=processor,
    )

    if format == "qa":
        vidore_evaluator = ViDoReEvaluatorQA(vision_retriever)
    elif format == "beir":
        vidore_evaluator = ViDoReEvaluatorBEIR(vision_retriever)
    else:
        raise ValueError(f"Invalid dataset format: {format}, must be 'qa' or 'beir'")

    metrics = vidore_evaluator.evaluate_dataset(
        ds=dataset,
        batch_query=batch_query,
        batch_passage=batch_passage,
        batch_score=batch_score,
    )

    return metrics
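For context, a rough standalone usage sketch of `evaluate_dataset` outside the Trainer loop; the checkpoint path is a placeholder, and the QA-format tabfquad test set is the one referenced elsewhere in this PR:

```python
import torch
from datasets import load_dataset

from colpali_engine.models import ColQwen2, ColQwen2Processor
from colpali_engine.trainer.eval_utils import METRICS_TO_TRACK, evaluate_dataset

model = ColQwen2.from_pretrained(
    "./models/base_models/colqwen2-base",  # placeholder checkpoint path
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",
)
processor = ColQwen2Processor.from_pretrained("./models/base_models/colqwen2-base")

ds = load_dataset("vidore/tabfquad_test_subsampled", split="test")

metrics = evaluate_dataset(
    model=model,
    processor=processor,
    dataset=ds,
    format="qa",
    batch_passage=8,
    batch_query=8,
    batch_score=8,
)
print({k: v for k, v in metrics.items() if k in METRICS_TO_TRACK})
```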
@@ -210,6 +210,21 @@ def __call__(self, *args, **kwargs):
        return dataset


class TestSetFactoryBEIR:
[Review comment] 2 things here:

```python
def load_beir_test_dataset(dataset_path: str, split: str = "test") -> Dict[str, Dataset]:
    return {
        "corpus": cast(Dataset, load_dataset(dataset_path, name="corpus", split=split)),
        "queries": cast(Dataset, load_dataset(dataset_path, name="queries", split=split)),
        "qrels": cast(Dataset, load_dataset(dataset_path, name="qrels", split=split)),
    }
```

Wdyt?

    def __init__(self, dataset_path):
        self.dataset_path = dataset_path

    def __call__(self, *args, **kwargs):
        split = "test"
        dataset = {
            "corpus": cast(Dataset, load_dataset(self.dataset_path, name="corpus", split=split)),
            "queries": cast(Dataset, load_dataset(self.dataset_path, name="queries", split=split)),
            "qrels": cast(Dataset, load_dataset(self.dataset_path, name="qrels", split=split)),
        }

        return dataset


if __name__ == "__main__":
    ds = TestSetFactory("vidore/tabfquad_test_subsampled")()
    print(ds)
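To connect this back to the training config, a hedged sketch of what an `eval_dataset_loader` mapping could look like when built in Python (the BEIR dataset path is a placeholder; the YAML configs in this PR build the same mapping via `!import`):

```python
from colpali_engine.utils.dataset_transformation import TestSetFactoryBEIR

# Each value is a zero-argument callable returning the dataset to evaluate on;
# with eval_dataset_format set to "beir", every entry should be a BEIR-style dataset.
eval_dataset_loader = {
    "my_beir_dataset": TestSetFactoryBEIR("org/my-beir-formatted-dataset"),  # placeholder dataset path
}
```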
@@ -0,0 +1,65 @@
config:
  (): colpali_engine.trainer.colmodel_training.ColModelTrainingConfig
  output_dir: !path ../../../models/colqwen2-ba256-5e-0304
  processor:
    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
    class_to_instanciate: !ext colpali_engine.models.ColQwen2Processor
    pretrained_model_name_or_path: "./models/base_models/colqwen2-base"
    max_num_visual_tokens: 1024

  model:
    (): colpali_engine.utils.transformers_wrappers.AllPurposeWrapper
    class_to_instanciate: !ext colpali_engine.models.ColQwen2
    pretrained_model_name_or_path: "./models/base_models/colqwen2-base"
    torch_dtype: !ext torch.bfloat16
    use_cache: false
    attn_implementation: "flash_attention_2"

  dataset_loading_func: !ext colpali_engine.utils.dataset_transformation.load_train_set
  eval_dataset_loader: !import ../data/test_data.yaml
  vidore_eval_frequency: 200
  eval_dataset_format: "qa"

  # max_length: 50
  run_eval: true
  loss_func:
    (): colpali_engine.loss.late_interaction_losses.ColbertPairwiseCELoss
  tr_args:
    (): transformers.training_args.TrainingArguments
    output_dir: null
    overwrite_output_dir: true
    num_train_epochs: 5
    per_device_train_batch_size: 64
    gradient_checkpointing: true
    gradient_checkpointing_kwargs: { "use_reentrant": false }
    # 6 x 8 gpus = 48 batch size
    # gradient_accumulation_steps: 4
    per_device_eval_batch_size: 8
    eval_strategy: "steps"
    dataloader_num_workers: 8
    # bf16: true
    save_steps: 500
    logging_steps: 10
    eval_steps: 100
    warmup_steps: 100
    learning_rate: 2e-4
    save_total_limit: 1
    # resume_from_checkpoint: true
    # optim: "paged_adamw_8bit"
    # wandb logging
    # wandb_project: "colqwen2"
    # run_name: "colqwen2-ba32-nolora"
    report_to: "wandb"

  peft_config:
    (): peft.LoraConfig
    r: 32
    lora_alpha: 32
    lora_dropout: 0.1
    init_lora_weights: "gaussian"
    bias: "none"
    task_type: "FEATURE_EXTRACTION"
    target_modules: '(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
    # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'