Add doc comments to rust code (#28)
* Add doc comments to rust code

* Fix doc comment example
glazari authored Dec 6, 2023
1 parent d20f7d2 commit 3b9e071
Showing 6 changed files with 93 additions and 18 deletions.
25 changes: 23 additions & 2 deletions src/dataset.rs
@@ -11,11 +11,22 @@ use std::fs::File;

use pyo3::types::PyAny;

/// Dataset represents the data used to train the model.
///
/// Data is stored grouped by features (columns) and the target is stored separately.
#[pyclass]
#[derive(Clone, Debug, PartialEq)]
pub struct Dataset {
/// Keeping the feature_names allows for mapping feature indexes to feature names.
pub feature_names: Vec<String>,
/// feature_uniform is a vector of booleans that indicates whether a feature is uniform.
/// Uniform features are never useful for splitting, so we can skip them. Also, once a feature
/// is uniform, it stays uniform for all further splits.
pub feature_uniform: Vec<bool>,

/// feature_matrix is a vector of features, where each feature is a vector of values.
/// The algorithm often iterates over the features, so it is more efficient to keep
/// them in a vector of features, rather than a vector of rows.
pub feature_matrix: Vec<Vec<f32>>,
pub target_name: String,
pub target_vector: Vec<f32>,
@@ -93,7 +104,7 @@ impl Dataset {
}
}
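As the struct docs above note, the data is stored column-major: each feature is one contiguous `Vec<f32>`, so scanning a feature for split candidates is a cache-friendly linear pass. A minimal sketch of the access pattern (hypothetical helper, not part of the crate):

```rust
/// Sketch: with column-major storage, evaluating one feature touches a
/// single contiguous vector instead of striding across row structs.
fn column_minmax(feature_matrix: &[Vec<f32>], feature: usize) -> (f32, f32) {
    feature_matrix[feature]
        .iter()
        .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), &v| {
            (lo.min(v), hi.max(v))
        })
}
```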

pub fn clone_without_data(&self) -> Dataset {
pub(crate) fn clone_without_data(&self) -> Dataset {
Dataset {
feature_names: self.feature_names.clone(),
feature_uniform: vec![false; self.feature_names.len()],
@@ -103,11 +114,12 @@
}
}

/// Returns the number of samples in the dataset.
pub fn n_samples(&self) -> usize {
self.target_vector.len()
}

pub fn bootstrap(&self, rng: &mut StdRng) -> Dataset {
pub(crate) fn bootstrap(&self, rng: &mut StdRng) -> Dataset {
let mut feature_matrix: Vec<Vec<f32>> = vec![vec![]; self.feature_names.len()];
let mut target_vector: Vec<f32> = Vec::new();

@@ -130,15 +142,19 @@
}
}

/// Methods exposed to the Python binding.
#[pymethods]
impl Dataset {

/// Reads a CSV file and returns a Dataset.
#[staticmethod]
pub fn read_csv(path: &str, sep: &str) -> Dataset {
println!("Reading CSV file {}", path);
//let contents = fs::read_to_string(path).expect("Cannot read CSV file");
Self::_read_csv(path, sep)
}

/// Writes dataset to a CSV file.
pub fn write_csv(&self, path: &str, sep: &str) {
let mut contents: String = self.feature_names.join(sep) + sep + &self.target_name + "\n";

@@ -152,11 +168,16 @@ impl Dataset {
fs::write(path, contents).expect("Unable to write file");
}

/// Converts a pyarrow dataframe to a Dataset.
/// This is used to convert a pandas dataframe to a Dataset.
#[staticmethod]
pub fn from_pyarrow(df: &PyAny) -> Dataset {
Self::_from_pyarrow(df)
}

/// Adds the target vector to the dataset.
/// A dataset always has the feature matrix, but only has the target vector
/// when it is used for training.
pub fn add_target(&mut self, target: Vec<f32>) {
self.target_vector = target;
}
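For context, the `bootstrap` method above builds each tree's training set by sampling rows with replacement. A minimal sketch of the idea, assuming the rand 0.8 `gen_range` API (the crate's actual method also regroups the per-feature vectors and the target):

```rust
use rand::{rngs::StdRng, Rng};

/// Sketch: a bootstrap sample is n_samples row indices drawn with
/// replacement; features and target are then gathered by these indices.
fn bootstrap_indices(n_samples: usize, rng: &mut StdRng) -> Vec<usize> {
    (0..n_samples).map(|_| rng.gen_range(0..n_samples)).collect()
}
```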
29 changes: 28 additions & 1 deletion src/lib.rs
@@ -1,3 +1,30 @@
//! Rustrees is a library for building decision trees and random forests.
//!
//! The goal is to provide a fast implementation of decision trees in rust, with a python API.
//!
//! Example usage:
//!
//! ```rust
//! use rustrees::{DecisionTree, Dataset, r2};
//!
//! let dataset = Dataset::read_csv("datasets/titanic_train.csv", ",");
//!
//! let dt = DecisionTree::train_reg(
//! &dataset,
//! Some(5), // max_depth
//! Some(1), // min_samples_leaf
//! None, // max_features (None = all features)
//! Some(42), // random_state
//! );
//!
//! let pred = dt.predict(&dataset);
//!
//! println!("r2 score: {}", r2(&dataset.target_vector, &pred));
//!
//! ```
//!

mod dataset;
mod split_criteria;
mod tests;
@@ -9,7 +36,7 @@ pub use trees::DecisionTree;
pub use trees::RandomForest;
pub use trees::TrainOptions;
pub use trees::Tree;
pub use utils::*;
pub use utils::{accuracy, r2};

use pyo3::prelude::*;

14 changes: 4 additions & 10 deletions src/split_criteria.rs
@@ -1,5 +1,6 @@
use crate::utils;

/// A common interface for different split criteria.
pub(crate) type SplitFunction = fn(
col_index: usize,
feature_name: &str,
@@ -8,10 +9,6 @@ pub(crate) type SplitFunction = fn(
target: &[f32],
) -> SplitResult;

//pub(crate) trait SplitCriteria {
// fn split_feature(col_index: usize, feature: &[f32], target: &[f32]) -> SplitResult;
//}

#[derive(Debug, PartialEq)]
pub(crate) struct SplitResult {
pub(crate) col_index: usize,
@@ -35,9 +32,9 @@ impl SplitResult {
}
}

//pub(crate) struct MeanSquaredError;

//impl SplitCriteria for MeanSquaredError {
/// The split criterion used for regression problems. The mean squared error has a special
/// form that allows the loss of every candidate split to be computed while reusing most of
/// the computation across splits.
pub(crate) fn mean_squared_error_split_feature(
col_index: usize,
feature_name: &str,
@@ -100,11 +97,9 @@ pub(crate) fn mean_squared_error_split_feature(
loss: min_mse,
}
}
//}
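Concretely, the special form is the identity SSE = Σy² − (Σy)²/n for each side of a split: after sorting the target by the feature, a single left-to-right scan with running sums scores every candidate split in constant time each. A simplified sketch of that scan (not the crate's exact code, which also tracks thresholds and feature values):

```rust
/// Sketch: score all splits of one feature in a single pass.
/// For a group, SSE = sum(y^2) - (sum(y))^2 / n, so moving one sample
/// from the right side to the left side is a constant-time update.
fn best_split_sse(sorted_target: &[f32]) -> (usize, f32) {
    let n = sorted_target.len() as f32;
    let (total_sum, total_sq) = sorted_target
        .iter()
        .fold((0.0f32, 0.0f32), |(s, sq), &y| (s + y, sq + y * y));

    let (mut left_sum, mut left_sq) = (0.0f32, 0.0f32);
    let mut best = (0, f32::INFINITY);

    // Split i puts samples [0, i) on the left and [i, n) on the right.
    for i in 1..sorted_target.len() {
        let y = sorted_target[i - 1];
        left_sum += y;
        left_sq += y * y;
        let (left_n, right_n) = (i as f32, n - i as f32);
        let sse_left = left_sq - left_sum * left_sum / left_n;
        let sse_right = (total_sq - left_sq) - (total_sum - left_sum).powi(2) / right_n;
        if sse_left + sse_right < best.1 {
            best = (i, sse_left + sse_right);
        }
    }
    best
}
```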

//pub(crate) struct GiniCoefficient;

//impl SplitCriteria for GiniCoefficient {
/// The split criterion used for classification problems.
pub(crate) fn gini_coefficient_split_feature(
col_index: usize,
feature_name: &str,
@@ -160,7 +155,6 @@ pub(crate) fn gini_coefficient_split_feature(
loss: min_gini,
}
}
//}
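The Gini criterion allows the same single-pass trick: with binary 0/1 targets, only the running count of positives on the left side is needed, since the binary Gini impurity is 2p(1 − p). A simplified sketch under the same assumptions as above:

```rust
/// Sketch: weighted Gini impurity of every binary split, updated
/// incrementally while scanning the feature-sorted 0/1 targets.
fn best_split_gini(sorted_target: &[f32]) -> (usize, f32) {
    let n = sorted_target.len() as f32;
    let total_pos: f32 = sorted_target.iter().sum();

    let mut left_pos = 0.0f32;
    let mut best = (0, f32::INFINITY);

    for i in 1..sorted_target.len() {
        left_pos += sorted_target[i - 1]; // targets are 0.0 or 1.0
        let (left_n, right_n) = (i as f32, n - i as f32);
        let (p_l, p_r) = (left_pos / left_n, (total_pos - left_pos) / right_n);
        // Weighted average of the two sides' impurities, gini(p) = 2p(1 - p).
        let loss = (left_n * 2.0 * p_l * (1.0 - p_l) + right_n * 2.0 * p_r * (1.0 - p_r)) / n;
        if loss < best.1 {
            best = (i, loss);
        }
    }
    best
}
```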

#[cfg(test)]
mod test {
1 change: 1 addition & 0 deletions src/tests.rs
@@ -1,6 +1,7 @@
#[cfg(test)]
mod tests {
use crate::{trees::RandomForest, *};
use crate::utils::{accuracy, classification_threshold, r2};

fn assert_greater_than(a: f32, b: f32) {
if a <= b {
28 changes: 27 additions & 1 deletion src/trees.rs
@@ -14,24 +14,30 @@ use std::fmt::Formatter;

use pyo3::prelude::*;

/// Represents the decision tree model. Each node represents a split on a feature.
#[pyclass]
pub struct DecisionTree {
tree: Tree,
}

/// Represents the random forest model, a collection of decision trees.
#[pyclass]
pub struct RandomForest {
trees: Vec<Tree>,
}

/// Possible options for training the model.
#[derive(Clone, Copy)]
pub struct TrainOptions {
/// Minimum number of samples required to be at a leaf node. default: 1
min_samples_leaf: i32,
/// Maximum depth of the tree. default: 10
max_depth: i32,
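/// Maximum number of features considered per split. default: all features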
max_features: i32,
}

impl TrainOptions {
/// Returns the default options for training a model.
pub fn default_options(num_features: i32) -> TrainOptions {
TrainOptions {
max_depth: 10,
@@ -41,8 +47,12 @@ impl TrainOptions {
}
}

/// Methods for training and predicting with a random forest. These methods are exposed to Python.
#[pymethods]
impl RandomForest {

/// Trains a random forest for regression.
/// A regression tree uses the mean squared error as the split criterion.
#[staticmethod]
pub fn train_reg(
train: &Dataset,
@@ -77,6 +87,8 @@ impl RandomForest {
RandomForest { trees }
}

/// Trains a random forest for classification problems.
/// A classification tree uses the Gini coefficient as the split criterion.
#[staticmethod]
pub fn train_clf(
train: &Dataset,
@@ -116,6 +128,8 @@ impl RandomForest {
RandomForest { trees }
}

/// Predicts the target for a given dataset.
/// The prediction is the average of the predictions of each tree.
pub fn predict(&self, x: &Dataset) -> Vec<f32> {
let mut predictions = Vec::new();
for tree in &self.trees {
@@ -135,8 +149,13 @@ impl RandomForest {
}
}

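The averaging in `predict` above is elementwise across trees; a small sketch of that reduction (hypothetical helper, not the crate's code):

```rust
/// Sketch: the forest's prediction for sample i is the mean of the
/// per-tree predictions for sample i.
fn average_predictions(per_tree: &[Vec<f32>]) -> Vec<f32> {
    let n_trees = per_tree.len() as f32;
    (0..per_tree[0].len())
        .map(|i| per_tree.iter().map(|p| p[i]).sum::<f32>() / n_trees)
        .collect()
}
```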
/// Methods for training and predicting with a decision tree. These methods are exposed to Python.
#[pymethods]
impl DecisionTree {

/// Trains a decision tree for regression.
/// A regression tree uses the mean squared error as the split criterion.
#[staticmethod]
pub fn train_reg(
train: &Dataset,
@@ -164,6 +183,8 @@ impl DecisionTree {
}
}

/// Trains a decision tree for classification problems.
/// A classification tree uses the Gini coefficient as the split criterion.
#[staticmethod]
pub fn train_clf(
train: &Dataset,
@@ -184,6 +205,7 @@ impl DecisionTree {
}
}

/// Predicts the target for a given dataset.
pub fn predict(&self, test: &Dataset) -> Vec<f32> {
self.tree.predict(test)
}
@@ -243,6 +265,10 @@ fn should_stop(options: TrainOptions, depth: i32, ds: &Dataset) -> bool {
type NodeId = usize;
type FeatureIndex = usize;

/// An arena-based tree implementation. Each node is stored in a vector and children are
/// accessed by index.
///
/// Having all the nodes in a vector makes the implementation more cache-friendly, and
/// accessing them by index avoids the borrow-checker issues that come with recursive
/// data structures.
pub struct Tree {
root: NodeId,
nodes: Vec<Node>,
@@ -362,7 +388,7 @@ impl Tree {
self.nodes.len() - 1
}

pub fn predict(&self, test: &Dataset) -> Vec<f32> {
fn predict(&self, test: &Dataset) -> Vec<f32> {
let feature_matrix = self.reindex_features(&test);

let mut predictions = Vec::with_capacity(test.n_samples());
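To make the arena layout described above concrete, here is a minimal self-contained sketch of the pattern (field names are hypothetical, not the crate's actual `Node` definition):

```rust
/// Sketch of an arena-allocated binary tree: all nodes live in one Vec
/// and refer to their children by index instead of by owning pointers.
type Id = usize;

struct ArenaNode {
    feature: usize,             // feature index used for the split
    threshold: f32,             // split threshold
    children: Option<(Id, Id)>, // None for leaves
    prediction: f32,            // value returned at a leaf
}

struct ArenaTree {
    root: Id,
    nodes: Vec<ArenaNode>,
}

impl ArenaTree {
    /// Follow splits through the vector until a leaf is reached; no
    /// recursion and no borrow-checker friction, just index hops.
    fn predict_row(&self, row: &[f32]) -> f32 {
        let mut id = self.root;
        loop {
            let node = &self.nodes[id];
            match node.children {
                Some((left, right)) => {
                    id = if row[node.feature] < node.threshold { left } else { right };
                }
                None => return node.prediction,
            }
        }
    }
}
```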
14 changes: 10 additions & 4 deletions src/utils.rs
@@ -1,23 +1,28 @@
use std::cmp::Ordering::Equal;

pub fn sort_two_vectors(a: &[f32], b: &[f32]) -> (Vec<f32>, Vec<f32>) {
/// Sorts two vectors by the first one. This is used to sort the target by a feature's
/// values and is at the core of the decision tree algorithm.
pub(crate) fn sort_two_vectors(a: &[f32], b: &[f32]) -> (Vec<f32>, Vec<f32>) {
let a_sorter = permutation::sort_by(a, |a, b| a.partial_cmp(b).unwrap_or(Equal));

let a = a_sorter.apply_slice(a);
let b = a_sorter.apply_slice(b);
(a, b)
}
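For intuition, a small test-style usage example (values illustrative):

```rust
#[test]
fn sorts_b_by_a() {
    let (a, b) = sort_two_vectors(&[3.0, 1.0, 2.0], &[30.0, 10.0, 20.0]);
    assert_eq!(a, vec![1.0, 2.0, 3.0]);
    assert_eq!(b, vec![10.0, 20.0, 30.0]);
}
```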

pub fn float_avg(x: &[f32]) -> f32 {
pub(crate) fn float_avg(x: &[f32]) -> f32 {
x.iter().sum::<f32>() / x.len() as f32
}

pub fn classification_threshold(x: &[f32], clf_threshold: f32) -> Vec<f32> {
/// Applies a classification threshold to a vector of predictions, mapping each value to
/// 0.0 or 1.0. This is used for testing the classification case.
#[cfg(test)]
pub(crate) fn classification_threshold(x: &[f32], clf_threshold: f32) -> Vec<f32> {
x.iter()
.map(|&x| if x >= clf_threshold { 1.0 } else { 0.0 })
.collect()
}

/// Computes the R² score (1 - MSE / variance) between two vectors. Used for testing the
/// regression case.
pub fn r2(x_true: &[f32], x_pred: &[f32]) -> f32 {
let mse: f32 = x_true
.iter()
@@ -31,6 +36,7 @@ pub fn r2(x_true: &[f32], x_pred: &[f32]) -> f32 {
1.0 - mse / var
}

/// Computes the accuracy of a binary classification. Used for testing.
pub fn accuracy(x_true: &[f32], x_pred: &[f32]) -> f32 {
x_true
.iter()
@@ -40,7 +46,7 @@ pub fn accuracy(x_true: &[f32], x_pred: &[f32]) -> f32 {
/ x_true.len() as f32
}

pub fn get_rng(maybe_seed: Option<u64>, offset: u64) -> rand::rngs::StdRng {
pub(crate) fn get_rng(maybe_seed: Option<u64>, offset: u64) -> rand::rngs::StdRng {
match maybe_seed {
Some(seed) => rand::SeedableRng::seed_from_u64(seed + offset),
None => rand::SeedableRng::from_entropy(),
