Skip to content

Commit d69acd5

Browse files
authored
ENH add categorical_encoder to SMOTEN (#1001)
1 parent a1d9f3c commit d69acd5

File tree

3 files changed

+66
-4
lines changed

3 files changed

+66
-4
lines changed

doc/whats_new/v0.11.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,9 @@ Enhancements
3333
allowing to specify a :class:`~sklearn.preprocessing.OneHotEncoder` with custom
3434
parameters.
3535
:pr:`1000` by :user:`Guillaume Lemaitre <glemaitre>`.
36+
37+
- :class:`~imblearn.over_sampling.SMOTEN` now accepts a parameter `categorical_encoder`
38+
which allows specifying a :class:`~sklearn.preprocessing.OrdinalEncoder` with custom
39+
parameters. A new fitted parameter `categorical_encoder_` is exposed to access the
40+
fitted encoder.
41+
:pr:`1001` by :user:`Guillaume Lemaitre <glemaitre>`.

imblearn/over_sampling/_smote/base.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -764,6 +764,10 @@ class SMOTEN(SMOTE):
764764
765765
Parameters
766766
----------
767+
categorical_encoder : estimator, default=None
768+
Ordinal encoder used to encode the categorical features. If `None`, a
769+
:class:`~sklearn.preprocessing.OrdinalEncoder` is used with default parameters.
770+
767771
{sampling_strategy}
768772
769773
{random_state}
@@ -791,6 +795,9 @@ class SMOTEN(SMOTE):
791795
792796
Attributes
793797
----------
798+
categorical_encoder_ : estimator
799+
The encoder used to encode the categorical features.
800+
794801
sampling_strategy_ : dict
795802
Dictionary containing the information to sample the dataset. The keys
796803
correspond to the class labels from which to sample and the values
@@ -853,6 +860,31 @@ class SMOTEN(SMOTE):
853860
Class counts after resampling Counter({{0: 40, 1: 40}})
854861
"""
855862

863+
_parameter_constraints: dict = {
864+
**SMOTE._parameter_constraints,
865+
"categorical_encoder": [
866+
HasMethods(["fit_transform", "inverse_transform"]),
867+
None,
868+
],
869+
}
870+
871+
def __init__(
872+
self,
873+
categorical_encoder=None,
874+
*,
875+
sampling_strategy="auto",
876+
random_state=None,
877+
k_neighbors=5,
878+
n_jobs=None,
879+
):
880+
super().__init__(
881+
sampling_strategy=sampling_strategy,
882+
random_state=random_state,
883+
k_neighbors=k_neighbors,
884+
n_jobs=n_jobs,
885+
)
886+
self.categorical_encoder = categorical_encoder
887+
856888
def _check_X_y(self, X, y):
857889
"""Check should accept strings and not sparse matrices."""
858890
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
@@ -900,11 +932,14 @@ def _fit_resample(self, X, y):
900932
X_resampled = [X.copy()]
901933
y_resampled = [y.copy()]
902934

903-
encoder = OrdinalEncoder(dtype=np.int32)
904-
X_encoded = encoder.fit_transform(X)
935+
if self.categorical_encoder is None:
936+
self.categorical_encoder_ = OrdinalEncoder(dtype=np.int32)
937+
else:
938+
self.categorical_encoder_ = clone(self.categorical_encoder)
939+
X_encoded = self.categorical_encoder_.fit_transform(X)
905940

906941
vdm = ValueDifferenceMetric(
907-
n_categories=[len(cat) for cat in encoder.categories_]
942+
n_categories=[len(cat) for cat in self.categorical_encoder_.categories_]
908943
).fit(X_encoded, y)
909944

910945
for class_sample, n_samples in self.sampling_strategy_.items():
@@ -922,7 +957,7 @@ def _fit_resample(self, X, y):
922957
X_class, class_sample, y.dtype, nn_indices, n_samples
923958
)
924959

925-
X_new = encoder.inverse_transform(X_new)
960+
X_new = self.categorical_encoder_.inverse_transform(X_new)
926961
X_resampled.append(X_new)
927962
y_resampled.append(y_new)
928963

imblearn/over_sampling/_smote/tests/test_smoten.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
import pytest
3+
from sklearn.preprocessing import OrdinalEncoder
34

45
from imblearn.over_sampling import SMOTEN
56

@@ -27,6 +28,7 @@ def test_smoten(data):
2728

2829
assert X_res.shape == (80, 3)
2930
assert y_res.shape == (80,)
31+
assert isinstance(sampler.categorical_encoder_, OrdinalEncoder)
3032

3133

3234
def test_smoten_resampling():
@@ -52,3 +54,22 @@ def test_smoten_resampling():
5254
X_generated, y_generated = X_res[X.shape[0] :], y_res[X.shape[0] :]
5355
np.testing.assert_array_equal(X_generated, "blue")
5456
np.testing.assert_array_equal(y_generated, "not apple")
57+
58+
59+
def test_smoten_categorical_encoder(data):
60+
"""Check that `categorical_encoder` is used when provided."""
61+
62+
X, y = data
63+
sampler = SMOTEN(random_state=0)
64+
sampler.fit_resample(X, y)
65+
66+
assert isinstance(sampler.categorical_encoder_, OrdinalEncoder)
67+
assert sampler.categorical_encoder_.dtype == np.int32
68+
69+
encoder = OrdinalEncoder(dtype=np.int64)
70+
sampler.set_params(categorical_encoder=encoder).fit_resample(X, y)
71+
72+
assert isinstance(sampler.categorical_encoder_, OrdinalEncoder)
73+
assert sampler.categorical_encoder is encoder
74+
assert sampler.categorical_encoder_ is not encoder
75+
assert sampler.categorical_encoder_.dtype == np.int64

0 commit comments

Comments
 (0)