diff --git a/src/dataset.rs b/src/dataset.rs index fab58d1..3add5de 100644 --- a/src/dataset.rs +++ b/src/dataset.rs @@ -1,13 +1,13 @@ -use pyo3::prelude::*; -use rand::{rngs::StdRng, Rng}; -use std::fs; -use arrow::array::{Float32Array}; -use arrow::record_batch::RecordBatch; +use arrow::array::Float32Array; use arrow::compute::cast; use arrow::csv; -use std::fs::File; use arrow::datatypes::DataType; use arrow::pyarrow::PyArrowConvert; +use arrow::record_batch::RecordBatch; +use pyo3::prelude::*; +use rand::{rngs::StdRng, Rng}; +use std::fs; +use std::fs::File; use pyo3::types::PyAny; @@ -22,10 +22,9 @@ pub struct Dataset { } impl Dataset { - fn _from_pyarrow(df: &PyAny) -> Dataset { let batch = RecordBatch::from_pyarrow(df).unwrap(); - + let feature_names = batch .schema() .fields() @@ -41,22 +40,22 @@ impl Dataset { feature_matrix: feature_matrix[0..feature_matrix.len() - 1].to_vec(), target_name: feature_names.last().unwrap().to_string(), target_vector: feature_matrix.last().unwrap().to_vec(), - } + } } - fn _read_batch(batch: RecordBatch) -> Vec>{ + fn _read_batch(batch: RecordBatch) -> Vec> { batch - .columns() - .iter() - .map(|c| cast(c, &DataType::Float32).unwrap()) - .map(|c| { - c.as_any() - .downcast_ref::() - .unwrap() - .values() - .to_vec() - }) - .collect::>() + .columns() + .iter() + .map(|c| cast(c, &DataType::Float32).unwrap()) + .map(|c| { + c.as_any() + .downcast_ref::() + .unwrap() + .values() + .to_vec() + }) + .collect::>() } fn _read_csv(path: &str, sep: &str) -> Dataset { @@ -100,7 +99,7 @@ impl Dataset { feature_uniform: vec![false; self.feature_names.len()], feature_matrix: vec![], target_name: self.target_name.clone(), - target_vector: vec![] + target_vector: vec![], } } @@ -179,7 +178,6 @@ mod tests { target_vector: vec![1.0, 0.0], }; - assert_eq!(expected, got); } diff --git a/src/tests.rs b/src/tests.rs index aa859de..9a7bc93 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -12,7 +12,7 @@ mod tests { fn test_integration() { let train = Dataset::read_csv("datasets/diabetes_train.csv", ","); let test = Dataset::read_csv("datasets/diabetes_test.csv", ","); - let dt = DecisionTree::train_reg(&train, Some(5), Some(1), None, Some(42)); + let dt = DecisionTree::train_reg(&train, Some(5), Some(1), None, Some(42)); let mut pred = test.clone(); dt.predict(&mut pred); assert_eq!(r2(&test.target_vector, &pred.target_vector) > 0.28, true); @@ -30,7 +30,7 @@ mod tests { #[test] fn decision_tree_breast_cancer() { let (train, test) = read_train_test_dataset("breast_cancer"); - let dt = DecisionTree::train_clf(&train, Some(5), Some(1), None,Some(42)); + let dt = DecisionTree::train_clf(&train, Some(5), Some(1), None, Some(42)); let pred = dt.predict(&test); println!("Accuracy: {}", accuracy(&test.target_vector, &pred)); assert_greater_than(accuracy(&test.target_vector, &pred), 0.83); @@ -39,7 +39,7 @@ mod tests { #[test] fn decision_tree_housing() { let (train, test) = read_train_test_dataset("housing"); - let dt = DecisionTree::train_reg(&train, Some(5), Some(1),None, Some(42)); + let dt = DecisionTree::train_reg(&train, Some(5), Some(1), None, Some(42)); let pred = dt.predict(&test); println!("R2: {}", r2(&test.target_vector, &pred)); assert_greater_than(r2(&test.target_vector, &pred), 0.59); @@ -48,7 +48,7 @@ mod tests { #[test] fn decision_tree_diabeties() { let (train, test) = read_train_test_dataset("diabetes"); - let dt = DecisionTree::train_reg(&train, Some(5), Some(1), None,Some(42)); + let dt = DecisionTree::train_reg(&train, Some(5), Some(1), None, Some(42)); let pred = dt.predict(&test); println!("R2: {}", r2(&test.target_vector, &pred)); assert_greater_than(r2(&test.target_vector, &pred), 0.30); @@ -64,11 +64,10 @@ mod tests { (train, test) } - #[test] fn random_forest_diabetes() { let (train, test) = read_train_test_dataset("diabetes"); - let rf = RandomForest::train_reg(&train, 10, Some(5), Some(1), None,Some(42)); + let rf = RandomForest::train_reg(&train, 10, Some(5), Some(1), None, Some(42)); let pred = rf.predict(&test); println!("R2: {}", r2(&test.target_vector, &pred)); assert_greater_than(r2(&test.target_vector, &pred), 0.38); @@ -77,7 +76,7 @@ mod tests { #[test] fn random_forest_housing() { let (train, test) = read_train_test_dataset("housing"); - let rf = RandomForest::train_reg(&train, 10, Some(5), Some(1), None,Some(42)); + let rf = RandomForest::train_reg(&train, 10, Some(5), Some(1), None, Some(42)); let pred = rf.predict(&test); println!("R2: {}", r2(&test.target_vector, &pred)); assert_greater_than(r2(&test.target_vector, &pred), 0.641); @@ -86,19 +85,18 @@ mod tests { #[test] fn random_forest_breast_cancer() { let (train, test) = read_train_test_dataset("breast_cancer"); - let rf = RandomForest::train_clf(&train, 10, Some(5), Some(1), None,Some(42)); + let rf = RandomForest::train_clf(&train, 10, Some(5), Some(1), None, Some(42)); let pred = rf.predict(&test); let pred = classification_threshold(&pred, 0.5); println!("Accuracy: {}", accuracy(&test.target_vector, &pred),); assert_greater_than(accuracy(&test.target_vector, &pred), 0.96); - } #[test] fn random_forest_breast_titanic() { let (train, test) = read_train_test_dataset("titanic"); - let rf = RandomForest::train_clf(&train, 10, Some(5), Some(1), None,Some(42)); + let rf = RandomForest::train_clf(&train, 10, Some(5), Some(1), None, Some(42)); let pred = rf.predict(&test); let pred = classification_threshold(&pred, 0.5); diff --git a/src/trees.rs b/src/trees.rs index 5bcb5ce..a7a4480 100644 --- a/src/trees.rs +++ b/src/trees.rs @@ -55,9 +55,8 @@ impl RandomForest { let default_train_options = TrainOptions::default_options(train.feature_names.len() as i32); let params = TrainOptions { max_depth: max_depth.unwrap_or(default_train_options.max_depth), - min_samples_leaf: min_samples_leaf - .unwrap_or(default_train_options.min_samples_leaf), - max_features: max_features.unwrap_or(default_train_options.max_features) + min_samples_leaf: min_samples_leaf.unwrap_or(default_train_options.min_samples_leaf), + max_features: max_features.unwrap_or(default_train_options.max_features), }; let trees: Vec = (0..n_estimators) @@ -90,9 +89,8 @@ impl RandomForest { let default_train_options = TrainOptions::default_options(train.feature_names.len() as i32); let params = TrainOptions { max_depth: max_depth.unwrap_or(default_train_options.max_depth), - min_samples_leaf: min_samples_leaf - .unwrap_or(default_train_options.min_samples_leaf), - max_features: max_features.unwrap_or(default_train_options.max_features) + min_samples_leaf: min_samples_leaf.unwrap_or(default_train_options.min_samples_leaf), + max_features: max_features.unwrap_or(default_train_options.max_features), }; let trees: Vec = (0..n_estimators) .into_par_iter() @@ -137,8 +135,6 @@ impl RandomForest { } } - - #[pymethods] impl DecisionTree { #[staticmethod] @@ -153,9 +149,8 @@ impl DecisionTree { let default_train_options = TrainOptions::default_options(train.feature_names.len() as i32); let params = TrainOptions { max_depth: max_depth.unwrap_or(default_train_options.max_depth), - min_samples_leaf: min_samples_leaf - .unwrap_or(default_train_options.min_samples_leaf), - max_features: max_features.unwrap_or(default_train_options.max_features) + min_samples_leaf: min_samples_leaf.unwrap_or(default_train_options.min_samples_leaf), + max_features: max_features.unwrap_or(default_train_options.max_features), }; DecisionTree { @@ -181,9 +176,8 @@ impl DecisionTree { let default_train_options = TrainOptions::default_options(train.feature_names.len() as i32); let params = TrainOptions { max_depth: max_depth.unwrap_or(default_train_options.max_depth), - min_samples_leaf: min_samples_leaf - .unwrap_or(default_train_options.min_samples_leaf), - max_features: max_features.unwrap_or(default_train_options.max_features) + min_samples_leaf: min_samples_leaf.unwrap_or(default_train_options.min_samples_leaf), + max_features: max_features.unwrap_or(default_train_options.max_features), }; DecisionTree { tree: Tree::fit(&train, 0, params, gini_coefficient_split_feature, &mut rng),