Skip to content

Commit 92b5305

Browse files
authored
ENH allow any dtype in input from RandomSampler (#1004)
1 parent 758cd92 commit 92b5305

File tree

11 files changed

+99
-24
lines changed

11 files changed

+99
-24
lines changed

doc/whats_new/v0.11.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,8 @@ Enhancements
4343
parameters. A new fitted parameter `categorical_encoder_` is exposed to access the
4444
fitted encoder.
4545
:pr:`1001` by :user:`Guillaume Lemaitre <glemaitre>`.
46+
47+
- :class:`~imblearn.under_sampling.RandomUnderSampler` and
48+
:class:`~imblearn.over_sampling.RandomOverSampler` (when `shrinkage is not
49+
None`) now accept any data types and will not attempt any data conversion.
50+
:pr:`1004` by :user:`Guillaume Lemaitre <glemaitre>`.

examples/api/plot_sampling_strategy_usage.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,9 @@
5959
# resampling and the number of samples in the minority class, respectively.
6060

6161
# %%
62-
import numpy as np
6362

6463
# select only 2 classes since the ratio make sense in this case
65-
binary_mask = np.bitwise_or(y == 0, y == 2)
64+
binary_mask = y.isin([0, 1])
6665
binary_y = y[binary_mask]
6766
binary_X = X[binary_mask]
6867

imblearn/datasets/tests/test_imbalance.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,14 @@ def test_make_imbalance_dict(iris, sampling_strategy, expected_counts):
6767
],
6868
)
6969
def test_make_imbalanced_iris(as_frame, sampling_strategy, expected_counts):
70-
pytest.importorskip("pandas")
71-
iris = load_iris(as_frame=True)
70+
pd = pytest.importorskip("pandas")
71+
iris = load_iris(as_frame=as_frame)
7272
X, y = iris.data, iris.target
7373
y = iris.target_names[iris.target]
74+
if as_frame:
75+
y = pd.Series(iris.target_names[iris.target], name="target")
7476
X_res, y_res = make_imbalance(X, y, sampling_strategy=sampling_strategy)
7577
if as_frame:
7678
assert hasattr(X_res, "loc")
79+
pd.testing.assert_index_equal(X_res.index, y_res.index)
7780
assert Counter(y_res) == expected_counts

imblearn/ensemble/tests/test_bagging.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -572,11 +572,12 @@ def roughly_balanced_bagging(X, y, replace=False):
572572

573573
# Roughly Balanced Bagging
574574
rbb = BalancedBaggingClassifier(
575-
estimator=CountDecisionTreeClassifier(),
575+
estimator=CountDecisionTreeClassifier(random_state=0),
576576
n_estimators=2,
577577
sampler=FunctionSampler(
578578
func=roughly_balanced_bagging, kw_args={"replace": replace}
579579
),
580+
random_state=0,
580581
)
581582
rbb.fit(X, y)
582583

imblearn/over_sampling/_random_over_sampler.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from ..utils import Substitution, check_target_type
1616
from ..utils._docstring import _random_state_docstring
1717
from ..utils._param_validation import Interval
18+
from ..utils._validation import _check_X
1819
from .base import BaseOverSampler
1920

2021

@@ -154,14 +155,9 @@ def __init__(
154155

155156
def _check_X_y(self, X, y):
156157
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
157-
X, y = self._validate_data(
158-
X,
159-
y,
160-
reset=True,
161-
accept_sparse=["csr", "csc"],
162-
dtype=None,
163-
force_all_finite=False,
164-
)
158+
X = _check_X(X)
159+
self._check_n_features(X, reset=True)
160+
self._check_feature_names(X, reset=True)
165161
return X, y, binarize_y
166162

167163
def _fit_resample(self, X, y):
@@ -258,4 +254,7 @@ def _more_tags(self):
258254
"X_types": ["2darray", "string", "sparse", "dataframe"],
259255
"sample_indices": True,
260256
"allow_nan": True,
257+
"_xfail_checks": {
258+
"check_complex_data": "Robust to this type of data.",
259+
},
261260
}

imblearn/over_sampling/_smote/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from ...utils import Substitution, check_neighbors_object, check_target_type
2828
from ...utils._docstring import _n_jobs_docstring, _random_state_docstring
2929
from ...utils._param_validation import HasMethods, Interval
30+
from ...utils._validation import _check_X
3031
from ...utils.fixes import _mode
3132
from ..base import BaseOverSampler
3233

@@ -559,9 +560,9 @@ def _check_X_y(self, X, y):
559560
features.
560561
"""
561562
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
562-
if not (hasattr(X, "__array__") or sparse.issparse(X)):
563-
X = check_array(X, dtype=object)
563+
X = _check_X(X)
564564
self._check_n_features(X, reset=True)
565+
self._check_feature_names(X, reset=True)
565566
return X, y, binarize_y
566567

567568
def _validate_estimator(self):

imblearn/over_sampling/tests/test_random_over_sampler.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# License: MIT
55

66
from collections import Counter
7+
from datetime import datetime
78

89
import numpy as np
910
import pytest
@@ -273,3 +274,16 @@ def test_random_over_sampler_strings(sampling_strategy):
273274
random_state=0,
274275
)
275276
RandomOverSampler(sampling_strategy=sampling_strategy).fit_resample(X, y)
277+
278+
279+
def test_random_over_sampling_datetime():
280+
"""Check that we don't convert input data and only sample from it."""
281+
pd = pytest.importorskip("pandas")
282+
X = pd.DataFrame({"label": [0, 0, 0, 1], "td": [datetime.now()] * 4})
283+
y = X["label"]
284+
ros = RandomOverSampler(random_state=0)
285+
X_res, y_res = ros.fit_resample(X, y)
286+
287+
pd.testing.assert_series_equal(X_res.dtypes, X.dtypes)
288+
pd.testing.assert_index_equal(X_res.index, y_res.index)
289+
assert_array_equal(y_res.to_numpy(), np.array([0, 0, 0, 1, 1, 1]))

imblearn/under_sampling/_prototype_selection/_random_under_sampler.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from ...utils import Substitution, check_target_type
1111
from ...utils._docstring import _random_state_docstring
12+
from ...utils._validation import _check_X
1213
from ..base import BaseUnderSampler
1314

1415

@@ -97,14 +98,9 @@ def __init__(
9798

9899
def _check_X_y(self, X, y):
99100
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
100-
X, y = self._validate_data(
101-
X,
102-
y,
103-
reset=True,
104-
accept_sparse=["csr", "csc"],
105-
dtype=None,
106-
force_all_finite=False,
107-
)
101+
X = _check_X(X)
102+
self._check_n_features(X, reset=True)
103+
self._check_feature_names(X, reset=True)
108104
return X, y, binarize_y
109105

110106
def _fit_resample(self, X, y):
@@ -140,4 +136,7 @@ def _more_tags(self):
140136
"X_types": ["2darray", "string", "sparse", "dataframe"],
141137
"sample_indices": True,
142138
"allow_nan": True,
139+
"_xfail_checks": {
140+
"check_complex_data": "Robust to this type of data.",
141+
},
143142
}

imblearn/under_sampling/_prototype_selection/tests/test_random_under_sampler.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# License: MIT
55

66
from collections import Counter
7+
from datetime import datetime
78

89
import numpy as np
910
import pytest
@@ -148,3 +149,16 @@ def test_random_under_sampler_strings(sampling_strategy):
148149
random_state=0,
149150
)
150151
RandomUnderSampler(sampling_strategy=sampling_strategy).fit_resample(X, y)
152+
153+
154+
def test_random_under_sampling_datetime():
155+
"""Check that we don't convert input data and only sample from it."""
156+
pd = pytest.importorskip("pandas")
157+
X = pd.DataFrame({"label": [0, 0, 0, 1], "td": [datetime.now()] * 4})
158+
y = X["label"]
159+
rus = RandomUnderSampler(random_state=0)
160+
X_res, y_res = rus.fit_resample(X, y)
161+
162+
pd.testing.assert_series_equal(X_res.dtypes, X.dtypes)
163+
pd.testing.assert_index_equal(X_res.index, y_res.index)
164+
assert_array_equal(y_res.to_numpy(), np.array([0, 1]))

imblearn/utils/_validation.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@
1212
import numpy as np
1313
from sklearn.base import clone
1414
from sklearn.neighbors import NearestNeighbors
15-
from sklearn.utils import column_or_1d
15+
from sklearn.utils import check_array, column_or_1d
1616
from sklearn.utils.multiclass import type_of_target
17+
from sklearn.utils.validation import _num_samples
18+
19+
from .fixes import _is_pandas_df
1720

1821
SAMPLING_KIND = (
1922
"over-sampling",
@@ -35,6 +38,12 @@ def __init__(self, X, y):
3538
def transform(self, X, y):
3639
X = self._transfrom_one(X, self.x_props)
3740
y = self._transfrom_one(y, self.y_props)
41+
if self.x_props["type"].lower() == "dataframe" and self.y_props[
42+
"type"
43+
].lower() in {"series", "dataframe"}:
44+
# We lost the y.index during resampling. We can safely use X.index to align
45+
# them.
46+
y.index = X.index
3847
return X, y
3948

4049
def _gets_props(self, array):
@@ -607,3 +616,18 @@ def inner_f(*args, **kwargs):
607616
return f(**kwargs)
608617

609618
return inner_f
619+
620+
621+
def _check_X(X):
622+
"""Check X and do not check it if a dataframe."""
623+
n_samples = _num_samples(X)
624+
if n_samples < 1:
625+
raise ValueError(
626+
f"Found array with {n_samples} sample(s) while a minimum of 1 is "
627+
"required."
628+
)
629+
if _is_pandas_df(X):
630+
return X
631+
return check_array(
632+
X, dtype=None, accept_sparse=["csr", "csc"], force_all_finite=False
633+
)

imblearn/utils/fixes.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
which the fix is no longer needed.
66
"""
77
import functools
8+
import sys
89

910
import numpy as np
1011
import scipy
@@ -132,3 +133,18 @@ def _is_fitted(estimator, attributes=None, all_or_any=all):
132133

133134
else:
134135
from sklearn.utils.validation import _is_fitted # type: ignore[no-redef]
136+
137+
try:
138+
from sklearn.utils.validation import _is_pandas_df
139+
except ImportError:
140+
141+
def _is_pandas_df(X):
142+
"""Return True if the X is a pandas dataframe."""
143+
if hasattr(X, "columns") and hasattr(X, "iloc"):
144+
# Likely a pandas DataFrame, we explicitly check the type to confirm.
145+
try:
146+
pd = sys.modules["pandas"]
147+
except KeyError:
148+
return False
149+
return isinstance(X, pd.DataFrame)
150+
return False

0 commit comments

Comments
 (0)