From 36c5582be3e3df51bfffd3eb6c10e67cd2c5e2b2 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Wed, 8 Nov 2023 09:09:01 -0500 Subject: [PATCH] Fix type checK Signed-off-by: Adam Li --- sktree/stats/forestht.py | 22 ++++++++++------------ sktree/stats/permutationforest.py | 5 +++-- sktree/stats/utils.py | 6 +++--- 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/sktree/stats/forestht.py b/sktree/stats/forestht.py index 7f35ff9f4..130739ace 100644 --- a/sktree/stats/forestht.py +++ b/sktree/stats/forestht.py @@ -1,4 +1,4 @@ -from typing import Callable, Tuple, Union +from typing import Optional, Callable, Tuple, Union import numpy as np from joblib import Parallel, delayed @@ -54,7 +54,7 @@ def _parallel_build_trees_with_sepdata( covariate_index, bootstrap: bool, max_samples, - sample_weight: ArrayLike = None, + sample_weight: Optional[ArrayLike] = None, class_weight=None, missing_values_in_feature_mask=None, classes=None, @@ -76,6 +76,7 @@ def _parallel_build_trees_with_sepdata( else: n_samples_bootstrap = None + # XXX: this currently creates a copy of X_train on RAM, which is not ideal # individual tree permutation of y labels if covariate_index is not None: indices = np.arange(X_train.shape[0], dtype=int) @@ -280,7 +281,7 @@ def _statistic( ): raise NotImplementedError("Subclasses should implement this!") - def _check_input(self, X: ArrayLike, y: ArrayLike, covariate_index: ArrayLike = None): + def _check_input(self, X: ArrayLike, y: ArrayLike, covariate_index: Optional[ArrayLike] = None): X, y = check_X_y(X, y, ensure_2d=True, copy=True, multi_output=True, dtype=DTYPE) if y.ndim != 2: y = y.reshape(-1, 1) @@ -315,9 +316,7 @@ def _check_input(self, X: ArrayLike, y: ArrayLike, covariate_index: ArrayLike = ) if not self.train_test_split and not isinstance(self.estimator, HonestForestClassifier): - raise RuntimeError( - "Train test split must occur if not using honest forest classifier." - ) + raise RuntimeError("Train test split must occur if not using honest forest classifier.") return X, y, covariate_index @@ -325,7 +324,7 @@ def statistic( self, X: ArrayLike, y: ArrayLike, - covariate_index: ArrayLike = None, + covariate_index: Optional[ArrayLike] = None, metric="mi", return_posteriors: bool = False, check_input: bool = True, @@ -444,7 +443,7 @@ def test( self, X, y, - covariate_index: ArrayLike = None, + covariate_index: Optional[ArrayLike] = None, metric: str = "mi", n_repeats: int = 1000, return_posteriors: bool = True, @@ -612,7 +611,8 @@ class FeatureImportanceForestRegressor(BaseForestHT): permute_forest_fraction : float, default=None The fraction of trees to permute the covariate index for. If None, then - just one permutation is performed. + just one permutation is performed. If sampling a permutation per tree + is desirable, then the fraction should be set to ``1. / n_estimators``. train_test_split : bool, default=True Whether to split the dataset before passing to the forest. @@ -669,7 +669,6 @@ def __init__( random_state=None, verbose=0, test_size=0.2, - # permute_per_tree=False, sample_dataset_per_tree=False, permute_forest_fraction=None, train_test_split=True, @@ -679,7 +678,6 @@ def __init__( random_state=random_state, verbose=verbose, test_size=test_size, - # permute_per_tree=permute_per_tree, sample_dataset_per_tree=sample_dataset_per_tree, permute_forest_fraction=permute_forest_fraction, train_test_split=train_test_split, @@ -698,7 +696,7 @@ def statistic( self, X: ArrayLike, y: ArrayLike, - covariate_index: ArrayLike = None, + covariate_index: Optional[ArrayLike] = None, metric="mse", return_posteriors: bool = False, check_input: bool = True, diff --git a/sktree/stats/permutationforest.py b/sktree/stats/permutationforest.py index 4a27f539f..78a2437a4 100644 --- a/sktree/stats/permutationforest.py +++ b/sktree/stats/permutationforest.py @@ -10,6 +10,7 @@ from sktree._lib.sklearn.ensemble._forest import BaseForest, ForestClassifier, ForestRegressor from .utils import METRIC_FUNCTIONS, REGRESSOR_METRICS, _compute_null_distribution_perm +from typing import Optional class BasePermutationForest(MetaEstimatorMixin): @@ -62,7 +63,7 @@ def _statistic( estimator: BaseForest, X: ArrayLike, y: ArrayLike, - covariate_index: ArrayLike = None, + covariate_index: Optional[ArrayLike] = None, metric="mse", return_posteriors: bool = False, seed=None, @@ -117,7 +118,7 @@ def statistic( self, X: ArrayLike, y: ArrayLike, - covariate_index: ArrayLike = None, + covariate_index: Optional[ArrayLike] = None, metric="mse", return_posteriors: bool = False, check_input: bool = True, diff --git a/sktree/stats/utils.py b/sktree/stats/utils.py index f21dd00a8..7ff7a9f05 100644 --- a/sktree/stats/utils.py +++ b/sktree/stats/utils.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Optional, Tuple import numpy as np from numpy.typing import ArrayLike @@ -112,7 +112,7 @@ def _compute_null_distribution_perm( est: ForestClassifier, metric: str = "mse", n_repeats: int = 1000, - seed: int = None, + seed: Optional[int] = None, ) -> ArrayLike: """Compute null distribution using permutation method. @@ -173,7 +173,7 @@ def _compute_null_distribution_coleman( y_pred_proba_perm: ArrayLike, metric: str = "mse", n_repeats: int = 1000, - seed: int = None, + seed: Optional[int] = None, ) -> Tuple[ArrayLike, ArrayLike]: """Compute null distribution using Coleman method.