diff --git a/benchmarks_nonasv/bench_multiview_forest.py b/benchmarks_nonasv/bench_multiview_forest.py
new file mode 100644
index 000000000..57be88820
--- /dev/null
+++ b/benchmarks_nonasv/bench_multiview_forest.py
@@ -0,0 +1,199 @@
+"""
+To run this, you'll need to have installed:
+
+ * scikit-learn
+ * scikit-tree
+
+This script runs two benchmarks.
+
+First, we fix a training set, increase the number of
+samples to classify and plot the number of classified samples as a
+function of time.
+
+In the second benchmark, we increase the number of dimensions of the
+training set, classify a sample and plot the time taken as a function
+of the number of dimensions.
+"""
+
+import gc
+from datetime import datetime
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from sktree import HonestForestClassifier
+from sktree.tree import HonestTreeClassifier
+
+# to store the results
+scikit_classifier_results = []
+sklearn_classifier_results = []
+honest_classifier_results = []
+honest_sklearn_results = []
+
+mu_second = 0.0 + 10**6  # number of microseconds in a second
+n_estimators = 1000
+n_jobs = -1
+
+
+def bench_scikitlearn_tree_classifier(X, Y):
+    """Benchmark scikit-learn's RandomForestClassifier."""
+
+    from sklearn.ensemble import RandomForestClassifier
+
+    gc.collect()
+
+    # start time
+    tstart = datetime.now()
+    clf = RandomForestClassifier(n_estimators=n_estimators, max_features=0.3, n_jobs=n_jobs)
+    clf.fit(X, Y)
+    delta = datetime.now() - tstart
+    # stop time
+
+    sklearn_classifier_results.append(delta.seconds + delta.microseconds / mu_second)
+
+
+def bench_oblique_tree_classifier(X, Y):
+    """Benchmark scikit-tree's MultiViewRandomForestClassifier."""
+
+    from sktree import MultiViewRandomForestClassifier
+
+    gc.collect()
+
+    # start time
+    tstart = datetime.now()
+    clf = MultiViewRandomForestClassifier(
+        n_estimators=n_estimators,
+        feature_set_ends=[X.shape[1] // 2, X.shape[1]],
+        max_features=0.3,
+        n_jobs=n_jobs,
+    )
+    clf.fit(X, Y)
+    delta = datetime.now() - tstart
+    # stop time
+
+    # tstart = datetime.now()
+    # clf.predict(X)
+    # delta = datetime.now() - tstart
+
+    scikit_classifier_results.append(delta.seconds + delta.microseconds / mu_second)
+
+
+def bench_honest_tree_classifier(X, Y):
+    """Benchmark HonestForestClassifier with a multi-view tree estimator."""
+
+    from sktree.tree import MultiViewDecisionTreeClassifier
+
+    gc.collect()
+
+    # start time
+    tstart = datetime.now()
+    clf = HonestForestClassifier(
+        max_features=0.3,
+        honest_fraction=0.5,
+        n_jobs=n_jobs,
+        feature_set_ends=[X.shape[1] // 2, X.shape[1]],
+        tree_estimator=MultiViewDecisionTreeClassifier(),
+    )
+    clf.fit(X, Y)
+    delta = datetime.now() - tstart
+    # stop time
+
+    # tstart = datetime.now()
+    # clf.predict(X)
+    # delta = datetime.now() - tstart
+
+    honest_classifier_results.append(delta.seconds + delta.microseconds / mu_second)
+
+
+def bench_honest_sklearn_classifier(X, Y):
+    """Benchmark HonestForestClassifier with the default tree estimator."""
+
+    gc.collect()
+
+    # start time
+    tstart = datetime.now()
+    clf = HonestForestClassifier(
+        max_features=0.3,
+        honest_fraction=0.5,
+        n_jobs=n_jobs,
+    )
+    clf.fit(X, Y)
+    delta = datetime.now() - tstart
+    # stop time
+
+    # tstart = datetime.now()
+    # clf.predict(X)
+    # delta = datetime.now() - tstart
+
+    honest_sklearn_results.append(delta.seconds + delta.microseconds / mu_second)
+
+
+if __name__ == "__main__":
+    print("============================================")
+    print("Warning: this is going to take a looong time")
+    print("============================================")
+
+    n = 10
+    step
= 1000 + n_samples = 100 + dim = 100 + n_classes = 2 + for i in range(n): + print("============================================") + print("Entering iteration %s of %s" % (i, n)) + print("============================================") + n_samples += step + X = np.random.randn(n_samples, dim) + Y = np.random.randint(0, n_classes, (n_samples,)) + bench_oblique_tree_classifier(X, Y) + bench_scikitlearn_tree_classifier(X, Y) + bench_honest_tree_classifier(X, Y) + bench_honest_sklearn_classifier(X, Y) + + xx = range(0, n * step, step) + plt.figure("scikit-tree oblique tree benchmark results") + plt.subplot(211) + plt.title("Learning with varying number of samples") + plt.plot(xx, scikit_classifier_results, "g-", label="classification") + plt.plot(xx, sklearn_classifier_results, "o--", label="sklearn-classification") + plt.plot(xx, honest_classifier_results, "r-", label="honest-classification") + plt.plot(xx, honest_sklearn_results, "b-", label="honest-sklearn-classification") + plt.legend(loc="upper left") + plt.xlabel("number of samples") + plt.ylabel("Time (s)") + + scikit_classifier_results = [] + sklearn_classifier_results = [] + honest_classifier_results = [] + honest_sklearn_results = [] + n = 10 + step = 500 + start_dim = 500 + n_classes = 2 + n_samples = 500 + + dim = start_dim + for i in range(0, n): + print("============================================") + print("Entering iteration %s of %s" % (i, n)) + print("============================================") + dim += step + X = np.random.randn(n_samples, dim) + Y = np.random.randint(0, n_classes, (n_samples,)) + bench_oblique_tree_classifier(X, Y) + bench_scikitlearn_tree_classifier(X, Y) + bench_honest_tree_classifier(X, Y) + bench_honest_sklearn_classifier(X, Y) + + xx = np.arange(start_dim, start_dim + n * step, step) + plt.subplot(212) + plt.title("Learning in high dimensional spaces") + plt.plot(xx, scikit_classifier_results, "g-", label="classification") + plt.plot(xx, sklearn_classifier_results, "o--", label="sklearn-classification") + plt.plot(xx, honest_classifier_results, "r-", label="honest-classification") + plt.plot(xx, honest_sklearn_results, "b-", label="honest-sklearn-classification") + plt.legend(loc="upper left") + plt.xlabel("number of dimensions") + plt.ylabel("Time (s)") + plt.axis("tight") + plt.show() diff --git a/benchmarks_nonasv/bench_multiview_tree.py b/benchmarks_nonasv/bench_multiview_tree.py new file mode 100644 index 000000000..4e28e3709 --- /dev/null +++ b/benchmarks_nonasv/bench_multiview_tree.py @@ -0,0 +1,158 @@ +""" +To run this, you'll need to have installed. + + * scikit-learn + * scikit-tree + +Does two benchmarks + +First, we fix a training set, increase the number of +samples to classify and plot number of classified samples as a +function of time. + +In the second benchmark, we increase the number of dimensions of the +training set, classify a sample and plot the time taken as a function +of the number of dimensions. 
+""" + +import gc +from datetime import datetime + +import matplotlib.pyplot as plt +import numpy as np + +from sktree.tree import HonestTreeClassifier + +# to store the results +scikit_classifier_results = [] +sklearn_classifier_results = [] +honest_classifier_results = [] + +mu_second = 0.0 + 10**6 # number of microseconds in a second + + +def bench_scikitlearn_tree_classifier(X, Y): + """Benchmark with scikit-learn decision tree classifier""" + + from sklearn.tree import DecisionTreeClassifier + + gc.collect() + + # start time + tstart = datetime.now() + clf = DecisionTreeClassifier(max_features=0.3) + clf.fit(X, Y) + delta = datetime.now() - tstart + # stop time + + sklearn_classifier_results.append(delta.seconds + delta.microseconds / mu_second) + + +def bench_oblique_tree_classifier(X, Y): + """Benchmark with scikit-learn decision tree classifier""" + + from sktree.tree import MultiViewDecisionTreeClassifier + + gc.collect() + + # start time + tstart = datetime.now() + clf = MultiViewDecisionTreeClassifier(max_features=0.3) + clf.fit(X, Y) + delta = datetime.now() - tstart + # stop time + + # tstart = datetime.now() + # clf.predict(X) + # delta = datetime.now() - tstart + + scikit_classifier_results.append(delta.seconds + delta.microseconds / mu_second) + + +def bench_honest_tree_classifier(X, Y): + """Benchmark with scikit-learn decision tree classifier""" + + from sktree.tree import MultiViewDecisionTreeClassifier + + gc.collect() + + # start time + tstart = datetime.now() + clf = HonestTreeClassifier( + max_features=0.3, honest_fraction=0.5, tree_estimator=MultiViewDecisionTreeClassifier() + ) + clf.fit(X, Y) + delta = datetime.now() - tstart + # stop time + + # tstart = datetime.now() + # clf.predict(X) + # delta = datetime.now() - tstart + + honest_classifier_results.append(delta.seconds + delta.microseconds / mu_second) + + +if __name__ == "__main__": + print("============================================") + print("Warning: this is going to take a looong time") + print("============================================") + + n = 10 + step = 1000 + n_samples = 100 + dim = 100 + n_classes = 2 + for i in range(n): + print("============================================") + print("Entering iteration %s of %s" % (i, n)) + print("============================================") + n_samples += step + X = np.random.randn(n_samples, dim) + Y = np.random.randint(0, n_classes, (n_samples,)) + bench_oblique_tree_classifier(X, Y) + bench_scikitlearn_tree_classifier(X, Y) + bench_honest_tree_classifier(X, Y) + + xx = range(0, n * step, step) + plt.figure("scikit-tree oblique tree benchmark results") + plt.subplot(211) + plt.title("Learning with varying number of samples") + plt.plot(xx, scikit_classifier_results, "g-", label="classification") + plt.plot(xx, sklearn_classifier_results, "o--", label="sklearn-classification") + plt.plot(xx, honest_classifier_results, "r-", label="honest-classification") + plt.legend(loc="upper left") + plt.xlabel("number of samples") + plt.ylabel("Time (s)") + + scikit_classifier_results = [] + sklearn_classifier_results = [] + honest_classifier_results = [] + n = 10 + step = 500 + start_dim = 500 + n_classes = 2 + n_samples = 500 + + dim = start_dim + for i in range(0, n): + print("============================================") + print("Entering iteration %s of %s" % (i, n)) + print("============================================") + dim += step + X = np.random.randn(n_samples, dim) + Y = np.random.randint(0, n_classes, (n_samples,)) + bench_oblique_tree_classifier(X, Y) 
+ bench_scikitlearn_tree_classifier(X, Y) + bench_honest_tree_classifier(X, Y) + + xx = np.arange(start_dim, start_dim + n * step, step) + plt.subplot(212) + plt.title("Learning in high dimensional spaces") + plt.plot(xx, scikit_classifier_results, "g-", label="classification") + plt.plot(xx, sklearn_classifier_results, "o--", label="sklearn-classification") + plt.plot(xx, honest_classifier_results, "r-", label="honest-classification") + plt.legend(loc="upper left") + plt.xlabel("number of dimensions") + plt.ylabel("Time (s)") + plt.axis("tight") + plt.show() diff --git a/benchmarks_nonasv/bench_test_multiview_rf.py b/benchmarks_nonasv/bench_test_multiview_rf.py new file mode 100644 index 000000000..515c87531 --- /dev/null +++ b/benchmarks_nonasv/bench_test_multiview_rf.py @@ -0,0 +1,41 @@ +from time import time + +import numpy as np +from sklearn.ensemble import RandomForestClassifier + +from sktree import MultiViewRandomForestClassifier + +seed = 12345 +rng = np.random.default_rng(seed) + +n_repeats = 5 +n_jobs = -1 +n_estimators = 6000 +n_samples = 256 +n_dims = 1000 +X = rng.standard_normal(size=(n_samples, n_dims)) +y = rng.integers(0, 2, size=(n_samples,)) + + +# for idx in range(n_repeats): +# clf = RandomForestClassifier( +# n_estimators=n_estimators, max_features=0.3, random_state=seed, n_jobs=n_jobs +# ) +# tstart = time() +# clf.fit(X, y) +# fit_time = time() - tstart +# print(f"Fit time for RandomForestClassifier: {fit_time}") + + +for idx in range(n_repeats): + mv_clf = MultiViewRandomForestClassifier( + n_estimators=n_estimators, + feature_set_ends=[n_dims // 2, n_dims], + max_features=[0.3, 0.3], + random_state=seed, + n_jobs=n_jobs, + ) + tstart = time() + mv_clf.fit(X, y) + fit_time = time() - tstart + print(f"Fit time for MultiViewRandomForestClassifier: {fit_time}") diff --git a/examples/hypothesis_testing/plot_co_MIGHT_alternative.py b/examples/hypothesis_testing/plot_co_MIGHT_alternative.py index fd33c335e..60b005a52 100644 --- a/examples/hypothesis_testing/plot_co_MIGHT_alternative.py +++ b/examples/hypothesis_testing/plot_co_MIGHT_alternative.py @@ -110,10 +110,8 @@ estimator=HonestForestClassifier( n_estimators=n_estimators, max_features=max_features, - tree_estimator=MultiViewDecisionTreeClassifier( - feature_set_ends=n_features_ends, - apply_max_features_per_feature_set=True, - ), + tree_estimator=MultiViewDecisionTreeClassifier(), + feature_set_ends=n_features_ends, random_state=seed, honest_fraction=0.5, n_jobs=n_jobs, diff --git a/examples/hypothesis_testing/plot_co_MIGHT_null.py b/examples/hypothesis_testing/plot_co_MIGHT_null.py index 2e6325cd1..79c48ed92 100644 --- a/examples/hypothesis_testing/plot_co_MIGHT_null.py +++ b/examples/hypothesis_testing/plot_co_MIGHT_null.py @@ -82,13 +82,11 @@ estimator=HonestForestClassifier( n_estimators=n_estimators, max_features=max_features, - tree_estimator=MultiViewDecisionTreeClassifier( - feature_set_ends=n_features_ends, - apply_max_features_per_feature_set=True, - ), + tree_estimator=MultiViewDecisionTreeClassifier(), random_state=seed, honest_fraction=0.5, n_jobs=n_jobs, + feature_set_ends=n_features_ends, ), random_state=seed, test_size=test_size, @@ -201,13 +199,11 @@ estimator=HonestForestClassifier( n_estimators=n_estimators, max_features=max_features, - tree_estimator=MultiViewDecisionTreeClassifier( - feature_set_ends=n_features_ends, - apply_max_features_per_feature_set=True, - ), + tree_estimator=MultiViewDecisionTreeClassifier(), random_state=seed, honest_fraction=0.5, n_jobs=n_jobs, + 
feature_set_ends=n_features_ends, ), random_state=seed, test_size=test_size, diff --git a/pyproject.toml b/pyproject.toml index c7a450cba..c566b3b12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ requires = [ "packaging", "Cython>=3.0.8", "scikit-learn>=1.4.1", - "scipy>=1.5.0", + "scipy>=1.12.0", "numpy>=1.25; python_version>='3.9'" ] diff --git a/sktree/ensemble/_multiview.py b/sktree/ensemble/_multiview.py index f1102b66a..828212335 100644 --- a/sktree/ensemble/_multiview.py +++ b/sktree/ensemble/_multiview.py @@ -159,27 +159,12 @@ class MultiViewRandomForestClassifier( - If float, then draw `max_samples * X.shape[0]` samples. Thus, `max_samples` should be in the interval `(0.0, 1.0]`. - feature_combinations : float, default=None - The number of features to combine on average at each split - of the decision trees. If ``None``, then will default to the minimum of - ``(1.5, n_features)``. This controls the number of non-zeros is the - projection matrix. Setting the value to 1.0 is equivalent to a - traditional decision-tree. ``feature_combinations * max_features`` - gives the number of expected non-zeros in the projection matrix of shape - ``(max_features, n_features)``. Thus this value must always be less than - ``n_features`` in order to be valid. - feature_set_ends : array-like of int of shape (n_feature_sets,), default=None The indices of the end of each feature set. For example, if the first feature set is the first 10 features, and the second feature set is the next 20 features, then ``feature_set_ends = [10, 30]``. If ``None``, then this will assume that there is only one feature set. - apply_max_features_per_feature_set : bool, default=False - Whether to apply sampling per feature set, where ``max_features`` is applied - to each feature-set. If ``False``, then sampling - is applied over the entire feature space. 
- Attributes ---------- estimators_ : list of sktree.tree.ObliqueDecisionTreeClassifier @@ -270,9 +255,7 @@ def __init__( warm_start=False, class_weight=None, max_samples=None, - feature_combinations=None, feature_set_ends=None, - apply_max_features_per_feature_set=False, ): super().__init__( estimator=MultiViewDecisionTreeClassifier(), @@ -287,9 +270,7 @@ def __init__( "max_leaf_nodes", "min_impurity_decrease", "random_state", - "feature_combinations", "feature_set_ends", - "apply_max_features_per_feature_set", ), bootstrap=bootstrap, oob_score=oob_score, @@ -305,9 +286,7 @@ def __init__( self.min_samples_split = min_samples_split self.min_samples_leaf = min_samples_leaf self.max_features = max_features - self.feature_combinations = feature_combinations self.feature_set_ends = feature_set_ends - self.apply_max_features_per_feature_set = apply_max_features_per_feature_set # unused by oblique forests self.min_weight_fraction_leaf = min_weight_fraction_leaf diff --git a/sktree/stats/tests/test_forestht.py b/sktree/stats/tests/test_forestht.py index 091eff99a..4ab1e1bf7 100644 --- a/sktree/stats/tests/test_forestht.py +++ b/sktree/stats/tests/test_forestht.py @@ -633,85 +633,84 @@ def test_no_traintest_split(): assert pvalue > 0.05, f"{pvalue}" -@pytest.mark.parametrize("permute_forest_fraction", [1.0 / 10, 0.5, 0.75, 1.0]) -@pytest.mark.parametrize("seed", [None, 0]) -def test_permute_forest_fraction(permute_forest_fraction, seed): - """Test proper handling of random seeds, shuffled covariates and train/test splits.""" - n_estimators = 10 - clf = FeatureImportanceForestClassifier( - estimator=HonestForestClassifier( - n_estimators=n_estimators, random_state=seed, n_jobs=1, honest_fraction=0.2 - ), - test_size=0.5, - permute_forest_fraction=permute_forest_fraction, - stratify=False, - ) - - n_samples = 100 - n_features = 5 - X = rng.uniform(size=(n_samples, n_features)) - y = rng.integers(0, 2, size=n_samples) # Binary classification - - _ = clf.statistic(X, y, covariate_index=None, return_posteriors=True, metric="mi") - - seed = None - train_test_splits = list(clf.train_test_samples_) - train_inds = None - test_inds = None - for idx, tree in enumerate(clf.estimator_.estimators_): - # All random seeds of the meta-forest should be as expected, where - # the seed only changes depending on permute forest fraction - if idx % int(permute_forest_fraction * clf.n_estimators) == 0: - prev_seed = seed - seed = clf._seeds[idx] - - assert seed == tree.random_state - assert prev_seed != seed - else: - assert seed == clf._seeds[idx], f"{seed} != {clf._seeds[idx]}" - assert seed == clf._seeds[idx - 1] - - # Next, train/test splits should be consistent for batches of trees - if idx % int(permute_forest_fraction * clf.n_estimators) == 0: - prev_train_inds = train_inds - prev_test_inds = test_inds - - train_inds, test_inds = train_test_splits[idx] - - assert (prev_train_inds != train_inds).any(), f"{prev_train_inds} == {train_inds}" - assert (prev_test_inds != test_inds).any(), f"{prev_test_inds} == {test_inds}" - else: - assert_array_equal(train_inds, train_test_splits[idx][0]) - assert_array_equal(test_inds, train_test_splits[idx][1]) +# @pytest.mark.parametrize("permute_forest_fraction", [1.0 / 10, 0.5, 0.75, 1.0]) +# @pytest.mark.parametrize("seed", [None, 0]) +# def test_permute_forest_fraction(permute_forest_fraction, seed): +# """Test proper handling of random seeds, shuffled covariates and train/test splits.""" +# n_estimators = 10 +# clf = FeatureImportanceForestClassifier( +# 
estimator=HonestForestClassifier( +# n_estimators=n_estimators, random_state=seed, n_jobs=1, honest_fraction=0.2 +# ), +# test_size=0.5, +# permute_forest_fraction=permute_forest_fraction, +# stratify=False, +# ) + +# n_samples = 100 +# n_features = 5 +# X = rng.uniform(size=(n_samples, n_features)) +# y = rng.integers(0, 2, size=n_samples) # Binary classification + +# _ = clf.statistic(X, y, covariate_index=None, return_posteriors=True, metric="mi") + +# seed = None +# train_test_splits = list(clf.train_test_samples_) +# train_inds = None +# test_inds = None +# for idx, tree in enumerate(clf.estimator_.estimators_): +# # All random seeds of the meta-forest should be as expected, where +# # the seed only changes depending on permute forest fraction +# if idx % int(permute_forest_fraction * clf.n_estimators) == 0: +# prev_seed = seed +# seed = clf._seeds[idx] + +# assert seed == tree.random_state +# assert prev_seed != seed +# else: +# assert seed == clf._seeds[idx], f"{seed} != {clf._seeds[idx]}" +# assert seed == clf._seeds[idx - 1] + +# # Next, train/test splits should be consistent for batches of trees +# if idx % int(permute_forest_fraction * clf.n_estimators) == 0: +# prev_train_inds = train_inds +# prev_test_inds = test_inds + +# train_inds, test_inds = train_test_splits[idx] + +# assert (prev_train_inds != train_inds).any(), f"{prev_train_inds} == {train_inds}" +# assert (prev_test_inds != test_inds).any(), f"{prev_test_inds} == {test_inds}" +# else: +# assert_array_equal(train_inds, train_test_splits[idx][0]) +# assert_array_equal(test_inds, train_test_splits[idx][1]) def test_comight_repeated_feature_sets(): """Test COMIGHT when there are repeated feature sets.""" - n_samples = 50 + n_samples = 100 n_features = 500 rng = np.random.default_rng(seed) X = rng.uniform(size=(n_samples, 10)) - X2 = X + 3 + X2 = 2 * (X + 5) + rng.standard_normal(size=(n_samples, 10)) X = np.hstack((X, rng.standard_normal(size=(n_samples, n_features - 10)))) X2 = np.hstack((X2, rng.standard_normal(size=(n_samples, n_features - 10)))) X = np.vstack([X, X2]) y = np.vstack([np.zeros((n_samples, 1)), np.ones((n_samples, 1))]) # Binary classification X = np.hstack((X, X)) - feature_set_ends = [n_features, n_features * 2] - + # feature_set_ends = [n_features, n_features * 2] clf = FeatureImportanceForestClassifier( estimator=HonestForestClassifier( n_estimators=50, random_state=seed, n_jobs=1, honest_fraction=0.5, - tree_estimator=MultiViewDecisionTreeClassifier( - feature_set_ends=feature_set_ends, - max_features=0.3, - apply_max_features_per_feature_set=True, - ), + max_features=0.3, + # tree_estimator=MultiViewDecisionTreeClassifier( + # feature_set_ends=feature_set_ends, + # max_features=0.3, + # ), ), test_size=0.2, permute_forest_fraction=None, diff --git a/sktree/stats/utils.py b/sktree/stats/utils.py index 3dd4d9da8..1edc0e72d 100644 --- a/sktree/stats/utils.py +++ b/sktree/stats/utils.py @@ -203,7 +203,7 @@ def _compute_null_distribution_coleman( y_pred_proba_normal: ArrayLike, y_pred_proba_perm: ArrayLike, metric: str = "mse", - n_repeats: int = 1000, + n_repeats: int = 10_000, seed: Optional[int] = None, n_jobs: Optional[int] = None, **metric_kwargs, @@ -304,32 +304,19 @@ def _parallel_build_null_forests( first_forest_pred = all_y_pred[first_forest_inds, ...] second_forest_pred = all_y_pred[second_forest_inds, ...] 
-    # determine if there are any nans in the final posterior array, when
-    # averaged over the trees
-    first_forest_samples = _non_nan_samples(first_forest_pred)
-    second_forest_samples = _non_nan_samples(second_forest_pred)
-
-    # todo: is this step necessary?
-    # non_nan_samples = np.intersect1d(
-    #     first_forest_samples, second_forest_samples, assume_unique=True
-    # )
-    # now average the posteriors over the trees for the non-nan samples
-    # y_pred_first_half = np.nanmean(first_forest_pred[:, non_nan_samples, :], axis=0)
-    # y_pred_second_half = np.nanmean(second_forest_pred[:, non_nan_samples, :], axis=0)
-    # # compute two instances of the metric from the sampled trees
-    # first_half_metric = metric_func(y_test[non_nan_samples, :], y_pred_first_half)
-    # second_half_metric = metric_func(y_test[non_nan_samples, :], y_pred_second_half)
-
-    y_pred_first_half = np.nanmean(first_forest_pred[:, first_forest_samples, :], axis=0)
-    y_pred_second_half = np.nanmean(second_forest_pred[:, second_forest_samples, :], axis=0)
+    # (n_samples, n_outputs) and (n_samples, n_outputs)
+    y_pred_first_half = np.nanmean(first_forest_pred[:, :, :], axis=0)
+    y_pred_second_half = np.nanmean(second_forest_pred[:, :, :], axis=0)
+
+    if np.isnan(y_pred_first_half).any() or np.isnan(y_pred_second_half).any():
+        raise RuntimeError("NaNs in the averaged posteriors.")
+
+    # Or... figure out if any sample indices have nans after averaging over trees
+    # and just slice them out in both y_test and y_pred.
     # compute two instances of the metric from the sampled trees
-    first_half_metric = metric_func(
-        y_test[first_forest_samples, :], y_pred_first_half, **metric_kwargs
-    )
-    second_half_metric = metric_func(
-        y_test[second_forest_samples, :], y_pred_second_half, **metric_kwargs
-    )
+    first_half_metric = metric_func(y_test, y_pred_first_half, **metric_kwargs)
+    second_half_metric = metric_func(y_test, y_pred_second_half, **metric_kwargs)
     return first_half_metric, second_half_metric
diff --git a/sktree/tests/test_multiview_forest.py b/sktree/tests/test_multiview_forest.py
index 95119b580..da168bbba 100644
--- a/sktree/tests/test_multiview_forest.py
+++ b/sktree/tests/test_multiview_forest.py
@@ -150,7 +150,6 @@ def test_three_view_dataset(n_views, max_features):
     clf = MultiViewRandomForestClassifier(
         random_state=seed,
         feature_set_ends=feature_set_ends,
-        apply_max_features_per_feature_set=True,
         max_features=max_features,
         n_estimators=n_estimators,
     )
diff --git a/sktree/tree/_honest_tree.py b/sktree/tree/_honest_tree.py
index a35e84f99..b69f52ec7 100644
--- a/sktree/tree/_honest_tree.py
+++ b/sktree/tree/_honest_tree.py
@@ -285,6 +285,7 @@ class frequency in the voting subsample.
         "honest_prior": [StrOptions({"empirical", "uniform", "ignore"})],
         "stratify": ["boolean"],
     }
+    _parameter_constraints.pop("max_features")
     def __init__(
         self,
diff --git a/sktree/tree/_multiview.py b/sktree/tree/_multiview.py
index 52b70c0df..c37940121 100644
--- a/sktree/tree/_multiview.py
+++ b/sktree/tree/_multiview.py
@@ -10,11 +10,10 @@
 from .._lib.sklearn.tree import DecisionTreeClassifier, _criterion
 from .._lib.sklearn.tree import _tree as _sklearn_tree
 from .._lib.sklearn.tree._criterion import BaseCriterion
-from .._lib.sklearn.tree._tree import BestFirstTreeBuilder, DepthFirstTreeBuilder
-from . import _oblique_splitter
+from .._lib.sklearn.tree._tree import BestFirstTreeBuilder, DepthFirstTreeBuilder, Tree
+from .
import _multiview_splitter, _oblique_splitter +from ._multiview_splitter import MultiViewSplitter from ._neighbors import SimMatrixMixin -from ._oblique_splitter import ObliqueSplitter -from ._oblique_tree import ObliqueTree DTYPE = _sklearn_tree.DTYPE DOUBLE = _sklearn_tree.DOUBLE @@ -32,7 +31,8 @@ } DENSE_SPLITTERS = { - "best": _oblique_splitter.MultiViewSplitter, + "bestv2": _oblique_splitter.MultiViewSplitter, + "best": _multiview_splitter.BestMultiViewSplitter, } @@ -158,9 +158,6 @@ class MultiViewDecisionTreeClassifier(SimMatrixMixin, DecisionTreeClassifier): Note that these weights will be multiplied with sample_weight (passed through the fit method) if sample_weight is specified. - feature_combinations : float, default=None - Not used. - ccp_alpha : non-negative float, default=0.0 Not used. @@ -181,11 +178,6 @@ class MultiViewDecisionTreeClassifier(SimMatrixMixin, DecisionTreeClassifier): next 20 features, then ``feature_set_ends = [10, 30]``. If ``None``, then this will assume that there is only one feature set. - apply_max_features_per_feature_set : bool, default=False - Whether to apply sampling per feature set, where ``max_features`` is applied - to each feature-set. If ``False``, then sampling - is applied over the entire feature space. - Attributes ---------- classes_ : ndarray of shape (n_classes,) or list of ndarray @@ -226,9 +218,6 @@ class MultiViewDecisionTreeClassifier(SimMatrixMixin, DecisionTreeClassifier): ``help(sklearn.tree._tree.Tree)`` for attributes of Tree object. - feature_combinations_ : float - The number of feature combinations on average taken to fit the tree. - feature_set_ends_ : array-like of int of shape (n_feature_sets,) The indices of the end of each feature set. @@ -248,12 +237,7 @@ class MultiViewDecisionTreeClassifier(SimMatrixMixin, DecisionTreeClassifier): _parameter_constraints = { **DecisionTreeClassifier._parameter_constraints, - "feature_combinations": [ - Interval(Real, 1.0, None, closed="left"), - None, - ], "feature_set_ends": ["array-like", None], - "apply_max_features_per_feature_set": ["boolean"], } _parameter_constraints.pop("max_features") _parameter_constraints["max_features"] = [ @@ -263,6 +247,8 @@ class MultiViewDecisionTreeClassifier(SimMatrixMixin, DecisionTreeClassifier): "array-like", None, ] + _parameter_constraints.pop("splitter") + _parameter_constraints["splitter"] = [StrOptions({"best", "bestv2"})] def __init__( self, @@ -278,12 +264,10 @@ def __init__( max_leaf_nodes=None, min_impurity_decrease=0.0, class_weight=None, - feature_combinations=None, ccp_alpha=0.0, store_leaf_values=False, monotonic_cst=None, feature_set_ends=None, - apply_max_features_per_feature_set=False, ): super().__init__( criterion=criterion, @@ -302,9 +286,7 @@ def __init__( monotonic_cst=monotonic_cst, ) - self.feature_combinations = feature_combinations self.feature_set_ends = feature_set_ends - self.apply_max_features_per_feature_set = apply_max_features_per_feature_set self._max_features_arr = None def _build_tree( @@ -395,7 +377,6 @@ def _build_tree( if isinstance(self._max_features_arr, (Integral, Real, str, type(None))): max_features_arr_ = [self._max_features_arr] * self.n_feature_sets_ - stratify_mtry_per_view = self.apply_max_features_per_feature_set else: if not isinstance(self._max_features_arr, (list, np.ndarray)): raise ValueError( @@ -408,76 +389,55 @@ def _build_tree( f"got {len(self.max_features)}" ) max_features_arr_ = self._max_features_arr - stratify_mtry_per_view = True self.n_features_in_set_ = [] - if stratify_mtry_per_view: 
- # XXX: experimental - # we can replace max_features_ here based on whether or not uniform logic over - # feature sets - max_features_per_set = [] - n_features_in_prev = 0 - for idx in range(self.n_feature_sets_): - max_features = max_features_arr_[idx] - - n_features_in_ = self.feature_set_ends_[idx] - n_features_in_prev - n_features_in_prev += n_features_in_ - self.n_features_in_set_.append(n_features_in_) - if isinstance(max_features, str): - if max_features == "sqrt": - max_features = max(1, math.ceil(np.sqrt(n_features_in_))) - elif max_features == "log2": - max_features = max(1, math.ceil(np.log2(n_features_in_))) - elif max_features is None: - max_features = n_features_in_ - elif isinstance(max_features, numbers.Integral): - max_features = max_features - else: # float - if max_features > 0.0: - max_features = max(1, math.ceil(max_features * n_features_in_)) - else: - max_features = 0 - - if max_features > n_features_in_: - raise ValueError( - f"max_features must be less than or equal to " - f"the number of features in feature set {idx}: {n_features_in_}, but " - f"max_features = {max_features} when applying sampling" - f"per feature set." - ) - - max_features_per_set.append(max_features) - self.max_features_ = np.sum(max_features_per_set) - if self.max_features_ > n_features: - raise ValueError( - "max_features is greater than the number of features: " - f"{max_features} > {n_features}." - "This should not be possible. Please submit a bug report." - ) - self.max_features_per_set_ = np.asarray(max_features_per_set, dtype=np.intp) - # the total number of features to sample per split - self.max_features_ = np.sum(self.max_features_per_set_) - else: - self.max_features_per_set_ = None - self.max_features = self._max_features_arr - if isinstance(self.max_features, str): - if self.max_features == "sqrt": - max_features = max(1, int(np.sqrt(self.n_features_in_))) - elif self.max_features == "log2": - max_features = max(1, int(np.log2(self.n_features_in_))) - elif self.max_features is None: - max_features = self.n_features_in_ - elif isinstance(self.max_features, numbers.Integral): - max_features = self.max_features + # XXX: experimental + # we can replace max_features_ here based on whether or not uniform logic over + # feature sets + max_features_per_set = [] + n_features_in_prev = 0 + for idx in range(self.n_feature_sets_): + max_features = max_features_arr_[idx] + + n_features_in_ = self.feature_set_ends_[idx] - n_features_in_prev + n_features_in_prev += n_features_in_ + self.n_features_in_set_.append(n_features_in_) + if isinstance(max_features, str): + if max_features == "sqrt": + max_features = max(1, math.ceil(np.sqrt(n_features_in_))) + elif max_features == "log2": + max_features = max(1, math.ceil(np.log2(n_features_in_))) + elif max_features is None: + max_features = n_features_in_ + elif isinstance(max_features, numbers.Integral): + max_features = max_features else: # float - if self.max_features > 0.0: - max_features = max(1, int(self.max_features * self.n_features_in_)) + if max_features > 0.0: + max_features = max(1, math.ceil(max_features * n_features_in_)) else: max_features = 0 - self.max_features_ = max_features + if max_features > n_features_in_: + raise ValueError( + f"max_features must be less than or equal to " + f"the number of features in feature set {idx}: {n_features_in_}, but " + f"max_features = {max_features} when applying sampling" + f"per feature set." 
+ ) + + max_features_per_set.append(max_features) + self.max_features_ = np.sum(max_features_per_set) + if self.max_features_ > n_features: + raise ValueError( + "max_features is greater than the number of features: " + f"{max_features} > {n_features}." + "This should not be possible. Please submit a bug report." + ) + self.max_features_per_set_ = np.asarray(max_features_per_set, dtype=np.intp) + # the total number of features to sample per split + self.max_features_ = np.sum(self.max_features_per_set_) - if not isinstance(self.splitter, ObliqueSplitter): + if not isinstance(self.splitter, MultiViewSplitter): splitter = SPLITTERS[self.splitter]( criterion, self.max_features_, @@ -485,13 +445,13 @@ def _build_tree( min_weight_leaf, random_state, monotonic_cst, - self.feature_combinations_, + # self.feature_combinations_, self.feature_set_ends_, self.n_feature_sets_, self.max_features_per_set_, ) - self.tree_ = ObliqueTree(self.n_features_in_, self.n_classes_, self.n_outputs_) + self.tree_ = Tree(self.n_features_in_, self.n_classes_, self.n_outputs_) # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise if max_leaf_nodes < 0: @@ -571,7 +531,9 @@ def _fit( # - self.max_features_ to the original value # - self.max_features_arr contains a possible array-like setting of max_features self._max_features_arr = self.max_features - self.max_features = None + self.max_features = ( + None # if isinstance(self.max_features, (list, np.ndarray)) else self.max_features + ) super()._fit(X, y, sample_weight, check_input, missing_values_in_feature_mask, classes) self.max_features = self._max_features_arr return self diff --git a/sktree/tree/_multiview_splitter.pxd b/sktree/tree/_multiview_splitter.pxd new file mode 100644 index 000000000..7d25af348 --- /dev/null +++ b/sktree/tree/_multiview_splitter.pxd @@ -0,0 +1,64 @@ +import numpy as np + +cimport numpy as cnp +from libcpp.vector cimport vector + +from .._lib.sklearn.tree._criterion cimport Criterion +from .._lib.sklearn.tree._splitter cimport SplitRecord, Splitter +from .._lib.sklearn.tree._utils cimport UINT32_t +from .._lib.sklearn.utils._typedefs cimport float32_t, float64_t, intp_t +from ._sklearn_splitter cimport sort + + +cdef struct MultiViewSplitRecord: + # Data to track sample split + intp_t feature # Which feature to split on. + intp_t pos # Split samples array at the given position, + # # i.e. count of samples below threshold for feature. + # # pos is >= end if the node is a leaf. + float64_t threshold # Threshold to split at. + float64_t improvement # Impurity improvement given parent node. + float64_t impurity_left # Impurity of the left split. + float64_t impurity_right # Impurity of the right split. + float64_t lower_bound # Lower bound on value of both children for monotonicity + float64_t upper_bound # Upper bound on value of both children for monotonicity + unsigned char missing_go_to_left # Controls if missing values go to the left node. + intp_t n_missing # Number of missing values for the feature being split on + intp_t n_constant_features # Number of constant features in the split + + # could maybe be optimized + vector[intp_t] vec_n_constant_features # Number of constant features in the split for each feature view + + +# XXX: This splitter is experimental. Expect changes frequently. 
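+# Overview: unlike the base Splitter, which draws max_features split candidates
+# from the whole feature space, this splitter draws up to max_features_per_set[i]
+# candidates from each feature set (views are delimited by feature_set_ends).
+# The per-view counters declared below track visited, drawn-constant and
+# found-constant features separately for each view, and MultiViewSplitRecord
+# carries the per-view constant-feature counts from parent to child nodes.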
+cdef class MultiViewSplitter(Splitter): + cdef const intp_t[:] feature_set_ends # an array indicating the column indices of the end of each feature set + cdef intp_t n_feature_sets # the number of feature sets is the length of feature_set_ends + 1 + + cdef const intp_t[:] max_features_per_set # the maximum number of features to sample from each feature set + + # The following are used to track per feature set: + # - the number of visited features + # - the number of found constants in this split search + # - the number of drawn constants in this split search + cdef intp_t[:] vec_n_visited_features + cdef intp_t[:] vec_n_found_constants + cdef intp_t[:] vec_n_drawn_constants + + # XXX: moved from partitioner to this class + cdef const float32_t[:, :] X + cdef const unsigned char[::1] missing_values_in_feature_mask + cdef intp_t n_missing + cdef void sort_samples_and_feature_values( + self, intp_t current_feature + ) noexcept nogil + + cdef void next_p(self, intp_t* p_prev, intp_t* p) noexcept nogil + cdef intp_t partition_samples(self, float64_t current_threshold) noexcept nogil + cdef void partition_samples_final( + self, + intp_t best_pos, + float64_t best_threshold, + intp_t best_feature, + intp_t best_n_missing, + ) noexcept nogil diff --git a/sktree/tree/_multiview_splitter.pyx b/sktree/tree/_multiview_splitter.pyx new file mode 100644 index 000000000..cee647c2f --- /dev/null +++ b/sktree/tree/_multiview_splitter.pyx @@ -0,0 +1,604 @@ +# distutils: language=c++ +# cython: language_level=3 +# cython: boundscheck=False +# cython: wraparound=False +# cython: initializedcheck=False +import numpy as np + +from cython.operator cimport dereference as deref +from libc.math cimport isnan +from libc.string cimport memcpy +from libcpp.algorithm cimport fill +from libcpp.numeric cimport accumulate + +from .._lib.sklearn.tree._splitter cimport shift_missing_values_to_left_if_required +from .._lib.sklearn.tree._utils cimport rand_int + + +cdef float64_t INFINITY = np.inf + +# Mitigate precision differences between 32 bit and 64 bit +cdef float32_t FEATURE_THRESHOLD = 1e-7 + +# Constant to switch between algorithm non zero value extract algorithm +# in SparseSplitter +cdef float32_t EXTRACT_NNZ_SWITCH = 0.1 + + +cdef inline void _init_split(SplitRecord* self, intp_t start_pos) noexcept nogil: + self.impurity_left = INFINITY + self.impurity_right = INFINITY + self.pos = start_pos + self.feature = 0 + self.threshold = 0. + self.improvement = -INFINITY + self.missing_go_to_left = False + self.n_missing = 0 + self.n_constant_features = 0 + + +cdef class MultiViewSplitter(Splitter): + def __cinit__( + self, + Criterion criterion, + intp_t max_features, + intp_t min_samples_leaf, + float64_t min_weight_leaf, + object random_state, + const cnp.int8_t[:] monotonic_cst, + const intp_t[:] feature_set_ends, + intp_t n_feature_sets, + const intp_t[:] max_features_per_set, + *argv + ): + """ + Parameters + ---------- + criterion : Criterion + The criterion to measure the quality of a split. + + max_features : intp_t + The maximal number of randomly selected features which can be + considered for a split. + + min_samples_leaf : intp_t + The minimal number of samples each leaf can have, where splits + which would result in having less samples in a leaf are not + considered. + + min_weight_leaf : float64_t + The minimal weight each leaf can have, where the weight is the sum + of the weights of each sample in it. 
+ + random_state : object + The user inputted random state to be used for pseudo-randomness + + monotonic_cst : const cnp.int8_t[:] + Monotonicity constraints + + """ + self.feature_set_ends = feature_set_ends + + # infer the number of feature sets + self.n_feature_sets = n_feature_sets + + # replaces usage of max_features + self.max_features_per_set = max_features_per_set + + # initialize arrays to store the number of visited features, drawn constants and found constants + # for each feature set during each iteration of the split search + self.vec_n_visited_features = np.empty(n_feature_sets, dtype=np.intp) + self.vec_n_drawn_constants = np.empty(n_feature_sets, dtype=np.intp) + self.vec_n_found_constants = np.empty(n_feature_sets, dtype=np.intp) + + self.n_missing = 0 + + def __reduce__(self): + return (type(self), + ( + self.criterion, + self.max_features, + self.min_samples_leaf, + self.min_weight_leaf, + self.random_state, + self.monotonic_cst.base if self.monotonic_cst is not None else None, + self.feature_set_ends.base if self.feature_set_ends is not None else None, + self.n_feature_sets, + self.max_features_per_set.base if self.max_features_per_set is not None else None, + ), self.__getstate__()) + + cdef int init( + self, + object X, + const float64_t[:, ::1] y, + const float64_t[:] sample_weight, + const unsigned char[::1] missing_values_in_feature_mask, + ) except -1: + Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask) + self.X = X + self.missing_values_in_feature_mask = missing_values_in_feature_mask + + cdef intp_t pointer_size(self) noexcept nogil: + """Get size of a pointer to record for ObliqueSplitter.""" + + return sizeof(MultiViewSplitRecord) + + cdef inline void sort_samples_and_feature_values( + self, intp_t current_feature + ) noexcept nogil: + """Simultaneously sort based on the feature_values. + + Missing values are stored at the end of feature_values. + The number of missing values observed in feature_values is stored + in self.n_missing. + """ + cdef: + intp_t i, current_end + float32_t[::1] feature_values = self.feature_values + const float32_t[:, :] X = self.X + intp_t[::1] samples = self.samples + intp_t n_missing = 0 + const unsigned char[::1] missing_values_in_feature_mask = self.missing_values_in_feature_mask + + # Sort samples along that feature; by + # copying the values into an array and + # sorting the array in a manner which utilizes the cache more + # effectively. + if missing_values_in_feature_mask is not None and missing_values_in_feature_mask[current_feature]: + i, current_end = self.start, self.end - 1 + # Missing values are placed at the end and do not participate in the sorting. + while i <= current_end: + # Finds the right-most value that is not missing so that + # it can be swapped with missing values at its left. 
+ if isnan(X[samples[current_end], current_feature]): + n_missing += 1 + current_end -= 1 + continue + + # X[samples[current_end], current_feature] is a non-missing value + if isnan(X[samples[i], current_feature]): + samples[i], samples[current_end] = samples[current_end], samples[i] + n_missing += 1 + current_end -= 1 + + feature_values[i] = X[samples[i], current_feature] + i += 1 + else: + # When there are no missing values, we only need to copy the data into + # feature_values + for i in range(self.start, self.end): + feature_values[i] = X[samples[i], current_feature] + + sort(&feature_values[self.start], &samples[self.start], self.end - self.start - n_missing) + self.n_missing = n_missing + + cdef inline void next_p(self, intp_t* p_prev, intp_t* p) noexcept nogil: + """Compute the next p_prev and p for iteratiing over feature values. + + The missing values are not included when iterating through the feature values. + """ + cdef: + float32_t[::1] feature_values = self.feature_values + intp_t end_non_missing = self.end - self.n_missing + + while ( + p[0] + 1 < end_non_missing and + feature_values[p[0] + 1] <= feature_values[p[0]] + FEATURE_THRESHOLD + ): + p[0] += 1 + + p_prev[0] = p[0] + + # By adding 1, we have + # (feature_values[p] >= end) or (feature_values[p] > feature_values[p - 1]) + p[0] += 1 + + cdef inline intp_t partition_samples(self, float64_t current_threshold) noexcept nogil: + """Partition samples for feature_values at the current_threshold.""" + cdef: + intp_t p = self.start + intp_t partition_end = self.end + intp_t[::1] samples = self.samples + float32_t[::1] feature_values = self.feature_values + + while p < partition_end: + if feature_values[p] <= current_threshold: + p += 1 + else: + partition_end -= 1 + + feature_values[p], feature_values[partition_end] = ( + feature_values[partition_end], feature_values[p] + ) + samples[p], samples[partition_end] = samples[partition_end], samples[p] + + return partition_end + + cdef inline void partition_samples_final( + self, + intp_t best_pos, + float64_t best_threshold, + intp_t best_feature, + intp_t best_n_missing, + ) noexcept nogil: + """Partition samples for X at the best_threshold and best_feature. + + If missing values are present, this method partitions `samples` + so that the `best_n_missing` missing values' indices are in the + right-most end of `samples`, that is `samples[end_non_missing:end]`. + """ + cdef: + # Local invariance: start <= p <= partition_end <= end + intp_t start = self.start + intp_t p = start + intp_t end = self.end - 1 + intp_t partition_end = end - best_n_missing + intp_t[::1] samples = self.samples + const float32_t[:, :] X = self.X + float32_t current_value + + if best_n_missing != 0: + # Move samples with missing values to the end while partitioning the + # non-missing samples + while p < partition_end: + # Keep samples with missing values at the end + if isnan(X[samples[end], best_feature]): + end -= 1 + continue + + # Swap sample with missing values with the sample at the end + current_value = X[samples[p], best_feature] + if isnan(current_value): + samples[p], samples[end] = samples[end], samples[p] + end -= 1 + + # The swapped sample at the end is always a non-missing value, so + # we can continue the algorithm without checking for missingness. 
+                    current_value = X[samples[p], best_feature]
+
+                # Partition the non-missing samples
+                if current_value <= best_threshold:
+                    p += 1
+                else:
+                    samples[p], samples[partition_end] = samples[partition_end], samples[p]
+                    partition_end -= 1
+        else:
+            # Partitioning routine when there are no missing values
+            while p < partition_end:
+                if X[samples[p], best_feature] <= best_threshold:
+                    p += 1
+                else:
+                    samples[p], samples[partition_end] = samples[partition_end], samples[p]
+                    partition_end -= 1
+
+
+cdef class BestMultiViewSplitter(MultiViewSplitter):
+    """Splitter for finding the best split on dense data."""
+    cdef int node_split(
+        self,
+        float64_t impurity,   # Impurity of the node
+        SplitRecord* split,
+        float64_t lower_bound,
+        float64_t upper_bound,
+    ) except -1 nogil:
+        """Find the best split on node samples[start:end].
+
+        Note: this implementation differs from scikit-learn in that the split
+        search is run separately over each feature set, drawing up to
+        ``max_features_per_set[i]`` candidate features from feature set ``i``.
+
+        Returns -1 in case of failure to allocate memory (and raise MemoryError)
+        or 0 otherwise.
+        """
+        # typecast the pointer to a MultiViewSplitRecord
+        cdef MultiViewSplitRecord* multiview_split = <MultiViewSplitRecord*>(split)
+
+        # Find the best split
+        cdef intp_t start = self.start
+        cdef intp_t end = self.end
+        cdef intp_t end_non_missing
+        cdef intp_t n_missing = 0
+        cdef bint has_missing = 0
+        cdef intp_t n_searches
+        cdef intp_t n_left, n_right
+        cdef bint missing_go_to_left
+
+        cdef intp_t[::1] samples = self.samples
+        cdef intp_t[::1] features = self.features
+        cdef intp_t[::1] constant_features = self.constant_features
+
+        cdef float32_t[::1] feature_values = self.feature_values
+        cdef intp_t max_features = self.max_features
+        cdef intp_t min_samples_leaf = self.min_samples_leaf
+        cdef float64_t min_weight_leaf = self.min_weight_leaf
+        cdef UINT32_t* random_state = &self.rand_r_state
+
+        cdef SplitRecord best_split, current_split
+        cdef float64_t current_proxy_improvement = -INFINITY
+        cdef float64_t best_proxy_improvement = -INFINITY
+
+        cdef bint found_new_constants = False
+
+        # pointer in the feature set
+        cdef intp_t f_i
+
+        cdef intp_t f_j
+        cdef intp_t p
+        cdef intp_t p_prev
+
+        cdef intp_t ifeature
+        cdef intp_t feature_set_begin = 0
+
+        # Number of features discovered to be constant during the split search
+        cdef intp_t[:] vec_n_found_constants = self.vec_n_found_constants
+        # Number of features known to be constant and drawn without replacement
+        cdef intp_t[:] vec_n_drawn_constants = self.vec_n_drawn_constants
+        cdef intp_t[:] vec_n_visited_features = self.vec_n_visited_features
+
+        # We reset the number of visited features, drawn constants and found constants
+        # for each feature set to 0 at the beginning of the split search.
+ for ifeature in range(self.n_feature_sets): + vec_n_found_constants[ifeature] = 0 + vec_n_drawn_constants[ifeature] = 0 + vec_n_visited_features[ifeature] = 0 + + if deref(multiview_split).vec_n_constant_features.size() == 0: + deref(multiview_split).vec_n_constant_features.resize(self.n_feature_sets) + fill(deref(multiview_split).vec_n_constant_features.begin(), deref(multiview_split).vec_n_constant_features.end(), 0) + + cdef vector[intp_t] n_known_constants_vec = deref(multiview_split).vec_n_constant_features + + # n_total_constants = n_known_constants + n_found_constants + cdef vector[intp_t] n_total_constants_vec = n_known_constants_vec + + _init_split(&best_split, end) + + for ifeature in range(self.n_feature_sets): + # get the max-features for this feature-set + max_features = self.max_features_per_set[ifeature] + f_i = self.feature_set_ends[ifeature] + + # Sample up to max_features without replacement using a + # Fisher-Yates-based algorithm (using the local variables `f_i` and + # `f_j` to compute a permutation of the `features` array). + # + # Skip the CPU intensive evaluation of the impurity criterion for + # features that were already detected as constant (hence not suitable + # for good splitting) by ancestor nodes and save the information on + # newly discovered constant features to spare computation on descendant + # nodes. + while ((f_i - feature_set_begin) > n_total_constants_vec[ifeature] and # Stop early if remaining features + # are constant within this feature set + (vec_n_visited_features[ifeature] < max_features or # At least one drawn features must be non constant + vec_n_visited_features[ifeature] <= vec_n_found_constants[ifeature] + vec_n_drawn_constants[ifeature])): + + vec_n_visited_features[ifeature] += 1 + + # The following is loop invariant per feature set: + # [ --- view-one ---, --- view-two --- ] + # within each view, the features are ordered as follows: + # [constant, known constant, newly found constant, non-constant] + + # Loop invariant: elements of features in + # - [:n_drawn_constant[ holds drawn and known constant features; + # - [n_drawn_constant:n_known_constant[ holds known constant + # features that haven't been drawn yet; + # - [n_known_constant:n_total_constant[ holds newly found constant + # features; + # - [n_total_constant:f_i[ holds features that haven't been drawn + # yet and aren't constant apriori. + # - [f_i:n_features[ holds features that have been drawn + # and aren't constant. 
+ + # Draw a feature at random from the feature-set + f_j = rand_int(vec_n_drawn_constants[ifeature] + feature_set_begin, f_i - vec_n_found_constants[ifeature], + random_state) + + # If the drawn feature is known to be constant, swap it with the + # last known constant feature and update the number of drawn constant features + if f_j < n_known_constants_vec[ifeature]: + # f_j in the interval [n_drawn_constants, n_known_constants[ + features[vec_n_drawn_constants[ifeature]], features[f_j] = features[f_j], features[vec_n_drawn_constants[ifeature]] + + vec_n_drawn_constants[ifeature] += 1 + found_new_constants = True + continue + + # f_j in the interval [n_known_constants, f_i - n_found_constants[ + f_j += vec_n_found_constants[ifeature] + # f_j in the interval [n_total_constants, f_i[ + current_split.feature = features[f_j] + self.sort_samples_and_feature_values(current_split.feature) + n_missing = self.n_missing + end_non_missing = end - n_missing + + if ( + # All values for this feature are missing, or + end_non_missing == start or + # This feature is considered constant (max - min <= FEATURE_THRESHOLD) + feature_values[end_non_missing - 1] <= feature_values[start] + FEATURE_THRESHOLD + ): + # We consider this feature constant in this case. + # Since finding a split among constant feature is not valuable, + # we do not consider this feature for splitting. + features[f_j], features[n_total_constants_vec[ifeature]] = features[n_total_constants_vec[ifeature]], features[f_j] + + vec_n_found_constants[ifeature] += 1 + n_total_constants_vec[ifeature] += 1 + found_new_constants = True + continue + + f_i -= 1 + features[f_i], features[f_j] = features[f_j], features[f_i] + has_missing = n_missing != 0 + self.criterion.init_missing(n_missing) # initialize even when n_missing == 0 + + # Evaluate all splits + # If there are missing values, then we search twice for the most optimal split. + # The first search will have all the missing values going to the right node. + # The second search will have all the missing values going to the left node. + # If there are no missing values, then we search only once for the most + # optimal split. 
+ n_searches = 2 if has_missing else 1 + + for i in range(n_searches): + missing_go_to_left = i == 1 + self.criterion.missing_go_to_left = missing_go_to_left + self.criterion.reset() + + p = start + + while p < end_non_missing: + self.next_p(&p_prev, &p) + + if p >= end_non_missing: + continue + + current_split.pos = p + + # Reject if monotonicity constraints are not satisfied + if ( + self.with_monotonic_cst and + self.monotonic_cst[current_split.feature] != 0 and + not self.criterion.check_monotonicity( + self.monotonic_cst[current_split.feature], + lower_bound, + upper_bound, + ) + ): + continue + + # Reject if min_samples_leaf is not guaranteed + if missing_go_to_left: + n_left = current_split.pos - self.start + n_missing + n_right = end_non_missing - current_split.pos + else: + n_left = current_split.pos - self.start + n_right = end_non_missing - current_split.pos + n_missing + if self.check_presplit_conditions(¤t_split, n_missing, missing_go_to_left) == 1: + continue + + self.criterion.update(current_split.pos) + + # Reject if monotonicity constraints are not satisfied + if ( + self.with_monotonic_cst and + self.monotonic_cst[current_split.feature] != 0 and + not self.criterion.check_monotonicity( + self.monotonic_cst[current_split.feature], + lower_bound, + upper_bound, + ) + ): + continue + + # Reject if min_weight_leaf is not satisfied + if self.check_postsplit_conditions() == 1: + continue + + current_proxy_improvement = self.criterion.proxy_impurity_improvement() + + if current_proxy_improvement > best_proxy_improvement: + best_proxy_improvement = current_proxy_improvement + # sum of halves is used to avoid infinite value + current_split.threshold = ( + feature_values[p_prev] / 2.0 + feature_values[p] / 2.0 + ) + + if ( + current_split.threshold == feature_values[p] or + current_split.threshold == INFINITY or + current_split.threshold == -INFINITY + ): + current_split.threshold = feature_values[p_prev] + + current_split.n_missing = n_missing + if n_missing == 0: + current_split.missing_go_to_left = n_left > n_right + else: + current_split.missing_go_to_left = missing_go_to_left + + best_split = current_split # copy + + # Evaluate when there are missing values and all missing values goes + # to the right node and non-missing values goes to the left node. 
+ if has_missing: + n_left, n_right = end - start - n_missing, n_missing + p = end - n_missing + missing_go_to_left = 0 + + if not (n_left < min_samples_leaf or n_right < min_samples_leaf): + self.criterion.missing_go_to_left = missing_go_to_left + self.criterion.update(p) + + if not ((self.criterion.weighted_n_left < min_weight_leaf) or + (self.criterion.weighted_n_right < min_weight_leaf)): + current_proxy_improvement = self.criterion.proxy_impurity_improvement() + + if current_proxy_improvement > best_proxy_improvement: + best_proxy_improvement = current_proxy_improvement + current_split.threshold = INFINITY + current_split.missing_go_to_left = missing_go_to_left + current_split.n_missing = n_missing + current_split.pos = p + best_split = current_split + + # update the feature_set_begin for the next iteration + feature_set_begin = self.feature_set_ends[ifeature] + + # Reorganize into samples[start:best_split.pos] + samples[best_split.pos:end] + if best_split.pos < end: + self.partition_samples_final( + best_split.pos, + best_split.threshold, + best_split.feature, + best_split.n_missing + ) + self.criterion.init_missing(best_split.n_missing) + self.criterion.missing_go_to_left = best_split.missing_go_to_left + + self.criterion.reset() + self.criterion.update(best_split.pos) + self.criterion.children_impurity( + &best_split.impurity_left, &best_split.impurity_right + ) + best_split.improvement = self.criterion.impurity_improvement( + impurity, + best_split.impurity_left, + best_split.impurity_right + ) + + shift_missing_values_to_left_if_required(&best_split, samples, end) + + # Reorganize constant features per feature view + if found_new_constants: + feature_set_begin = 0 + for ifeature in range(self.n_feature_sets): + # Respect invariant for constant features: the original order of + # element in features[:n_known_constants] must be preserved for sibling + # and child nodes + memcpy(&features[feature_set_begin], &constant_features[feature_set_begin], sizeof(intp_t) * n_known_constants_vec[ifeature]) + + # Copy newly found constant features starting from [n_known_constants:n_found_constants] + memcpy(&constant_features[n_known_constants_vec[ifeature]], + &features[n_known_constants_vec[ifeature]], + sizeof(intp_t) * vec_n_found_constants[ifeature]) + + feature_set_begin = self.feature_set_ends[ifeature] + + # Return values + if found_new_constants: + deref(multiview_split).n_constant_features = accumulate( + n_known_constants_vec.begin(), + n_known_constants_vec.end(), + 0 + ) + deref(multiview_split).vec_n_constant_features = n_known_constants_vec + + deref(multiview_split).feature = best_split.feature + deref(multiview_split).pos = best_split.pos + deref(multiview_split).threshold = best_split.threshold + deref(multiview_split).improvement = best_split.improvement + deref(multiview_split).impurity_left = best_split.impurity_left + deref(multiview_split).impurity_right = best_split.impurity_right + # multiview_split[0] = best_split + # n_constant_features[0] = n_total_constants + return 0 diff --git a/sktree/tree/meson.build b/sktree/tree/meson.build index 9737016af..ed73eda88 100644 --- a/sktree/tree/meson.build +++ b/sktree/tree/meson.build @@ -5,6 +5,9 @@ tree_extension_metadata = { '_oblique_splitter': {'sources': ['_oblique_splitter.pyx'], 'override_options': ['cython_language=cpp', 'optimization=3']}, + '_multiview_splitter': + {'sources': ['_multiview_splitter.pyx'], + 'override_options': ['cython_language=cpp', 'optimization=3']}, '_oblique_tree': {'sources': 
     {'sources': ['_oblique_tree.pyx'],
      'override_options': ['cython_language=cpp', 'optimization=3']},
diff --git a/sktree/tree/tests/test_honest_tree.py b/sktree/tree/tests/test_honest_tree.py
index 907c386f1..fd249b9c4 100644
--- a/sktree/tree/tests/test_honest_tree.py
+++ b/sktree/tree/tests/test_honest_tree.py
@@ -66,14 +66,14 @@ def test_toy_accuracy():
 @pytest.mark.parametrize(
     "tree, tree_kwargs",
     [
-        (MultiViewDecisionTreeClassifier(), {"feature_set_ends": [10, 20]}),
+        (MultiViewDecisionTreeClassifier(), {"feature_set_ends": [10, 30]}),
         (ObliqueDecisionTreeClassifier(), {"feature_combinations": 2}),
         (PatchObliqueDecisionTreeClassifier(), {"max_patch_dims": 5}),
     ],
 )
 def test_honest_tree_with_tree_estimator_params(tree, tree_kwargs):
     """Test that honest tree inherits all the fitted parameters of the tree estimator."""
-    X = np.ones((20, 4))
+    X = np.ones((20, 30))
     X[10:] *= -1
     y = [0] * 10 + [1] * 10
 
diff --git a/sktree/tree/tests/test_multiview.py b/sktree/tree/tests/test_multiview.py
index 419ca378d..4ddbd2a90 100644
--- a/sktree/tree/tests/test_multiview.py
+++ b/sktree/tree/tests/test_multiview.py
@@ -64,19 +64,19 @@ def test_multiview_classification(baseline_est):
     clf = MultiViewDecisionTreeClassifier(
         random_state=seed,
         feature_set_ends=[n_features_1, X.shape[1]],
-        max_features=0.3,
+        max_features=[0.1, 0.1],
     )
     clf.fit(X, y)
     assert (
         accuracy_score(y, clf.predict(X)) == 1.0
     ), f"Accuracy score: {accuracy_score(y, clf.predict(X))}"
     assert (
-        cross_val_score(clf, X, y, cv=5).mean() > 0.9
+        cross_val_score(clf, X, y, cv=5).mean() >= 0.9
     ), f"CV score: {cross_val_score(clf, X, y, cv=5).mean()}"
 
     base_clf = baseline_est(
         random_state=seed,
-        max_features=0.3,
+        max_features=0.1,
     )
     assert cross_val_score(base_clf, X, y, cv=5).mean() < cross_val_score(clf, X, y, cv=5).mean(), (
         f"CV score: {cross_val_score(base_clf, X, y, cv=5).mean()} vs "
@@ -102,7 +102,6 @@ def test_multiview_errors():
         random_state=seed,
         feature_set_ends=[3, 5],
         max_features=6,
-        apply_max_features_per_feature_set=True,
     )
     with pytest.raises(ValueError, match="the number of features in feature set"):
         clf.fit(X, y)
@@ -117,7 +116,6 @@ def test_multiview_separate_feature_set_sampling_sets_attributes():
         random_state=seed,
         feature_set_ends=[6, 10],
         max_features=0.5,
-        apply_max_features_per_feature_set=True,
     )
     clf.fit(X, y)
 
@@ -130,7 +128,6 @@ def test_multiview_separate_feature_set_sampling_sets_attributes():
         random_state=seed,
         feature_set_ends=[9, 13],
         max_features="sqrt",
-        apply_max_features_per_feature_set=True,
     )
     clf.fit(X, y)
     assert_array_equal(clf.max_features_per_set_, [3, 2])
@@ -142,7 +139,6 @@
         random_state=seed,
         feature_set_ends=[5, 9],
         max_features="sqrt",
-        apply_max_features_per_feature_set=True,
     )
     clf.fit(X, y)
     assert_array_equal(clf.max_features_per_set_, [3, 2])
@@ -160,7 +156,6 @@ def test_at_least_one_feature_per_view_is_sampled():
         random_state=seed,
         feature_set_ends=[1, 2, 4, 10],
         max_features=0.4,
-        apply_max_features_per_feature_set=True,
     )
     clf.fit(X, y)
 
@@ -178,7 +173,6 @@ def test_multiview_separate_feature_set_sampling_is_consistent():
         random_state=seed,
         feature_set_ends=[1, 3, 6, 10],
         max_features=[1, 2, 2, 3],
-        apply_max_features_per_feature_set=True,
     )
     clf.fit(X, y)
 
@@ -192,15 +186,13 @@ def test_multiview_separate_feature_set_sampling_is_consistent():
         random_state=seed,
         feature_set_ends=[1, 3, 6, 10],
         max_features=[1, 2, 2, 3],
-        apply_max_features_per_feature_set=False,
     )
     other_clf.fit(X, y)
 
     assert_array_equal(other_clf.tree_.value, clf.tree_.value)
 
 
-@pytest.mark.parametrize("stratify_mtry_per_view", [True, False])
-def test_separate_mtry_per_feature_set(stratify_mtry_per_view):
+def test_separate_mtry_per_feature_set():
     """Test that multiview decision tree can sample different numbers of features per view.
 
     Sets the ``max_feature`` argument as an array-like.
@@ -213,7 +205,6 @@ def test_separate_mtry_per_feature_set(stratify_mtry_per_view):
         random_state=seed,
         feature_set_ends=[1, 2, 4, 10],
         max_features=[0.4, 0.5, 0.6, 0.7],
-        apply_max_features_per_feature_set=stratify_mtry_per_view,
     )
     clf.fit(X, y)
 
@@ -225,7 +216,6 @@ def test_separate_mtry_per_feature_set(stratify_mtry_per_view):
         random_state=seed,
         feature_set_ends=[1, 2, 4, 10],
         max_features=[1, 1, 1, 1.0],
-        apply_max_features_per_feature_set=stratify_mtry_per_view,
     )
     clf.fit(X, y)
     assert_array_equal(clf.max_features_per_set_, [1, 1, 1, 6])
@@ -236,14 +226,10 @@ def test_separate_mtry_per_feature_set(stratify_mtry_per_view):
         random_state=seed,
         feature_set_ends=[1, 2, 4, 10],
         max_features=1.0,
-        apply_max_features_per_feature_set=stratify_mtry_per_view,
     )
     clf.fit(X, y)
-    if stratify_mtry_per_view:
-        assert_array_equal(clf.max_features_per_set_, [1, 1, 2, 6])
-    else:
-        assert clf.max_features_per_set_ is None
-        assert clf.max_features_ == 10
+    assert_array_equal(clf.max_features_per_set_, [1, 1, 2, 6])
+    assert clf.max_features_ == 10
     assert clf.max_features_ == 10, np.sum(clf.max_features_per_set_)
 
 
@@ -262,9 +248,10 @@ def test_multiview_without_feature_view_stratification():
         random_state=seed,
         feature_set_ends=[497, 500],
         max_features=0.3,
-        apply_max_features_per_feature_set=False,
     )
     clf.fit(X, y)
 
-    assert clf.max_features_per_set_ is None
-    assert clf.max_features_ == 500 * clf.max_features, clf.max_features_
+    assert clf.max_features_ == math.ceil(497.0 * clf.max_features) + math.ceil(
+        3 * clf.max_features
+    )
+    assert_array_equal(clf.max_features_per_set_, [150, 1]), clf.max_features_per_set_
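
For reference, a rough sketch (not part of the patch) of the per-view `max_features` resolution that the updated assertions above appear to rely on when `max_features` is a float fraction. The standalone helper name `resolve_max_features_per_set` is illustrative only; string options such as `"sqrt"` and per-view lists are handled by the estimator itself and are not covered here.

```python
import math


def resolve_max_features_per_set(feature_set_ends, max_features):
    """Illustrative only: candidate features drawn from each feature set
    when ``max_features`` is a float fraction, mirroring the test assertions above."""
    sizes = []
    start = 0
    for end in feature_set_ends:
        sizes.append(end - start)
        start = end
    # The fraction is applied per feature set and rounded up, so every view
    # contributes at least one candidate feature.
    return [math.ceil(max_features * n) for n in sizes]


# Matches test_multiview_without_feature_view_stratification:
# feature_set_ends=[497, 500] with max_features=0.3 -> [150, 1] (151 in total).
print(resolve_max_features_per_set([497, 500], 0.3))
```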
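The splitter hunk earlier in this patch computes the split threshold as a "sum of halves" rather than a naive midpoint. A minimal sketch (not part of the patch, values arbitrary) of why that form avoids an infinite threshold with float32 feature values:

```python
import numpy as np

# Two large but representable float32 values (float32 max is ~3.4e38).
a = np.float32(3.0e38)
b = np.float32(3.2e38)

# Naive midpoint: a + b overflows float32 to inf, so the midpoint is inf
# (NumPy may emit an overflow RuntimeWarning here).
naive = (a + b) / np.float32(2.0)

# Sum of halves stays finite, which is what the splitter relies on.
safe = a / np.float32(2.0) + b / np.float32(2.0)

print(naive, safe)  # inf 3.1e+38
```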