We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9146c1e commit 8fd9ff9Copy full SHA for 8fd9ff9
paddlenlp/trainer/trainer.py
@@ -2780,6 +2780,12 @@ def evaluation_loop(
2780
2781
# Metrics!
2782
if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
2783
+ if self.args.tensor_parallel_degree > 1 and all_preds.shape != all_labels.shape:
2784
+ hcg = fleet.get_hybrid_communicate_group()
2785
+ model_parallel_group = hcg.get_model_parallel_group()
2786
+ gathered_predictions = []
2787
+ dist.all_gather(gathered_predictions, all_preds, group=model_parallel_group)
2788
+ all_preds = paddle.concat(gathered_predictions, axis=0)
2789
metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
2790
else:
2791
metrics = {}
0 commit comments