36
36
from imblearn .over_sampling .base import BaseOverSampler
37
37
from imblearn .under_sampling .base import BaseCleaningSampler , BaseUnderSampler
38
38
39
+
39
40
def sample_dataset_generator ():
40
41
X , y = make_classification (
41
42
n_samples = 1000 ,
@@ -45,10 +46,13 @@ def sample_dataset_generator():
45
46
random_state = 0 ,
46
47
)
47
48
return X , y
49
+
50
+
48
51
@pytest .fixture (name = "sample_dataset_generator" )
49
52
def sample_dataset_generator_fixture ():
50
53
return sample_dataset_generator ()
51
54
55
+
52
56
def _set_checking_parameters (estimator ):
53
57
params = estimator .get_params ()
54
58
name = estimator .__class__ .__name__
@@ -261,6 +265,7 @@ def check_samplers_sampling_strategy_fit_resample(name, sampler_orig):
261
265
X_res , y_res = sampler .fit_resample (X , y )
262
266
assert Counter (y_res )[1 ] == expected_stat
263
267
268
+
264
269
def check_samplers_sparse (name , sampler_orig ):
265
270
sampler = clone (sampler_orig )
266
271
# check that sparse matrices can be passed through the sampler leading to
@@ -320,7 +325,7 @@ def check_samplers_list(name, sampler_orig):
320
325
assert_allclose (y_res , y_res_list )
321
326
322
327
323
- def check_samplers_multiclass_ova (name , sampler_orig , sample_dataset_generator ):
328
+ def check_samplers_multiclass_ova (name , sampler_orig ):
324
329
sampler = clone (sampler_orig )
325
330
# Check that multiclass target lead to the same results than OVA encoding
326
331
X , y = sample_dataset_generator ()
@@ -332,15 +337,15 @@ def check_samplers_multiclass_ova(name, sampler_orig, sample_dataset_generator):
332
337
assert_allclose (y_res , y_res_ova .argmax (axis = 1 ))
333
338
334
339
335
- def check_samplers_2d_target (name , sampler_orig , sample_dataset_generator ):
340
+ def check_samplers_2d_target (name , sampler_orig ):
336
341
sampler = clone (sampler_orig )
337
342
X , y = sample_dataset_generator ()
338
343
339
344
y = y .reshape (- 1 , 1 ) # Make the target 2d
340
345
sampler .fit_resample (X , y )
341
346
342
347
343
- def check_samplers_preserve_dtype (name , sampler_orig , sample_dataset_generator ):
348
+ def check_samplers_preserve_dtype (name , sampler_orig ):
344
349
sampler = clone (sampler_orig )
345
350
X , y = sample_dataset_generator ()
346
351
# Cast X and y to not default dtype
@@ -351,7 +356,7 @@ def check_samplers_preserve_dtype(name, sampler_orig, sample_dataset_generator):
351
356
assert y .dtype == y_res .dtype , "y dtype is not preserved"
352
357
353
358
354
- def check_samplers_sample_indices (name , sampler_orig , sample_dataset_generator ):
359
+ def check_samplers_sample_indices (name , sampler_orig ):
355
360
sampler = clone (sampler_orig )
356
361
X , y = sample_dataset_generator ()
357
362
sampler .fit_resample (X , y )
0 commit comments