From 0246b6a01280836531934ddad8fef142cae6951d Mon Sep 17 00:00:00 2001 From: Adam Li Date: Fri, 9 Aug 2024 08:13:46 -0400 Subject: [PATCH] Refactoring to separate cython files Signed-off-by: Adam Li --- treeple/tree/_honest_prune.pxd | 3 + treeple/tree/_honest_prune.pyx | 145 +++------------------------------ treeple/tree/_prune.pxd | 19 +++++ treeple/tree/_prune.pyx | 109 +++++++++++++++++++++++++ treeple/tree/meson.build | 3 + 5 files changed, 146 insertions(+), 133 deletions(-) create mode 100644 treeple/tree/_prune.pxd create mode 100644 treeple/tree/_prune.pyx diff --git a/treeple/tree/_honest_prune.pxd b/treeple/tree/_honest_prune.pxd index d0ded7a3a..00e41f909 100644 --- a/treeple/tree/_honest_prune.pxd +++ b/treeple/tree/_honest_prune.pxd @@ -22,6 +22,8 @@ cdef class HonestPruner(Splitter): cdef intp_t pos # The current position to split left/right children cdef intp_t n_missing # The number of missing values in the feature currently considered cdef uint8_t missing_go_to_left + + # TODO: only supports dense for now. cdef const float32_t[:, :] X cdef int init( @@ -32,6 +34,7 @@ cdef class HonestPruner(Splitter): const uint8_t[::1] missing_values_in_feature_mask, ) except -1 + # This function is not used, and should be disabled for pruners cdef int node_split( self, ParentInfo* parent_record, diff --git a/treeple/tree/_honest_prune.pyx b/treeple/tree/_honest_prune.pyx index 9bcae910f..d20602460 100644 --- a/treeple/tree/_honest_prune.pyx +++ b/treeple/tree/_honest_prune.pyx @@ -13,6 +13,8 @@ from libc.stdlib cimport free, malloc from libcpp.stack cimport stack from sklearn.tree._tree cimport ParentInfo +from ._prune cimport _build_pruned_tree + TREE_LEAF = -1 TREE_UNDEFINED = -2 cdef intp_t _TREE_LEAF = TREE_LEAF @@ -265,7 +267,7 @@ cdef class HonestPruner(Splitter): Returns 0 if a split cannot be done, 1 if a split can be done and -1 in case of failure to allocate memory (and raise MemoryError). 
""" - pass + raise NotImplementedError("node_split is not used in honest pruning") cdef _honest_prune( @@ -327,10 +329,6 @@ cdef _honest_prune( float64_t lower_bound, upper_bound float64_t left_child_min, left_child_max, right_child_min, right_child_max, middle_value - cdef bint first = 0 - cdef ParentInfo parent_record - _init_parent_record(&parent_record) - # find parent node ids and leaves with nogil: # Push the root node @@ -352,9 +350,7 @@ cdef _honest_prune( pruning_stack.pop() start = stack_record.start end = stack_record.end - parent_record.impurity = stack_record.impurity - parent_record.lower_bound = stack_record.lower_bound - parent_record.upper_bound = stack_record.upper_bound + impurity = stack_record.impurity lower_bound = stack_record.lower_bound upper_bound = stack_record.upper_bound @@ -366,7 +362,7 @@ cdef _honest_prune( # get the impurity to initialize passing into its children if first: - parent_record.impurity = pruner.node_impurity() + impurity = pruner.node_impurity() first = 0 # partition samples into left/right child based on the @@ -377,6 +373,7 @@ cdef _honest_prune( split_ptr.feature = orig_tree.nodes[node_idx].feature invalid_split = pruner.check_node_partition_conditions( split_ptr, + impurity, lower_bound, upper_bound ) @@ -402,12 +399,12 @@ cdef _honest_prune( # Current bounds must always be propagated to both children. # If a monotonic constraint is active, bounds are used in # node value clipping. 
- left_child_min = right_child_min = parent_record.lower_bound - left_child_max = right_child_max = parent_record.upper_bound + left_child_min = right_child_min = lower_bound + left_child_max = right_child_max = upper_bound elif pruner.monotonic_cst[split_ptr.feature] == 1: # Split on a feature with monotonic increase constraint - left_child_min = parent_record.lower_bound - right_child_max = parent_record.upper_bound + left_child_min = lower_bound + right_child_max = upper_bound # Lower bound for right child and upper bound for left child # are set to the same value. @@ -416,8 +413,8 @@ cdef _honest_prune( left_child_max = middle_value else: # i.e. pruner.monotonic_cst[split.feature] == -1 # Split on a feature with monotonic decrease constraint - right_child_min = parent_record.lower_bound - left_child_max = parent_record.upper_bound + right_child_min = lower_bound + left_child_max = upper_bound # Lower bound for left child and upper bound for right child # are set to the same value. @@ -444,121 +441,3 @@ cdef _honest_prune( # free the memory created for the SplitRecord pointer free(split_ptr) - - -from libc.stdint cimport INTPTR_MAX -from libc.string cimport memcpy - - -cdef struct BuildPrunedRecord: - intp_t start - intp_t depth - intp_t parent - bint is_left - - -cdef _build_pruned_tree( - Tree tree, # OUT - Tree orig_tree, - const uint8_t[:] leaves_in_subtree, - intp_t capacity -): - """Build a pruned tree. - - Build a pruned tree from the original tree by transforming the nodes in - ``leaves_in_subtree`` into leaves. 
- - Parameters - ---------- - tree : Tree - Location to place the pruned tree - orig_tree : Tree - Original tree - leaves_in_subtree : uint8_t memoryview, shape=(node_count, ) - Boolean mask for leaves to include in subtree - capacity : intp_t - Number of nodes to initially allocate in pruned tree - """ - tree._resize(capacity) - - cdef: - intp_t orig_node_id - intp_t new_node_id - intp_t depth - intp_t parent - bint is_left - bint is_leaf - - # value_stride for original tree and new tree are the same - intp_t value_stride = orig_tree.value_stride - intp_t max_depth_seen = -1 - intp_t rc = 0 - Node* node - float64_t* orig_value_ptr - float64_t* new_value_ptr - - stack[BuildPrunedRecord] prune_stack - BuildPrunedRecord stack_record - - SplitRecord split - - with nogil: - # push root node onto stack - prune_stack.push({"start": 0, "depth": 0, "parent": _TREE_UNDEFINED, "is_left": 0}) - - while not prune_stack.empty(): - stack_record = prune_stack.top() - prune_stack.pop() - - orig_node_id = stack_record.start - depth = stack_record.depth - parent = stack_record.parent - is_left = stack_record.is_left - - is_leaf = leaves_in_subtree[orig_node_id] - node = &orig_tree.nodes[orig_node_id] - - # redefine to a SplitRecord to pass into _add_node - split.feature = node.feature - split.threshold = node.threshold - - # protect against an infinite loop as a runtime error, when leaves_in_subtree - # are improperly set where a node is not marked as a leaf, but is a node - # in the original tree. Thus, it violates the assumption that the node - # is a leaf in the pruned tree, or has a descendant that will be pruned. - if (not is_leaf and node.left_child == _TREE_LEAF - and node.right_child == _TREE_LEAF): - raise ValueError( - "Node has reached a leaf in the original tree, but is not " - "marked as a leaf in the leaves_in_subtree mask." 
- ) - - new_node_id = tree._add_node( - parent, is_left, is_leaf, &split, - node.impurity, node.n_node_samples, - node.weighted_n_node_samples, node.missing_go_to_left) - - if new_node_id == INTPTR_MAX: - rc = -1 - break - - # copy value from original tree to new tree - orig_value_ptr = orig_tree.value + value_stride * orig_node_id - new_value_ptr = tree.value + value_stride * new_node_id - memcpy(new_value_ptr, orig_value_ptr, sizeof(float64_t) * value_stride) - - if not is_leaf: - # Push right child on stack - prune_stack.push({"start": node.right_child, "depth": depth + 1, - "parent": new_node_id, "is_left": 0}) - # push left child on stack - prune_stack.push({"start": node.left_child, "depth": depth + 1, - "parent": new_node_id, "is_left": 1}) - - if depth > max_depth_seen: - max_depth_seen = depth - - if rc >= 0: - tree.max_depth = max_depth_seen - if rc == -1: - raise MemoryError("pruning tree") diff --git a/treeple/tree/_prune.pxd b/treeple/tree/_prune.pxd new file mode 100644 index 000000000..3e7a638d9 --- /dev/null +++ b/treeple/tree/_prune.pxd @@ -0,0 +1,19 @@ +# Copied from scikit-learn/tree/_tree.pyx + +from libc.stdint cimport INTPTR_MAX +from libc.string cimport memcpy + + +cdef struct BuildPrunedRecord: + intp_t start + intp_t depth + intp_t parent + bint is_left + + +cdef void _build_pruned_tree( + Tree tree, # OUT + Tree orig_tree, + const uint8_t[:] leaves_in_subtree, + intp_t capacity +) noexcept diff --git a/treeple/tree/_prune.pyx b/treeple/tree/_prune.pyx new file mode 100644 index 000000000..040491f09 --- /dev/null +++ b/treeple/tree/_prune.pyx @@ -0,0 +1,109 @@ +# cython: boundscheck=False +# cython: wraparound=False +# cython: initializedcheck=False + +cdef void _build_pruned_tree( + Tree tree, # OUT + Tree orig_tree, + const uint8_t[:] leaves_in_subtree, + intp_t capacity +) noexcept: + """Build a pruned tree. + + Build a pruned tree from the original tree by transforming the nodes in + ``leaves_in_subtree`` into leaves. 
+ + Parameters + ---------- + tree : Tree + Location to place the pruned tree + orig_tree : Tree + Original tree + leaves_in_subtree : uint8_t memoryview, shape=(node_count, ) + Boolean mask for leaves to include in subtree + capacity : intp_t + Number of nodes to initially allocate in pruned tree + """ + tree._resize(capacity) + + cdef: + intp_t orig_node_id + intp_t new_node_id + intp_t depth + intp_t parent + bint is_left + bint is_leaf + + # value_stride for original tree and new tree are the same + intp_t value_stride = orig_tree.value_stride + intp_t max_depth_seen = -1 + intp_t rc = 0 + Node* node + float64_t* orig_value_ptr + float64_t* new_value_ptr + + stack[BuildPrunedRecord] prune_stack + BuildPrunedRecord stack_record + + SplitRecord split + + with nogil: + # push root node onto stack + prune_stack.push({"start": 0, "depth": 0, "parent": _TREE_UNDEFINED, "is_left": 0}) + + while not prune_stack.empty(): + stack_record = prune_stack.top() + prune_stack.pop() + + orig_node_id = stack_record.start + depth = stack_record.depth + parent = stack_record.parent + is_left = stack_record.is_left + + is_leaf = leaves_in_subtree[orig_node_id] + node = &orig_tree.nodes[orig_node_id] + + # redefine to a SplitRecord to pass into _add_node + split.feature = node.feature + split.threshold = node.threshold + + # protect against an infinite loop as a runtime error, when leaves_in_subtree + # are improperly set where a node is not marked as a leaf, but is a leaf + # in the original tree. Thus, it violates the assumption that the node + # is a leaf in the pruned tree, or has a descendant that will be pruned. + if (not is_leaf and node.left_child == _TREE_LEAF + and node.right_child == _TREE_LEAF): + raise ValueError( + "Node has reached a leaf in the original tree, but is not " + "marked as a leaf in the leaves_in_subtree mask." 
+ ) + + new_node_id = tree._add_node( + parent, is_left, is_leaf, &split, + node.impurity, node.n_node_samples, + node.weighted_n_node_samples, node.missing_go_to_left) + + if new_node_id == INTPTR_MAX: + rc = -1 + break + + # copy value from original tree to new tree + orig_value_ptr = orig_tree.value + value_stride * orig_node_id + new_value_ptr = tree.value + value_stride * new_node_id + memcpy(new_value_ptr, orig_value_ptr, sizeof(float64_t) * value_stride) + + if not is_leaf: + # Push right child on stack + prune_stack.push({"start": node.right_child, "depth": depth + 1, + "parent": new_node_id, "is_left": 0}) + # push left child on stack + prune_stack.push({"start": node.left_child, "depth": depth + 1, + "parent": new_node_id, "is_left": 1}) + + if depth > max_depth_seen: + max_depth_seen = depth + + if rc >= 0: + tree.max_depth = max_depth_seen + if rc == -1: + raise MemoryError("pruning tree") diff --git a/treeple/tree/meson.build b/treeple/tree/meson.build index 00bc8f728..a13f192af 100644 --- a/treeple/tree/meson.build +++ b/treeple/tree/meson.build @@ -14,6 +14,9 @@ tree_extension_metadata = { '_honest_prune': {'sources': ['_honest_prune.pyx'], 'override_options': ['cython_language=cpp', 'optimization=3']}, + '_prune': + {'sources': ['_prune.pyx'], + 'override_options': ['cython_language=cpp', 'optimization=3']}, '_marginal': {'sources': ['_marginal.pyx'], 'override_options': ['cython_language=cpp', 'optimization=3']},