Use vidore benchmark to monitor performance during training #195
Conversation
Recap from our conversation 👋🏼 Let's:
@QuentinJGMace @tonywu71 updates?
Force-pushed from b4aba07 to dff8e6f.
CHANGELOG.md (outdated)
### Changed
- Warn about evaluation being different from Vidore, and do not store results to prevent confusion.
Not true, update
Force-pushed from f85bd9e to 7b4b1aa.
Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.
Comments suppressed due to low confidence (2)
colpali_engine/utils/dataset_transformation.py:229
- The class is defined as TestSetFactoryBEIR, but the call refers to TestSetFactory, which could lead to a NameError. Please update the class name in the call to match TestSetFactoryBEIR.
ds = TestSetFactory("vidore/tabfquad_test_subsampled")()
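If the newly defined class is indeed the one intended here, the fix would presumably just rename the call (a sketch, assuming `TestSetFactoryBEIR` keeps the same constructor and call signature):

# Hypothetical fix: call the class that is actually defined.
ds = TestSetFactoryBEIR("vidore/tabfquad_test_subsampled")()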
colpali_engine/trainer/eval_utils.py:110
- The attribute 'self.eval_collection' is not defined in this class. It looks like it should reference an existing attribute, possibly related to the evaluation dataset loader. Please verify and update the reference.
print(f"Error during benchmark evaluation on collection '{self.eval_collection}': {e}")
METRICS_TO_TRACK = [
    "ndcg_at_1",
    "ndcg_at_3",
    "ndcg_at_5",
    "ndcg_at_10",
    "ndcg_at_50",
    "ndcg_at_100",
    "recall_at_1",
    "recall_at_3",
    "recall_at_5",
    "recall_at_10",
    "recall_at_50",
    "recall_at_100",
    "map_at_1",
    "map_at_3",
    "map_at_5",
    "map_at_10",
    "map_at_50",
    "map_at_100",
]
That is quite a lot to keep track of in a wandb window, especially on multiple datasets.
Which few should we keep?
I would pick nDCG@k and recall@k for
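For illustration, a trimmed-down list along those lines could look like this (the exact values of k are an assumption, since the comment above is cut off; the metric names come from the list quoted earlier):

# Hypothetical reduced metric list: nDCG@k and recall@k only, with assumed values of k.
METRICS_TO_TRACK = [
    "ndcg_at_1",
    "ndcg_at_5",
    "ndcg_at_10",
    "recall_at_1",
    "recall_at_5",
    "recall_at_10",
]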
except Exception as e:
    print(f"Error during benchmark evaluation on collection '{self.eval_collection}': {e}")
Update, not relevant anymore
Very cool, thanks!
Let's wait for the updates in the Dataset code before merging this, so that we can adapt it. Super cool work, thanks!!
eval_dataset_loader=self.config.eval_dataset_loader,
batch_query=self.config.tr_args.per_device_eval_batch_size,
batch_passage=self.config.tr_args.per_device_eval_batch_size,
batch_score=4,
That's super low, you can probably push this to 256 at least
max_length (int): Maximum sequence length for inputs. Default: 256.
run_eval (bool): If True, runs evaluation. Default: True.
run_train (bool): If True, runs training. Default: True.
vidore_eval_frequency (int): Vidore evaluation frequency, must be a multiple of tr_args.eval_steps.
Maybe an assert to guarantee this?
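A minimal sketch of such a check, assuming the config exposes `vidore_eval_frequency` and `tr_args.eval_steps` as named in the docstring above (the exact attribute paths may differ in the actual code):

# Hypothetical guard: fail fast if the ViDoRe eval frequency is not a multiple of eval_steps.
assert self.vidore_eval_frequency % self.tr_args.eval_steps == 0, (
    f"vidore_eval_frequency ({self.vidore_eval_frequency}) must be a multiple of "
    f"tr_args.eval_steps ({self.tr_args.eval_steps})."
)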
from peft import PeftModel
from transformers import PreTrainedModel, TrainerControl, TrainerState, TrainingArguments
from transformers.integrations import WandbCallback
from vidore_benchmark.evaluation.vidore_evaluators import ViDoReEvaluatorBEIR, ViDoReEvaluatorQA
No circular imports anymore?
print(f"\n=== Running benchmark evaluation at global step {state.global_step} ===") | ||
|
||
# Evaluate on a collection. | ||
if self.eval_dataset_loader is not None: |
Any reason for this to be None?
Not really; since this eval is deactivated by default, I guess a user wanting to use it should have specified eval datasets.
def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
    if state.global_step % self.eval_steps_frequency != 0:
        self.counter_eval += 1
When do you use this?
no need, it's an artefact from a previous implementation
Very nice work overall, thanks a ton @QuentinJGMace! A few comments to address but otherwise LGTM :)
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).

### Added

- Add the possibility for a user to evaluate a model on retrieval datasets (e.g ViDoRe benchmark) during its training.
Suggested change:
Old: - Add the possibility for a user to evaluate a model on retrieval datasets (e.g ViDoRe benchmark) during its training.
New: - Add `BenchmarkEvalCallback` to evaluate a model on retrieval datasets (e.g. ViDoRe benchmark) during its training and display the metrics on Weights & Biases.
    dataset_format: str = "beir",
):
    """
    Callback to evaluate the model on a collection of datasets during training.
Add mention of Wandb in the docstring.
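For instance, the opening of the docstring could read something like this (wording is only a suggestion):

"""
Callback to evaluate the model on a collection of datasets during training
and log the tracked retrieval metrics to Weights & Biases (wandb).
"""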
metrics_collection[test_name] = {k: v for k, v in metrics.items() if k in METRICS_TO_TRACK}
print(f"Benchmark metrics for tests datasets at step {state.global_step}:")
print(metrics_collection)
print("logging metrics to wandb")
print("logging metrics to wandb") |
)
metrics_collection[test_name] = {k: v for k, v in metrics.items() if k in METRICS_TO_TRACK}
print(f"Benchmark metrics for tests datasets at step {state.global_step}:")
print(metrics_collection)
What does this look like? If it's not already formatted (since it's a dict), you can try pprint.pprint if needed!
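A quick sketch of the pprint variant, as a drop-in for the plain print quoted above:

from pprint import pprint

# Pretty-print the nested {dataset: {metric: value}} dict, one entry per line.
pprint(metrics_collection)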
I actually think we could remove the print entirely; it looks a bit messy when training.
@@ -210,6 +210,21 @@ def __call__(self, *args, **kwargs):
    return dataset


class TestSetFactoryBEIR:
2 things here:
- Need a short docstring here
- It's not a factory per se (a factory is a design pattern that provides a way to create objects without specifying the exact class of object that will be created, which is not the case here). I think we should keep the implementation as simple as possible. Here is a recommendation below.
from typing import Dict, cast
from datasets import Dataset, load_dataset

def load_beir_test_dataset(dataset_path: str, split: str = "test") -> Dict[str, Dataset]:
    return {
        "corpus": cast(Dataset, load_dataset(dataset_path, name="corpus", split=split)),
        "queries": cast(Dataset, load_dataset(dataset_path, name="queries", split=split)),
        "qrels": cast(Dataset, load_dataset(dataset_path, name="qrels", split=split)),
    }
Wdyt?
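For reference, a hypothetical call site for the suggested helper (the dataset path is reused from the Copilot comment above):

# Hypothetical usage of the suggested helper in place of the factory class.
ds = load_beir_test_dataset("vidore/tabfquad_test_subsampled", split="test")
corpus, queries, qrels = ds["corpus"], ds["queries"], ds["qrels"]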
def __init__(self, dataset_path):
    self.dataset_path = dataset_path

def __call__(self, *args, **kwargs):
2 questions here:
- Is there a reason for keeping *args, **kwargs?
- Can we expose split: str = "test" in the __call__ args?
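For illustration only, a sketch of what exposing the split could look like if the class is kept (this is not the author's actual change, just a reading of the question above):

class TestSetFactoryBEIR:
    def __init__(self, dataset_path: str):
        self.dataset_path = dataset_path

    # Hypothetical: expose the split explicitly instead of accepting *args/**kwargs.
    def __call__(self, split: str = "test"):
        ...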
Code to be able to monitor real retrieval metrics on datasets (e.g. the ViDoRe benchmark) during training.
This feature is deactivated by default and is designed for power users.
To use it, simply add the relevant options to your training config.
An example can be found at scripts/configs/qwen2/train_colqwen2_model_eval_vidore.yaml.
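For readers who prefer code, here is a rough sketch of what the config option wires up, pieced together from the diffs quoted in this review (the import path, constructor arguments, and trainer hookup are assumptions; additional required arguments may be omitted, and the example YAML above remains the reference):

# Hypothetical wiring of the callback; names are taken from the snippets quoted in this PR.
from colpali_engine.trainer.eval_utils import BenchmarkEvalCallback  # assumed import path

callback = BenchmarkEvalCallback(
    eval_dataset_loader=config.eval_dataset_loader,
    batch_query=config.tr_args.per_device_eval_batch_size,
    batch_passage=config.tr_args.per_device_eval_batch_size,
    batch_score=256,  # see the review comment above suggesting a larger scoring batch
    dataset_format="beir",
)
trainer.add_callback(callback)  # standard Hugging Face Trainer API for registering callbacks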