Skip to content

Commit a1d9f3c

Browse files
authored
FIX handle heterogeneous data type in categorical feature in SMOTENC (#1002)
1 parent 020f278 commit a1d9f3c

File tree

3 files changed

+55
-13
lines changed

3 files changed

+55
-13
lines changed

doc/whats_new/v0.11.rst

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,30 @@ Version 0.11.0 (Under development)
66
Changelog
77
---------
88

9+
Bug fixes
10+
.........
11+
12+
- :class:`~imblearn.over_sampling.SMOTENC` now handles mix types of data type such as
13+
`bool` and `pd.category` by delegating the conversion to scikit-learn encoder.
14+
:pr:`1002` by :user:`Guillaume Lemaitre <glemaitre>`.
15+
916
Compatibility
1017
.............
1118

1219
- Maintenance release for being compatible with scikit-learn >= 1.3.0.
1320
:pr:`999` by :user:`Guillaume Lemaitre <glemaitre>`.
1421

22+
Deprecation
23+
...........
24+
25+
- The fitted attribute `ohe_` in :class:`~imblearn.over_sampling.SMOTENC` is deprecated
26+
and will be removed in version 0.13. Use `categorical_encoder_` instead.
27+
:pr:`1000` by :user:`Guillaume Lemaitre <glemaitre>`.
28+
1529
Enhancements
1630
............
1731

1832
- :class:`~imblearn.over_sampling.SMOTENC` now accepts a parameter `categorical_encoder`
1933
allowing to specify a :class:`~sklearn.preprocessing.OneHotEncoder` with custom
2034
parameters.
2135
:pr:`1000` by :user:`Guillaume Lemaitre <glemaitre>`.
22-
23-
Deprecation
24-
...........
25-
26-
- The fitted attribute `ohe_` in :class:`~imblearn.over_sampling.SMOTENC` is deprecated
27-
and will be removed in version 0.13. Use `categorical_encoder_` instead.
28-
:pr:`1000` by :user:`Guillaume Lemaitre <glemaitre>`.

imblearn/over_sampling/_smote/base.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
csc_mean_variance_axis0,
2121
csr_mean_variance_axis0,
2222
)
23+
from sklearn.utils.validation import _num_features
2324

2425
from ...metrics.pairwise import ValueDifferenceMetric
2526
from ...utils import Substitution, check_neighbors_object, check_target_type
@@ -557,9 +558,9 @@ def _check_X_y(self, X, y):
557558
features.
558559
"""
559560
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
560-
X, y = self._validate_data(
561-
X, y, reset=True, dtype=None, accept_sparse=["csr", "csc"]
562-
)
561+
if not (hasattr(X, "__array__") or sparse.issparse(X)):
562+
X = check_array(X, dtype=object)
563+
self._check_n_features(X, reset=True)
563564
return X, y, binarize_y
564565

565566
def _validate_estimator(self):
@@ -596,14 +597,14 @@ def _fit_resample(self, X, y):
596597
FutureWarning,
597598
)
598599

599-
self.n_features_ = X.shape[1]
600+
self.n_features_ = _num_features(X)
600601
self._validate_estimator()
601602

602603
# compute the median of the standard deviation of the minority class
603604
target_stats = Counter(y)
604605
class_minority = min(target_stats, key=target_stats.get)
605606

606-
X_continuous = X[:, self.continuous_features_]
607+
X_continuous = _safe_indexing(X, self.continuous_features_, axis=1)
607608
X_continuous = check_array(X_continuous, accept_sparse=["csr", "csc"])
608609
X_minority = _safe_indexing(X_continuous, np.flatnonzero(y == class_minority))
609610

@@ -616,7 +617,7 @@ def _fit_resample(self, X, y):
616617
var = X_minority.var(axis=0)
617618
self.median_std_ = np.median(np.sqrt(var))
618619

619-
X_categorical = X[:, self.categorical_features_]
620+
X_categorical = _safe_indexing(X, self.categorical_features_, axis=1)
620621
if X_continuous.dtype.name != "object":
621622
dtype_ohe = X_continuous.dtype
622623
else:

imblearn/over_sampling/_smote/tests/test_smote_nc.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,3 +290,37 @@ def test_smotenc_param_validation():
290290
name = smote.__class__.__name__
291291
_set_checking_parameters(smote)
292292
check_param_validation(name, smote)
293+
294+
295+
def test_smotenc_bool_categorical():
296+
"""Check that we don't try to early convert the full input data to numeric when
297+
handling a pandas dataframe.
298+
299+
Non-regression test for:
300+
https://github.com/scikit-learn-contrib/imbalanced-learn/issues/974
301+
"""
302+
pd = pytest.importorskip("pandas")
303+
304+
X = pd.DataFrame(
305+
{
306+
"c": pd.Categorical([x for x in "abbacaba" * 3]),
307+
"f": [0.3, 0.5, 0.1, 0.2] * 6,
308+
"b": [False, False, True] * 8,
309+
}
310+
)
311+
y = pd.DataFrame({"out": [1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0] * 2})
312+
smote = SMOTENC(categorical_features=[0])
313+
314+
X_res, y_res = smote.fit_resample(X, y)
315+
pd.testing.assert_series_equal(X_res.dtypes, X.dtypes)
316+
assert len(X_res) == len(y_res)
317+
318+
smote.set_params(categorical_features=[0, 2])
319+
X_res, y_res = smote.fit_resample(X, y)
320+
pd.testing.assert_series_equal(X_res.dtypes, X.dtypes)
321+
assert len(X_res) == len(y_res)
322+
323+
X = X.astype({"b": "category"})
324+
X_res, y_res = smote.fit_resample(X, y)
325+
pd.testing.assert_series_equal(X_res.dtypes, X.dtypes)
326+
assert len(X_res) == len(y_res)

0 commit comments

Comments
 (0)