Skip to content

Commit

Permalink
document test better
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <[email protected]>
  • Loading branch information
adam2392 committed Nov 6, 2023
1 parent c02326d commit 714bf76
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions sktree/tests/test_honest_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from numpy.testing import assert_allclose, assert_array_almost_equal
from scipy.stats import entropy
from sklearn import datasets
from sklearn.metrics import accuracy_score, r2_score
from sklearn.metrics import accuracy_score, r2_score, roc_auc_score
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier as skDecisionTreeClassifier
from sklearn.utils import check_random_state
Expand Down Expand Up @@ -296,7 +296,8 @@ def test_honest_forest_with_sklearn_trees_with_auc():
"""Test against regression in power-curves discussed in:
https://github.com/neurodata/scikit-tree/pull/157.
This unit-test now tests explicitly the power curve.
This unit-test tests the equivalent of the AUC using sklearn's DTC
vs our forked version of sklearn's DTC as the base tree.
"""
skForest = HonestForestClassifier(
n_estimators=10, tree_estimator=skDecisionTreeClassifier(), random_state=0
Expand All @@ -306,6 +307,7 @@ def test_honest_forest_with_sklearn_trees_with_auc():
n_estimators=10, tree_estimator=DecisionTreeClassifier(), random_state=0
)

max_fpr = 0.1
scores = []
sk_scores = []
for idx in range(10):
Expand All @@ -316,27 +318,27 @@ def test_honest_forest_with_sklearn_trees_with_auc():
Forest.fit(X, y)

# compute MI
y_pred_proba = skForest.predict_proba(X)[:, 1]
sk_mi = _mutual_information(y, y_pred_proba)
y_pred_proba = skForest.predict_proba(X)[:, 1].reshape(-1, 1)
sk_mi = roc_auc_score(y, y_pred_proba, max_fpr=max_fpr)

y_pred_proba = Forest.predict_proba(X)[:, 1]
mi = _mutual_information(y, y_pred_proba)
y_pred_proba = Forest.predict_proba(X)[:, 1].reshape(-1, 1)
mi = roc_auc_score(y, y_pred_proba, max_fpr=max_fpr)

scores.append(mi)
sk_scores.append(sk_mi)

print(scores, sk_scores)
print(np.mean(scores), np.mean(sk_scores))
print(np.std(scores), np.std(sk_scores))

assert_allclose(np.mean(sk_scores), np.mean(scores), atol=0.05)


def test_honest_forest_with_sklearn_trees_with_mi():
"""Test against regression in power-curves discussed in:
https://github.com/neurodata/scikit-tree/pull/157.
This unit-test now tests explicitly the power curve.
This unit-test tests the equivalent of the MI using sklearn's DTC
vs our forked version of sklearn's DTC as the base tree.
"""
skForest = HonestForestClassifier(
n_estimators=10, tree_estimator=skDecisionTreeClassifier(), random_state=0
Expand Down Expand Up @@ -373,5 +375,4 @@ def test_honest_forest_with_sklearn_trees_with_mi():
print(scores, sk_scores)
print(np.mean(scores), np.mean(sk_scores))
print(np.std(scores), np.std(sk_scores))

assert_allclose(np.mean(sk_scores), np.mean(scores), atol=0.05)

0 comments on commit 714bf76

Please sign in to comment.