
Commit

more docstring
jpaillard committed Feb 20, 2025
1 parent f4e572b commit c325bcc
Showing 9 changed files with 43 additions and 18 deletions.
4 changes: 3 additions & 1 deletion src/hidimstat/base_perturbation.py
@@ -36,7 +36,8 @@ def __init__(
averaged over all permutations.
method : str, default="predict"
The method used for making predictions. This determines the predictions
passed to the loss function.
passed to the loss function. Supported methods are "predict",
"predict_proba", "decision_function", "transform".
n_jobs : int, default=1
The number of parallel jobs to run. Parallelization is done over the
variables or groups of variables.
@@ -51,6 +52,7 @@ def __init__(
self.n_groups = None

def fit(self, X, y, groups=None):
"""Base fit method for perturbation-based methods. Identifies the groups."""
if groups is None:
self.n_groups = X.shape[1]
self.groups = {j: [j] for j in range(self.n_groups)}
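The default grouping applied by this fit method (one singleton group per column when groups is None) can be reproduced in isolation. A minimal standalone sketch; the helper name is hypothetical and not part of hidimstat:

import numpy as np

def identify_groups(X, groups=None):
    # Mirrors the default in BasePerturbation.fit: each column becomes its own group.
    if groups is None:
        return {j: [j] for j in range(X.shape[1])}
    return {key: list(idx) for key, idx in groups.items()}

X = np.random.randn(100, 5)
print(identify_groups(X))                     # {0: [0], 1: [1], 2: [2], 3: [3], 4: [4]}
print(identify_groups(X, {"block": [0, 1]}))  # {'block': [0, 1]}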
18 changes: 12 additions & 6 deletions src/hidimstat/conditional_permutation_importance.py
@@ -8,7 +8,6 @@


class CPI(BasePerturbation):

def __init__(
self,
estimator,
@@ -29,27 +28,28 @@ def __init__(
Parameters
----------
estimator : object
estimator : sklearn compatible estimator, optional
The estimator to use for the prediction.
loss : callable, default=root_mean_squared_error
The loss function to use when comparing the perturbed model to the full
model.
method : str, default="predict"
The method to use for the prediction. This determines the predictions passed
to the loss function.
to the loss function. Supported methods are "predict", "predict_proba",
"decision_function", "transform".
n_jobs : int, default=1
The number of jobs to run in parallel. Parallelization is done over the
variables or groups of variables.
n_permutations : int, default=50
The number of permutations to perform. For each variable/group of variables,
the mean of the losses over the `n_permutations` is computed.
imputation_model_continuous : object, default=None
imputation_model_continuous : sklearn compatible estimator, optional
The model used to estimate the conditional distribution of a given
continuous variable/group of variables given the others.
imputation_model_binary : object, default=None
imputation_model_binary : sklearn compatible estimator, optional
The model used to estimate the conditional distribution of a given
binary variable/group of variables given the others.
imputation_model_categorical : object, default=None
imputation_model_categorical : sklearn compatible estimator, optional
The model used to estimate the conditional distribution of a given
categorical variable/group of variables given the others.
random_state : int, default=None
@@ -77,6 +77,7 @@ def __init__(
self.categorical_max_cardinality = categorical_max_cardinality

def fit(self, X, y=None, groups=None, var_type="auto"):
"""Fit the imputation models."""
super().fit(X, None, groups=groups)
if isinstance(var_type, str):
self.var_type = [var_type for _ in range(self.n_groups)]
@@ -119,18 +120,23 @@ def fit(self, X, y=None, groups=None, var_type="auto"):
return self

def _joblib_fit_one_group(self, estimator, X, groups_ids):
"""Fit a single imputation model, for a single group of variables. This method
is parallelized."""
X_j = X[:, groups_ids].copy()
X_minus_j = np.delete(X, groups_ids, axis=1)
estimator.fit(X_minus_j, X_j)
return estimator

def _check_fit(self):
"""Check if the imputation models are fitted."""
if len(self._list_imputation_models) == 0:
raise ValueError("The estimators must be fitted before use")
for m in self._list_imputation_models:
check_is_fitted(m.model)

def _permutation(self, X, group_id):
"""Sample from the conditional distribution using a permutation of the
residuals."""
X_j = X[:, self._groups_ids[group_id]].copy()
X_minus_j = np.delete(X, self._groups_ids[group_id], axis=1)
return self._list_imputation_models[group_id].sample(
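A usage sketch of the CPI estimator documented in this file, assuming the module path shown in the diff; the final importance(X_test, y_test) call is an assumption about the scoring entry point and should be checked against the API at this commit:

import numpy as np
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import train_test_split
from hidimstat.conditional_permutation_importance import CPI

rng = np.random.RandomState(0)
X = rng.randn(300, 5)
y = X[:, 0] + 0.5 * X[:, 1] + rng.randn(300)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

estimator = RidgeCV().fit(X_train, y_train)   # CPI perturbs a pre-fitted model
cpi = CPI(
    estimator=estimator,
    imputation_model_continuous=RidgeCV(),    # conditional model for continuous groups
    n_permutations=50,
    random_state=0,
)
cpi.fit(X_train, groups=None, var_type="auto")  # fits one imputation model per group
print(cpi.importance(X_test, y_test))           # assumed scoring call; verify against the API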
10 changes: 5 additions & 5 deletions src/hidimstat/conditional_sampling.py
@@ -19,20 +19,20 @@ def __init__(
Parameters
----------
model_regression : object
model_regression : sklearn compatible estimator, optional
The model to use for continuous data.
model_binary : object
model_binary : sklearn compatible estimator, optional
The model to use for binary data.
model_categorical : object
model_categorical : sklearn compatible estimator, optional
The model to use for categorical data.
data_type : str, default="auto"
The variable type. Supported types include "auto", "continuous", "binary",
and "categorical". If "auto", the type is inferred from the cardinality of
the unique values passed to the `fit` method. For categorical variables, the
default strategy is to use a one-vs-rest classifier.
random_state : int
random_state : int, optional
The random state to use for sampling.
categorical_max_cardinality : int
categorical_max_cardinality : int, default=10
The maximum cardinality of a variable to be considered as categorical
when `data_type` is "auto".
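A usage sketch of ConditionalSampler, following the fit/sample calls exercised in test_conditional_sampling.py below; the exact return shape of sample is not guaranteed here:

import numpy as np
from sklearn.linear_model import LinearRegression
from hidimstat.conditional_sampling import ConditionalSampler

rng = np.random.RandomState(0)
X = rng.multivariate_normal([0, 0], [[1.0, 0.8], [0.8, 1.0]], size=1000)
X_minus_j, X_j = X[:, [0]], X[:, 1]

sampler = ConditionalSampler(model_regression=LinearRegression(), data_type="continuous")
sampler.fit(X_minus_j, X_j)                  # learn the conditional mean of X_j given X_minus_j
X_j_tilde = sampler.sample(X_minus_j, X_j)   # permute residuals around that conditional mean
print(np.shape(X_j_tilde))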
11 changes: 9 additions & 2 deletions src/hidimstat/leave_one_covariate_out.py
@@ -23,14 +23,15 @@ def __init__(
Parameters
----------
estimator : object
estimator : sklearn compatible estimator, optional
The estimator to use for the prediction.
loss : callable, default=root_mean_squared_error
The loss function to use when comparing the perturbed model to the full
model.
method : str, default="predict"
The method to use for the prediction. This determines the predictions passed
to the loss function.
to the loss function. Supported methods are "predict", "predict_proba",
"decision_function", "transform".
n_jobs : int, default=1
The number of jobs to run in parallel. Parallelization is done over the
variables or groups of variables.
@@ -45,6 +46,7 @@ def __init__(
self._list_estimators = []

def fit(self, X, y, groups=None):
"""Fit a model after removing each covariate/group of covariates."""
super().fit(X, y, groups)
# create a list of covariate estimators for each group if not provided
self._list_estimators = [clone(self.estimator) for _ in range(self.n_groups)]
@@ -57,6 +59,7 @@ def fit(self, X, y, groups=None):
return self

def _joblib_fit_one_group(self, estimator, X, y, key_groups):
"""Fit the estimator after removing a group of covariates. Used in parallel."""
if isinstance(X, pd.DataFrame):
X_minus_j = X.drop(columns=self.groups[key_groups])
else:
@@ -65,13 +68,17 @@ def _joblib_fit_one_group(self, estimator, X, y, key_groups):
return estimator

def _joblib_predict_one_group(self, X, group_id, key_groups):
"""Predict the target variable after removing a group of covariates.
Used in parallel."""
X_minus_j = np.delete(X, self._groups_ids[group_id], axis=1)

y_pred_loco = getattr(self._list_estimators[group_id], self.method)(X_minus_j)

return [y_pred_loco]

def _check_fit(self):
"""Check that an estimator has been fitted after removing each group of
covariates."""
check_is_fitted(self.estimator)
if len(self._list_estimators) == 0:
raise ValueError("The estimators must be fitted before use")
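The leave-one-covariate-out logic documented in this file amounts to comparing the loss of the full model with the loss of a model refitted after dropping a group of covariates. A standalone sketch of that principle with plain scikit-learn, not the hidimstat class itself:

import numpy as np
from sklearn.base import clone
from sklearn.linear_model import RidgeCV
from sklearn.metrics import root_mean_squared_error
from sklearn.model_selection import train_test_split

rng = np.random.RandomState(0)
X = rng.randn(300, 4)
y = 2 * X[:, 0] + rng.randn(300)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=0)

full = RidgeCV().fit(X_tr, y_tr)
loss_full = root_mean_squared_error(y_te, full.predict(X_te))

for j in range(X.shape[1]):
    # Refit a clone of the estimator without column j, as in _joblib_fit_one_group.
    sub = clone(full).fit(np.delete(X_tr, j, axis=1), y_tr)
    loss_j = root_mean_squared_error(y_te, sub.predict(np.delete(X_te, j, axis=1)))
    print(f"feature {j}: LOCO importance ~ {loss_j - loss_full:.3f}")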
7 changes: 4 additions & 3 deletions src/hidimstat/permutation_importance.py
@@ -23,14 +23,15 @@ def __init__(
Parameters
----------
estimator : object
estimator : sklearn compatible estimator, optional
The estimator to use for the prediction.
loss : callable, default=root_mean_squared_error
The loss function to use when comparing the perturbed model to the full
model.
method : str, default="predict"
The method to use for the prediction. This determines the predictions passed
to the loss function.
to the loss function. Supported methods are "predict", "predict_proba",
"decision_function", "transform".
n_jobs : int, default=1
The number of jobs to run in parallel. Parallelization is done over the
variables or groups of variables.
@@ -50,7 +51,7 @@ def __init__(
self.rng = check_random_state(random_state)

def _permutation(self, X, group_id):
# Create the permuted data for the j-th group of covariates
"""Create the permuted data for the j-th group of covariates"""
X_perm_j = np.array(
[
self.rng.permutation(X[:, self._groups_ids[group_id]].copy())
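The _permutation helper above stacks independent row permutations of the columns belonging to one group. The same pattern in isolation:

import numpy as np
from sklearn.utils import check_random_state

rng = check_random_state(0)
X = np.arange(20, dtype=float).reshape(10, 2)
groups_ids = [0]        # permute the first column as its own group
n_permutations = 3

# Each entry is an independent permutation of the rows of X[:, groups_ids],
# stacked into shape (n_permutations, n_samples, n_features_in_group).
X_perm_j = np.array([
    rng.permutation(X[:, groups_ids].copy())
    for _ in range(n_permutations)
])
print(X_perm_j.shape)   # (3, 10, 1)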
5 changes: 4 additions & 1 deletion test/test_conditional_sampling.py
@@ -10,6 +10,7 @@


def test_continuous_case():
"""Test sampling from the conditional distribution of a continuous variable."""
n = 1000
np.random.seed(40)
sampler = ConditionalSampler(
@@ -47,6 +48,7 @@ def test_continuous_case():


def test_binary_case():
"""Test sampling from the conditional distribution of a binary variable."""
n = 1000
np.random.seed(40)

@@ -92,6 +94,7 @@ def test_binary_case():


def test_error():
"""Test for error when model does not have predict_proba or predict."""
# Test for error when model does not have predict_proba
np.random.seed(40)
sampler = ConditionalSampler(
@@ -111,7 +114,7 @@
)
with pytest.raises(AttributeError):
sampler.fit(np.delete(X, 1, axis=1), X[:, 1])
sampler.sample(np.delete(X, 1, axis=1), X[:, 1])
sampler.sample()

sampler = ConditionalSampler(
data_type="auto",
3 changes: 3 additions & 0 deletions test/test_cpi.py
@@ -11,6 +11,7 @@


def test_cpi(linear_scenario):
"""Test the Conditional Permutation Importance algorithm on a linear scenario."""
X, y, beta = linear_scenario
important_features = np.where(beta != 0)[0]
non_important_features = np.where(beta == 0)[0]
@@ -98,6 +99,8 @@ def test_cpi(linear_scenario):
def test_raises_value_error(
linear_scenario,
):
"""Test for the ValueError raised by the Conditional Permutation Importance
algorithm."""
X, y, _ = linear_scenario

# Predict method not recognized
2 changes: 2 additions & 0 deletions test/test_loco.py
@@ -10,6 +10,7 @@


def test_loco(linear_scenario):
"""Test the Leave-One-Covariate-Out algorithm on a linear scenario."""
X, y, beta = linear_scenario
important_features = np.where(beta != 0)[0]
non_important_features = np.where(beta == 0)[0]
@@ -88,6 +89,7 @@ def test_loco(linear_scenario):
def test_raises_value_error(
linear_scenario,
):
"""Test for error when model does not have predict_proba or predict."""
X, y, _ = linear_scenario
# Not fitted estimator
with pytest.raises(NotFittedError):
1 change: 1 addition & 0 deletions test/test_permutation_importance.py
@@ -8,6 +8,7 @@


def test_permutation_importance(linear_scenario):
"""Test the Permutation Importance algorithm on a linear scenario."""
X, y, beta = linear_scenario
important_features = np.where(beta != 0)[0]
non_important_features = np.where(beta == 0)[0]
