diff --git a/src/lib.rs b/src/lib.rs index 47f4a30..8fdbd2e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,6 +7,8 @@ mod utils; pub use dataset::Dataset; pub use trees::DecisionTree; pub use trees::RandomForest; +pub use trees::TrainOptions; +pub use trees::Tree; pub use utils::*; use pyo3::prelude::*; diff --git a/src/trees.rs b/src/trees.rs index c79926f..bec7536 100644 --- a/src/trees.rs +++ b/src/trees.rs @@ -9,45 +9,32 @@ use rand::seq::SliceRandom; use rand::SeedableRng; use rayon::prelude::*; use std::cmp::Ordering::Equal; -use std::collections::HashMap; +use std::fmt::Debug; +use std::fmt::Formatter; use pyo3::prelude::*; -#[derive(Debug)] -pub struct TreeNode { - split: Option, - prediction: f32, - samples: usize, - feature_name: Option, - left: Option>, - right: Option>, -} - #[pyclass] pub struct DecisionTree { - root: TreeNode, - params: TrainOptions, + tree: Tree, } #[pyclass] pub struct RandomForest { - roots: Vec, - params: TrainOptions, + trees: Vec, } #[derive(Clone, Copy)] pub struct TrainOptions { min_samples_leaf: i32, max_depth: i32, - n_estimators: Option, } impl TrainOptions { - fn default_options() -> TrainOptions { + pub fn default_options() -> TrainOptions { TrainOptions { max_depth: 10, min_samples_leaf: 1, - n_estimators: None, } } } @@ -66,21 +53,14 @@ impl RandomForest { max_depth: max_depth.unwrap_or(TrainOptions::default_options().max_depth), min_samples_leaf: min_samples_leaf .unwrap_or(TrainOptions::default_options().min_samples_leaf), - n_estimators: Some(n_estimators), }; - let roots: Vec = (0..n_estimators) + let trees: Vec = (0..n_estimators) .into_par_iter() .map(|i| { - let mut rng; - if let Some(seed) = random_state { - rng = StdRng::seed_from_u64(seed + i as u64); - } else { - rng = StdRng::from_entropy(); - } - + let mut rng = utils::get_rng(random_state, i as u64); let bootstrap = train.bootstrap(&mut rng); - TreeNode::_train( + Tree::fit( &bootstrap, 0, params, @@ -90,7 +70,7 @@ impl RandomForest { }) .collect(); - RandomForest { roots, params } + RandomForest { trees } } #[staticmethod] @@ -105,10 +85,9 @@ impl RandomForest { max_depth: max_depth.unwrap_or(TrainOptions::default_options().max_depth), min_samples_leaf: min_samples_leaf .unwrap_or(TrainOptions::default_options().min_samples_leaf), - n_estimators: Some(n_estimators), }; - let roots: Vec = (0..n_estimators) + let trees: Vec = (0..n_estimators) .into_par_iter() .map(|i| { let mut rng; @@ -119,7 +98,7 @@ impl RandomForest { } let bootstrap = train.bootstrap(&mut rng); - TreeNode::_train( + Tree::fit( &bootstrap, 0, params, @@ -129,14 +108,13 @@ impl RandomForest { }) .collect(); - RandomForest { roots, params } + RandomForest { trees } } pub fn predict(&self, x: &Dataset) -> Vec { let mut predictions = Vec::new(); - for root in &self.roots { - let new_tree = NewTree::from_old_tree(&root, x.feature_names.clone()); - predictions.push(new_tree.predict(x)); + for tree in &self.trees { + predictions.push(tree.predict(x)); } let mut final_predictions = vec![0.0; x.n_samples()]; @@ -152,6 +130,8 @@ impl RandomForest { } } + + #[pymethods] impl DecisionTree { #[staticmethod] @@ -161,22 +141,20 @@ impl DecisionTree { min_samples_leaf: Option, random_state: Option, ) -> DecisionTree { - let mut rng = utils::get_rng(random_state); + let mut rng = utils::get_rng(random_state, 0); let params = TrainOptions { max_depth, min_samples_leaf: min_samples_leaf .unwrap_or(TrainOptions::default_options().min_samples_leaf), - n_estimators: None, }; DecisionTree { - root: TreeNode::_train( + tree: Tree::fit( &train, 0, params, mean_squared_error_split_feature, &mut rng, ), - params, } } @@ -187,138 +165,19 @@ impl DecisionTree { min_samples_leaf: Option, random_state: Option, ) -> DecisionTree { - let mut rng = utils::get_rng(random_state); + let mut rng = utils::get_rng(random_state, 0); let params = TrainOptions { max_depth, min_samples_leaf: min_samples_leaf .unwrap_or(TrainOptions::default_options().min_samples_leaf), - n_estimators: None, }; DecisionTree { - root: TreeNode::_train(&train, 0, params, gini_coefficient_split_feature, &mut rng), - params, + tree: Tree::fit(&train, 0, params, gini_coefficient_split_feature, &mut rng), } } pub fn predict(&self, test: &Dataset) -> Vec { - let index_tree = NewTree::from_old_tree(&self.root, test.feature_names.clone()); - index_tree.predict(test) - } -} - -impl TreeNode { - fn new_leaf(prediction: f32, n_samples: usize) -> TreeNode { - TreeNode { - prediction, - samples: n_samples, - split: None, - feature_name: None, - left: None, - right: None, - } - } - - fn new_from_split( - left: TreeNode, - right: TreeNode, - split: SplitResult, - feature_name: &str, - ) -> TreeNode { - TreeNode { - prediction: split.prediction, - samples: left.samples + right.samples, - split: Some(split.split), - feature_name: Some(feature_name.to_string()), - left: Some(Box::new(left)), - right: Some(Box::new(right)), - } - } - - fn _train( - train: &Dataset, - depth: i32, - train_options: TrainOptions, - split_feature: SplitFunction, - rng: &mut StdRng, - ) -> TreeNode { - if should_stop(train_options, depth, train) { - return TreeNode::new_leaf(utils::float_avg(&train.target_vector), train.n_samples()); - } - - let mut best_feature = SplitResult::new_max_loss(); - let mut feature_indexes = (0..train.feature_names.len()).collect::>(); - feature_indexes.shuffle(rng); - - for i in feature_indexes { - if train.feature_uniform[i] { - continue; - } - - let split = split_feature( - i, - &train.feature_names[i], - train_options.min_samples_leaf, - &train.feature_matrix[i], - &train.target_vector, - ); - - if split.loss < best_feature.loss { - best_feature = split; - } - } - - let (left_ds, right_ds) = split_dataset(&best_feature, train); - - let left_child = TreeNode::_train(&left_ds, depth + 1, train_options, split_feature, rng); - let right_child = TreeNode::_train(&right_ds, depth + 1, train_options, split_feature, rng); - - let name = &train.feature_names[best_feature.col_index]; - TreeNode::new_from_split(left_child, right_child, best_feature, name) - } - - pub fn predict_row(&self, row: &HashMap<&String, f32>) -> f32 { - if let Some(feature) = &self.feature_name { - if *row.get(&feature).unwrap() >= self.split.unwrap() { - self.right - .as_ref() - .expect("Right node expected") - .predict_row(row) - } else { - self.left - .as_ref() - .expect("Left node expected") - .predict_row(row) - } - } else { - self.prediction - } - } - - fn print(&self, depth: usize) { - match &self.feature_name { - None => { - println!( - "{:indent$}|-> Leaf: pred: {}, N: {}", - "", - self.prediction, - self.samples, - indent = depth * 4 - ) - } - Some(f) => { - println!( - "{:indent$}-> Branch: feat: {}, th: {}, N: {}, pred: {}", - "", - f, - self.split.unwrap(), - self.samples, - self.prediction, - indent = depth * 4 - ); - self.left.as_ref().unwrap().print(depth + 1); - self.right.as_ref().unwrap().print(depth + 1); - } - } + self.tree.predict(test) } } @@ -371,27 +230,30 @@ fn should_stop(options: TrainOptions, depth: i32, ds: &Dataset) -> bool { } // ------------------------------------- -// New tree test +// Base Tree implementation type NodeId = usize; type FeatureIndex = usize; -struct NewTree { +pub struct Tree { root: NodeId, - nodes: Vec, + nodes: Vec, feature_names: Vec, } -enum NewNode { +#[derive(Debug)] +enum Node { Leaf(Leaf), Branch(Branch), } +#[derive(PartialEq, Debug)] struct Leaf { prediction: f32, samples: usize, } +#[derive(PartialEq, Debug)] struct Branch { feature: FeatureIndex, threshold: f32, @@ -401,6 +263,18 @@ struct Branch { prediction: f32, } +impl Node { + fn new_leaf(prediction: f32, samples: usize) -> Self { + Node::Leaf(Leaf::new(prediction, samples)) + } + fn samples(&self) -> usize { + match self { + Node::Leaf(leaf) => leaf.samples, + Node::Branch(branch) => branch.samples, + } + } +} + impl Leaf { fn new(prediction: f32, samples: usize) -> Self { Leaf { @@ -410,59 +284,34 @@ impl Leaf { } } -impl NewTree { - fn new(feature_names: Vec) -> Self { - NewTree { +impl Tree { + fn new>(feature_names: Vec) -> Self { + Tree { root: 0, nodes: Vec::new(), - feature_names, + feature_names: feature_names.into_iter().map(|x| x.into()).collect(), } } - fn from_old_tree(root: &TreeNode, feature_names: Vec) -> Self { - let mut tree = NewTree::new(feature_names); - tree.root = tree.new_node_from_old(root); - tree - } - - fn new_node_from_old(&mut self, old: &TreeNode) -> NodeId { - let node = match old { - TreeNode { - split: None, - prediction, - samples, - .. - } => NewNode::Leaf(Leaf::new(*prediction, *samples)), - TreeNode { - split: Some(threshold), - prediction, - samples, - feature_name: Some(feature_name), - left: Some(left), - right: Some(right), - } => { - let feature = match self.feature_names.iter().position(|x| x == feature_name) { - Some(i) => i, - None => { - self.feature_names.push(feature_name.clone()); - self.feature_names.len() - 1 - } - }; - let left = self.new_node_from_old(&*left); - let right = self.new_node_from_old(&*right); - NewNode::Branch(Branch { - feature, - threshold: *threshold, - left, - right, - samples: *samples, - prediction: *prediction, - }) - } - _ => panic!("Invalid Node, either leaf or branch with children expected"), - }; - - self.add_node(node) + fn new_from_split( + &self, + left: NodeId, + right: NodeId, + split: SplitResult, + feature_name: &str, + ) -> Node { + Node::Branch(Branch { + prediction: split.prediction, + samples: self.nodes[left].samples() + self.nodes[right].samples(), + threshold: split.split, + feature: self + .feature_names + .iter() + .position(|x| x == feature_name) + .unwrap(), + left, + right, + }) } fn print(&self) { @@ -471,7 +320,7 @@ impl NewTree { fn print_node(&self, node: NodeId, depth: usize) { match &self.nodes[node] { - NewNode::Leaf(l) => { + Node::Leaf(l) => { println!( "{:indent$}|-> Leaf: pred: {}, N: {}", "", @@ -480,7 +329,7 @@ impl NewTree { indent = depth * 4 ); } - NewNode::Branch(b) => { + Node::Branch(b) => { println!( "{:indent$}-> Branch: feat: {}, th: {}, N: {}, pred: {}", "", @@ -496,28 +345,29 @@ impl NewTree { } } - fn add_root(&mut self, node: NewNode) { - self.nodes.push(node); - self.root = self.nodes.len() - 1; + fn set_root(&mut self, node_id: NodeId) { + self.root = node_id; } - fn add_node(&mut self, node: NewNode) -> NodeId { + fn add_node(&mut self, node: Node) -> NodeId { self.nodes.push(node); self.nodes.len() - 1 } - fn predict(&self, test: &Dataset) -> Vec { + pub fn predict(&self, test: &Dataset) -> Vec { + let feature_matrix = self.reindex_features(&test); + let mut predictions = Vec::with_capacity(test.n_samples()); let mut nodes: Vec = Vec::new(); for i in 0..test.n_samples() { nodes.push(self.root); while let Some(node) = nodes.pop() { match &self.nodes[node] { - NewNode::Leaf(l) => { + Node::Leaf(l) => { predictions.push(l.prediction); } - NewNode::Branch(b) => { - if test.feature_matrix[b.feature][i] < b.threshold { + Node::Branch(b) => { + if feature_matrix[b.feature][i] < b.threshold { nodes.push(b.left); } else { nodes.push(b.right); @@ -529,6 +379,125 @@ impl NewTree { } predictions } + + fn reindex_features<'a>(&self, ds: &'a Dataset) -> Vec<&'a Vec> { + let mut feature_indexes = Vec::with_capacity(self.feature_names.len()); + for feature in &self.feature_names { + let index = ds.feature_names.iter().position(|x| x == feature); + match index { + Some(index) => feature_indexes.push(index), + None => panic!("Feature {} not found in tree", feature), + }; + } + + let mut feature_matrix = Vec::with_capacity(self.feature_names.len()); + for i in 0..self.feature_names.len() { + feature_matrix.push(&ds.feature_matrix[feature_indexes[i]]); + } + feature_matrix + } + + fn fit( + train: &Dataset, + depth: i32, + train_options: TrainOptions, + split_feature: SplitFunction, + rng: &mut StdRng, + ) -> Self { + let mut tree = Tree::new(train.feature_names.clone()); + let root = tree.fit_node(train, depth, train_options, split_feature, rng); + tree.set_root(root); + tree + } + + fn fit_node( + &mut self, + train: &Dataset, + depth: i32, + train_options: TrainOptions, + split_feature: SplitFunction, + rng: &mut StdRng, + ) -> NodeId { + if should_stop(train_options, depth, train) { + let leaf = self.add_node(Node::new_leaf( + utils::float_avg(&train.target_vector), + train.n_samples(), + )); + return leaf; + } + + let mut best_feature = SplitResult::new_max_loss(); + let mut feature_indexes = (0..train.feature_names.len()).collect::>(); + feature_indexes.shuffle(rng); + + for i in feature_indexes { + if train.feature_uniform[i] { + continue; + } + + let split = split_feature( + i, + &train.feature_names[i], + train_options.min_samples_leaf, + &train.feature_matrix[i], + &train.target_vector, + ); + + if split.loss < best_feature.loss { + best_feature = split; + } + } + + let (left_ds, right_ds) = split_dataset(&best_feature, train); + + let left_child = self.fit_node(&left_ds, depth + 1, train_options, split_feature, rng); + let right_child = self.fit_node(&right_ds, depth + 1, train_options, split_feature, rng); + + let name = &train.feature_names[best_feature.col_index]; + let node = self.new_from_split(left_child, right_child, best_feature, name); + let node_id = self.add_node(node); + node_id + } +} + +impl Debug for Tree { + fn fmt(&self, _f: &mut Formatter<'_>) -> std::fmt::Result { + self.print(); + Ok(()) + } +} + +impl PartialEq for Tree { + fn eq(&self, other: &Self) -> bool { + let mut nodes_self = vec![self.root]; + let mut nodes_other = vec![other.root]; + + while let Some(node) = nodes_self.pop() { + let other_n = nodes_other.pop(); + if other_n.is_none() { + return false; + } + let o = other_n.unwrap(); + match &self.nodes[node] { + Node::Leaf(l) => match &other.nodes[o] { + Node::Leaf(l2) if l2 == l => { + continue; + } + _ => return false, + }, + Node::Branch(b) => match &other.nodes[o] { + Node::Branch(b2) if b2 == b => { + nodes_self.push(b.left); + nodes_self.push(b.right); + nodes_other.push(b2.left); + nodes_other.push(b2.right); + } + _ => return false, + }, + } + } + return true; + } } // ------------------------------------- @@ -541,69 +510,23 @@ mod test { fn test_predict() { let dataset = Dataset::read_csv("datasets/toy_test.csv", ";"); - let root = TreeNode { - split: Some(2.), - prediction: 0.5, - samples: 2, - feature_name: Some("feature_a".to_string()), - left: Some(Box::new(TreeNode::new_leaf(1., 1))), - right: Some(Box::new(TreeNode::new_leaf(0., 1))), - }; - - let dt = DecisionTree { - root, - params: TrainOptions { - max_depth: 1, - min_samples_leaf: 1, - n_estimators: None, - }, - }; - - let expected = Dataset::read_csv("datasets/toy_test_predict.csv", ";"); - let pred = dt.predict(&dataset); - assert_eq!(expected.target_vector, pred); - } - - #[test] - fn test_new_predict() { - let dataset = Dataset::read_csv("datasets/toy_test.csv", ";"); - - let root = TreeNode { - split: Some(2.), + let mut tree = Tree::new(vec!["feature_a"]); + let left = tree.add_node(Node::new_leaf(1., 1)); + let right = tree.add_node(Node::new_leaf(0., 1)); + let root = tree.add_node(Node::Branch(Branch { + feature: 0, + threshold: 2., prediction: 0.5, samples: 2, - feature_name: Some("feature_a".to_string()), - left: Some(Box::new(TreeNode::new_leaf(1., 1))), - right: Some(Box::new(TreeNode::new_leaf(0., 1))), - }; + left, + right, + })); + tree.set_root(root); - let dt = DecisionTree { - root, - params: TrainOptions { - max_depth: 1, - min_samples_leaf: 1, - n_estimators: None, - }, - }; + let dt = DecisionTree { tree }; let expected = Dataset::read_csv("datasets/toy_test_predict.csv", ";"); let pred = dt.predict(&dataset); assert_eq!(expected.target_vector, pred); - - let new_tree = NewTree::from_old_tree(&dt.root, dataset.feature_names.clone()); - let new_predictions = new_tree.predict(&dataset); - assert_eq!(pred, new_predictions); - } - - #[test] - fn print_trees() { - let dataset = Dataset::read_csv("datasets/titanic_train.csv", ","); - let dt = DecisionTree::train_reg(&dataset, 2, None, None); - println!("Old Tree"); - dt.root.print(0); - - let new_tree = NewTree::from_old_tree(&dt.root, dataset.feature_names.clone()); - println!("\nNew Tree"); - new_tree.print(); } } diff --git a/src/utils.rs b/src/utils.rs index 5ef0487..7ce17e9 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -40,9 +40,9 @@ pub fn accuracy(x_true: &[f32], x_pred: &[f32]) -> f32 { / x_true.len() as f32 } -pub fn get_rng(maybe_seed: Option) -> rand::rngs::StdRng { +pub fn get_rng(maybe_seed: Option, offset: u64) -> rand::rngs::StdRng { match maybe_seed { - Some(seed) => rand::SeedableRng::seed_from_u64(seed), + Some(seed) => rand::SeedableRng::seed_from_u64(seed + offset), None => rand::SeedableRng::from_entropy(), } }