From 4ee713de17e978eb8d6c255eede3d31d70c7fcd8 Mon Sep 17 00:00:00 2001 From: jpaillard Date: Mon, 17 Feb 2025 14:38:59 +0100 Subject: [PATCH] add docstring --- .../permutation_importance_classes.py | 112 +++++++++++++++++- 1 file changed, 108 insertions(+), 4 deletions(-) diff --git a/src/hidimstat/permutation_importance_classes.py b/src/hidimstat/permutation_importance_classes.py index 04ad6a3..6c24cb2 100644 --- a/src/hidimstat/permutation_importance_classes.py +++ b/src/hidimstat/permutation_importance_classes.py @@ -8,7 +8,7 @@ from hidimstat.utils import _check_vim_predict_method -class BasePermutation(BaseEstimator): +class BasePerturbation(BaseEstimator): def __init__( self, estimator, @@ -17,6 +17,28 @@ def __init__( method: str = "predict", n_jobs: int = 1, ): + """ + Base class for model agnostic variable importance measures based on + perturbation. + + Parameters + ---------- + estimator : object + The estimator to use for the prediction. + loss : callable, default=root_mean_squared_error + The loss function to use when comparing the perturbed model to the full + model. + n_permutations : int, default=50 + Only for PermutationImportance or CPI. The number of permutations to + perform. + method : str, default="predict" + The method to use for the prediction. This determines the predictions passed + to the loss function. + n_jobs : int, default=1 + The number of jobs to run in parallel. Parallelization is done over the + variables or groups of variables. + + """ check_is_fitted(estimator) self.estimator = estimator self.loss = loss @@ -111,7 +133,7 @@ def _permutation(self, X, group_id): raise NotImplementedError -class PermutationImportance(BasePermutation): +class PermutationImportance(BasePerturbation): def __init__( self, estimator, @@ -121,6 +143,31 @@ def __init__( n_permutations: int = 50, random_state: int = None, ): + """ + Permutation Importance algorithm as presented in + :footcite:t:`breimanRandomForests2001`. For each variable/group of variables, + the importance is computed as the difference between the loss of the initial + model and the loss of the model with the variable/group permuted. + + Parameters + ---------- + estimator : object + The estimator to use for the prediction. + loss : callable, default=root_mean_squared_error + The loss function to use when comparing the perturbed model to the full + model. + method : str, default="predict" + The method to use for the prediction. This determines the predictions passed + to the loss function. + n_jobs : int, default=1 + The number of jobs to run in parallel. Parallelization is done over the + variables or groups of variables. + n_permutations : int, default=50 + The number of permutations to perform. For each variable/group of variables, + the mean of the losses over the `n_permutations` is computed. + random_state : int, default=None + The random state to use for sampling. + """ super().__init__( estimator=estimator, loss=loss, @@ -141,7 +188,7 @@ def _permutation(self, X, group_id): return X_perm_j -class LOCO(BasePermutation): +class LOCO(BasePerturbation): def __init__( self, estimator, @@ -149,6 +196,26 @@ def __init__( method: str = "predict", n_jobs: int = 1, ): + """ + Leave-One-Covariate-Out (LOCO) as presented in + :footcite:t:`Williamson_General_2023`. The model is re-fitted for each variable/ + group of variables. The importance is then computed as the difference between + the loss of the full model and the loss of the model without the variable/group. + + Parameters + ---------- + estimator : object + The estimator to use for the prediction. + loss : callable, default=root_mean_squared_error + The loss function to use when comparing the perturbed model to the full + model. + method : str, default="predict" + The method to use for the prediction. This determines the predictions passed + to the loss function. + n_jobs : int, default=1 + The number of jobs to run in parallel. Parallelization is done over the + variables or groups of variables. + """ super().__init__( estimator=estimator, loss=loss, @@ -193,7 +260,8 @@ def _check_fit(self): check_is_fitted(m) -class CPI(BasePermutation): +class CPI(BasePerturbation): + def __init__( self, estimator, @@ -207,6 +275,42 @@ def __init__( random_state: int = None, categorical_max_cardinality: int = 10, ): + """ + Conditional Permutation Importance (CPI) algorithm. + :footcite:t:`Chamma_NeurIPS2023` and for group-level see + :footcite:t:`Chamma_AAAI2024`. + + Parameters + ---------- + estimator : object + The estimator to use for the prediction. + loss : callable, default=root_mean_squared_error + The loss function to use when comparing the perturbed model to the full + model. + method : str, default="predict" + The method to use for the prediction. This determines the predictions passed + to the loss function. + n_jobs : int, default=1 + The number of jobs to run in parallel. Parallelization is done over the + variables or groups of variables. + n_permutations : int, default=50 + The number of permutations to perform. For each variable/group of variables, + the mean of the losses over the `n_permutations` is computed. + imputation_model_continuous : object, default=None + The model used to estimate the conditional distribution of a given + continuous variable/group of variables given the others. + imputation_model_binary : object, default=None + The model used to estimate the conditional distribution of a given + binary variable/group of variables given the others. + imputation_model_classification : object, default=None + The model used to estimate the conditional distribution of a given + categorical variable/group of variables given the others. + random_state : int, default=None + The random state to use for sampling. + categorical_max_cardinality : int, default=10 + The maximum cardinality of a variable to be considered as categorical + when the variable type is inferred (set to "auto" or not provided). + """ super().__init__( estimator=estimator,