Commit

add docstring
jpaillard committed Feb 17, 2025
1 parent 0b7a467 commit 4ee713d
Showing 1 changed file with 108 additions and 4 deletions.
112 changes: 108 additions & 4 deletions src/hidimstat/permutation_importance_classes.py
@@ -8,7 +8,7 @@
from hidimstat.utils import _check_vim_predict_method


class BasePermutation(BaseEstimator):
class BasePerturbation(BaseEstimator):
def __init__(
self,
estimator,
@@ -17,6 +17,28 @@ def __init__(
method: str = "predict",
n_jobs: int = 1,
):
"""
Base class for model-agnostic variable importance measures based on
perturbation.

Parameters
----------
estimator : object
The estimator to use for the prediction.
loss : callable, default=root_mean_squared_error
The loss function to use when comparing the perturbed model to the full
model.
n_permutations : int, default=50
The number of permutations to perform. Only used by
PermutationImportance and CPI.
method : str, default="predict"
The method to use for the prediction. This determines the predictions passed
to the loss function.
n_jobs : int, default=1
The number of jobs to run in parallel. Parallelization is done over the
variables or groups of variables.
"""
check_is_fitted(estimator)
self.estimator = estimator
self.loss = loss
@@ -111,7 +133,7 @@ def _permutation(self, X, group_id):
raise NotImplementedError
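
For intuition, here is a minimal sketch of the pattern this base class describes: the loss of the intact model is compared with the loss obtained after perturbing each variable/group, and the work is parallelized over groups as the n_jobs entry states. The function and argument names below are assumptions for illustration; this is not hidimstat's implementation, and subclasses of BasePerturbation supply the perturbation through _permutation rather than a perturb callable.

# Illustrative sketch of the perturbation-importance pattern described above;
# names and structure are assumptions, not hidimstat's internals.
import numpy as np
from joblib import Parallel, delayed
from sklearn.metrics import root_mean_squared_error


def perturbation_importance(estimator, X, y, perturb, groups,
                            loss=root_mean_squared_error, method="predict", n_jobs=1):
    # Reference loss of the unperturbed model; `method` selects e.g. predict or predict_proba.
    predict = getattr(estimator, method)
    loss_reference = loss(y, predict(X))

    def one_group(cols):
        # `perturb(X, cols)` must return a copy of X with the given columns perturbed.
        return loss(y, predict(perturb(X, cols))) - loss_reference

    # Parallelization is over variables/groups of variables, as documented for n_jobs.
    return np.array(Parallel(n_jobs=n_jobs)(delayed(one_group)(cols) for cols in groups))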


class PermutationImportance(BasePermutation):
class PermutationImportance(BasePerturbation):
def __init__(
self,
estimator,
@@ -121,6 +143,31 @@ def __init__(
n_permutations: int = 50,
random_state: int = None,
):
"""
Permutation Importance algorithm as presented in
:footcite:t:`breimanRandomForests2001`. For each variable/group of variables,
the importance is computed as the difference between the loss of the initial
model and the loss of the model with the variable/group permuted.

Parameters
----------
estimator : object
The estimator to use for the prediction.
loss : callable, default=root_mean_squared_error
The loss function to use when comparing the perturbed model to the full
model.
method : str, default="predict"
The method to use for the prediction. This determines the predictions passed
to the loss function.
n_jobs : int, default=1
The number of jobs to run in parallel. Parallelization is done over the
variables or groups of variables.
n_permutations : int, default=50
The number of permutations to perform. For each variable/group of variables,
the mean of the losses over the `n_permutations` is computed.
random_state : int, default=None
The random state to use for sampling.
"""
super().__init__(
estimator=estimator,
loss=loss,
Expand All @@ -141,14 +188,34 @@ def _permutation(self, X, group_id):
return X_perm_j
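
As a usage-level illustration of the algorithm just described (a self-contained sketch, not the class's own code; the function name and NumPy-array handling are assumptions), permuting one column n_permutations times and averaging the loss increase looks like this:

# Sketch of standard permutation importance for a single column; illustration only.
import numpy as np
from sklearn.metrics import root_mean_squared_error


def single_column_permutation_importance(estimator, X, y, j, n_permutations=50,
                                         loss=root_mean_squared_error, random_state=None):
    rng = np.random.default_rng(random_state)
    loss_reference = loss(y, estimator.predict(X))
    losses = []
    for _ in range(n_permutations):
        X_perm = X.copy()
        X_perm[:, j] = rng.permutation(X_perm[:, j])   # break the link between x_j and y
        losses.append(loss(y, estimator.predict(X_perm)))
    # Importance = mean perturbed loss minus the loss of the intact model.
    return np.mean(losses) - loss_reference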


class LOCO(BasePermutation):
class LOCO(BasePerturbation):
def __init__(
self,
estimator,
loss: callable = root_mean_squared_error,
method: str = "predict",
n_jobs: int = 1,
):
"""
Leave-One-Covariate-Out (LOCO) as presented in
:footcite:t:`Williamson_General_2023`. The model is re-fitted for each
variable/group of variables. The importance is then computed as the difference
between the loss of the full model and the loss of the model without the
variable/group.

Parameters
----------
estimator : object
The estimator to use for the prediction.
loss : callable, default=root_mean_squared_error
The loss function to use when comparing the perturbed model to the full
model.
method : str, default="predict"
The method to use for the prediction. This determines the predictions passed
to the loss function.
n_jobs : int, default=1
The number of jobs to run in parallel. Parallelization is done over the
variables or groups of variables.
"""
super().__init__(
estimator=estimator,
loss=loss,
@@ -193,7 +260,8 @@ def _check_fit(self):
check_is_fitted(m)
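
For intuition, here is a from-scratch LOCO sketch (illustrative only; the function name, the train/test split, and the use of sklearn's clone are assumptions, and the estimator is assumed to be already fitted on the training data): each column is dropped in turn, a clone of the estimator is re-fitted, and the loss difference is taken.

# Illustrative LOCO sketch: re-fit without each column and compare losses.
import numpy as np
from sklearn.base import clone
from sklearn.metrics import root_mean_squared_error


def loco_importance(estimator, X_train, y_train, X_test, y_test,
                    loss=root_mean_squared_error):
    loss_full = loss(y_test, estimator.predict(X_test))
    importances = []
    for j in range(X_train.shape[1]):
        cols = [c for c in range(X_train.shape[1]) if c != j]
        sub_model = clone(estimator).fit(X_train[:, cols], y_train)   # re-fit without x_j
        loss_without_j = loss(y_test, sub_model.predict(X_test[:, cols]))
        importances.append(loss_without_j - loss_full)                # loss change when x_j is left out
    return np.array(importances)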


class CPI(BasePermutation):
class CPI(BasePerturbation):

def __init__(
self,
estimator,
@@ -207,6 +275,42 @@ def __init__(
random_state: int = None,
categorical_max_cardinality: int = 10,
):
"""
Conditional Permutation Importance (CPI) algorithm, as presented in
:footcite:t:`Chamma_NeurIPS2023`; for group-level importance, see
:footcite:t:`Chamma_AAAI2024`.

Parameters
----------
estimator : object
The estimator to use for the prediction.
loss : callable, default=root_mean_squared_error
The loss function to use when comparing the perturbed model to the full
model.
method : str, default="predict"
The method to use for the prediction. This determines the predictions passed
to the loss function.
n_jobs : int, default=1
The number of jobs to run in parallel. Parallelization is done over the
variables or groups of variables.
n_permutations : int, default=50
The number of permutations to perform. For each variable/group of variables,
the mean of the losses over the `n_permutations` is computed.
imputation_model_continuous : object, default=None
The model used to estimate the conditional distribution of a given
continuous variable/group of variables given the others.
imputation_model_binary : object, default=None
The model used to estimate the conditional distribution of a given
binary variable/group of variables given the others.
imputation_model_classification : object, default=None
The model used to estimate the conditional distribution of a given
categorical variable/group of variables given the others.
random_state : int, default=None
The random state to use for sampling.
categorical_max_cardinality : int, default=10
The maximum cardinality of a variable for it to be considered categorical
when the variable type is inferred (i.e. set to "auto" or not provided).
"""

super().__init__(
estimator=estimator,
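
To illustrate the conditional permutation idea for a single continuous variable (a simplified sketch under the residual-permutation view, with RidgeCV standing in for the imputation model; binary and categorical variables would use the dedicated imputation models listed above, and none of the names here are hidimstat's): x_j is predicted from the remaining columns, and only the part of x_j not explained by them is shuffled, so its dependence on the other covariates is approximately preserved.

# Simplified CPI sketch for one continuous column; illustration only.
import numpy as np
from sklearn.linear_model import RidgeCV
from sklearn.metrics import root_mean_squared_error


def cpi_single_column(estimator, X, y, j, n_permutations=50,
                      imputation_model=None, loss=root_mean_squared_error,
                      random_state=None):
    rng = np.random.default_rng(random_state)
    imputation_model = imputation_model or RidgeCV()

    others = [c for c in range(X.shape[1]) if c != j]
    x_j_hat = imputation_model.fit(X[:, others], X[:, j]).predict(X[:, others])
    residual = X[:, j] - x_j_hat                      # part of x_j not explained by the others

    loss_reference = loss(y, estimator.predict(X))
    losses = []
    for _ in range(n_permutations):
        X_cond = X.copy()
        # Shuffle only the residual: the conditional distribution of x_j given
        # the other columns is (approximately) preserved.
        X_cond[:, j] = x_j_hat + rng.permutation(residual)
        losses.append(loss(y, estimator.predict(X_cond)))
    return np.mean(losses) - loss_reference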
