|
32 | 32 | from imblearn.metrics import make_index_balanced_accuracy
|
33 | 33 | from imblearn.metrics import classification_report_imbalanced
|
34 | 34 |
|
35 |
| -from pytest import approx |
| 35 | +from pytest import approx, raises |
36 | 36 |
|
37 | 37 | RND_SEED = 42
|
38 | 38 | R_TOL = 1e-2
|
@@ -432,43 +432,22 @@ def test_classification_report_imbalanced_multiclass_with_long_string_label():
|
432 | 432 |
|
433 | 433 | def test_iba_sklearn_metrics():
|
434 | 434 | y_true, y_pred, _ = make_prediction(binary=True)
|
| 435 | + iba_scoring_func = make_index_balanced_accuracy(alpha=0.5, squared=True) |
| 436 | + expected_metric_result_pairs = ((accuracy_score, 0.54756), |
| 437 | + (jaccard_similarity_score, 0.54756), |
| 438 | + (precision_score, 0.65025), |
| 439 | + (recall_score, 0.41616000000000009)) |
435 | 440 |
|
436 |
| - acc = make_index_balanced_accuracy(alpha=0.5, squared=True)( |
437 |
| - accuracy_score) |
438 |
| - score = acc(y_true, y_pred) |
439 |
| - assert score == approx(0.54756) |
440 |
| - |
441 |
| - jss = make_index_balanced_accuracy(alpha=0.5, squared=True)( |
442 |
| - jaccard_similarity_score) |
443 |
| - score = jss(y_true, y_pred) |
444 |
| - assert score == approx(0.54756) |
445 |
| - |
446 |
| - pre = make_index_balanced_accuracy(alpha=0.5, squared=True)( |
447 |
| - precision_score) |
448 |
| - score = pre(y_true, y_pred) |
449 |
| - assert score == approx(0.65025) |
450 |
| - |
451 |
| - rec = make_index_balanced_accuracy(alpha=0.5, squared=True)( |
452 |
| - recall_score) |
453 |
| - score = rec(y_true, y_pred) |
454 |
| - assert score == approx(0.41616000000000009) |
| 441 | + for metric, expected_value in expected_metric_result_pairs: |
| 442 | + score = iba_scoring_func(metric)(y_true, y_pred) |
| 443 | + assert score == approx(expected_value) |
455 | 444 |
|
456 | 445 |
|
457 | 446 | def test_iba_error_y_score_prob():
|
458 | 447 | y_true, y_pred, _ = make_prediction(binary=True)
|
| 448 | + iba_scoring_func = make_index_balanced_accuracy(alpha=0.5, squared=True) |
459 | 449 |
|
460 |
| - aps = make_index_balanced_accuracy(alpha=0.5, squared=True)( |
461 |
| - average_precision_score) |
462 |
| - assert_raises(AttributeError, aps, y_true, y_pred) |
463 |
| - |
464 |
| - brier = make_index_balanced_accuracy(alpha=0.5, squared=True)( |
465 |
| - brier_score_loss) |
466 |
| - assert_raises(AttributeError, brier, y_true, y_pred) |
467 |
| - |
468 |
| - kappa = make_index_balanced_accuracy(alpha=0.5, squared=True)( |
469 |
| - cohen_kappa_score) |
470 |
| - assert_raises(AttributeError, kappa, y_true, y_pred) |
471 |
| - |
472 |
| - ras = make_index_balanced_accuracy(alpha=0.5, squared=True)( |
473 |
| - roc_auc_score) |
474 |
| - assert_raises(AttributeError, ras, y_true, y_pred) |
| 450 | + for score_func in (average_precision_score, brier_score_loss, |
| 451 | + cohen_kappa_score, roc_auc_score): |
| 452 | + with raises(AttributeError): |
| 453 | + iba_scoring_func(score_func)(y_true, y_pred) |
0 commit comments