Skip to content

Commit 9f8d6a6

Browse files
committed
Update imblearn/utils/estimator_checks.py
1 parent 7a6fc00 commit 9f8d6a6

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

imblearn/utils/estimator_checks.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from imblearn.over_sampling.base import BaseOverSampler
3737
from imblearn.under_sampling.base import BaseCleaningSampler, BaseUnderSampler
3838

39+
3940
def sample_dataset_generator():
4041
X, y = make_classification(
4142
n_samples=1000,
@@ -45,10 +46,13 @@ def sample_dataset_generator():
4546
random_state=0,
4647
)
4748
return X, y
49+
50+
4851
@pytest.fixture(name="sample_dataset_generator")
4952
def sample_dataset_generator_fixture():
5053
return sample_dataset_generator()
5154

55+
5256
def _set_checking_parameters(estimator):
5357
params = estimator.get_params()
5458
name = estimator.__class__.__name__
@@ -261,6 +265,7 @@ def check_samplers_sampling_strategy_fit_resample(name, sampler_orig):
261265
X_res, y_res = sampler.fit_resample(X, y)
262266
assert Counter(y_res)[1] == expected_stat
263267

268+
264269
def check_samplers_sparse(name, sampler_orig):
265270
sampler = clone(sampler_orig)
266271
# check that sparse matrices can be passed through the sampler leading to
@@ -320,7 +325,7 @@ def check_samplers_list(name, sampler_orig):
320325
assert_allclose(y_res, y_res_list)
321326

322327

323-
def check_samplers_multiclass_ova(name, sampler_orig, sample_dataset_generator):
328+
def check_samplers_multiclass_ova(name, sampler_orig):
324329
sampler = clone(sampler_orig)
325330
# Check that multiclass target lead to the same results than OVA encoding
326331
X, y = sample_dataset_generator()
@@ -332,15 +337,15 @@ def check_samplers_multiclass_ova(name, sampler_orig, sample_dataset_generator):
332337
assert_allclose(y_res, y_res_ova.argmax(axis=1))
333338

334339

335-
def check_samplers_2d_target(name, sampler_orig, sample_dataset_generator):
340+
def check_samplers_2d_target(name, sampler_orig):
336341
sampler = clone(sampler_orig)
337342
X, y = sample_dataset_generator()
338343

339344
y = y.reshape(-1, 1) # Make the target 2d
340345
sampler.fit_resample(X, y)
341346

342347

343-
def check_samplers_preserve_dtype(name, sampler_orig, sample_dataset_generator):
348+
def check_samplers_preserve_dtype(name, sampler_orig):
344349
sampler = clone(sampler_orig)
345350
X, y = sample_dataset_generator()
346351
# Cast X and y to not default dtype
@@ -351,7 +356,7 @@ def check_samplers_preserve_dtype(name, sampler_orig, sample_dataset_generator):
351356
assert y.dtype == y_res.dtype, "y dtype is not preserved"
352357

353358

354-
def check_samplers_sample_indices(name, sampler_orig, sample_dataset_generator):
359+
def check_samplers_sample_indices(name, sampler_orig):
355360
sampler = clone(sampler_orig)
356361
X, y = sample_dataset_generator()
357362
sampler.fit_resample(X, y)

0 commit comments

Comments
 (0)