Description
Describe the bug
The test test_estimators_compatibility_sklearn fails for SMOTEN in check_estimator_sparse_tag with:
AssertionError: Estimator SMOTEN didn't fail when fitted on sparse data but should have according to its tag self.input_tags.sparse=False. The tag is inconsistent and must be fixed.
This appears to be an issue with scikit-learn >= 1.6.1.
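To make the failing condition concrete, here is a minimal sketch of the decision logic behind check_estimator_sparse_tag, written with standard-library stand-ins only. FakeSparse, SilentSampler, StrictSampler, and sparse_tag_is_consistent are hypothetical names used for illustration; they are not part of scikit-learn or imbalanced-learn, and the real check (quoted in the traceback under Actual Results) uses get_tags and scipy sparse containers instead.

```python
import re


class FakeSparse:
    """Hypothetical stand-in for a scipy.sparse container."""


class SilentSampler:
    """Declares sparse=False but accepts sparse input anyway (the reported situation)."""

    sparse_tag = False  # simplified stand-in for tags.input_tags.sparse

    def fit(self, X, y):
        return self  # no input validation, so sparse input passes silently


class StrictSampler:
    """Declares sparse=False and rejects sparse input with an explicit message."""

    sparse_tag = False

    def fit(self, X, y):
        if isinstance(X, FakeSparse):
            raise TypeError("Sparse data was passed, but dense data is required.")
        return self


def sparse_tag_is_consistent(estimator):
    """Return True when the declared tag matches the behavior on sparse input."""
    X, y = FakeSparse(), [0, 1]
    try:
        estimator.fit(X, y)
    except (ValueError, TypeError) as exc:
        # Rejection is consistent only if the tag is False and the error
        # message mentions sparse input explicitly.
        mentions_sparse = re.search("[Ss]parse", str(exc)) is not None
        return (not estimator.sparse_tag) and mentions_sparse
    # fit succeeded: consistent only if the tag says sparse input is supported.
    return estimator.sparse_tag


print(sparse_tag_is_consistent(SilentSampler()))
print(sparse_tag_is_consistent(StrictSampler()))
```

SilentSampler mirrors what the test reports for SMOTEN: the tag says sparse input is unsupported, yet fit succeeds. Presumably either the tag or the input validation would have to change to make the two agree; which one is appropriate is a question for the maintainers.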
Expected Results
Test succeeds.
Actual Results
=================================== FAILURES ===================================
_ test_estimators_compatibility_sklearn[SMOTEN(random_state=42)-check_estimator_sparse_tag] _
estimator = SMOTEN(random_state=42)
check = functools.partial(<function check_estimator_sparse_tag at 0x7f5dc05c6fc0>, 'SMOTEN')
request = <FixtureRequest for <Function test_estimators_compatibility_sklearn[SMOTEN(random_state=42)-check_estimator_sparse_tag]>>
    @parametrize_with_checks_sklearn(
        list(_tested_estimators()), expected_failed_checks=_get_expected_failed_checks
    )
    def test_estimators_compatibility_sklearn(estimator, check, request):
        _set_checking_parameters(estimator)
>       check(estimator)
imblearn/tests/test_common.py:46:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
name = 'SMOTEN', estimator_orig = SMOTEN(random_state=42)
    def check_estimator_sparse_tag(name, estimator_orig):
        """Check that estimator tag related with accepting sparse data is properly set."""
        if SPARSE_ARRAY_PRESENT:
            sparse_container = sparse.csr_array
        else:
            sparse_container = sparse.csr_matrix
        estimator = clone(estimator_orig)
        rng = np.random.RandomState(0)
        n_samples = 15 if name == "SpectralCoclustering" else 40
        X = rng.uniform(size=(n_samples, 3))
        X[X < 0.6] = 0
        y = rng.randint(0, 3, size=n_samples)
        X = _enforce_estimator_tags_X(estimator, X)
        y = _enforce_estimator_tags_y(estimator, y)
        X = sparse_container(X)
        tags = get_tags(estimator)
        if tags.input_tags.sparse:
            try:
                estimator.fit(X, y)  # should pass
            except Exception as e:
                err_msg = (
                    f"Estimator {name} raised an exception. "
                    f"The tag self.input_tags.sparse={tags.input_tags.sparse} "
                    "might not be consistent with the estimator's ability to "
                    "handle sparse data (i.e. controlled by the parameter `accept_sparse`"
                    " in `validate_data` or `check_array` functions)."
                )
                raise AssertionError(err_msg) from e
        else:
            err_msg = (
                f"Estimator {name} raised an exception. "
                "The estimator failed when fitted on sparse data in accordance "
                f"with its tag self.input_tags.sparse={tags.input_tags.sparse} "
                "but didn't raise the appropriate error: error message should "
                "state explicitly that sparse input is not supported if this is "
                "not the case, e.g. by using check_array(X, accept_sparse=False)."
            )
            try:
                estimator.fit(X, y)  # should fail with appropriate error
            except (ValueError, TypeError) as e:
                if re.search("[Ss]parse", str(e)):
                    # Got the right error type and mentioning sparse issue
                    return
                raise AssertionError(err_msg) from e
            except Exception as e:
                raise AssertionError(err_msg) from e
>           raise AssertionError(
                f"Estimator {name} didn't fail when fitted on sparse data "
                "but should have according to its tag "
                f"self.input_tags.sparse={tags.input_tags.sparse}. "
                f"The tag is inconsistent and must be fixed."
            )
E           AssertionError: Estimator SMOTEN didn't fail when fitted on sparse data but should have according to its tag self.input_tags.sparse=False. The tag is inconsistent and must be fixed.
Versions
System:
    python: 3.13.2 (main, Feb 6 2025, 00:00:00) [GCC 15.0.1 20250204 (Red Hat 15.0.1-0)]
    executable: /usr/bin/python3
    machine: Linux-6.13.5-200.fc41.x86_64-x86_64-with-glibc2.41.9000

Python dependencies:
    sklearn: 1.6.1
    pip: 24.3.1
    setuptools: 74.1.3
    numpy: 2.2.4
    scipy: 1.14.1
    Cython: None
    pandas: 2.2.3
    matplotlib: None
    joblib: 1.4.2
    threadpoolctl: 3.5.0

Built with OpenMP: True

threadpoolctl info:
    user_api: blas
    internal_api: flexiblas
    num_threads: 8
    prefix: libflexiblas
    filepath: /usr/lib64/libflexiblas.so.3.4
    version: 3.4.5
    available_backends: ['NETLIB', 'OPENBLAS-OPENMP']
    loaded_backends: ['OPENBLAS-OPENMP']
    current_backend: OPENBLAS-OPENMP

    user_api: blas
    internal_api: openblas
    num_threads: 8
    prefix: libopenblas
    filepath: /usr/lib64/libopenblaso-r0.3.29.so
    version: 0.3.29
    threading_layer: openmp
    architecture: Sandybridge

    user_api: openmp
    internal_api: openmp
    num_threads: 8
    prefix: libgomp
    filepath: /usr/lib64/libgomp.so.1.0.0
    version: None