diff --git a/src/hidimstat/base_perturbation.py b/src/hidimstat/base_perturbation.py new file mode 100644 index 0000000..56b2db3 --- /dev/null +++ b/src/hidimstat/base_perturbation.py @@ -0,0 +1,132 @@ +import numpy as np +import pandas as pd +from joblib import Parallel, delayed +from sklearn.base import BaseEstimator, check_is_fitted +from sklearn.metrics import root_mean_squared_error + +from hidimstat.utils import _check_vim_predict_method + + +class BasePerturbation(BaseEstimator): + def __init__( + self, + estimator, + loss: callable = root_mean_squared_error, + n_permutations: int = 50, + method: str = "predict", + n_jobs: int = 1, + ): + """ + Base class for model agnostic variable importance measures based on + perturbation. + + Parameters + ---------- + estimator : object + The estimator to use for the prediction. + loss : callable, default=root_mean_squared_error + The loss function to use when comparing the perturbed model to the full + model. + n_permutations : int, default=50 + Only for PermutationImportance or CPI. The number of permutations to + perform. + method : str, default="predict" + The method to use for the prediction. This determines the predictions passed + to the loss function. + n_jobs : int, default=1 + The number of jobs to run in parallel. Parallelization is done over the + variables or groups of variables. + + """ + check_is_fitted(estimator) + self.estimator = estimator + self.loss = loss + _check_vim_predict_method(method) + self.method = method + self.n_jobs = n_jobs + self.n_permutations = n_permutations + self.n_groups = None + + def fit(self, X, y, groups=None): + if groups is None: + self.n_groups = X.shape[1] + self.groups = {j: [j] for j in range(self.n_groups)} + self._groups_ids = np.array(list(self.groups.values()), dtype=int) + else: + self.n_groups = len(groups) + self.groups = groups + if isinstance(X, pd.DataFrame): + self._groups_ids = [] + for group_key in self.groups.keys(): + self._groups_ids.append( + [ + i + for i, col in enumerate(X.columns) + if col in self.groups[group_key] + ] + ) + else: + self._groups_ids = np.array(list(self.groups.values()), dtype=int) + + def predict(self, X): + self._check_fit() + X_ = np.asarray(X) + + # Parallelize the computation of the importance scores for each group + out_list = Parallel(n_jobs=self.n_jobs)( + delayed(self._joblib_predict_one_group)(X_, group_id, group_key) + for group_id, group_key in enumerate(self.groups.keys()) + ) + return np.stack(out_list, axis=0) + + def score(self, X, y): + self._check_fit() + + out_dict = dict() + + y_pred = getattr(self.estimator, self.method)(X) + loss_reference = self.loss(y, y_pred) + out_dict["loss_reference"] = loss_reference + + y_pred = self.predict(X) + out_dict["loss"] = dict() + for j, y_pred_j in enumerate(y_pred): + list_loss = [] + for y_pred_perm in y_pred_j: + list_loss.append(self.loss(y, y_pred_perm)) + out_dict["loss"][j] = np.array(list_loss) + + out_dict["importance"] = np.array( + [ + np.mean(out_dict["loss"][j]) - loss_reference + for j in range(self.n_groups) + ] + ) + return out_dict + + def _check_fit(self): + pass + + def _joblib_predict_one_group(self, X, group_id, group_key): + group_ids = self._groups_ids[group_id] + non_group_ids = np.delete(np.arange(X.shape[1]), group_ids) + # Create an array X_perm_j of shape (n_permutations, n_samples, n_features) + # where the j-th group of covariates is permuted + X_perm = np.empty((self.n_permutations, X.shape[0], X.shape[1])) + X_perm[:, :, non_group_ids] = np.delete(X, group_ids, axis=1) 
+ X_perm[:, :, group_ids] = self._permutation(X, group_id=group_id) + # Reshape X_perm to allow for batch prediction + X_perm_batch = X_perm.reshape(-1, X.shape[1]) + y_pred_perm = getattr(self.estimator, self.method)(X_perm_batch) + + # In case of classification, the output is a 2D array. Reshape accordingly + if y_pred_perm.ndim == 1: + y_pred_perm = y_pred_perm.reshape(self.n_permutations, X.shape[0]) + else: + y_pred_perm = y_pred_perm.reshape( + self.n_permutations, X.shape[0], y_pred_perm.shape[1] + ) + return y_pred_perm + + def _permutation(self, X, group_id): + raise NotImplementedError diff --git a/src/hidimstat/conditional_permutation_importance.py b/src/hidimstat/conditional_permutation_importance.py new file mode 100644 index 0000000..9c43932 --- /dev/null +++ b/src/hidimstat/conditional_permutation_importance.py @@ -0,0 +1,144 @@ +import numpy as np +from joblib import Parallel, delayed +from sklearn.base import check_is_fitted, clone +from sklearn.metrics import root_mean_squared_error + +from hidimstat.base_perturbation import BasePerturbation +from hidimstat.conditional_sampling import ConditionalSampler + + +class CPI(BasePerturbation): + + def __init__( + self, + estimator, + loss: callable = root_mean_squared_error, + method: str = "predict", + n_jobs: int = 1, + n_permutations: int = 50, + imputation_model_continuous=None, + imputation_model_binary=None, + imputation_model_classification=None, + random_state: int = None, + categorical_max_cardinality: int = 10, + ): + """ + Conditional Permutation Importance (CPI) algorithm. + :footcite:t:`Chamma_NeurIPS2023` and for group-level see + :footcite:t:`Chamma_AAAI2024`. + + Parameters + ---------- + estimator : object + The estimator to use for the prediction. + loss : callable, default=root_mean_squared_error + The loss function to use when comparing the perturbed model to the full + model. + method : str, default="predict" + The method to use for the prediction. This determines the predictions passed + to the loss function. + n_jobs : int, default=1 + The number of jobs to run in parallel. Parallelization is done over the + variables or groups of variables. + n_permutations : int, default=50 + The number of permutations to perform. For each variable/group of variables, + the mean of the losses over the `n_permutations` is computed. + imputation_model_continuous : object, default=None + The model used to estimate the conditional distribution of a given + continuous variable/group of variables given the others. + imputation_model_binary : object, default=None + The model used to estimate the conditional distribution of a given + binary variable/group of variables given the others. + imputation_model_classification : object, default=None + The model used to estimate the conditional distribution of a given + categorical variable/group of variables given the others. + random_state : int, default=None + The random state to use for sampling. + categorical_max_cardinality : int, default=10 + The maximum cardinality of a variable to be considered as categorical + when the variable type is inferred (set to "auto" or not provided). 
+ """ + + super().__init__( + estimator=estimator, + loss=loss, + method=method, + n_jobs=n_jobs, + n_permutations=n_permutations, + ) + self.rng = np.random.RandomState(random_state) + self._list_imputation_models = [] + + self.imputation_model = { + "continuous": imputation_model_continuous, + "binary": imputation_model_binary, + "categorical": imputation_model_classification, + } + self.categorical_max_cardinality = categorical_max_cardinality + + def fit(self, X, groups=None, var_type="auto"): + super().fit(X, None, groups=groups) + if isinstance(var_type, str): + self.var_type = [var_type for _ in range(self.n_groups)] + else: + self.var_type = var_type + + self._list_imputation_models = [ + ConditionalSampler( + data_type=self.var_type[groupd_id], + model_regression=( + None + if self.imputation_model["continuous"] is None + else clone(self.imputation_model["continuous"]) + ), + model_binary=( + None + if self.imputation_model["binary"] is None + else clone(self.imputation_model["binary"]) + ), + model_categorical=( + None + if self.imputation_model["categorical"] is None + else clone(self.imputation_model["categorical"]) + ), + random_state=self.rng, + categorical_max_cardinality=self.categorical_max_cardinality, + ) + for groupd_id in range(self.n_groups) + ] + + # Parallelize the fitting of the covariate estimators + X_ = np.asarray(X) + self._list_imputation_models = Parallel(n_jobs=self.n_jobs)( + delayed(self._joblib_fit_one_group)(estimator, X_, groups_ids) + for groups_ids, estimator in zip( + self._groups_ids, self._list_imputation_models + ) + ) + + return self + + def _joblib_fit_one_group(self, estimator, X, groups_ids): + X_ = self._remove_nan(X) + X_j = X_[:, groups_ids].copy() + X_minus_j = np.delete(X_, groups_ids, axis=1) + estimator.fit(X_minus_j, X_j) + return estimator + + def _check_fit(self): + if len(self._list_imputation_models) == 0: + raise ValueError("The estimators require to be fit before to use them") + for m in self._list_imputation_models: + check_is_fitted(m.model) + + def _permutation(self, X, group_id): + X_ = self._remove_nan(X) + X_j = X_[:, self._groups_ids[group_id]].copy() + X_minus_j = np.delete(X_, self._groups_ids[group_id], axis=1) + return self._list_imputation_models[group_id].sample( + X_minus_j, X_j, n_samples=self.n_permutations + ) + + def _remove_nan(self, X): + # TODO: specify the strategy to handle NaN values + return X diff --git a/src/hidimstat/leave_one_covariate_out.py b/src/hidimstat/leave_one_covariate_out.py new file mode 100644 index 0000000..6f68a05 --- /dev/null +++ b/src/hidimstat/leave_one_covariate_out.py @@ -0,0 +1,79 @@ +import numpy as np +import pandas as pd +from joblib import Parallel, delayed +from sklearn.base import check_is_fitted, clone +from sklearn.metrics import root_mean_squared_error + +from hidimstat.base_perturbation import BasePerturbation + + +class LOCO(BasePerturbation): + def __init__( + self, + estimator, + loss: callable = root_mean_squared_error, + method: str = "predict", + n_jobs: int = 1, + ): + """ + Leave-One-Covariate-Out (LOCO) as presented in + :footcite:t:`Williamson_General_2023`. The model is re-fitted for each variable/ + group of variables. The importance is then computed as the difference between + the loss of the full model and the loss of the model without the variable/group. + + Parameters + ---------- + estimator : object + The estimator to use for the prediction. 
+ loss : callable, default=root_mean_squared_error + The loss function to use when comparing the perturbed model to the full + model. + method : str, default="predict" + The method to use for the prediction. This determines the predictions passed + to the loss function. + n_jobs : int, default=1 + The number of jobs to run in parallel. Parallelization is done over the + variables or groups of variables. + """ + super().__init__( + estimator=estimator, + loss=loss, + method=method, + n_jobs=n_jobs, + n_permutations=1, + ) + self._list_estimators = [] + + def fit(self, X, y, groups=None): + super().fit(X, y, groups) + # create a list of covariate estimators for each group if not provided + self._list_estimators = [clone(self.estimator) for _ in range(self.n_groups)] + + # Parallelize the fitting of the covariate estimators + self._list_estimators = Parallel(n_jobs=self.n_jobs)( + delayed(self._joblib_fit_one_group)(estimator, X, y, key_groups) + for key_groups, estimator in zip(self.groups.keys(), self._list_estimators) + ) + return self + + def _joblib_fit_one_group(self, estimator, X, y, key_groups): + if isinstance(X, pd.DataFrame): + X_minus_j = X.drop(columns=self.groups[key_groups]) + else: + X_minus_j = np.delete(X, self.groups[key_groups], axis=1) + estimator.fit(X_minus_j, y) + return estimator + + def _joblib_predict_one_group(self, X, group_id, key_groups): + X_minus_j = np.delete(X, self._groups_ids[group_id], axis=1) + + y_pred_loco = getattr(self._list_estimators[group_id], self.method)(X_minus_j) + + return [y_pred_loco] + + def _check_fit(self): + check_is_fitted(self.estimator) + if len(self._list_estimators) == 0: + raise ValueError("The estimators require to be fit before to use them") + for m in self._list_estimators: + check_is_fitted(m) diff --git a/src/hidimstat/permutation_importance.py b/src/hidimstat/permutation_importance.py index 41ae4e3..d90e953 100644 --- a/src/hidimstat/permutation_importance.py +++ b/src/hidimstat/permutation_importance.py @@ -1,327 +1,59 @@ import numpy as np -from joblib import Parallel, delayed from sklearn.metrics import root_mean_squared_error -from sklearn.exceptions import NotFittedError -from sklearn.utils import check_random_state -from sklearn.base import clone -from hidimstat.utils import _check_vim_predict_method - - -def _base_permutation( - X, - y, - estimator, - n_permutations: int = 50, - loss: callable = root_mean_squared_error, - method: str = "predict", - n_jobs: int = None, - groups=None, - permutation_data=None, - update_estimator=False -): - """ - # Permutation importance - - Calculate permutation importance scores for features or feature groups in a machine learning model. - Permutation importance is a model inspection technique that measures the increase in the model's - prediction error after permuting a feature's values. A feature is considered "important" if shuffling - its values increases the model error, because the model relied on the feature for the prediction. - The implementation follows the methodology described in chapter 10 :cite:breimanRandomForests2001. - One implementation: https://github.com/SkadiEye/deepTL/blob/master/R/4-2-permfit.R - - Parameters - ---------- - X : np.ndarray of shape (n_samples, n_features) - Training data. Can be numpy array or pandas DataFrame. - y : np.ndarray of shape (n_samples,) - Target values for the model. - estimator : object - A fitted estimator object implementing scikit-learn estimator interface. 
- The estimator must have a fitting method and one of the following prediction methods: - 'predict', 'predict_proba', 'decision_function', or 'transform'. - n_permutations : int, default=50 - Number of times to permute each feature or feature group. - Higher values give more stable results but take longer to compute. - loss : callable, default=root_mean_squared_error - Function to measure the prediction error. Must take two arguments (y_true, y_pred) - and return a scalar value. Higher return values must indicate worse predictions. - method : str, default='predict' - The estimator method used for prediction. Must be one of: - - 'predict': Use estimator.predict() - - 'predict_proba': Use estimator.predict_proba() - - 'decision_function': Use estimator.decision_function() - - 'transform': Use estimator.transform() - random_state : int, default=None - Controls the randomness of the feature permutations. - Pass an int for reproducible results across multiple function calls. - n_jobs : int, default=None - Number of jobs to run in parallel. None means 1 unless in a joblib.parallel_backend context. - -1 means using all processors. - groups : dict, default=None - Dictionary specifying feature groups. Keys are group names and values are lists of feature - indices or feature names (if X is a pandas DataFrame). If None, each feature is treated - as its own group. - - Returns - ------- - importance : np.ndarray of shape (n_features,) or (n_groups,) - The importance scores for each feature or feature group. - Higher values indicate more important features. - list_loss_j : np.ndarray - Array containing all computed loss values for each permutation of each feature/group. - loss_reference : float - The reference loss (baseline) computed on the original, non-permuted data. - - Notes - ----- - The implementation supports both individual feature importance and group feature importance. - For group importance, features within the same group are permuted together. - - References - ---------- - .. 
footbibliography:: - """ - - # check parameters - _check_vim_predict_method(method) - - # management of the group - if groups is None: - n_groups = X.shape[1] - groups_ = {j: [j] for j in range(n_groups)} - else: - n_groups = len(groups) - if type(list(groups.values())[0][0]) is str: - groups_ = {} - for key, indexe_names in zip(groups.keys(), groups.values()): - groups_[key] = [] - for index_name in indexe_names: - index = np.where(index_name == X.columns)[0] - assert len(index) == 1 - groups_[key].append(index[0]) - else: - groups_ = groups - - X_ = np.asarray(X) # avoid the management of panda dataframe - - # compute the reference residual - try: - y_pred = getattr(estimator, method)(X) - estimator_ = estimator - except NotFittedError: - estimator_ = clone(estimator) - # case for not fitted esimator - estimator_.fit(X_, y) - y_pred = getattr(estimator_, method)(X) - loss_reference = loss(y, y_pred) - - # Parallelize the computation of the residual for each permutation - # of each group - if permutation_data is None: - raise ValueError("Require a function") - list_result = Parallel(n_jobs=n_jobs)( - delayed(_predict_one_group_generic)( - j, - estimator_, - groups_[j], - X_, - y, - loss, - n_permutations, - method, - permutation_data=permutation_data, - update_estimator=update_estimator +from hidimstat.base_perturbation import BasePerturbation + + +class PermutationImportance(BasePerturbation): + def __init__( + self, + estimator, + loss: callable = root_mean_squared_error, + method: str = "predict", + n_jobs: int = 1, + n_permutations: int = 50, + random_state: int = None, + ): + """ + Permutation Importance algorithm as presented in + :footcite:t:`breimanRandomForests2001`. For each variable/group of variables, + the importance is computed as the difference between the loss of the initial + model and the loss of the model with the variable/group permuted. + + Parameters + ---------- + estimator : object + The estimator to use for the prediction. + loss : callable, default=root_mean_squared_error + The loss function to use when comparing the perturbed model to the full + model. + method : str, default="predict" + The method to use for the prediction. This determines the predictions passed + to the loss function. + n_jobs : int, default=1 + The number of jobs to run in parallel. Parallelization is done over the + variables or groups of variables. + n_permutations : int, default=50 + The number of permutations to perform. For each variable/group of variables, + the mean of the losses over the `n_permutations` is computed. + random_state : int, default=None + The random state to use for sampling. + """ + super().__init__( + estimator=estimator, + loss=loss, + method=method, + n_jobs=n_jobs, + n_permutations=n_permutations, ) - for j in groups_.keys() - ) - list_loss_j = np.array([i[0] for i in list_result]) - list_additional_output = [i[1] for i in list_result] - - # compute the importance - # equation 5 of mi2021permutation - importance = np.mean(list_loss_j - loss_reference, axis=1) - - return (importance, list_loss_j, loss_reference), list_additional_output - - -def _predict_one_group_generic( - index_group, estimator, group_ids, X, y, loss, n_permutations, method, permutation_data=None, - update_estimator=False -): - """ - Compute prediction loss scores after permuting a single group of features. 
- - Parameters - ---------- - estimator : object - Fitted estimator implementing scikit-learn API - group_ids : list - Indices of features in the group to permute - X : np.ndarray - Input data matrix - y : np.ndarray - Target values - loss : callable - Loss function to evaluate predictions - n_permutations : int - Number of permutations to perform - rng : RandomState - Random number generator instance - method : str - Prediction method to use ('predict', 'predict_proba', etc.) - - Returns - ------- - list - Loss values for each permutation - """ - # get ids - non_group_ids = np.delete(np.arange(X.shape[1]), group_ids) - - # get data - X_j = X[:, group_ids].copy() - X_minus_j = np.delete(X, group_ids, axis=1) - - # Create an array X_perm_j of shape (n_permutations, n_samples, n_features) - # where the j-th group of covariates is permuted - X_perm_j = np.empty((n_permutations, X.shape[0], X.shape[1])) - X_perm_j[:, :, non_group_ids] = X_minus_j - - if permutation_data is None: - raise ValueError("require a function") - else: - X_perm_j, additional_output = permutation_data(index_group=index_group, - X_minus_j=X_minus_j, X_j=X_j, X_perm_j=X_perm_j, group_ids=group_ids - ) - if update_estimator: - estimator = additional_output[0] - additional_output = additional_output[1] - - # Reshape X_perm_j to allow for remove the indexation by groups - y_pred_perm = getattr(estimator, method)(X_perm_j) - - if y_pred_perm.ndim == 1: - # one value per y: regression - y_pred_perm = y_pred_perm.reshape(n_permutations, X.shape[0]) - else: - # probability per y: classification - y_pred_perm = y_pred_perm.reshape( - n_permutations, X.shape[0], y_pred_perm.shape[1] - ) - loss_i = [loss(y, y_pred_perm[i]) for i in range(n_permutations)] - return loss_i, additional_output - - -def permutation_importance( - *args, - # additional argument - random_state: int = None, - n_permutations: int = 50, - **kwargs, -): - # define a random generator - check_random_state(random_state) - rng = np.random.RandomState(random_state) - - def permute_column(index_group, X_minus_j, X_j, X_perm_j, group_ids): - # Create the permuted data for the j-th group of covariates - group_j_permuted = np.array( - [rng.permutation(X_j) for _ in range(n_permutations)] - ) - X_perm_j[:, :, group_ids] = group_j_permuted - X_perm_j = X_perm_j.reshape(-1, X_minus_j.shape[1] + X_j.shape[1]) - return X_perm_j, None - - result, _ = _base_permutation( - *args, **kwargs, n_permutations=n_permutations, permutation_data=permute_column - ) - return result - - -def loco( - X_train, - y_train, - *args, - # additional argument - **kwargs, -): - if len(args)>=3: - estimator = clone(args[2]) - else: - estimator = kwargs['estimator'] - X_train_ = np.asarray(X_train) - - - def create_new_estimator(index_group, X_minus_j, X_j, X_perm_j, group_ids): - # Modify the actual estimator for fitting without the colomn j - X_train_minus_j = np.delete(X_train_, group_ids, axis=1) - estimator_ = clone(estimator) - estimator_.fit(X_train_minus_j, y_train) - X_perm_j = X_minus_j - return X_perm_j, (estimator_, estimator_) - - result, list_estimator = _base_permutation( - *args, **kwargs, n_permutations=1, permutation_data=create_new_estimator, update_estimator=True - ) - return result - - -def cpi( - X_train, - *args, - # additional argument - imputation_model=None, - imputation_method: str = "predict", - random_state: int = None, - distance_residual: callable = np.subtract, - n_permutations: int = 50, - **kwargs, -): - X_train_ = np.asarray(X_train) - if imputation_model is None: - 
raise ValueError("missing estimator for imputation") - n_permutations = n_permutations - # define a random generator - check_random_state(random_state) - rng = np.random.RandomState(random_state) - - def permutation_conditional(index_group, X_minus_j, X_j, X_perm_j, group_ids): - X_train_j = X_train_[:, group_ids].copy() - X_train_minus_j = np.delete(X_train_, group_ids, axis=1) - # create X from residual - # add one parameter: estimator_imputation - if type(imputation_model) is list or type(imputation_model) is dict: - estimator_ = imputation_model[index_group] - else: - estimator_ = clone(imputation_model) - estimator_.fit(X_train_minus_j, X_train_j) - - # Reshape X_perm_j to allow for remove the indexation by groups - X_j_hat = getattr(estimator_, imputation_method)(X_minus_j) - - if X_j_hat.ndim == 1 or X_j_hat.shape[1] == 1: - # one value per X_j_hat: regression - X_j_hat = X_j_hat.reshape(X_j.shape) - else: - # probability per X_j_hat: classification - X_j_hat = X_j_hat.reshape(X_j.shape[0], X_j_hat.shape[1]) - residual_j = distance_residual(X_j, X_j_hat) + self.rng = np.random.RandomState(random_state) + def _permutation(self, X, group_id): # Create the permuted data for the j-th group of covariates - residual_j_perm = np.array( - [rng.permutation(residual_j) for _ in range(n_permutations)] + X_perm_j = np.array( + [ + self.rng.permutation(X[:, self._groups_ids[group_id]].copy()) + for _ in range(self.n_permutations) + ] ) - X_perm_j[:, :, group_ids] = X_j_hat[np.newaxis, :, :] + residual_j_perm - - X_perm_j = X_perm_j.reshape(-1, X_minus_j.shape[1] + X_j.shape[1]) - - return X_perm_j, estimator_ - - result, list_estimator = _base_permutation( - *args, - **kwargs, - n_permutations=n_permutations, - permutation_data=permutation_conditional, - ) - return result + return X_perm_j diff --git a/src/hidimstat/permutation_importance_classes.py b/src/hidimstat/permutation_importance_classes.py deleted file mode 100644 index 6c24cb2..0000000 --- a/src/hidimstat/permutation_importance_classes.py +++ /dev/null @@ -1,397 +0,0 @@ -import numpy as np -import pandas as pd -from joblib import Parallel, delayed -from sklearn.base import BaseEstimator, check_is_fitted, clone -from sklearn.metrics import root_mean_squared_error - -from hidimstat.conditional_sampling import ConditionalSampler -from hidimstat.utils import _check_vim_predict_method - - -class BasePerturbation(BaseEstimator): - def __init__( - self, - estimator, - loss: callable = root_mean_squared_error, - n_permutations: int = 50, - method: str = "predict", - n_jobs: int = 1, - ): - """ - Base class for model agnostic variable importance measures based on - perturbation. - - Parameters - ---------- - estimator : object - The estimator to use for the prediction. - loss : callable, default=root_mean_squared_error - The loss function to use when comparing the perturbed model to the full - model. - n_permutations : int, default=50 - Only for PermutationImportance or CPI. The number of permutations to - perform. - method : str, default="predict" - The method to use for the prediction. This determines the predictions passed - to the loss function. - n_jobs : int, default=1 - The number of jobs to run in parallel. Parallelization is done over the - variables or groups of variables. 
- - """ - check_is_fitted(estimator) - self.estimator = estimator - self.loss = loss - _check_vim_predict_method(method) - self.method = method - self.n_jobs = n_jobs - self.n_permutations = n_permutations - self.n_groups = None - - def fit(self, X, y, groups=None): - if groups is None: - self.n_groups = X.shape[1] - self.groups = {j: [j] for j in range(self.n_groups)} - self._groups_ids = np.array(list(self.groups.values()), dtype=int) - else: - self.n_groups = len(groups) - self.groups = groups - if isinstance(X, pd.DataFrame): - self._groups_ids = [] - for group_key in self.groups.keys(): - self._groups_ids.append( - [ - i - for i, col in enumerate(X.columns) - if col in self.groups[group_key] - ] - ) - else: - self._groups_ids = np.array(list(self.groups.values()), dtype=int) - - def predict(self, X): - self._check_fit() - X_ = np.asarray(X) - - # Parallelize the computation of the importance scores for each group - out_list = Parallel(n_jobs=self.n_jobs)( - delayed(self._joblib_predict_one_group)(X_, group_id, group_key) - for group_id, group_key in enumerate(self.groups.keys()) - ) - return np.stack(out_list, axis=0) - - def score(self, X, y): - self._check_fit() - - out_dict = dict() - - y_pred = getattr(self.estimator, self.method)(X) - loss_reference = self.loss(y, y_pred) - out_dict["loss_reference"] = loss_reference - - y_pred = self.predict(X) - out_dict["loss"] = dict() - for j, y_pred_j in enumerate(y_pred): - list_loss = [] - for y_pred_perm in y_pred_j: - list_loss.append(self.loss(y, y_pred_perm)) - out_dict["loss"][j] = np.array(list_loss) - - out_dict["importance"] = np.array( - [ - np.mean(out_dict["loss"][j]) - loss_reference - for j in range(self.n_groups) - ] - ) - return out_dict - - def _check_fit(self): - pass - - def _joblib_predict_one_group(self, X, group_id, group_key): - group_ids = self._groups_ids[group_id] - non_group_ids = np.delete(np.arange(X.shape[1]), group_ids) - # Create an array X_perm_j of shape (n_permutations, n_samples, n_features) - # where the j-th group of covariates is permuted - X_perm = np.empty((self.n_permutations, X.shape[0], X.shape[1])) - X_perm[:, :, non_group_ids] = np.delete(X, group_ids, axis=1) - X_perm[:, :, group_ids] = self._permutation(X, group_id=group_id) - # Reshape X_perm to allow for batch prediction - X_perm_batch = X_perm.reshape(-1, X.shape[1]) - y_pred_perm = getattr(self.estimator, self.method)(X_perm_batch) - - # In case of classification, the output is a 2D array. Reshape accordingly - if y_pred_perm.ndim == 1: - y_pred_perm = y_pred_perm.reshape(self.n_permutations, X.shape[0]) - else: - y_pred_perm = y_pred_perm.reshape( - self.n_permutations, X.shape[0], y_pred_perm.shape[1] - ) - return y_pred_perm - - def _permutation(self, X, group_id): - raise NotImplementedError - - -class PermutationImportance(BasePerturbation): - def __init__( - self, - estimator, - loss: callable = root_mean_squared_error, - method: str = "predict", - n_jobs: int = 1, - n_permutations: int = 50, - random_state: int = None, - ): - """ - Permutation Importance algorithm as presented in - :footcite:t:`breimanRandomForests2001`. For each variable/group of variables, - the importance is computed as the difference between the loss of the initial - model and the loss of the model with the variable/group permuted. - - Parameters - ---------- - estimator : object - The estimator to use for the prediction. - loss : callable, default=root_mean_squared_error - The loss function to use when comparing the perturbed model to the full - model. 
- method : str, default="predict" - The method to use for the prediction. This determines the predictions passed - to the loss function. - n_jobs : int, default=1 - The number of jobs to run in parallel. Parallelization is done over the - variables or groups of variables. - n_permutations : int, default=50 - The number of permutations to perform. For each variable/group of variables, - the mean of the losses over the `n_permutations` is computed. - random_state : int, default=None - The random state to use for sampling. - """ - super().__init__( - estimator=estimator, - loss=loss, - method=method, - n_jobs=n_jobs, - n_permutations=n_permutations, - ) - self.rng = np.random.RandomState(random_state) - - def _permutation(self, X, group_id): - # Create the permuted data for the j-th group of covariates - X_perm_j = np.array( - [ - self.rng.permutation(X[:, self._groups_ids[group_id]].copy()) - for _ in range(self.n_permutations) - ] - ) - return X_perm_j - - -class LOCO(BasePerturbation): - def __init__( - self, - estimator, - loss: callable = root_mean_squared_error, - method: str = "predict", - n_jobs: int = 1, - ): - """ - Leave-One-Covariate-Out (LOCO) as presented in - :footcite:t:`Williamson_General_2023`. The model is re-fitted for each variable/ - group of variables. The importance is then computed as the difference between - the loss of the full model and the loss of the model without the variable/group. - - Parameters - ---------- - estimator : object - The estimator to use for the prediction. - loss : callable, default=root_mean_squared_error - The loss function to use when comparing the perturbed model to the full - model. - method : str, default="predict" - The method to use for the prediction. This determines the predictions passed - to the loss function. - n_jobs : int, default=1 - The number of jobs to run in parallel. Parallelization is done over the - variables or groups of variables. 
- """ - super().__init__( - estimator=estimator, - loss=loss, - method=method, - n_jobs=n_jobs, - n_permutations=1, - ) - self._list_estimators = [] - - def fit(self, X, y, groups=None): - super().fit(X, y, groups) - # create a list of covariate estimators for each group if not provided - self._list_estimators = [clone(self.estimator) for _ in range(self.n_groups)] - - # Parallelize the fitting of the covariate estimators - self._list_estimators = Parallel(n_jobs=self.n_jobs)( - delayed(self._joblib_fit_one_group)(estimator, X, y, key_groups) - for key_groups, estimator in zip(self.groups.keys(), self._list_estimators) - ) - return self - - def _joblib_fit_one_group(self, estimator, X, y, key_groups): - if isinstance(X, pd.DataFrame): - X_minus_j = X.drop(columns=self.groups[key_groups]) - else: - X_minus_j = np.delete(X, self.groups[key_groups], axis=1) - estimator.fit(X_minus_j, y) - return estimator - - def _joblib_predict_one_group(self, X, group_id, key_groups): - X_minus_j = np.delete(X, self._groups_ids[group_id], axis=1) - - y_pred_loco = getattr(self._list_estimators[group_id], self.method)(X_minus_j) - - return [y_pred_loco] - - def _check_fit(self): - check_is_fitted(self.estimator) - if len(self._list_estimators) == 0: - raise ValueError("The estimators require to be fit before to use them") - for m in self._list_estimators: - check_is_fitted(m) - - -class CPI(BasePerturbation): - - def __init__( - self, - estimator, - loss: callable = root_mean_squared_error, - method: str = "predict", - n_jobs: int = 1, - n_permutations: int = 50, - imputation_model_continuous=None, - imputation_model_binary=None, - imputation_model_classification=None, - random_state: int = None, - categorical_max_cardinality: int = 10, - ): - """ - Conditional Permutation Importance (CPI) algorithm. - :footcite:t:`Chamma_NeurIPS2023` and for group-level see - :footcite:t:`Chamma_AAAI2024`. - - Parameters - ---------- - estimator : object - The estimator to use for the prediction. - loss : callable, default=root_mean_squared_error - The loss function to use when comparing the perturbed model to the full - model. - method : str, default="predict" - The method to use for the prediction. This determines the predictions passed - to the loss function. - n_jobs : int, default=1 - The number of jobs to run in parallel. Parallelization is done over the - variables or groups of variables. - n_permutations : int, default=50 - The number of permutations to perform. For each variable/group of variables, - the mean of the losses over the `n_permutations` is computed. - imputation_model_continuous : object, default=None - The model used to estimate the conditional distribution of a given - continuous variable/group of variables given the others. - imputation_model_binary : object, default=None - The model used to estimate the conditional distribution of a given - binary variable/group of variables given the others. - imputation_model_classification : object, default=None - The model used to estimate the conditional distribution of a given - categorical variable/group of variables given the others. - random_state : int, default=None - The random state to use for sampling. - categorical_max_cardinality : int, default=10 - The maximum cardinality of a variable to be considered as categorical - when the variable type is inferred (set to "auto" or not provided). 
- """ - - super().__init__( - estimator=estimator, - loss=loss, - method=method, - n_jobs=n_jobs, - n_permutations=n_permutations, - ) - self.rng = np.random.RandomState(random_state) - self._list_imputation_models = [] - - self.imputation_model = { - "continuous": imputation_model_continuous, - "binary": imputation_model_binary, - "categorical": imputation_model_classification, - } - self.categorical_max_cardinality = categorical_max_cardinality - - def fit(self, X, groups=None, var_type="auto"): - super().fit(X, None, groups=groups) - if isinstance(var_type, str): - self.var_type = [var_type for _ in range(self.n_groups)] - else: - self.var_type = var_type - - self._list_imputation_models = [ - ConditionalSampler( - data_type=self.var_type[groupd_id], - model_regression=( - None - if self.imputation_model["continuous"] is None - else clone(self.imputation_model["continuous"]) - ), - model_binary=( - None - if self.imputation_model["binary"] is None - else clone(self.imputation_model["binary"]) - ), - model_categorical=( - None - if self.imputation_model["categorical"] is None - else clone(self.imputation_model["categorical"]) - ), - random_state=self.rng, - categorical_max_cardinality=self.categorical_max_cardinality, - ) - for groupd_id in range(self.n_groups) - ] - - # Parallelize the fitting of the covariate estimators - X_ = np.asarray(X) - self._list_imputation_models = Parallel(n_jobs=self.n_jobs)( - delayed(self._joblib_fit_one_group)(estimator, X_, groups_ids) - for groups_ids, estimator in zip( - self._groups_ids, self._list_imputation_models - ) - ) - - return self - - def _joblib_fit_one_group(self, estimator, X, groups_ids): - X_ = self._remove_nan(X) - X_j = X_[:, groups_ids].copy() - X_minus_j = np.delete(X_, groups_ids, axis=1) - estimator.fit(X_minus_j, X_j) - return estimator - - def _check_fit(self): - if len(self._list_imputation_models) == 0: - raise ValueError("The estimators require to be fit before to use them") - for m in self._list_imputation_models: - check_is_fitted(m.model) - - def _permutation(self, X, group_id): - X_ = self._remove_nan(X) - X_j = X_[:, self._groups_ids[group_id]].copy() - X_minus_j = np.delete(X_, self._groups_ids[group_id], axis=1) - return self._list_imputation_models[group_id].sample( - X_minus_j, X_j, n_samples=self.n_permutations - ) - - def _remove_nan(self, X): - # TODO: specify the strategy to handle NaN values - return X diff --git a/src/hidimstat/permutation_importance_func.py b/src/hidimstat/permutation_importance_func.py new file mode 100644 index 0000000..83d45fc --- /dev/null +++ b/src/hidimstat/permutation_importance_func.py @@ -0,0 +1,342 @@ +import numpy as np +from joblib import Parallel, delayed +from sklearn.base import clone +from sklearn.exceptions import NotFittedError +from sklearn.metrics import root_mean_squared_error +from sklearn.utils import check_random_state + +from hidimstat.utils import _check_vim_predict_method + + +def _base_permutation( + X, + y, + estimator, + n_permutations: int = 50, + loss: callable = root_mean_squared_error, + method: str = "predict", + n_jobs: int = None, + groups=None, + permutation_data=None, + update_estimator=False, +): + """ + # Permutation importance + + Calculate permutation importance scores for features or feature groups in a machine learning model. + Permutation importance is a model inspection technique that measures the increase in the model's + prediction error after permuting a feature's values. 
A feature is considered "important" if shuffling + its values increases the model error, because the model relied on the feature for the prediction. + The implementation follows the methodology described in chapter 10 :cite:breimanRandomForests2001. + One implementation: https://github.com/SkadiEye/deepTL/blob/master/R/4-2-permfit.R + + Parameters + ---------- + X : np.ndarray of shape (n_samples, n_features) + Training data. Can be numpy array or pandas DataFrame. + y : np.ndarray of shape (n_samples,) + Target values for the model. + estimator : object + A fitted estimator object implementing scikit-learn estimator interface. + The estimator must have a fitting method and one of the following prediction methods: + 'predict', 'predict_proba', 'decision_function', or 'transform'. + n_permutations : int, default=50 + Number of times to permute each feature or feature group. + Higher values give more stable results but take longer to compute. + loss : callable, default=root_mean_squared_error + Function to measure the prediction error. Must take two arguments (y_true, y_pred) + and return a scalar value. Higher return values must indicate worse predictions. + method : str, default='predict' + The estimator method used for prediction. Must be one of: + - 'predict': Use estimator.predict() + - 'predict_proba': Use estimator.predict_proba() + - 'decision_function': Use estimator.decision_function() + - 'transform': Use estimator.transform() + random_state : int, default=None + Controls the randomness of the feature permutations. + Pass an int for reproducible results across multiple function calls. + n_jobs : int, default=None + Number of jobs to run in parallel. None means 1 unless in a joblib.parallel_backend context. + -1 means using all processors. + groups : dict, default=None + Dictionary specifying feature groups. Keys are group names and values are lists of feature + indices or feature names (if X is a pandas DataFrame). If None, each feature is treated + as its own group. + + Returns + ------- + importance : np.ndarray of shape (n_features,) or (n_groups,) + The importance scores for each feature or feature group. + Higher values indicate more important features. + list_loss_j : np.ndarray + Array containing all computed loss values for each permutation of each feature/group. + loss_reference : float + The reference loss (baseline) computed on the original, non-permuted data. + + Notes + ----- + The implementation supports both individual feature importance and group feature importance. + For group importance, features within the same group are permuted together. + + References + ---------- + .. 
footbibliography:: + """ + + # check parameters + _check_vim_predict_method(method) + + # management of the group + if groups is None: + n_groups = X.shape[1] + groups_ = {j: [j] for j in range(n_groups)} + else: + n_groups = len(groups) + if type(list(groups.values())[0][0]) is str: + groups_ = {} + for key, indexe_names in zip(groups.keys(), groups.values()): + groups_[key] = [] + for index_name in indexe_names: + index = np.where(index_name == X.columns)[0] + assert len(index) == 1 + groups_[key].append(index[0]) + else: + groups_ = groups + + X_ = np.asarray(X) # avoid the management of panda dataframe + + # compute the reference residual + try: + y_pred = getattr(estimator, method)(X) + estimator_ = estimator + except NotFittedError: + estimator_ = clone(estimator) + # case for not fitted esimator + estimator_.fit(X_, y) + y_pred = getattr(estimator_, method)(X) + loss_reference = loss(y, y_pred) + + # Parallelize the computation of the residual for each permutation + # of each group + if permutation_data is None: + raise ValueError("Require a function") + list_result = Parallel(n_jobs=n_jobs)( + delayed(_predict_one_group_generic)( + j, + estimator_, + groups_[j], + X_, + y, + loss, + n_permutations, + method, + permutation_data=permutation_data, + update_estimator=update_estimator, + ) + for j in groups_.keys() + ) + list_loss_j = np.array([i[0] for i in list_result]) + list_additional_output = [i[1] for i in list_result] + + # compute the importance + # equation 5 of mi2021permutation + importance = np.mean(list_loss_j - loss_reference, axis=1) + + return (importance, list_loss_j, loss_reference), list_additional_output + + +def _predict_one_group_generic( + index_group, + estimator, + group_ids, + X, + y, + loss, + n_permutations, + method, + permutation_data=None, + update_estimator=False, +): + """ + Compute prediction loss scores after permuting a single group of features. + + Parameters + ---------- + estimator : object + Fitted estimator implementing scikit-learn API + group_ids : list + Indices of features in the group to permute + X : np.ndarray + Input data matrix + y : np.ndarray + Target values + loss : callable + Loss function to evaluate predictions + n_permutations : int + Number of permutations to perform + rng : RandomState + Random number generator instance + method : str + Prediction method to use ('predict', 'predict_proba', etc.) 
+ + Returns + ------- + list + Loss values for each permutation + """ + # get ids + non_group_ids = np.delete(np.arange(X.shape[1]), group_ids) + + # get data + X_j = X[:, group_ids].copy() + X_minus_j = np.delete(X, group_ids, axis=1) + + # Create an array X_perm_j of shape (n_permutations, n_samples, n_features) + # where the j-th group of covariates is permuted + X_perm_j = np.empty((n_permutations, X.shape[0], X.shape[1])) + X_perm_j[:, :, non_group_ids] = X_minus_j + + if permutation_data is None: + raise ValueError("require a function") + else: + X_perm_j, additional_output = permutation_data( + index_group=index_group, + X_minus_j=X_minus_j, + X_j=X_j, + X_perm_j=X_perm_j, + group_ids=group_ids, + ) + if update_estimator: + estimator = additional_output[0] + additional_output = additional_output[1] + + # Reshape X_perm_j to allow for remove the indexation by groups + y_pred_perm = getattr(estimator, method)(X_perm_j) + + if y_pred_perm.ndim == 1: + # one value per y: regression + y_pred_perm = y_pred_perm.reshape(n_permutations, X.shape[0]) + else: + # probability per y: classification + y_pred_perm = y_pred_perm.reshape( + n_permutations, X.shape[0], y_pred_perm.shape[1] + ) + loss_i = [loss(y, y_pred_perm[i]) for i in range(n_permutations)] + return loss_i, additional_output + + +def permutation_importance( + *args, + # additional argument + random_state: int = None, + n_permutations: int = 50, + **kwargs, +): + # define a random generator + check_random_state(random_state) + rng = np.random.RandomState(random_state) + + def permute_column(index_group, X_minus_j, X_j, X_perm_j, group_ids): + # Create the permuted data for the j-th group of covariates + group_j_permuted = np.array( + [rng.permutation(X_j) for _ in range(n_permutations)] + ) + X_perm_j[:, :, group_ids] = group_j_permuted + X_perm_j = X_perm_j.reshape(-1, X_minus_j.shape[1] + X_j.shape[1]) + return X_perm_j, None + + result, _ = _base_permutation( + *args, **kwargs, n_permutations=n_permutations, permutation_data=permute_column + ) + return result + + +def loco( + X_train, + y_train, + *args, + # additional argument + **kwargs, +): + if len(args) >= 3: + estimator = clone(args[2]) + else: + estimator = kwargs["estimator"] + X_train_ = np.asarray(X_train) + + def create_new_estimator(index_group, X_minus_j, X_j, X_perm_j, group_ids): + # Modify the actual estimator for fitting without the colomn j + X_train_minus_j = np.delete(X_train_, group_ids, axis=1) + estimator_ = clone(estimator) + estimator_.fit(X_train_minus_j, y_train) + X_perm_j = X_minus_j + return X_perm_j, (estimator_, estimator_) + + result, list_estimator = _base_permutation( + *args, + **kwargs, + n_permutations=1, + permutation_data=create_new_estimator, + update_estimator=True, + ) + return result + + +def cpi( + X_train, + *args, + # additional argument + imputation_model=None, + imputation_method: str = "predict", + random_state: int = None, + distance_residual: callable = np.subtract, + n_permutations: int = 50, + **kwargs, +): + X_train_ = np.asarray(X_train) + if imputation_model is None: + raise ValueError("missing estimator for imputation") + n_permutations = n_permutations + # define a random generator + check_random_state(random_state) + rng = np.random.RandomState(random_state) + + def permutation_conditional(index_group, X_minus_j, X_j, X_perm_j, group_ids): + X_train_j = X_train_[:, group_ids].copy() + X_train_minus_j = np.delete(X_train_, group_ids, axis=1) + # create X from residual + # add one parameter: estimator_imputation + if 
type(imputation_model) is list or type(imputation_model) is dict: + estimator_ = imputation_model[index_group] + else: + estimator_ = clone(imputation_model) + estimator_.fit(X_train_minus_j, X_train_j) + + # Predict the j-th group of covariates from the remaining ones + X_j_hat = getattr(estimator_, imputation_method)(X_minus_j) + + if X_j_hat.ndim == 1 or X_j_hat.shape[1] == 1: + # one value per X_j_hat: regression + X_j_hat = X_j_hat.reshape(X_j.shape) + else: + # probability per X_j_hat: classification + X_j_hat = X_j_hat.reshape(X_j.shape[0], X_j_hat.shape[1]) + residual_j = distance_residual(X_j, X_j_hat) + + # Create the permuted data for the j-th group of covariates + residual_j_perm = np.array( + [rng.permutation(residual_j) for _ in range(n_permutations)] + ) + X_perm_j[:, :, group_ids] = X_j_hat[np.newaxis, :, :] + residual_j_perm + + X_perm_j = X_perm_j.reshape(-1, X_minus_j.shape[1] + X_j.shape[1]) + + return X_perm_j, estimator_ + + result, list_estimator = _base_permutation( + *args, + **kwargs, + n_permutations=n_permutations, + permutation_data=permutation_conditional, + ) + return result
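
Usage sketch for the refactored class-based API added in this diff. This is a minimal, illustrative example, not part of the change itself: the synthetic dataset, the RidgeCV imputation model, and the per-module import paths are assumptions taken from the new files above and may differ if the classes are re-exported from the package root.

import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression, RidgeCV
from sklearn.model_selection import train_test_split

from hidimstat.permutation_importance import PermutationImportance
from hidimstat.conditional_permutation_importance import CPI
from hidimstat.leave_one_covariate_out import LOCO

# Toy regression problem; any scikit-learn regressor could stand in for LinearRegression.
X, y = make_regression(n_samples=300, n_features=5, noise=1.0, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
model = LinearRegression().fit(X_train, y_train)  # BasePerturbation expects a pre-fitted estimator

# Permutation Importance: each variable/group is permuted marginally.
pi = PermutationImportance(model, n_permutations=20, random_state=0)
pi.fit(X_train, y_train)  # with groups=None, each column becomes its own group
print(pi.score(X_test, y_test)["importance"])

# CPI: the permuted group is resampled conditionally on the other covariates.
cpi = CPI(model, imputation_model_continuous=RidgeCV(), n_permutations=20, random_state=0)
cpi.fit(X_train)  # fits one ConditionalSampler per variable/group
print(cpi.score(X_test, y_test)["importance"])

# LOCO: a clone of the estimator is refitted without each variable/group.
loco = LOCO(model)
loco.fit(X_train, y_train)
print(loco.score(X_test, y_test)["importance"])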