Adding max_features argument #27

Merged 2 commits on Nov 22, 2023

Changes from all commits
benches/benchmarks.rs (4 changes: 2 additions & 2 deletions)
@@ -5,7 +5,7 @@ use rustrees::Dataset;
 use rustrees::DecisionTree;
 
 fn decision_tree_housing(train: &Dataset, test: &Dataset) {
-    let dt = DecisionTree::train_reg(train, 5, Some(1), Some(42));
+    let dt = DecisionTree::train_reg(train, Some(5), Some(1), None, Some(42));
     if train.n_samples() <= 1 {
         let pred = dt.predict(&test);
         println!("R2: {}", r2(&test.target_vector, &pred));
@@ -42,7 +42,7 @@ fn criterion_benchmark(c: &mut Criterion) {
 
         // benchmark prediction
         let pred_name = "predict_decision_tree_".to_string() + dataset;
-        let dt = DecisionTree::train_reg(&train, 5, Some(1), Some(42));
+        let dt = DecisionTree::train_reg(&train, Some(5), Some(1), None, Some(42));
         c.bench_function(&pred_name, |b| {
             b.iter(|| predict_decision_tree_housing(&dt, &test))
         });
python/rustrees/decision_tree.py (13 changes: 12 additions & 1 deletion)
@@ -13,19 +13,28 @@ class DecisionTree(BaseEstimator):
     Options for regression and classification are available.
     """
 
-    def __init__(self, min_samples_leaf=1, max_depth: int = 10, random_state=None):
+    def __init__(
+        self,
+        min_samples_leaf=1,
+        max_depth: int = 10,
+        max_features: int = None,
+        random_state=None,
+    ):
         """
         Parameters
         ----------
         min_samples_leaf : int, optional
             The minimum number of samples required to be at a leaf node. The default is 1.
         max_depth : int, optional
             The maximum depth of the tree. The default is 10.
+        max_features: int, optional
+            The maximum number of features per split. Default is None, which means all features are considered.
         random_state : int, optional
             The seed used by the random number generator. The default is None.
         """
         self.min_samples_leaf = min_samples_leaf
         self.max_depth = max_depth
+        self.max_features = max_features
         self.random_state = random_state
 
     def fit(self, X, y):
@@ -84,6 +93,7 @@ def fit(self, X, y) -> "DecisionTreeRegressor":
             dataset,
             min_samples_leaf=self.min_samples_leaf,
             max_depth=self.max_depth,
+            max_features=self.max_features,
             random_state=self.random_state,
         )
         return self
@@ -103,6 +113,7 @@ def fit(self, X, y) -> "DecisionTreeClassifier":
             dataset,
             min_samples_leaf=self.min_samples_leaf,
             max_depth=self.max_depth,
+            max_features=self.max_features,
             random_state=self.random_state,
         )
         return self
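As a quick usage sketch (not part of this diff): the new argument slots into the scikit-learn-style constructor. The import path, the predict call, and the CSV path and column names below are assumptions for illustration:

import pandas as pd

from rustrees.decision_tree import DecisionTreeRegressor  # assumed import path

df = pd.read_csv("datasets/housing_train.csv")  # hypothetical file and columns
X, y = df.drop(columns=["price"]), df["price"]

# Consider at most 3 candidate features at each split;
# max_features=None (the default) considers all features.
dt = DecisionTreeRegressor(max_depth=5, max_features=3, random_state=42)
dt.fit(X, y)
preds = dt.predict(X)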
python/rustrees/random_forest.py (22 changes: 14 additions & 8 deletions)
@@ -15,26 +15,30 @@ class RandomForest(BaseEstimator):
 
     def __init__(
         self,
-        n_estimators: int = 100,
         min_samples_leaf=1,
         max_depth: int = 10,
+        n_estimators: int = 100,
+        max_features: int = None,
         random_state=None,
     ):
         """
         Parameters
         ----------
-        n_estimators : int, optional
-            The number of trees in the forest. The default is 100.
         min_samples_leaf : int, optional
             The minimum number of samples required to be at a leaf node. The default is 1.
-        max_depth : int, optional
+        max_depth : int, optional
             The maximum depth of the tree. The default is 10.
+        n_estimators : int, optional
+            The number of trees in the forest. The default is 100.
-        random_state : int, optional
+        max_features: int, optional
+            The maximum number of features per split. Default is None, which means all features are considered.
+        random_state : int, optional
             The seed used by the random number generator. The default is None.
         """
-        self.n_estimators = n_estimators
         self.min_samples_leaf = min_samples_leaf
         self.max_depth = max_depth
+        self.n_estimators = n_estimators
+        self.max_features = max_features
         self.random_state = random_state
 
     def fit(self, X, y):
@@ -91,9 +95,10 @@ def fit(self, X, y) -> "RandomForestRegressor":
         dataset = prepare_dataset(X, y)
         self.forest = rt_dt.train_reg(
             dataset,
-            n_estimators=self.n_estimators,
             min_samples_leaf=self.min_samples_leaf,
             max_depth=self.max_depth,
+            n_estimators=self.n_estimators,
+            max_features=self.max_features,
             random_state=self.random_state,
         )
         return self
@@ -111,9 +116,10 @@ def fit(self, X, y) -> "RandomForestClassifier":
         dataset = prepare_dataset(X, y)
         self.forest = rt_dt.train_clf(
             dataset,
-            n_estimators=self.n_estimators,
             min_samples_leaf=self.min_samples_leaf,
             max_depth=self.max_depth,
+            n_estimators=self.n_estimators,
+            max_features=self.max_features,
             random_state=self.random_state,
         )
         return self
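Per-split feature subsampling is the standard way to decorrelate the trees of a forest, so max_features is the more consequential knob here. A sketch under the same assumptions as the DecisionTree example (import path and predict assumed):

from rustrees.random_forest import RandomForestRegressor  # assumed import path

# A common heuristic is max_features near sqrt(n_features),
# e.g. 4 for a dataset with 16 feature columns.
rf = RandomForestRegressor(
    n_estimators=100,
    max_depth=10,
    max_features=4,
    random_state=42,
)
rf.fit(X, y)  # X, y as in the DecisionTree sketch
preds = rf.predict(X)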
src/dataset.rs (44 changes: 21 additions & 23 deletions)
@@ -1,13 +1,13 @@
-use pyo3::prelude::*;
-use rand::{rngs::StdRng, Rng};
-use std::fs;
-use arrow::array::{Float32Array};
-use arrow::record_batch::RecordBatch;
+use arrow::array::Float32Array;
 use arrow::compute::cast;
 use arrow::csv;
-use std::fs::File;
 use arrow::datatypes::DataType;
 use arrow::pyarrow::PyArrowConvert;
+use arrow::record_batch::RecordBatch;
+use pyo3::prelude::*;
+use rand::{rngs::StdRng, Rng};
+use std::fs;
+use std::fs::File;
 
 use pyo3::types::PyAny;

@@ -22,10 +22,9 @@ pub struct Dataset {
 }
 
 impl Dataset {
-
     fn _from_pyarrow(df: &PyAny) -> Dataset {
         let batch = RecordBatch::from_pyarrow(df).unwrap();
 
         let feature_names = batch
             .schema()
             .fields()
@@ -41,22 +40,22 @@ impl Dataset {
             feature_matrix: feature_matrix[0..feature_matrix.len() - 1].to_vec(),
             target_name: feature_names.last().unwrap().to_string(),
             target_vector: feature_matrix.last().unwrap().to_vec(),
-            }
+        }
     }
 
-    fn _read_batch(batch: RecordBatch) -> Vec<Vec<f32>>{
+    fn _read_batch(batch: RecordBatch) -> Vec<Vec<f32>> {
         batch
-        .columns()
-        .iter()
-        .map(|c| cast(c, &DataType::Float32).unwrap())
-        .map(|c| {
-            c.as_any()
-                .downcast_ref::<Float32Array>()
-                .unwrap()
-                .values()
-                .to_vec()
-        })
-        .collect::<Vec<_>>()
+            .columns()
+            .iter()
+            .map(|c| cast(c, &DataType::Float32).unwrap())
+            .map(|c| {
+                c.as_any()
+                    .downcast_ref::<Float32Array>()
+                    .unwrap()
+                    .values()
+                    .to_vec()
+            })
+            .collect::<Vec<_>>()
     }
 
     fn _read_csv(path: &str, sep: &str) -> Dataset {
@@ -100,7 +99,7 @@ impl Dataset {
             feature_uniform: vec![false; self.feature_names.len()],
             feature_matrix: vec![],
             target_name: self.target_name.clone(),
-            target_vector: vec![]
+            target_vector: vec![],
         }
     }

@@ -179,7 +178,6 @@ mod tests {
             target_vector: vec![1.0, 0.0],
         };
 
-
         assert_eq!(expected, got);
     }

src/tests.rs (20 changes: 9 additions & 11 deletions)
@@ -12,7 +12,7 @@ mod tests {
     fn test_integration() {
         let train = Dataset::read_csv("datasets/diabetes_train.csv", ",");
         let test = Dataset::read_csv("datasets/diabetes_test.csv", ",");
-        let dt = DecisionTree::train_reg(&train, 5, Some(1), Some(42));
+        let dt = DecisionTree::train_reg(&train, Some(5), Some(1), None, Some(42));
         let mut pred = test.clone();
         dt.predict(&mut pred);
         assert_eq!(r2(&test.target_vector, &pred.target_vector) > 0.28, true);
@@ -21,7 +21,7 @@ mod tests {
     #[test]
     fn decision_tree_titanic() {
         let (train, test) = read_train_test_dataset("titanic");
-        let dt = DecisionTree::train_clf(&train, 5, Some(1), Some(43));
+        let dt = DecisionTree::train_clf(&train, Some(5), Some(1), None, Some(43));
         let pred = dt.predict(&test);
         println!("Accuracy: {}", accuracy(&test.target_vector, &pred));
         assert_greater_than(accuracy(&test.target_vector, &pred), 0.237);
@@ -30,7 +30,7 @@ mod tests {
     #[test]
     fn decision_tree_breast_cancer() {
         let (train, test) = read_train_test_dataset("breast_cancer");
-        let dt = DecisionTree::train_clf(&train, 5, Some(1), Some(42));
+        let dt = DecisionTree::train_clf(&train, Some(5), Some(1), None, Some(42));
         let pred = dt.predict(&test);
         println!("Accuracy: {}", accuracy(&test.target_vector, &pred));
         assert_greater_than(accuracy(&test.target_vector, &pred), 0.83);
@@ -39,7 +39,7 @@ mod tests {
     #[test]
     fn decision_tree_housing() {
         let (train, test) = read_train_test_dataset("housing");
-        let dt = DecisionTree::train_reg(&train, 5, Some(1), Some(42));
+        let dt = DecisionTree::train_reg(&train, Some(5), Some(1), None, Some(42));
         let pred = dt.predict(&test);
         println!("R2: {}", r2(&test.target_vector, &pred));
         assert_greater_than(r2(&test.target_vector, &pred), 0.59);
@@ -48,7 +48,7 @@ mod tests {
     #[test]
     fn decision_tree_diabeties() {
         let (train, test) = read_train_test_dataset("diabetes");
-        let dt = DecisionTree::train_reg(&train, 5, Some(1), Some(42));
+        let dt = DecisionTree::train_reg(&train, Some(5), Some(1), None, Some(42));
        let pred = dt.predict(&test);
         println!("R2: {}", r2(&test.target_vector, &pred));
         assert_greater_than(r2(&test.target_vector, &pred), 0.30);
@@ -64,11 +64,10 @@ mod tests {
         (train, test)
     }
 
-
     #[test]
     fn random_forest_diabetes() {
         let (train, test) = read_train_test_dataset("diabetes");
-        let rf = RandomForest::train_reg(&train, 10, Some(5), Some(1), Some(42));
+        let rf = RandomForest::train_reg(&train, 10, Some(5), Some(1), None, Some(42));
         let pred = rf.predict(&test);
         println!("R2: {}", r2(&test.target_vector, &pred));
         assert_greater_than(r2(&test.target_vector, &pred), 0.38);
@@ -77,7 +76,7 @@ mod tests {
     #[test]
     fn random_forest_housing() {
         let (train, test) = read_train_test_dataset("housing");
-        let rf = RandomForest::train_reg(&train, 10, Some(5), Some(1), Some(42));
+        let rf = RandomForest::train_reg(&train, 10, Some(5), Some(1), None, Some(42));
         let pred = rf.predict(&test);
         println!("R2: {}", r2(&test.target_vector, &pred));
         assert_greater_than(r2(&test.target_vector, &pred), 0.641);
@@ -86,19 +85,18 @@ mod tests {
     #[test]
     fn random_forest_breast_cancer() {
         let (train, test) = read_train_test_dataset("breast_cancer");
-        let rf = RandomForest::train_clf(&train, 10, Some(5), Some(1), Some(42));
+        let rf = RandomForest::train_clf(&train, 10, Some(5), Some(1), None, Some(42));
         let pred = rf.predict(&test);
         let pred = classification_threshold(&pred, 0.5);
 
         println!("Accuracy: {}", accuracy(&test.target_vector, &pred),);
         assert_greater_than(accuracy(&test.target_vector, &pred), 0.96);
-
     }
 
     #[test]
     fn random_forest_breast_titanic() {
         let (train, test) = read_train_test_dataset("titanic");
-        let rf = RandomForest::train_clf(&train, 10, Some(5), Some(1), Some(42));
+        let rf = RandomForest::train_clf(&train, 10, Some(5), Some(1), None, Some(42));
         let pred = rf.predict(&test);
         let pred = classification_threshold(&pred, 0.5);
 
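The classification tests above threshold the raw forest output at 0.5 before scoring accuracy. A rough Python mirror of that flow (class name, import path, raw-score predict, and the train/test splits are all assumptions, not shown in this diff):

import numpy as np

from rustrees.random_forest import RandomForestClassifier  # assumed import path

clf = RandomForestClassifier(n_estimators=10, max_depth=5, min_samples_leaf=1, random_state=42)
clf.fit(X_train, y_train)  # hypothetical train split

raw = np.asarray(clf.predict(X_test))  # assumed to return raw scores
pred = (raw >= 0.5).astype(int)  # mirrors classification_threshold(&pred, 0.5)
accuracy = float((pred == np.asarray(y_test)).mean())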