Skip to content

Commit

Permalink
cargo fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
tabacof committed Nov 22, 2023
1 parent 7174273 commit 7f9ca05
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 47 deletions.
44 changes: 21 additions & 23 deletions src/dataset.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use pyo3::prelude::*;
use rand::{rngs::StdRng, Rng};
use std::fs;
use arrow::array::{Float32Array};
use arrow::record_batch::RecordBatch;
use arrow::array::Float32Array;
use arrow::compute::cast;
use arrow::csv;
use std::fs::File;
use arrow::datatypes::DataType;
use arrow::pyarrow::PyArrowConvert;
use arrow::record_batch::RecordBatch;
use pyo3::prelude::*;
use rand::{rngs::StdRng, Rng};
use std::fs;
use std::fs::File;

use pyo3::types::PyAny;

Expand All @@ -22,10 +22,9 @@ pub struct Dataset {
}

impl Dataset {

fn _from_pyarrow(df: &PyAny) -> Dataset {
let batch = RecordBatch::from_pyarrow(df).unwrap();

let feature_names = batch
.schema()
.fields()
Expand All @@ -41,22 +40,22 @@ impl Dataset {
feature_matrix: feature_matrix[0..feature_matrix.len() - 1].to_vec(),
target_name: feature_names.last().unwrap().to_string(),
target_vector: feature_matrix.last().unwrap().to_vec(),
}
}
}

fn _read_batch(batch: RecordBatch) -> Vec<Vec<f32>>{
fn _read_batch(batch: RecordBatch) -> Vec<Vec<f32>> {
batch
.columns()
.iter()
.map(|c| cast(c, &DataType::Float32).unwrap())
.map(|c| {
c.as_any()
.downcast_ref::<Float32Array>()
.unwrap()
.values()
.to_vec()
})
.collect::<Vec<_>>()
.columns()
.iter()
.map(|c| cast(c, &DataType::Float32).unwrap())
.map(|c| {
c.as_any()
.downcast_ref::<Float32Array>()
.unwrap()
.values()
.to_vec()
})
.collect::<Vec<_>>()
}

fn _read_csv(path: &str, sep: &str) -> Dataset {
Expand Down Expand Up @@ -100,7 +99,7 @@ impl Dataset {
feature_uniform: vec![false; self.feature_names.len()],
feature_matrix: vec![],
target_name: self.target_name.clone(),
target_vector: vec![]
target_vector: vec![],
}
}

Expand Down Expand Up @@ -179,7 +178,6 @@ mod tests {
target_vector: vec![1.0, 0.0],
};


assert_eq!(expected, got);
}

Expand Down
18 changes: 8 additions & 10 deletions src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ mod tests {
fn test_integration() {
let train = Dataset::read_csv("datasets/diabetes_train.csv", ",");
let test = Dataset::read_csv("datasets/diabetes_test.csv", ",");
let dt = DecisionTree::train_reg(&train, Some(5), Some(1), None, Some(42));
let dt = DecisionTree::train_reg(&train, Some(5), Some(1), None, Some(42));
let mut pred = test.clone();
dt.predict(&mut pred);
assert_eq!(r2(&test.target_vector, &pred.target_vector) > 0.28, true);
Expand All @@ -30,7 +30,7 @@ mod tests {
#[test]
fn decision_tree_breast_cancer() {
let (train, test) = read_train_test_dataset("breast_cancer");
let dt = DecisionTree::train_clf(&train, Some(5), Some(1), None,Some(42));
let dt = DecisionTree::train_clf(&train, Some(5), Some(1), None, Some(42));
let pred = dt.predict(&test);
println!("Accuracy: {}", accuracy(&test.target_vector, &pred));
assert_greater_than(accuracy(&test.target_vector, &pred), 0.83);
Expand All @@ -39,7 +39,7 @@ mod tests {
#[test]
fn decision_tree_housing() {
let (train, test) = read_train_test_dataset("housing");
let dt = DecisionTree::train_reg(&train, Some(5), Some(1),None, Some(42));
let dt = DecisionTree::train_reg(&train, Some(5), Some(1), None, Some(42));
let pred = dt.predict(&test);
println!("R2: {}", r2(&test.target_vector, &pred));
assert_greater_than(r2(&test.target_vector, &pred), 0.59);
Expand All @@ -48,7 +48,7 @@ mod tests {
#[test]
fn decision_tree_diabeties() {
let (train, test) = read_train_test_dataset("diabetes");
let dt = DecisionTree::train_reg(&train, Some(5), Some(1), None,Some(42));
let dt = DecisionTree::train_reg(&train, Some(5), Some(1), None, Some(42));
let pred = dt.predict(&test);
println!("R2: {}", r2(&test.target_vector, &pred));
assert_greater_than(r2(&test.target_vector, &pred), 0.30);
Expand All @@ -64,11 +64,10 @@ mod tests {
(train, test)
}


#[test]
fn random_forest_diabetes() {
let (train, test) = read_train_test_dataset("diabetes");
let rf = RandomForest::train_reg(&train, 10, Some(5), Some(1), None,Some(42));
let rf = RandomForest::train_reg(&train, 10, Some(5), Some(1), None, Some(42));
let pred = rf.predict(&test);
println!("R2: {}", r2(&test.target_vector, &pred));
assert_greater_than(r2(&test.target_vector, &pred), 0.38);
Expand All @@ -77,7 +76,7 @@ mod tests {
#[test]
fn random_forest_housing() {
let (train, test) = read_train_test_dataset("housing");
let rf = RandomForest::train_reg(&train, 10, Some(5), Some(1), None,Some(42));
let rf = RandomForest::train_reg(&train, 10, Some(5), Some(1), None, Some(42));
let pred = rf.predict(&test);
println!("R2: {}", r2(&test.target_vector, &pred));
assert_greater_than(r2(&test.target_vector, &pred), 0.641);
Expand All @@ -86,19 +85,18 @@ mod tests {
#[test]
fn random_forest_breast_cancer() {
let (train, test) = read_train_test_dataset("breast_cancer");
let rf = RandomForest::train_clf(&train, 10, Some(5), Some(1), None,Some(42));
let rf = RandomForest::train_clf(&train, 10, Some(5), Some(1), None, Some(42));
let pred = rf.predict(&test);
let pred = classification_threshold(&pred, 0.5);

println!("Accuracy: {}", accuracy(&test.target_vector, &pred),);
assert_greater_than(accuracy(&test.target_vector, &pred), 0.96);

}

#[test]
fn random_forest_breast_titanic() {
let (train, test) = read_train_test_dataset("titanic");
let rf = RandomForest::train_clf(&train, 10, Some(5), Some(1), None,Some(42));
let rf = RandomForest::train_clf(&train, 10, Some(5), Some(1), None, Some(42));
let pred = rf.predict(&test);
let pred = classification_threshold(&pred, 0.5);

Expand Down
22 changes: 8 additions & 14 deletions src/trees.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,8 @@ impl RandomForest {
let default_train_options = TrainOptions::default_options(train.feature_names.len() as i32);
let params = TrainOptions {
max_depth: max_depth.unwrap_or(default_train_options.max_depth),
min_samples_leaf: min_samples_leaf
.unwrap_or(default_train_options.min_samples_leaf),
max_features: max_features.unwrap_or(default_train_options.max_features)
min_samples_leaf: min_samples_leaf.unwrap_or(default_train_options.min_samples_leaf),
max_features: max_features.unwrap_or(default_train_options.max_features),
};

let trees: Vec<Tree> = (0..n_estimators)
Expand Down Expand Up @@ -90,9 +89,8 @@ impl RandomForest {
let default_train_options = TrainOptions::default_options(train.feature_names.len() as i32);
let params = TrainOptions {
max_depth: max_depth.unwrap_or(default_train_options.max_depth),
min_samples_leaf: min_samples_leaf
.unwrap_or(default_train_options.min_samples_leaf),
max_features: max_features.unwrap_or(default_train_options.max_features)
min_samples_leaf: min_samples_leaf.unwrap_or(default_train_options.min_samples_leaf),
max_features: max_features.unwrap_or(default_train_options.max_features),
};
let trees: Vec<Tree> = (0..n_estimators)
.into_par_iter()
Expand Down Expand Up @@ -137,8 +135,6 @@ impl RandomForest {
}
}



#[pymethods]
impl DecisionTree {
#[staticmethod]
Expand All @@ -153,9 +149,8 @@ impl DecisionTree {
let default_train_options = TrainOptions::default_options(train.feature_names.len() as i32);
let params = TrainOptions {
max_depth: max_depth.unwrap_or(default_train_options.max_depth),
min_samples_leaf: min_samples_leaf
.unwrap_or(default_train_options.min_samples_leaf),
max_features: max_features.unwrap_or(default_train_options.max_features)
min_samples_leaf: min_samples_leaf.unwrap_or(default_train_options.min_samples_leaf),
max_features: max_features.unwrap_or(default_train_options.max_features),
};

DecisionTree {
Expand All @@ -181,9 +176,8 @@ impl DecisionTree {
let default_train_options = TrainOptions::default_options(train.feature_names.len() as i32);
let params = TrainOptions {
max_depth: max_depth.unwrap_or(default_train_options.max_depth),
min_samples_leaf: min_samples_leaf
.unwrap_or(default_train_options.min_samples_leaf),
max_features: max_features.unwrap_or(default_train_options.max_features)
min_samples_leaf: min_samples_leaf.unwrap_or(default_train_options.min_samples_leaf),
max_features: max_features.unwrap_or(default_train_options.max_features),
};
DecisionTree {
tree: Tree::fit(&train, 0, params, gini_coefficient_split_feature, &mut rng),
Expand Down

0 comments on commit 7f9ca05

Please sign in to comment.