@@ -764,6 +764,10 @@ class SMOTEN(SMOTE):
764
764
765
765
Parameters
766
766
----------
767
+ categorical_encoder : estimator, default=None
768
+ Ordinal encoder used to encode the categorical features. If `None`, a
769
+ :class:`~sklearn.preprocessing.OrdinalEncoder` with `dtype=np.int32` is used.
770
+
767
771
{sampling_strategy}
768
772
769
773
{random_state}
@@ -791,6 +795,9 @@ class SMOTEN(SMOTE):
791
795
792
796
Attributes
793
797
----------
798
+ categorical_encoder_ : estimator
799
+ The fitted encoder used to encode the categorical features.
800
+
794
801
sampling_strategy_ : dict
795
802
Dictionary containing the information to sample the dataset. The keys
796
803
correspond to the class labels from which to sample and the values
@@ -853,6 +860,31 @@ class SMOTEN(SMOTE):
853
860
Class counts after resampling Counter({{0: 40, 1: 40}})
854
861
"""
855
862
863
# Validation table for the constructor parameters: a copy of the SMOTE
# constraints extended with an entry for `categorical_encoder`, which lists
# `HasMethods(["fit_transform", "inverse_transform"])` and `None`.
_parameter_constraints: dict = dict(
    SMOTE._parameter_constraints,
    categorical_encoder=[
        HasMethods(["fit_transform", "inverse_transform"]),
        None,
    ],
)
870
+
871
def __init__(
    self,
    categorical_encoder=None,
    *,
    sampling_strategy="auto",
    random_state=None,
    k_neighbors=5,
    n_jobs=None,
):
    # Every option other than the encoder is forwarded unchanged to the
    # parent constructor.
    parent_options = {
        "sampling_strategy": sampling_strategy,
        "random_state": random_state,
        "k_neighbors": k_neighbors,
        "n_jobs": n_jobs,
    }
    super().__init__(**parent_options)
    # Kept exactly as passed in; `_fit_resample` derives the working
    # `categorical_encoder_` from this value.
    self.categorical_encoder = categorical_encoder
887
+
856
888
def _check_X_y (self , X , y ):
857
889
"""Check should accept strings and not sparse matrices."""
858
890
y , binarize_y = check_target_type (y , indicate_one_vs_all = True )
@@ -900,11 +932,14 @@ def _fit_resample(self, X, y):
900
932
X_resampled = [X .copy ()]
901
933
y_resampled = [y .copy ()]
902
934
903
- encoder = OrdinalEncoder (dtype = np .int32 )
904
- X_encoded = encoder .fit_transform (X )
935
+ if self .categorical_encoder is None :
936
+ self .categorical_encoder_ = OrdinalEncoder (dtype = np .int32 )
937
+ else :
938
+ self .categorical_encoder_ = clone (self .categorical_encoder )
939
+ X_encoded = self .categorical_encoder_ .fit_transform (X )
905
940
906
941
vdm = ValueDifferenceMetric (
907
- n_categories = [len (cat ) for cat in encoder .categories_ ]
942
+ n_categories = [len (cat ) for cat in self . categorical_encoder_ .categories_ ]
908
943
).fit (X_encoded , y )
909
944
910
945
for class_sample , n_samples in self .sampling_strategy_ .items ():
@@ -922,7 +957,7 @@ def _fit_resample(self, X, y):
922
957
X_class , class_sample , y .dtype , nn_indices , n_samples
923
958
)
924
959
925
- X_new = encoder .inverse_transform (X_new )
960
+ X_new = self . categorical_encoder_ .inverse_transform (X_new )
926
961
X_resampled .append (X_new )
927
962
y_resampled .append (y_new )
928
963
0 commit comments