Skip to content

Commit

Permalink
Adding docstring
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <[email protected]>
  • Loading branch information
adam2392 committed Aug 8, 2024
1 parent 770cbe1 commit 9810ede
Showing 1 changed file with 20 additions and 19 deletions.
39 changes: 20 additions & 19 deletions treeple/tree/_honest_prune.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,25 @@ def _build_pruned_tree_honesty(
const float64_t[:] sample_weight,
const uint8_t[::1] missing_values_in_feature_mask=None,
):
"""Prune an existing tree with honest splits.
Parameters
----------
tree : Tree
The tree to be pruned.
orig_tree : Tree
The original tree to be pruned.
pruner : HonestPruner
The pruner to enforce honest splits.
X : array-like of shape (n_samples, n_features)
The input samples.
y : array-like of shape (n_samples,)
The target values.
sample_weight : array-like of shape (n_samples,)
The sample weights.
missing_values_in_feature_mask : array-like of shape (n_features,), default=None
The mask of missing values in the features.
"""
cdef:
intp_t n_nodes = orig_tree.node_count
uint8_t[:] leaves_in_subtree = np.zeros(
Expand Down Expand Up @@ -109,7 +128,7 @@ cdef class HonestPruner(Splitter):
self,
intp_t node_idx,
) noexcept nogil:
"""Partition samples for X at the threshold and feature index.
"""Partition samples for X at the threshold and feature index of `orig_tree`.
If missing values are present, this method partitions `samples`
so that the `best_n_missing` missing values' indices are in the
Expand Down Expand Up @@ -169,19 +188,13 @@ cdef class HonestPruner(Splitter):
current_split.n_missing = self.n_missing
current_split.missing_go_to_left = self.missing_go_to_left

# with gil:
# print('Inside check node partitions conditions')
# print(self.start, self.pos, self.end, self.n_missing, current_split.feature)

# first check the presplit conditions
cdef bint invalid_split = self.check_presplit_conditions(
current_split,
self.n_missing,
self.missing_go_to_left
)

# with gil:
# print('invalid presplit? ', invalid_split)
if invalid_split:
return 0

Expand All @@ -197,8 +210,6 @@ cdef class HonestPruner(Splitter):
upper_bound,
)
):
# with gil:
# print('No monotonic cst met ', invalid_split)
return 0

# Note this is called after pre-split condition checks
Expand All @@ -212,8 +223,6 @@ cdef class HonestPruner(Splitter):

# next, check the post-split conditions, which leverage the criterion
invalid_split = self.check_postsplit_conditions()
# with gil:
# print('Invalid postsplit ', invalid_split)
return invalid_split

cdef inline intp_t n_left_samples(
Expand Down Expand Up @@ -343,7 +352,6 @@ cdef _honest_prune(
pruning_stack.pop()
start = stack_record.start
end = stack_record.end

parent_record.impurity = stack_record.impurity
parent_record.lower_bound = stack_record.lower_bound
parent_record.upper_bound = stack_record.upper_bound
Expand Down Expand Up @@ -376,11 +384,6 @@ cdef _honest_prune(
pruner.n_left_samples() == 0 or pruner.n_right_samples() == 0
)
is_leaf_in_origtree = child_l[node_idx] == _TREE_LEAF
# with gil:
# print(f"Node {node_idx} is leaf in orig_tree: {is_leaf_in_origtree}")
# print(f"is degenerate: {split_is_degenerate}")
# print(f"invalid split: {invalid_split}")

if invalid_split or split_is_degenerate or is_leaf_in_origtree:
# ... and child_r[node_idx] == _TREE_LEAF:
#
Expand Down Expand Up @@ -426,7 +429,6 @@ cdef _honest_prune(
"node_idx": child_l[node_idx],
"start": pruner.start,
"end": pruner.pos,
# "parent": node_idx,
"impurity": split_ptr.impurity_left,
"lower_bound": left_child_min,
"upper_bound": left_child_max,
Expand All @@ -435,7 +437,6 @@ cdef _honest_prune(
"node_idx": child_r[node_idx],
"start": pruner.pos,
"end": pruner.end,
# "parent": node_idx,
"impurity": split_ptr.impurity_right,
"lower_bound": right_child_min,
"upper_bound": right_child_max,
Expand Down

0 comments on commit 9810ede

Please sign in to comment.