Merge pull request #27 from tabacof/max_features
Adding max_features argument
tabacof authored Nov 22, 2023
2 parents 3266824 + 7f9ca05 commit d20f7d2
Showing 6 changed files with 87 additions and 64 deletions.
4 changes: 2 additions & 2 deletions benches/benchmarks.rs
@@ -5,7 +5,7 @@ use rustrees::Dataset;
 use rustrees::DecisionTree;
 
 fn decision_tree_housing(train: &Dataset, test: &Dataset) {
-    let dt = DecisionTree::train_reg(train, 5, Some(1), Some(42));
+    let dt = DecisionTree::train_reg(train, Some(5), Some(1), None, Some(42));
     if train.n_samples() <= 1 {
         let pred = dt.predict(&test);
         println!("R2: {}", r2(&test.target_vector, &pred));
@@ -42,7 +42,7 @@ fn criterion_benchmark(c: &mut Criterion) {
 
         // benchmark prediction
         let pred_name = "predict_decision_tree_".to_string() + dataset;
-        let dt = DecisionTree::train_reg(&train, 5, Some(1), Some(42));
+        let dt = DecisionTree::train_reg(&train, Some(5), Some(1), None, Some(42));
         c.bench_function(&pred_name, |b| {
            b.iter(|| predict_decision_tree_housing(&dt, &test))
         });
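Both call sites show the updated DecisionTree::train_reg signature: the depth argument is now Option-wrapped (Some(5) rather than a bare 5), and a new max_features argument slots in just before the seed; judging from the Python docstrings below, None means all features are considered.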
13 changes: 12 additions & 1 deletion python/rustrees/decision_tree.py
@@ -13,19 +13,28 @@ class DecisionTree(BaseEstimator):
     Options for regression and classification are available.
     """
 
-    def __init__(self, min_samples_leaf=1, max_depth: int = 10, random_state=None):
+    def __init__(
+        self,
+        min_samples_leaf=1,
+        max_depth: int = 10,
+        max_features: int = None,
+        random_state=None,
+    ):
         """
         Parameters
         ----------
         min_samples_leaf : int, optional
             The minimum number of samples required to be at a leaf node. The default is 1.
         max_depth : int, optional
             The maximum depth of the tree. The default is 10.
+        max_features: int, optional
+            The maximum number of features per split. Default is None, which means all features are considered.
         random_state : int, optional
             The seed used by the random number generator. The default is None.
         """
         self.min_samples_leaf = min_samples_leaf
         self.max_depth = max_depth
+        self.max_features = max_features
         self.random_state = random_state
 
     def fit(self, X, y):
@@ -84,6 +93,7 @@ def fit(self, X, y) -> "DecisionTreeRegressor":
             dataset,
             min_samples_leaf=self.min_samples_leaf,
             max_depth=self.max_depth,
+            max_features=self.max_features,
             random_state=self.random_state,
         )
         return self
@@ -103,6 +113,7 @@ def fit(self, X, y) -> "DecisionTreeClassifier":
             dataset,
             min_samples_leaf=self.min_samples_leaf,
             max_depth=self.max_depth,
+            max_features=self.max_features,
             random_state=self.random_state,
         )
         return self
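For illustration, a minimal usage sketch of the new argument through the sklearn-style interface above. This is hedged: the data is hypothetical, and the import path and the DecisionTreeRegressor subclass named in the fit annotation are assumed to work as the file layout suggests.

import pandas as pd

from rustrees.decision_tree import DecisionTreeRegressor

# Hypothetical toy data: three features, one target.
X = pd.DataFrame(
    {
        "sqft": [850.0, 900.0, 1200.0, 1500.0],
        "rooms": [2.0, 2.0, 3.0, 4.0],
        "age": [30.0, 25.0, 10.0, 5.0],
    }
)
y = [100.0, 110.0, 180.0, 250.0]

# max_features=2 considers at most 2 of the 3 features at each split;
# the default max_features=None keeps the old behaviour (all features).
dt = DecisionTreeRegressor(max_depth=5, max_features=2, random_state=42)
dt.fit(X, y)
print(dt.predict(X))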
22 changes: 14 additions & 8 deletions python/rustrees/random_forest.py
@@ -15,26 +15,30 @@ class RandomForest(BaseEstimator):
 
     def __init__(
         self,
-        n_estimators: int = 100,
         min_samples_leaf=1,
         max_depth: int = 10,
+        n_estimators: int = 100,
+        max_features: int = None,
         random_state=None,
     ):
         """
         Parameters
         ----------
-        n_estimators : int, optional
-            The number of trees in the forest. The default is 100.
         min_samples_leaf : int, optional
             The minimum number of samples required to be at a leaf node. The default is 1.
-        max_depth : int, optional
+        max_depth : int, optional
             The maximum depth of the tree. The default is 10.
+        n_estimators : int, optional
+            The number of trees in the forest. The default is 100.
-        random_state : int, optional
+        max_features: int, optional
+            The maximum number of features per split. Default is None, which means all features are considered.
+        random_state : int, optional
             The seed used by the random number generator. The default is None.
         """
-        self.n_estimators = n_estimators
         self.min_samples_leaf = min_samples_leaf
         self.max_depth = max_depth
+        self.n_estimators = n_estimators
+        self.max_features = max_features
         self.random_state = random_state
 
     def fit(self, X, y):
@@ -91,9 +95,10 @@ def fit(self, X, y) -> "RandomForestRegressor":
         dataset = prepare_dataset(X, y)
         self.forest = rt_dt.train_reg(
             dataset,
-            n_estimators=self.n_estimators,
             min_samples_leaf=self.min_samples_leaf,
             max_depth=self.max_depth,
+            n_estimators=self.n_estimators,
+            max_features=self.max_features,
             random_state=self.random_state,
         )
         return self
@@ -111,9 +116,10 @@ def fit(self, X, y) -> "RandomForestClassifier":
         dataset = prepare_dataset(X, y)
         self.forest = rt_dt.train_clf(
             dataset,
-            n_estimators=self.n_estimators,
             min_samples_leaf=self.min_samples_leaf,
             max_depth=self.max_depth,
+            n_estimators=self.n_estimators,
+            max_features=self.max_features,
             random_state=self.random_state,
         )
         return self
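The forest version, under the same assumptions (hypothetical data; RandomForestRegressor is the class named in the fit annotation above). Capping the features considered per split is the usual random-forest trick for decorrelating trees, so it pairs naturally with n_estimators:

import pandas as pd

from rustrees.random_forest import RandomForestRegressor

X = pd.DataFrame(
    {
        "sqft": [850.0, 900.0, 1200.0, 1500.0],
        "rooms": [2.0, 2.0, 3.0, 4.0],
        "age": [30.0, 25.0, 10.0, 5.0],
    }
)
y = [100.0, 110.0, 180.0, 250.0]

# Each split samples at most 2 candidate features, so individual trees
# differ more from each other than with max_features=None (all features).
rf = RandomForestRegressor(n_estimators=100, max_features=2, random_state=0)
rf.fit(X, y)
print(rf.predict(X))

Also visible in this hunk: n_estimators moved after max_depth in the parameter lists. Keyword callers are unaffected; anyone passing these arguments positionally needs to update.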
44 changes: 21 additions & 23 deletions src/dataset.rs
@@ -1,13 +1,13 @@
-use pyo3::prelude::*;
-use rand::{rngs::StdRng, Rng};
-use std::fs;
-use arrow::array::{Float32Array};
-use arrow::record_batch::RecordBatch;
+use arrow::array::Float32Array;
 use arrow::compute::cast;
 use arrow::csv;
-use std::fs::File;
 use arrow::datatypes::DataType;
 use arrow::pyarrow::PyArrowConvert;
+use arrow::record_batch::RecordBatch;
+use pyo3::prelude::*;
+use rand::{rngs::StdRng, Rng};
+use std::fs;
+use std::fs::File;
 
 use pyo3::types::PyAny;
 
@@ -22,10 +22,9 @@ pub struct Dataset {
 }
 
 impl Dataset {
-
     fn _from_pyarrow(df: &PyAny) -> Dataset {
         let batch = RecordBatch::from_pyarrow(df).unwrap();
 
         let feature_names = batch
             .schema()
             .fields()
@@ -41,22 +40,22 @@ impl Dataset {
             feature_matrix: feature_matrix[0..feature_matrix.len() - 1].to_vec(),
             target_name: feature_names.last().unwrap().to_string(),
             target_vector: feature_matrix.last().unwrap().to_vec(),
-            }
+        }
     }
 
-    fn _read_batch(batch: RecordBatch) -> Vec<Vec<f32>>{
+    fn _read_batch(batch: RecordBatch) -> Vec<Vec<f32>> {
         batch
-        .columns()
-        .iter()
-        .map(|c| cast(c, &DataType::Float32).unwrap())
-        .map(|c| {
-            c.as_any()
-            .downcast_ref::<Float32Array>()
-            .unwrap()
-            .values()
-            .to_vec()
-        })
-        .collect::<Vec<_>>()
+            .columns()
+            .iter()
+            .map(|c| cast(c, &DataType::Float32).unwrap())
+            .map(|c| {
+                c.as_any()
+                    .downcast_ref::<Float32Array>()
+                    .unwrap()
+                    .values()
+                    .to_vec()
+            })
+            .collect::<Vec<_>>()
     }
 
     fn _read_csv(path: &str, sep: &str) -> Dataset {
@@ -100,7 +99,7 @@ impl Dataset {
             feature_uniform: vec![false; self.feature_names.len()],
             feature_matrix: vec![],
             target_name: self.target_name.clone(),
-            target_vector: vec![]
+            target_vector: vec![],
         }
     }
 
@@ -179,7 +178,6 @@ mod tests {
             target_vector: vec![1.0, 0.0],
         };
 
-
         assert_eq!(expected, got);
     }
 
20 changes: 9 additions & 11 deletions src/tests.rs
@@ -12,7 +12,7 @@ mod tests {
     fn test_integration() {
         let train = Dataset::read_csv("datasets/diabetes_train.csv", ",");
         let test = Dataset::read_csv("datasets/diabetes_test.csv", ",");
-        let dt = DecisionTree::train_reg(&train, 5, Some(1), Some(42));
+        let dt = DecisionTree::train_reg(&train, Some(5), Some(1), None, Some(42));
         let mut pred = test.clone();
         dt.predict(&mut pred);
         assert_eq!(r2(&test.target_vector, &pred.target_vector) > 0.28, true);
@@ -21,7 +21,7 @@ mod tests {
     #[test]
     fn decision_tree_titanic() {
         let (train, test) = read_train_test_dataset("titanic");
-        let dt = DecisionTree::train_clf(&train, 5, Some(1), Some(43));
+        let dt = DecisionTree::train_clf(&train, Some(5), Some(1), None, Some(43));
         let pred = dt.predict(&test);
         println!("Accuracy: {}", accuracy(&test.target_vector, &pred));
         assert_greater_than(accuracy(&test.target_vector, &pred), 0.237);
@@ -30,7 +30,7 @@ mod tests {
     #[test]
     fn decision_tree_breast_cancer() {
         let (train, test) = read_train_test_dataset("breast_cancer");
-        let dt = DecisionTree::train_clf(&train, 5, Some(1), Some(42));
+        let dt = DecisionTree::train_clf(&train, Some(5), Some(1), None, Some(42));
         let pred = dt.predict(&test);
         println!("Accuracy: {}", accuracy(&test.target_vector, &pred));
         assert_greater_than(accuracy(&test.target_vector, &pred), 0.83);
@@ -39,7 +39,7 @@ mod tests {
     #[test]
     fn decision_tree_housing() {
         let (train, test) = read_train_test_dataset("housing");
-        let dt = DecisionTree::train_reg(&train, 5, Some(1), Some(42));
+        let dt = DecisionTree::train_reg(&train, Some(5), Some(1), None, Some(42));
         let pred = dt.predict(&test);
         println!("R2: {}", r2(&test.target_vector, &pred));
         assert_greater_than(r2(&test.target_vector, &pred), 0.59);
@@ -48,7 +48,7 @@ mod tests {
     #[test]
     fn decision_tree_diabeties() {
         let (train, test) = read_train_test_dataset("diabetes");
-        let dt = DecisionTree::train_reg(&train, 5, Some(1), Some(42));
+        let dt = DecisionTree::train_reg(&train, Some(5), Some(1), None, Some(42));
         let pred = dt.predict(&test);
         println!("R2: {}", r2(&test.target_vector, &pred));
         assert_greater_than(r2(&test.target_vector, &pred), 0.30);
@@ -64,11 +64,10 @@ mod tests {
         (train, test)
     }
 
-
     #[test]
     fn random_forest_diabetes() {
         let (train, test) = read_train_test_dataset("diabetes");
-        let rf = RandomForest::train_reg(&train, 10, Some(5), Some(1), Some(42));
+        let rf = RandomForest::train_reg(&train, 10, Some(5), Some(1), None, Some(42));
         let pred = rf.predict(&test);
         println!("R2: {}", r2(&test.target_vector, &pred));
         assert_greater_than(r2(&test.target_vector, &pred), 0.38);
@@ -77,7 +76,7 @@ mod tests {
     #[test]
     fn random_forest_housing() {
         let (train, test) = read_train_test_dataset("housing");
-        let rf = RandomForest::train_reg(&train, 10, Some(5), Some(1), Some(42));
+        let rf = RandomForest::train_reg(&train, 10, Some(5), Some(1), None, Some(42));
         let pred = rf.predict(&test);
         println!("R2: {}", r2(&test.target_vector, &pred));
         assert_greater_than(r2(&test.target_vector, &pred), 0.641);
@@ -86,19 +85,18 @@ mod tests {
     #[test]
     fn random_forest_breast_cancer() {
         let (train, test) = read_train_test_dataset("breast_cancer");
-        let rf = RandomForest::train_clf(&train, 10, Some(5), Some(1), Some(42));
+        let rf = RandomForest::train_clf(&train, 10, Some(5), Some(1), None, Some(42));
         let pred = rf.predict(&test);
         let pred = classification_threshold(&pred, 0.5);
 
         println!("Accuracy: {}", accuracy(&test.target_vector, &pred),);
         assert_greater_than(accuracy(&test.target_vector, &pred), 0.96);
-
     }
 
     #[test]
     fn random_forest_breast_titanic() {
         let (train, test) = read_train_test_dataset("titanic");
-        let rf = RandomForest::train_clf(&train, 10, Some(5), Some(1), Some(42));
+        let rf = RandomForest::train_clf(&train, 10, Some(5), Some(1), None, Some(42));
         let pred = rf.predict(&test);
         let pred = classification_threshold(&pred, 0.5);
 
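As with the decision-tree tests, the RandomForest::train_reg and train_clf call sites gain the new max_features argument just before the seed. It is None in every test here, preserving the previous behaviour, which is consistent with the assertion thresholds being left unchanged.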