Adding honest forest checks

Signed-off-by: Adam Li <[email protected]>
adam2392 committed Nov 2, 2023
1 parent bc2780e commit 6d914f6
Showing 15 changed files with 9,894 additions and 117 deletions.
743 changes: 643 additions & 100 deletions benchmarks_nonasv/bench_multiview_hyppo.ipynb


2,201 changes: 2,201 additions & 0 deletions benchmarks_nonasv/cv_partial_auc_mv_vs_rf_correlated_latentfactor_model.csv


2,401 changes: 2,401 additions & 0 deletions benchmarks_nonasv/cv_partial_auc_mv_vs_rf_ind_views_gaussian_mixture.csv


2,401 changes: 2,401 additions & 0 deletions benchmarks_nonasv/cv_partial_auc_mv_vs_rf_ind_views_gaussian_mixture_v2.csv


2,201 changes: 2,201 additions & 0 deletions benchmarks_nonasv/cv_partial_auc_mv_vs_rf_linear_transform.csv


13 changes: 13 additions & 0 deletions doc/api.rst
@@ -150,6 +150,19 @@ tree models.
PermutationForestClassifier
PermutationForestRegressor

+Datasets
+------------------------------
+We provide some convenience functions for simulating datasets beyond
+those offered in scikit-learn.
+
+.. currentmodule:: sktree.datasets
+.. autosummary::
+   :toctree: generated/
+
+   make_gaussian_mixture
+   make_joint_factor_model
+   make_quadratic_classification
+

Experimental Functionality
--------------------------
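The new `sktree.datasets` entries documented above can be exercised directly. A minimal sketch, using the `make_quadratic_classification` signature from this commit (the small problem sizes here are illustrative only):

    import numpy as np

    from sktree.datasets import make_quadratic_classification

    # Simulate a two-class quadratic problem; noise=True adds the noise term.
    X, y = make_quadratic_classification(n_samples=128, n_features=16, noise=True, seed=0)
    y = np.asarray(y).squeeze()  # the test below also squeezes the returned targets
    print(X.shape, y.shape)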
7 changes: 7 additions & 0 deletions doc/references.bib
@@ -146,4 +146,11 @@ @misc{perry2009crossvalidation
eprint = {0909.3052},
archiveprefix = {arXiv},
primaryclass = {stat.ME}
}

+@article{panda2018learning,
+title = {Learning Interpretable Characteristic Kernels via Decision Forests},
+author = {Panda, Sambit and Shen, Cencheng and Vogelstein, Joshua T},
+journal = {arXiv preprint arXiv:1812.00029},
+year = {2018}
+}
2 changes: 2 additions & 0 deletions sktree/datasets/__init__.py
@@ -0,0 +1,2 @@
+from .hyppo import make_quadratic_classification
+from .multiview import make_gaussian_mixture, make_joint_factor_model
9 changes: 8 additions & 1 deletion sktree/datasets/hyppo.py
@@ -1,9 +1,11 @@
import numpy as np


-def quadratic(n_samples: int, n_features: int, noise=False, seed=None):
+def make_quadratic_classification(n_samples: int, n_features: int, noise=False, seed=None):
"""Simulate classification data from a quadratic model.
+This is a form of the simulation used in :footcite:`panda2018learning`.
Parameters
----------
n_samples : int
@@ -21,6 +23,10 @@ def quadratic(n_samples: int, n_features: int, noise=False, seed=None):
Data array.
v : array-like, shape (n_samples,)
Target array of 1's and 0's.
+References
+----------
+.. footbibliography::
"""
rng = np.random.default_rng(seed)

@@ -31,6 +37,7 @@ def quadratic(n_samples: int, n_features: int, noise=False, seed=None):
x_coeffs = x * coeffs
y = x_coeffs**2 + noise * eps

+# generate the classification labels
n1 = x.shape[0]
n2 = y.shape[0]
v = np.vstack([np.zeros((n1, 1)), np.ones((n2, 1))])
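The hunk above makes the labeling explicit: the simulation stacks the n1 raw samples (class 0) on top of the n2 quadratically transformed samples (class 1). A quick sketch of that label construction:

    import numpy as np

    # Mirrors the vstack in make_quadratic_classification: zeros for the
    # raw samples, ones for the transformed ones.
    n1, n2 = 3, 3
    v = np.vstack([np.zeros((n1, 1)), np.ones((n2, 1))])
    print(v.ravel())  # [0. 0. 0. 1. 1. 1.]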
4 changes: 2 additions & 2 deletions sktree/datasets/multiview.py
@@ -1,5 +1,5 @@
# Original source: https://github.com/mvlearn/mvlearn
-# License: MIT
+# Author: Ronan Perry

import numpy as np
from scipy.stats import ortho_group
@@ -339,7 +339,7 @@ def make_joint_factor_model(
U = np.linalg.qr(U)[0]

# random noise for each view
-Es = [noise_std * rng.standard_normal(size=(n_samples, d)) for d in zip(n_features)]
+Es = [noise_std * rng.standard_normal(size=(n_samples, d)) for d in n_features]
Xs = [(U * svals) @ view_loadings[b].T + Es[b] for b in range(n_views)]

if return_decomp:
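The change above is a real bug fix, not a cleanup: `zip(n_features)` yields 1-tuples, so each `d` would be a tuple and `rng.standard_normal(size=(n_samples, d))` would raise a TypeError; iterating `n_features` directly yields the per-view integer dimensions. A minimal illustration:

    n_features = [10, 20, 30]

    print(list(zip(n_features)))  # [(10,), (20,), (30,)] -- 1-tuples, not ints
    print(list(n_features))       # [10, 20, 30] -- what the noise term needs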
2 changes: 1 addition & 1 deletion sktree/experimental/mutual_info.py
@@ -147,7 +147,7 @@ def mutual_info_ksg(
algorithm="kd_tree",
n_jobs: int = -1,
transform: str = "rank",
-random_seed: int = None,
+random_seed: Optional[int] = None,
):
"""Compute the generalized (conditional) mutual information KSG estimate.
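This annotation fix (repeated in the files below) follows PEP 484: a parameter defaulting to `None` must be typed `Optional[...]`; type checkers reject the implicit form (mypy, for instance, does so by default in recent releases). A sketch of the distinction:

    from typing import Optional

    def bad(seed: int = None):  # flagged by type checkers: None is not an int
        ...

    def good(seed: Optional[int] = None):  # explicit and PEP 484-compliant
        ...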
12 changes: 6 additions & 6 deletions sktree/stats/forestht.py
@@ -1,4 +1,4 @@
-from typing import Callable, Tuple, Union
+from typing import Optional, Callable, Tuple, Union

import numpy as np
from joblib import Parallel, delayed
@@ -43,7 +43,7 @@ def _parallel_build_trees_and_compute_posteriors(
predict_posteriors: bool,
permute_per_tree: bool,
type_of_target,
-sample_weight: ArrayLike = None,
+sample_weight: Optional[ArrayLike] = None,
class_weight=None,
missing_values_in_feature_mask=None,
classes=None,
@@ -255,7 +255,7 @@ def _statistic(
):
raise NotImplementedError("Subclasses should implement this!")

-def _check_input(self, X: ArrayLike, y: ArrayLike, covariate_index: ArrayLike = None):
+def _check_input(self, X: ArrayLike, y: ArrayLike, covariate_index: Optional[ArrayLike] = None):
X, y = check_X_y(X, y, ensure_2d=True, copy=True, multi_output=True)
if y.ndim != 2:
y = y.reshape(-1, 1)
@@ -295,7 +295,7 @@ def statistic(
self,
X: ArrayLike,
y: ArrayLike,
-covariate_index: ArrayLike = None,
+covariate_index: Optional[ArrayLike] = None,
metric="mi",
return_posteriors: bool = False,
check_input: bool = True,
@@ -414,7 +414,7 @@ def test(
self,
X,
y,
-covariate_index: ArrayLike = None,
+covariate_index: Optional[ArrayLike] = None,
metric: str = "mi",
n_repeats: int = 1000,
return_posteriors: bool = True,
@@ -660,7 +660,7 @@ def statistic(
self,
X: ArrayLike,
y: ArrayLike,
-covariate_index: ArrayLike = None,
+covariate_index: Optional[ArrayLike] = None,
metric="mse",
return_posteriors: bool = False,
check_input: bool = True,
5 changes: 3 additions & 2 deletions sktree/stats/permutationforest.py
@@ -10,6 +10,7 @@
from sktree._lib.sklearn.ensemble._forest import BaseForest, ForestClassifier, ForestRegressor

from .utils import METRIC_FUNCTIONS, REGRESSOR_METRICS, _compute_null_distribution_perm
+from typing import Optional


class BasePermutationForest(MetaEstimatorMixin):
@@ -62,7 +63,7 @@ def _statistic(
estimator: BaseForest,
X: ArrayLike,
y: ArrayLike,
-covariate_index: ArrayLike = None,
+covariate_index: Optional[ArrayLike] = None,
metric="mse",
return_posteriors: bool = False,
seed=None,
@@ -117,7 +118,7 @@ def statistic(
self,
X: ArrayLike,
y: ArrayLike,
-covariate_index: ArrayLike = None,
+covariate_index: Optional[ArrayLike] = None,
metric="mse",
return_posteriors: bool = False,
check_input: bool = True,
6 changes: 3 additions & 3 deletions sktree/stats/utils.py
@@ -1,4 +1,4 @@
-from typing import Tuple
+from typing import Optional, Tuple

import numpy as np
from numpy.typing import ArrayLike
@@ -112,7 +112,7 @@ def _compute_null_distribution_perm(
est: ForestClassifier,
metric: str = "mse",
n_repeats: int = 1000,
-seed: int = None,
+seed: Optional[int] = None,
) -> ArrayLike:
"""Compute null distribution using permutation method.
@@ -173,7 +173,7 @@ def _compute_null_distribution_coleman(
y_pred_proba_perm: ArrayLike,
metric: str = "mse",
n_repeats: int = 1000,
-seed: int = None,
+seed: Optional[int] = None,
) -> Tuple[ArrayLike, ArrayLike]:
"""Compute null distribution using Coleman method.
4 changes: 2 additions & 2 deletions sktree/tests/test_honest_forest.py
@@ -9,7 +9,7 @@
from sklearn.utils.estimator_checks import parametrize_with_checks

from sktree._lib.sklearn.tree import DecisionTreeClassifier
-from sktree.datasets.hyppo import quadratic
+from sktree.datasets.hyppo import make_quadratic_classification
from sktree.ensemble import HonestForestClassifier
from sktree.tree import ObliqueDecisionTreeClassifier, PatchObliqueDecisionTreeClassifier

@@ -262,7 +262,7 @@ def test_honest_forest_with_sklearn_trees():
https://github.com/neurodata/scikit-tree/pull/157."""

# generate the high-dimensional quadratic data
-X, y = quadratic(1024, 4096, noise=True, seed=0)
+X, y = make_quadratic_classification(1024, 4096, noise=True, seed=0)
y = y.squeeze()
print(X.shape, y.shape)
print(np.sum(y) / len(y))
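For context, the regression test above boils down to fitting an honest forest on the renamed simulation. A minimal sketch at a smaller scale (the constructor arguments here are illustrative assumptions, not taken from this diff):

    from sktree.datasets.hyppo import make_quadratic_classification
    from sktree.ensemble import HonestForestClassifier

    # A much smaller problem than the test's 1024 x 4096 setup.
    X, y = make_quadratic_classification(128, 16, noise=True, seed=0)
    clf = HonestForestClassifier(n_estimators=10, random_state=0)
    clf.fit(X, y.squeeze())
    print(clf.predict_proba(X).shape)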
