Skip to content

Commit 8fd9ff9

Browse files
committed
[Distributed] metric calculation supports tp logits
1 parent 9146c1e commit 8fd9ff9

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

paddlenlp/trainer/trainer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2780,6 +2780,12 @@ def evaluation_loop(
27802780

27812781
# Metrics!
27822782
if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
2783+
if self.args.tensor_parallel_degree > 1 and all_preds.shape != all_labels.shape:
2784+
hcg = fleet.get_hybrid_communicate_group()
2785+
model_parallel_group = hcg.get_model_parallel_group()
2786+
gathered_predictions = []
2787+
dist.all_gather(gathered_predictions, all_preds, group=model_parallel_group)
2788+
all_preds = paddle.concat(gathered_predictions, axis=0)
27832789
metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
27842790
else:
27852791
metrics = {}

0 commit comments

Comments
 (0)