Skip to content

Commit 020f278

Browse files
authored
ENH add categorical_encoder param to SMOTENC (#1000)
1 parent cdf9327 commit 020f278

File tree

3 files changed

+107
-11
lines changed

3 files changed

+107
-11
lines changed

doc/whats_new/v0.11.rst

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,20 @@ Changelog
99
Compatibility
1010
.............
1111

12-
- Maintenance release for be compatible with scikit-learn >= 1.3.0.
12+
- Maintenance release for being compatible with scikit-learn >= 1.3.0.
1313
:pr:`999` by :user:`Guillaume Lemaitre <glemaitre>`.
14+
15+
Enhancements
16+
............
17+
18+
- :class:`~imblearn.over_sampling.SMOTENC` now accepts a parameter `categorical_encoder`
19+
allowing to specify a :class:`~sklearn.preprocessing.OneHotEncoder` with custom
20+
parameters.
21+
:pr:`1000` by :user:`Guillaume Lemaitre <glemaitre>`.
22+
23+
Deprecation
24+
...........
25+
26+
- The fitted attribute `ohe_` in :class:`~imblearn.over_sampling.SMOTENC` is deprecated
27+
and will be removed in version 0.13. Use `categorical_encoder_` instead.
28+
:pr:`1000` by :user:`Guillaume Lemaitre <glemaitre>`.

imblearn/over_sampling/_smote/base.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import numpy as np
1515
from scipy import sparse
16+
from sklearn.base import clone
1617
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
1718
from sklearn.utils import _safe_indexing, check_array, check_random_state
1819
from sklearn.utils.sparsefuncs_fast import (
@@ -393,6 +394,11 @@ class SMOTENC(SMOTE):
393394
- mask array of shape (n_features, ) and ``bool`` dtype for which
394395
``True`` indicates the categorical features.
395396
397+
categorical_encoder : estimator, default=None
398+
One-hot encoder used to encode the categorical features. If `None`, a
399+
:class:`~sklearn.preprocessing.OneHotEncoder` is used with default parameters
400+
apart from `handle_unknown` which is set to 'ignore'.
401+
396402
{sampling_strategy}
397403
398404
{random_state}
@@ -431,6 +437,13 @@ class SMOTENC(SMOTE):
431437
ohe_ : :class:`~sklearn.preprocessing.OneHotEncoder`
432438
The one-hot encoder used to encode the categorical features.
433439
440+
.. deprecated:: 0.11
441+
`ohe_` is deprecated in 0.11 and will be removed in 0.13. Use
442+
`categorical_encoder_` instead.
443+
444+
categorical_encoder_ : estimator
445+
The encoder used to encode the categorical features.
446+
434447
categorical_features_ : ndarray of shape (n_cat_features,), dtype=np.int64
435448
Indices of the categorical features.
436449
@@ -514,12 +527,17 @@ class SMOTENC(SMOTE):
514527
_parameter_constraints: dict = {
515528
**SMOTE._parameter_constraints,
516529
"categorical_features": ["array-like"],
530+
"categorical_encoder": [
531+
HasMethods(["fit_transform", "inverse_transform"]),
532+
None,
533+
],
517534
}
518535

519536
def __init__(
520537
self,
521538
categorical_features,
522539
*,
540+
categorical_encoder=None,
523541
sampling_strategy="auto",
524542
random_state=None,
525543
k_neighbors=5,
@@ -532,6 +550,7 @@ def __init__(
532550
n_jobs=n_jobs,
533551
)
534552
self.categorical_features = categorical_features
553+
self.categorical_encoder = categorical_encoder
535554

536555
def _check_X_y(self, X, y):
537556
"""Overwrite the checking to let pass some string for categorical
@@ -603,17 +622,19 @@ def _fit_resample(self, X, y):
603622
else:
604623
dtype_ohe = np.float64
605624

606-
self.ohe_ = OneHotEncoder(handle_unknown="ignore", dtype=dtype_ohe)
607-
if hasattr(self.ohe_, "sparse_output"):
608-
# scikit-learn >= 1.2
609-
self.ohe_.set_params(sparse_output=True)
625+
if self.categorical_encoder is None:
626+
self.categorical_encoder_ = OneHotEncoder(
627+
handle_unknown="ignore", dtype=dtype_ohe
628+
)
610629
else:
611-
self.ohe_.set_params(sparse=True)
630+
self.categorical_encoder_ = clone(self.categorical_encoder)
612631

613632
# the input of the OneHotEncoder needs to be dense
614-
X_ohe = self.ohe_.fit_transform(
633+
X_ohe = self.categorical_encoder_.fit_transform(
615634
X_categorical.toarray() if sparse.issparse(X_categorical) else X_categorical
616635
)
636+
if not sparse.issparse(X_ohe):
637+
X_ohe = sparse.csr_matrix(X_ohe, dtype=dtype_ohe)
617638

618639
# we can replace the 1 entries of the categorical features with the
619640
# median of the standard deviation. It will ensure that whenever
@@ -636,7 +657,7 @@ def _fit_resample(self, X, y):
636657
# reverse the encoding of the categorical features
637658
X_res_cat = X_resampled[:, self.continuous_features_.size :]
638659
X_res_cat.data = np.ones_like(X_res_cat.data)
639-
X_res_cat_dec = self.ohe_.inverse_transform(X_res_cat)
660+
X_res_cat_dec = self.categorical_encoder_.inverse_transform(X_res_cat)
640661

641662
if sparse.issparse(X):
642663
X_resampled = sparse.hstack(
@@ -695,7 +716,7 @@ def _generate_samples(self, X, nn_data, nn_num, rows, cols, steps):
695716
all_neighbors = nn_data[nn_num[rows]]
696717

697718
categories_size = [self.continuous_features_.size] + [
698-
cat.size for cat in self.ohe_.categories_
719+
cat.size for cat in self.categorical_encoder_.categories_
699720
]
700721

701722
for start_idx, end_idx in zip(
@@ -714,6 +735,16 @@ def _generate_samples(self, X, nn_data, nn_num, rows, cols, steps):
714735

715736
return X_new
716737

738+
@property
739+
def ohe_(self):
740+
"""One-hot encoder used to encode the categorical features."""
741+
warnings.warn(
742+
"'ohe_' attribute has been deprecated in 0.11 and will be removed "
743+
"in 0.13. Use 'categorical_encoder_' instead.",
744+
FutureWarning,
745+
)
746+
return self.categorical_encoder_
747+
717748

718749
@Substitution(
719750
sampling_strategy=BaseOverSampler._sampling_strategy_docstring,

imblearn/over_sampling/_smote/tests/test_smote_nc.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,20 @@
88

99
import numpy as np
1010
import pytest
11+
import sklearn
1112
from scipy import sparse
1213
from sklearn.datasets import make_classification
14+
from sklearn.preprocessing import OneHotEncoder
1315
from sklearn.utils._testing import assert_allclose, assert_array_equal
16+
from sklearn.utils.fixes import parse_version
1417

1518
from imblearn.over_sampling import SMOTENC
19+
from imblearn.utils.estimator_checks import (
20+
_set_checking_parameters,
21+
check_param_validation,
22+
)
23+
24+
sklearn_version = parse_version(sklearn.__version__)
1625

1726

1827
def data_heterogneous_ordered():
@@ -182,8 +191,7 @@ def test_smotenc_pandas():
182191
smote = SMOTENC(categorical_features=categorical_features, random_state=0)
183192
X_res_pd, y_res_pd = smote.fit_resample(X_pd, y)
184193
X_res, y_res = smote.fit_resample(X, y)
185-
# FIXME: we should use to_numpy with pandas >= 0.25
186-
assert_array_equal(X_res_pd.values, X_res)
194+
assert_array_equal(X_res_pd.to_numpy(), X_res)
187195
assert_allclose(y_res_pd, y_res)
188196

189197

@@ -240,3 +248,45 @@ def test_smote_nc_with_null_median_std():
240248
# check that the categorical feature is not random but correspond to the
241249
# categories seen in the minority class samples
242250
assert X_res[-1, -1] == "C"
251+
252+
253+
def test_smotenc_categorical_encoder():
254+
"""Check that we can pass our own categorical encoder."""
255+
256+
# TODO: only use `sparse_output` when sklearn >= 1.2
257+
param = "sparse" if sklearn_version < parse_version("1.2") else "sparse_output"
258+
259+
X, y, categorical_features = data_heterogneous_unordered()
260+
smote = SMOTENC(categorical_features=categorical_features, random_state=0)
261+
smote.fit_resample(X, y)
262+
263+
assert getattr(smote.categorical_encoder_, param) is True
264+
265+
encoder = OneHotEncoder()
266+
encoder.set_params(**{param: False})
267+
smote.set_params(categorical_encoder=encoder).fit_resample(X, y)
268+
assert smote.categorical_encoder is encoder
269+
assert smote.categorical_encoder_ is not encoder
270+
assert getattr(smote.categorical_encoder_, param) is False
271+
272+
273+
# TODO(0.13): remove this test
274+
def test_smotenc_deprecation_ohe_():
275+
"""Check that we raise a deprecation warning when using `ohe_`."""
276+
X, y, categorical_features = data_heterogneous_unordered()
277+
smote = SMOTENC(categorical_features=categorical_features, random_state=0)
278+
smote.fit_resample(X, y)
279+
280+
with pytest.warns(FutureWarning, match="'ohe_' attribute has been deprecated"):
281+
smote.ohe_
282+
283+
284+
def test_smotenc_param_validation():
285+
"""Check that we validate the parameters correctly since this estimator requires
286+
a specific parameter.
287+
"""
288+
categorical_features = [0]
289+
smote = SMOTENC(categorical_features=categorical_features, random_state=0)
290+
name = smote.__class__.__name__
291+
_set_checking_parameters(smote)
292+
check_param_validation(name, smote)

0 commit comments

Comments
 (0)