Make sklearn dependency optional (#20657)
mattdangerw authored Dec 18, 2024
1 parent bce0f5b commit ed1442e
Showing 5 changed files with 104 additions and 66 deletions.
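
With this change scikit-learn becomes optional: importing Keras and the wrapper module no longer requires it, and the wrappers raise a clear ImportError only when they are actually constructed. A hypothetical usage sketch on a machine without scikit-learn (the model factory name is a placeholder, not part of the commit):

import keras                                      # succeeds without scikit-learn
from keras.src.wrappers import SKLearnClassifier  # importable thanks to the stub base classes

def build_model(X, y, loss):  # placeholder model factory, not from the commit
    ...

try:
    clf = SKLearnClassifier(model=build_model)
except ImportError as err:
    print(err)  # "SKLearnClassifier requires `scikit-learn` to be installed. ..."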
52 changes: 8 additions & 44 deletions keras/src/wrappers/fixes.py
@@ -1,42 +1,7 @@
import sklearn
from packaging.version import parse as parse_version
from sklearn import get_config

sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)

if sklearn_version < parse_version("1.6"):

def patched_more_tags(estimator, expected_failed_checks):
import copy

from sklearn.utils._tags import _safe_tags

original_tags = copy.deepcopy(_safe_tags(estimator))

def patched_more_tags(self):
original_tags.update({"_xfail_checks": expected_failed_checks})
return original_tags

estimator.__class__._more_tags = patched_more_tags
return estimator

def parametrize_with_checks(
estimators,
*,
legacy: bool = True,
expected_failed_checks=None,
):
# legacy is not supported and ignored
from sklearn.utils.estimator_checks import parametrize_with_checks # noqa: F401, I001

estimators = [
patched_more_tags(estimator, expected_failed_checks(estimator))
for estimator in estimators
]

return parametrize_with_checks(estimators)
else:
from sklearn.utils.estimator_checks import parametrize_with_checks # noqa: F401, I001
try:
import sklearn
except ImportError:
sklearn = None


def _validate_data(estimator, *args, **kwargs):
@@ -59,9 +24,6 @@ def _validate_data(estimator, *args, **kwargs):


def type_of_target(y, input_name="", *, raise_unknown=False):
# fix for raise_unknown which is introduced in scikit-learn 1.6
from sklearn.utils.multiclass import type_of_target

def _raise_or_return(target_type):
"""Depending on the value of raise_unknown, either raise an error or
return 'unknown'.
@@ -72,7 +34,9 @@ def _raise_or_return(target_type):
else:
return target_type

target_type = type_of_target(y, input_name=input_name)
target_type = sklearn.utils.multiclass.type_of_target(
y, input_name=input_name
)
return _raise_or_return(target_type)


@@ -86,7 +50,7 @@ def _routing_enabled():
TODO: remove when the config key is no longer available in scikit-learn
"""
return get_config().get("enable_metadata_routing", False)
return sklearn.get_config().get("enable_metadata_routing", False)


def _raise_for_params(params, owner, method):
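
The rewritten fixes.py above replaces top-level from-imports with a guarded import of sklearn and attribute lookups that happen only inside the helper bodies, so the module can be imported even when scikit-learn is absent. A minimal sketch of that pattern, with an illustrative helper name (not the commit's code):

try:
    import sklearn
except ImportError:
    sklearn = None

def routing_enabled():
    # The sklearn attribute is only touched when the helper runs, so merely
    # importing this module never needs scikit-learn to be installed.
    return sklearn.get_config().get("enable_metadata_routing", False)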
45 changes: 43 additions & 2 deletions keras/src/wrappers/sklearn_test.py
@@ -3,6 +3,9 @@
from contextlib import contextmanager

import pytest
import sklearn
from packaging.version import parse as parse_version
from sklearn.utils.estimator_checks import parametrize_with_checks

import keras
from keras.src.backend import floatx
@@ -13,7 +16,45 @@
from keras.src.wrappers import SKLearnClassifier
from keras.src.wrappers import SKLearnRegressor
from keras.src.wrappers import SKLearnTransformer
from keras.src.wrappers.fixes import parametrize_with_checks


def wrapped_parametrize_with_checks(
estimators,
*,
legacy: bool = True,
expected_failed_checks=None,
):
"""Wrapped `parametrize_with_checks` handling backwards compat."""
sklearn_version = parse_version(
parse_version(sklearn.__version__).base_version
)

if sklearn_version >= parse_version("1.6"):
return parametrize_with_checks(
estimators,
legacy=legacy,
expected_failed_checks=expected_failed_checks,
)

def patched_more_tags(estimator, expected_failed_checks):
import copy

original_tags = copy.deepcopy(sklearn.utils._tags._safe_tags(estimator))

def patched_more_tags(self):
original_tags.update({"_xfail_checks": expected_failed_checks})
return original_tags

estimator.__class__._more_tags = patched_more_tags
return estimator

estimators = [
patched_more_tags(estimator, expected_failed_checks(estimator))
for estimator in estimators
]

# legacy is not supported and ignored
return parametrize_with_checks(estimators)


def dynamic_model(X, y, loss, layers=[10]):
@@ -80,7 +121,7 @@ def use_floatx(x: str):
}


@parametrize_with_checks(
@wrapped_parametrize_with_checks(
estimators=[
SKLearnClassifier(
model=dynamic_model,
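
The version gate in wrapped_parametrize_with_checks parses sklearn.__version__ twice on purpose: the inner parse extracts base_version to strip pre-release and dev suffixes, and the outer parse turns that back into a comparable Version. A short illustration of why (values chosen for the example):

from packaging.version import parse as parse_version

v = parse_version("1.6.0rc1")
print(v < parse_version("1.6"))                              # True: a pre-release sorts before the final 1.6 release
print(parse_version(v.base_version) < parse_version("1.6"))  # False: base_version "1.6.0" compares equal to 1.6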
44 changes: 31 additions & 13 deletions keras/src/wrappers/sklearn_wrapper.py
@@ -1,14 +1,6 @@
import copy

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.base import ClassifierMixin
from sklearn.base import RegressorMixin
from sklearn.base import TransformerMixin
from sklearn.base import check_is_fitted
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.utils.metadata_routing import MetadataRequest

from keras.src.api_export import keras_export
from keras.src.models.cloning import clone_model
@@ -18,6 +10,28 @@
from keras.src.wrappers.fixes import type_of_target
from keras.src.wrappers.utils import TargetReshaper
from keras.src.wrappers.utils import _check_model
from keras.src.wrappers.utils import assert_sklearn_installed

try:
import sklearn
from sklearn.base import BaseEstimator
from sklearn.base import ClassifierMixin
from sklearn.base import RegressorMixin
from sklearn.base import TransformerMixin
except ImportError:
sklearn = None

class BaseEstimator:
pass

class ClassifierMixin:
pass

class RegressorMixin:
pass

class TransformerMixin:
pass


class SKLBase(BaseEstimator):
@@ -64,6 +78,7 @@ def __init__(
model_kwargs=None,
fit_kwargs=None,
):
assert_sklearn_installed(self.__class__.__name__)
self.model = model
self.warm_start = warm_start
self.model_kwargs = model_kwargs
@@ -119,7 +134,9 @@ def set_fit_request(self, **kwargs):
"sklearn.set_config(enable_metadata_routing=True)."
)

self._metadata_request = MetadataRequest(owner=self.__class__.__name__)
self._metadata_request = sklearn.utils.metadata_routing.MetadataRequest(
owner=self.__class__.__name__
)
for param, alias in kwargs.items():
self._metadata_request.score.add_request(param=param, alias=alias)
return self
@@ -155,7 +172,7 @@ def fit(self, X, y, **kwargs):

def predict(self, X):
"""Predict using the model."""
check_is_fitted(self)
sklearn.base.check_is_fitted(self)
X = _validate_data(self, X, reset=False)
raw_output = self.model_.predict(X)
return self._reverse_process_target(raw_output)
@@ -267,8 +284,9 @@ def _process_target(self, y, reset=False):
f" Target type: {target_type}"
)
if reset:
self._target_encoder = make_pipeline(
TargetReshaper(), OneHotEncoder(sparse_output=False)
self._target_encoder = sklearn.pipeline.make_pipeline(
TargetReshaper(),
sklearn.preprocessing.OneHotEncoder(sparse_output=False),
).fit(y)
self.classes_ = np.unique(y)
if len(self.classes_) == 1:
@@ -454,7 +472,7 @@ def transform(self, X):
X_transformed: array-like, shape=(n_samples, n_features)
The transformed data.
"""
check_is_fitted(self)
sklearn.base.check_is_fitted(self)
X = _validate_data(self, X, reset=False)
return self.model_.predict(X)

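
For context on the stub classes above: defining inert stand-ins for the sklearn mixins keeps the wrapper classes importable and definable without scikit-learn, while the assert_sklearn_installed call in __init__ turns the missing dependency into a clear error at construction time. A minimal sketch of the idea with illustrative names (not the commit's classes):

try:
    import sklearn
    from sklearn.base import ClassifierMixin  # real mixin when scikit-learn is present
except ImportError:
    sklearn = None

    class ClassifierMixin:  # inert stand-in so the class statement below still works
        pass

class DemoClassifier(ClassifierMixin):
    def __init__(self):
        if sklearn is None:
            # Mirrors assert_sklearn_installed: fail here, not at import time.
            raise ImportError(
                "DemoClassifier requires `scikit-learn` to be installed."
            )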
28 changes: 22 additions & 6 deletions keras/src/wrappers/utils.py
@@ -1,7 +1,23 @@
from sklearn.base import BaseEstimator
from sklearn.base import TransformerMixin
from sklearn.base import check_is_fitted
from sklearn.utils._array_api import get_namespace
try:
import sklearn
from sklearn.base import BaseEstimator
from sklearn.base import TransformerMixin
except ImportError:
sklearn = None

class BaseEstimator:
pass

class TransformerMixin:
pass


def assert_sklearn_installed(symbol_name):
if sklearn is None:
raise ImportError(
f"{symbol_name} requires `scikit-learn` to be installed. "
"Run `pip install scikit-learn` to install it."
)


def _check_model(model):
@@ -64,8 +80,8 @@ def inverse_transform(self, y):
is passed, it will be squeezed back to 1D. Otherwise, it
will be left untouched.
"""
check_is_fitted(self)
xp, _ = get_namespace(y)
sklearn.base.check_is_fitted(self)
xp, _ = sklearn.utils._array_api.get_namespace(y)
if self.ndim_ == 1 and y.ndim == 2:
return xp.squeeze(y, axis=1)
return y
1 change: 0 additions & 1 deletion pyproject.toml
@@ -34,7 +34,6 @@ dependencies = [
"optree",
"ml-dtypes",
"packaging",
"scikit-learn",
]
# Run also: pip install -r requirements.txt

