Skip to content

Commit

Permalink
Fix type check
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <[email protected]>
  • Loading branch information
adam2392 committed Nov 8, 2023
1 parent 30e9e95 commit 36c5582
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 17 deletions.
22 changes: 10 additions & 12 deletions sktree/stats/forestht.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Tuple, Union
from typing import Optional, Callable, Tuple, Union

import numpy as np
from joblib import Parallel, delayed
Expand Down Expand Up @@ -54,7 +54,7 @@ def _parallel_build_trees_with_sepdata(
covariate_index,
bootstrap: bool,
max_samples,
sample_weight: ArrayLike = None,
sample_weight: Optional[ArrayLike] = None,
class_weight=None,
missing_values_in_feature_mask=None,
classes=None,
Expand All @@ -76,6 +76,7 @@ def _parallel_build_trees_with_sepdata(
else:
n_samples_bootstrap = None

# XXX: this currently creates a copy of X_train in RAM, which is not ideal
# individual tree permutation of y labels
if covariate_index is not None:
indices = np.arange(X_train.shape[0], dtype=int)
Expand Down Expand Up @@ -280,7 +281,7 @@ def _statistic(
):
raise NotImplementedError("Subclasses should implement this!")

def _check_input(self, X: ArrayLike, y: ArrayLike, covariate_index: ArrayLike = None):
def _check_input(self, X: ArrayLike, y: ArrayLike, covariate_index: Optional[ArrayLike] = None):
X, y = check_X_y(X, y, ensure_2d=True, copy=True, multi_output=True, dtype=DTYPE)
if y.ndim != 2:
y = y.reshape(-1, 1)
Expand Down Expand Up @@ -315,17 +316,15 @@ def _check_input(self, X: ArrayLike, y: ArrayLike, covariate_index: ArrayLike =
)

if not self.train_test_split and not isinstance(self.estimator, HonestForestClassifier):
raise RuntimeError(
"Train test split must occur if not using honest forest classifier."
)
raise RuntimeError("Train test split must occur if not using honest forest classifier.")

return X, y, covariate_index

def statistic(
self,
X: ArrayLike,
y: ArrayLike,
covariate_index: ArrayLike = None,
covariate_index: Optional[ArrayLike] = None,
metric="mi",
return_posteriors: bool = False,
check_input: bool = True,
Expand Down Expand Up @@ -444,7 +443,7 @@ def test(
self,
X,
y,
covariate_index: ArrayLike = None,
covariate_index: Optional[ArrayLike] = None,
metric: str = "mi",
n_repeats: int = 1000,
return_posteriors: bool = True,
Expand Down Expand Up @@ -612,7 +611,8 @@ class FeatureImportanceForestRegressor(BaseForestHT):
permute_forest_fraction : float, default=None
The fraction of trees to permute the covariate index for. If None, then
just one permutation is performed.
just one permutation is performed. If sampling a permutation per tree
is desirable, then the fraction should be set to ``1. / n_estimators``.
train_test_split : bool, default=True
Whether to split the dataset before passing to the forest.
Expand Down Expand Up @@ -669,7 +669,6 @@ def __init__(
random_state=None,
verbose=0,
test_size=0.2,
# permute_per_tree=False,
sample_dataset_per_tree=False,
permute_forest_fraction=None,
train_test_split=True,
Expand All @@ -679,7 +678,6 @@ def __init__(
random_state=random_state,
verbose=verbose,
test_size=test_size,
# permute_per_tree=permute_per_tree,
sample_dataset_per_tree=sample_dataset_per_tree,
permute_forest_fraction=permute_forest_fraction,
train_test_split=train_test_split,
Expand All @@ -698,7 +696,7 @@ def statistic(
self,
X: ArrayLike,
y: ArrayLike,
covariate_index: ArrayLike = None,
covariate_index: Optional[ArrayLike] = None,
metric="mse",
return_posteriors: bool = False,
check_input: bool = True,
Expand Down
5 changes: 3 additions & 2 deletions sktree/stats/permutationforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sktree._lib.sklearn.ensemble._forest import BaseForest, ForestClassifier, ForestRegressor

from .utils import METRIC_FUNCTIONS, REGRESSOR_METRICS, _compute_null_distribution_perm
from typing import Optional


class BasePermutationForest(MetaEstimatorMixin):
Expand Down Expand Up @@ -62,7 +63,7 @@ def _statistic(
estimator: BaseForest,
X: ArrayLike,
y: ArrayLike,
covariate_index: ArrayLike = None,
covariate_index: Optional[ArrayLike] = None,
metric="mse",
return_posteriors: bool = False,
seed=None,
Expand Down Expand Up @@ -117,7 +118,7 @@ def statistic(
self,
X: ArrayLike,
y: ArrayLike,
covariate_index: ArrayLike = None,
covariate_index: Optional[ArrayLike] = None,
metric="mse",
return_posteriors: bool = False,
check_input: bool = True,
Expand Down
6 changes: 3 additions & 3 deletions sktree/stats/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Tuple
from typing import Optional, Tuple

import numpy as np
from numpy.typing import ArrayLike
Expand Down Expand Up @@ -112,7 +112,7 @@ def _compute_null_distribution_perm(
est: ForestClassifier,
metric: str = "mse",
n_repeats: int = 1000,
seed: int = None,
seed: Optional[int] = None,
) -> ArrayLike:
"""Compute null distribution using permutation method.
Expand Down Expand Up @@ -173,7 +173,7 @@ def _compute_null_distribution_coleman(
y_pred_proba_perm: ArrayLike,
metric: str = "mse",
n_repeats: int = 1000,
seed: int = None,
seed: Optional[int] = None,
) -> Tuple[ArrayLike, ArrayLike]:
"""Compute null distribution using Coleman method.
Expand Down

0 comments on commit 36c5582

Please sign in to comment.