13
13
14
14
import numpy as np
15
15
from scipy import sparse
16
+ from sklearn .base import clone
16
17
from sklearn .preprocessing import OneHotEncoder , OrdinalEncoder
17
18
from sklearn .utils import _safe_indexing , check_array , check_random_state
18
19
from sklearn .utils .sparsefuncs_fast import (
@@ -393,6 +394,11 @@ class SMOTENC(SMOTE):
393
394
- mask array of shape (n_features, ) and ``bool`` dtype for which
394
395
``True`` indicates the categorical features.
395
396
397
+ categorical_encoder : estimator, default=None
398
+ One-hot encoder used to encode the categorical features. If `None`, a
399
+ :class:`~sklearn.preprocessing.OneHotEncoder` is used with default parameters
400
+ apart from `handle_unknown`, which is set to 'ignore', and `dtype`, which is set internally.
401
+
396
402
{sampling_strategy}
397
403
398
404
{random_state}
@@ -431,6 +437,13 @@ class SMOTENC(SMOTE):
431
437
ohe_ : :class:`~sklearn.preprocessing.OneHotEncoder`
432
438
The one-hot encoder used to encode the categorical features.
433
439
440
+ .. deprecated:: 0.11
441
+ `ohe_` is deprecated in 0.11 and will be removed in 0.13. Use
442
+ `categorical_encoder_` instead.
443
+
444
+ categorical_encoder_ : estimator
445
+ The encoder used to encode the categorical features.
446
+
434
447
categorical_features_ : ndarray of shape (n_cat_features,), dtype=np.int64
435
448
Indices of the categorical features.
436
449
@@ -514,12 +527,17 @@ class SMOTENC(SMOTE):
514
527
# Validation schema for the constructor parameters: every constraint declared
# by SMOTE, extended with the two SMOTENC-specific parameters.
# `categorical_encoder` accepts either `None` or a value satisfying
# `HasMethods(["fit_transform", "inverse_transform"])`.
_parameter_constraints: dict = dict(
    SMOTE._parameter_constraints,
    categorical_features=["array-like"],
    categorical_encoder=[
        HasMethods(["fit_transform", "inverse_transform"]),
        None,
    ],
)
518
535
519
536
def __init__ (
520
537
self ,
521
538
categorical_features ,
522
539
* ,
540
+ categorical_encoder = None ,
523
541
sampling_strategy = "auto" ,
524
542
random_state = None ,
525
543
k_neighbors = 5 ,
@@ -532,6 +550,7 @@ def __init__(
532
550
n_jobs = n_jobs ,
533
551
)
534
552
self .categorical_features = categorical_features
553
+ self .categorical_encoder = categorical_encoder
535
554
536
555
def _check_X_y (self , X , y ):
537
556
"""Overwrite the checking to let pass some string for categorical
@@ -603,17 +622,19 @@ def _fit_resample(self, X, y):
603
622
else :
604
623
dtype_ohe = np .float64
605
624
606
- self .ohe_ = OneHotEncoder ( handle_unknown = "ignore" , dtype = dtype_ohe )
607
- if hasattr ( self .ohe_ , "sparse_output" ):
608
- # scikit-learn >= 1.2
609
- self . ohe_ . set_params ( sparse_output = True )
625
+ if self .categorical_encoder is None :
626
+ self .categorical_encoder_ = OneHotEncoder (
627
+ handle_unknown = "ignore" , dtype = dtype_ohe
628
+ )
610
629
else :
611
- self .ohe_ . set_params ( sparse = True )
630
+ self .categorical_encoder_ = clone ( self . categorical_encoder )
612
631
613
632
# the input of the OneHotEncoder needs to be dense
614
- X_ohe = self .ohe_ .fit_transform (
633
+ X_ohe = self .categorical_encoder_ .fit_transform (
615
634
X_categorical .toarray () if sparse .issparse (X_categorical ) else X_categorical
616
635
)
636
+ if not sparse .issparse (X_ohe ):
637
+ X_ohe = sparse .csr_matrix (X_ohe , dtype = dtype_ohe )
617
638
618
639
# we can replace the 1 entries of the categorical features with the
619
640
# median of the standard deviation. It will ensure that whenever
@@ -636,7 +657,7 @@ def _fit_resample(self, X, y):
636
657
# reverse the encoding of the categorical features
637
658
X_res_cat = X_resampled [:, self .continuous_features_ .size :]
638
659
X_res_cat .data = np .ones_like (X_res_cat .data )
639
- X_res_cat_dec = self .ohe_ .inverse_transform (X_res_cat )
660
+ X_res_cat_dec = self .categorical_encoder_ .inverse_transform (X_res_cat )
640
661
641
662
if sparse .issparse (X ):
642
663
X_resampled = sparse .hstack (
@@ -695,7 +716,7 @@ def _generate_samples(self, X, nn_data, nn_num, rows, cols, steps):
695
716
all_neighbors = nn_data [nn_num [rows ]]
696
717
697
718
categories_size = [self .continuous_features_ .size ] + [
698
- cat .size for cat in self .ohe_ .categories_
719
+ cat .size for cat in self .categorical_encoder_ .categories_
699
720
]
700
721
701
722
for start_idx , end_idx in zip (
@@ -714,6 +735,16 @@ def _generate_samples(self, X, nn_data, nn_num, rows, cols, steps):
714
735
715
736
return X_new
716
737
738
@property
def ohe_(self):
    """Deprecated alias of ``categorical_encoder_``.

    Returns
    -------
    estimator
        The object stored in ``categorical_encoder_``.

    Warns
    -----
    FutureWarning
        Emitted on every access: `ohe_` is deprecated in 0.11 and will be
        removed in 0.13.
    """
    warnings.warn(
        "'ohe_' attribute has been deprecated in 0.11 and will be removed "
        "in 0.13. Use 'categorical_encoder_' instead.",
        FutureWarning,
        # Attribute the warning to the code that reads `ohe_` rather than to
        # this getter, so users see where they need to migrate.
        stacklevel=2,
    )
    return self.categorical_encoder_
747
+
717
748
718
749
@Substitution (
719
750
sampling_strategy = BaseOverSampler ._sampling_strategy_docstring ,
0 commit comments