Skip to content

Commit 79107e8

Browse files
FIX make sure to accept "minority" as a valid strategy in over-samplers (#964)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 7cead9c commit 79107e8

File tree

5 files changed

+71
-1
lines changed

5 files changed

+71
-1
lines changed

doc/whats_new/v0.10.rst

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,20 @@
11
.. _changes_0_10:
22

3+
Version 0.10.1
4+
==============
5+
6+
**December 28, 2022**
7+
8+
Changelog
9+
---------
10+
11+
Bug fixes
12+
.........
13+
14+
- Fix a regression in over-sampler where the string `minority` was rejected as
15+
an unvalid sampling strategy.
16+
:pr:`964` by :user:`Prakhyath Bhandary <Prakhyath07>`.
17+
318
Version 0.10.0
419
==============
520

imblearn/over_sampling/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class BaseOverSampler(BaseSampler):
6161
_parameter_constraints: dict = {
6262
"sampling_strategy": [
6363
Interval(numbers.Real, 0, 1, closed="right"),
64-
StrOptions({"auto", "majority", "not minority", "not majority", "all"}),
64+
StrOptions({"auto", "minority", "not minority", "not majority", "all"}),
6565
Mapping,
6666
callable,
6767
],

imblearn/over_sampling/tests/test_random_over_sampler.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import numpy as np
99
import pytest
10+
from sklearn.datasets import make_classification
1011
from sklearn.utils._testing import (
1112
_convert_container,
1213
assert_allclose,
@@ -255,3 +256,20 @@ def test_random_over_sampler_shrinkage_error(data, shrinkage, err_msg):
255256
ros = RandomOverSampler(shrinkage=shrinkage)
256257
with pytest.raises(ValueError, match=err_msg):
257258
ros.fit_resample(X, y)
259+
260+
261+
@pytest.mark.parametrize(
262+
"sampling_strategy", ["auto", "minority", "not minority", "not majority", "all"]
263+
)
264+
def test_random_over_sampler_strings(sampling_strategy):
265+
"""Check that we support all supposed strings as `sampling_strategy` in
266+
a sampler inheriting from `BaseOverSampler`."""
267+
268+
X, y = make_classification(
269+
n_samples=100,
270+
n_clusters_per_class=1,
271+
n_classes=3,
272+
weights=[0.1, 0.3, 0.6],
273+
random_state=0,
274+
)
275+
RandomOverSampler(sampling_strategy=sampling_strategy).fit_resample(X, y)

imblearn/under_sampling/_prototype_selection/tests/test_random_under_sampler.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import numpy as np
99
import pytest
10+
from sklearn.datasets import make_classification
1011
from sklearn.utils._testing import assert_array_equal
1112

1213
from imblearn.under_sampling import RandomUnderSampler
@@ -130,3 +131,20 @@ def test_random_under_sampling_nan_inf():
130131
assert y_res.shape == (6,)
131132
assert X_res.shape == (6, 2)
132133
assert np.any(~np.isfinite(X_res))
134+
135+
136+
@pytest.mark.parametrize(
137+
"sampling_strategy", ["auto", "majority", "not minority", "not majority", "all"]
138+
)
139+
def test_random_under_sampler_strings(sampling_strategy):
140+
"""Check that we support all supposed strings as `sampling_strategy` in
141+
a sampler inheriting from `BaseUnderSampler`."""
142+
143+
X, y = make_classification(
144+
n_samples=100,
145+
n_clusters_per_class=1,
146+
n_classes=3,
147+
weights=[0.1, 0.3, 0.6],
148+
random_state=0,
149+
)
150+
RandomUnderSampler(sampling_strategy=sampling_strategy).fit_resample(X, y)

imblearn/under_sampling/_prototype_selection/tests/test_tomek_links.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# License: MIT
55

66
import numpy as np
7+
import pytest
8+
from sklearn.datasets import make_classification
79
from sklearn.utils._testing import assert_array_equal
810

911
from imblearn.under_sampling import TomekLinks
@@ -68,3 +70,20 @@ def test_tl_fit_resample():
6870
y_gt = np.array([1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0])
6971
assert_array_equal(X_resampled, X_gt)
7072
assert_array_equal(y_resampled, y_gt)
73+
74+
75+
@pytest.mark.parametrize(
76+
"sampling_strategy", ["auto", "majority", "not minority", "not majority", "all"]
77+
)
78+
def test_tomek_links_strings(sampling_strategy):
79+
"""Check that we support all supposed strings as `sampling_strategy` in
80+
a sampler inheriting from `BaseCleaningSampler`."""
81+
82+
X, y = make_classification(
83+
n_samples=100,
84+
n_clusters_per_class=1,
85+
n_classes=3,
86+
weights=[0.1, 0.3, 0.6],
87+
random_state=0,
88+
)
89+
TomekLinks(sampling_strategy=sampling_strategy).fit_resample(X, y)

0 commit comments

Comments
 (0)