Commit
Signed-off-by: Adam Li <[email protected]>
Showing 23 changed files with 2,242 additions and 835 deletions.
@@ -0,0 +1,25 @@
:orphan:

.. include:: _contributors.rst
.. currentmodule:: sktree

.. _current:

Version 0.4
===========

**In Development**

Changelog
---------

- |API| ``FeatureImportanceForest*`` now has a hyperparameter, ``permute_per_forest_fraction``, that controls the number of permutations done per forest (see the sketch below), by `Adam Li`_ (:pr:`145`)

Code and Documentation Contributors
-----------------------------------

Thanks to everyone who has contributed to the maintenance and improvement of
the project since version inception, including:

* `Adam Li`_
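
A minimal usage sketch of the new hyperparameter, assuming it is accepted as a
constructor argument (:pr:`145` defines the actual signature):

.. code-block:: python

    from sktree import HonestForestClassifier
    from sktree.stats import FeatureImportanceForestClassifier

    est = FeatureImportanceForestClassifier(
        estimator=HonestForestClassifier(n_estimators=100),
        # assumed keyword: controls how many permutations are done per forest
        permute_per_forest_fraction=0.5,
    )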
@@ -0,0 +1,135 @@
"""
=====================================================
Compute partial AUC using multi-view MIGHT (MV-MIGHT)
=====================================================
An example using :class:`~sktree.stats.FeatureImportanceForestClassifier` for a
nonparametric multivariate hypothesis test on simulated multi-view datasets.
Here, we show how to estimate the partial AUROC from a multi-view feature set.
We simulate a dataset with 510 features, 1000 samples, and a binary class target
variable. The first 10 features (X) are strongly correlated with the target (y),
and the second feature set (W) of 500 features is uninformative noise.
We then use MV-MIGHT to calculate the partial AUC of these sets.
"""

import numpy as np
from scipy.special import expit

from sktree import HonestForestClassifier
from sktree.stats import FeatureImportanceForestClassifier
from sktree.tree import DecisionTreeClassifier, MultiViewDecisionTreeClassifier

seed = 12345
rng = np.random.default_rng(seed)

# %%
# Simulate data
# -------------
# We simulate the two feature sets and the target variable, then combine them
# into a single dataset to perform hypothesis testing.

n_samples = 1000
n_features_set = 500
mean = 1.0
sigma = 2.0
beta = 5.0

unimportant_mean = 0.0
unimportant_sigma = 4.5

# first sample the informative features, then the uninformative features
X_important = rng.normal(loc=mean, scale=sigma, size=(n_samples, 10))
X_unimportant = rng.normal(
    loc=unimportant_mean, scale=unimportant_sigma, size=(n_samples, n_features_set)
)
X = np.hstack([X_important, X_unimportant])

# simulate the binary target variable
y = rng.binomial(n=1, p=expit(beta * X_important[:, :10].sum(axis=1)), size=n_samples)
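
# quick sanity check: with beta=5 the logistic link saturates, so positive
# labels dominate the simulated target
print(f"Fraction of positive labels: {y.mean():.3f}")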

# %%
# Use partial AUC as test statistic
# ---------------------------------
# You can bound the maximum false positive rate (i.e., the minimum
# specificity) by setting ``max_fpr`` in ``statistic``.

n_estimators = 125
max_features = "sqrt"
metric = "auc"
test_size = 0.2
n_jobs = -1
honest_fraction = 0.5
max_fpr = 0.1

est_mv = FeatureImportanceForestClassifier(
    estimator=HonestForestClassifier(
        n_estimators=n_estimators,
        max_features=max_features,
        tree_estimator=MultiViewDecisionTreeClassifier(feature_set_ends=[10, 10 + n_features_set]),
        honest_fraction=honest_fraction,
        n_jobs=n_jobs,
    ),
    random_state=seed,
    test_size=test_size,
    permute_per_tree=True,
    sample_dataset_per_tree=True,
)
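# note: ``feature_set_ends=[10, 510]`` marks where each view ends, so columns
# 0-9 form the informative view and columns 10-509 the noise view here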

# testing with the multi-view splitter should yield a higher AUC
stat, posterior_arr, samples = est_mv.statistic(
    X,
    y,
    metric=metric,
    return_posteriors=True,
    max_fpr=max_fpr,
)

print(f"ASH-90 / Partial AUC: {stat}")
print(f"Shape of Observed Samples: {samples.shape}")
print(f"Shape of Tree Posteriors for the positive class: {posterior_arr.shape}")

# %%
# Repeat without multi-view
# -------------------------
# This model yields a smaller statistic, as expected, because it does not
# exploit the multi-view structure.

est = FeatureImportanceForestClassifier(
    estimator=HonestForestClassifier(
        n_estimators=n_estimators,
        max_features=max_features,
        tree_estimator=DecisionTreeClassifier(),
        honest_fraction=honest_fraction,
        n_jobs=n_jobs,
    ),
    random_state=seed,
    test_size=test_size,
    permute_per_tree=True,
    sample_dataset_per_tree=True,
)

stat, posterior_arr, samples = est.statistic(
    X,
    y,
    metric=metric,
    return_posteriors=True,
    max_fpr=max_fpr,
)

print(f"ASH-90 / Partial AUC: {stat}")
print(f"Shape of Observed Samples: {samples.shape}")
print(f"Shape of Tree Posteriors for the positive class: {posterior_arr.shape}")

# %%
# All posteriors are saved within the model
# -----------------------------------------
# The results can be extracted from the model attributes at any time, and the
# fitted model can be saved with ``pickle``, as sketched below.
#
# ASH-90 / Partial AUC: ``est_mv.observe_stat_``
#
# Observed Samples: ``est_mv.observe_samples_``
#
# Tree Posteriors for the positive class: ``est_mv.observe_posteriors_``
# (n_trees, n_samples_test, 1)
#
# True Labels: ``est_mv.y_true_final_``
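
# %%
# A minimal sketch of the round trip with the standard-library ``pickle``
# module (the file name here is arbitrary):
import pickle

with open("est_mv.pkl", "wb") as f:
    pickle.dump(est_mv, f)

with open("est_mv.pkl", "rb") as f:
    est_loaded = pickle.load(f)

# the saved attributes survive the round trip
print(est_loaded.observe_stat_)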
@@ -0,0 +1,6 @@
.. _quantile_examples:

Quantile Predictions with Random Forest
---------------------------------------

Examples demonstrating how to generate quantile predictions using Random Forest variants.
examples/quantile_predictions/plot_quantile_interpolation_with_RF.py (111 additions, 0 deletions)
@@ -0,0 +1,111 @@
"""
========================================================
Predicting with different quantile interpolation methods
========================================================
An example comparison of interpolation methods that can be applied during
prediction when the desired quantile lies between two data points.
"""

from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
from sklearn.ensemble import RandomForestRegressor

# %%
# Generate the data
# -----------------
# We use five simple data points to illustrate the difference between the
# intervals generated by the different interpolation methods.

X = np.array([[-1, -1], [-1, -1], [-1, -1], [1, 1], [1, 1]])
y = np.array([-2, -1, 0, 1, 2])

# %%
# The interpolation methods
# -------------------------
# The interpolation methods demonstrated here are linear, lower, higher,
# midpoint, and nearest; they differ in how a quantile that falls between two
# data points i and j (``i <= j``) is interpolated. For more details, see
# `sktree.RandomForestRegressor`. The difference between the methods can be
# illustrated with the following example:
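
# a quick toy check (array values chosen arbitrarily; the ``method`` keyword
# of ``np.quantile`` requires NumPy >= 1.22): the 0.4 quantile falls between
# 2 and 3, and each method resolves it differently
toy = np.array([1, 2, 3, 4])
for method in ["linear", "lower", "higher", "midpoint", "nearest"]:
    print(method, np.quantile(toy, 0.4, method=method))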

interpolations = ["linear", "lower", "higher", "midpoint", "nearest"]
colors = ["#006aff", "#ffd237", "#0d4599", "#f2a619", "#a6e5ff"]
quantiles = [0.025, 0.5, 0.975]

y_medians = []
y_errs = []
est = RandomForestRegressor(
    n_estimators=1,
    random_state=0,
)
# fit the model
est.fit(X, y)
# get the leaf node that each training sample fell into, per tree
leaf_ids = est.apply(X)
# build one dictionary per tree that maps each leaf node to the indices of
# the samples that fell into it
node_to_indices = []
for tree in range(leaf_ids.shape[1]):
    d = defaultdict(list)
    for sample_idx, leaf in enumerate(leaf_ids[:, tree]):
        d[leaf].append(sample_idx)
    node_to_indices.append(d)
# pass the test samples down the trained trees and record which leaf node
# each one lands in (here we reuse X as the test input)
leaf_ids_test = est.apply(X)
# for each test sample, collect the indices of the training samples that
# fell into the same leaf node, for every tree
y_pred_quantile = []
for sample in range(leaf_ids_test.shape[0]):
    li = [
        node_to_indices[tree][leaf_ids_test[sample][tree]] for tree in range(leaf_ids_test.shape[1])
    ]
    # merge the per-tree lists of indices into one
    idx = [item for sublist in li for item in sublist]
    # gather the training targets for the collected indices
    y_pred_quantile.append(y[idx])

for interpolation in interpolations:
    # get the quantile predictions for each predicted sample
    y_pred = [
        np.array(
            [
                np.quantile(y_pred_quantile[i], quantile, method=interpolation)
                for i in range(len(y_pred_quantile))
            ]
        )
        for quantile in quantiles
    ]
    y_medians.append(y_pred[1])
    y_errs.append(
        np.concatenate(
            (
                [y_pred[1] - y_pred[0]],
                [y_pred[2] - y_pred[1]],
            ),
            axis=0,
        )
    )

sc = plt.scatter(np.arange(len(y)) - 0.35, y, color="k", zorder=10)
ebs = []
for i, (median, y_err) in enumerate(zip(y_medians, y_errs)):
    ebs.append(
        plt.errorbar(
            np.arange(len(y)) + (0.15 * (i + 1)) - 0.35,
            median,
            yerr=y_err,
            color=colors[i],
            ecolor=colors[i],
            fmt="o",
        )
    )
plt.xlim([-0.75, len(y) - 0.25])
plt.xticks(np.arange(len(y)), X.tolist())
plt.xlabel("Samples (Feature Values)")
plt.ylabel("Actual and Predicted Values")
plt.legend([sc] + ebs, ["actual"] + interpolations, loc=2)
plt.show()