Skip to content

Commit

Permalink
Fix doc comment example
Browse files Browse the repository at this point in the history
  • Loading branch information
glazari committed Nov 22, 2023
1 parent b0c752f commit 9fb0ff1
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
6 changes: 4 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
//! ```rust
//! use rustrees::{DecisionTree, Dataset, r2};
//!
//! let dataset = Dataset::read_csv("iris.csv", ",");
//! let dataset = Dataset::read_csv("datasets/titanic_train.csv", ",");
//!
//! let dt = DecisionTree::train_reg(
//! &dataset,
//! 5, // max_depth
//! Some(5), // max_depth
//! Some(1), // min_samples_leaf
//! None, // max_features (None = all features)
//! Some(42), // random_state
//! );
//!
Expand All @@ -35,6 +36,7 @@ pub use trees::DecisionTree;
pub use trees::RandomForest;
pub use trees::TrainOptions;
pub use trees::Tree;
pub use utils::{accuracy, r2};

use pyo3::prelude::*;

Expand Down
2 changes: 1 addition & 1 deletion src/tests.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#[cfg(test)]
mod tests {
use crate::{trees::RandomForest, *};
use crate::{utils::classification_threshold, utils::r2, utils::accuracy};

fn assert_greater_than(a: f32, b: f32) {
if a <= b {
Expand All @@ -10,7 +11,6 @@ mod tests {

#[test]
fn test_integration() {
let a = Box::new(5);
let train = Dataset::read_csv("datasets/diabetes_train.csv", ",");
let test = Dataset::read_csv("datasets/diabetes_test.csv", ",");
let dt = DecisionTree::train_reg(&train, Some(5), Some(1), None, Some(42));
Expand Down
6 changes: 2 additions & 4 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ pub(crate) fn classification_threshold(x: &[f32], clf_threshold: f32) -> Vec<f32
}

/// computes the mean squared error between two vectors used for testing regression case.
#[cfg(test)]
pub(crate) fn r2(x_true: &[f32], x_pred: &[f32]) -> f32 {
pub fn r2(x_true: &[f32], x_pred: &[f32]) -> f32 {
let mse: f32 = x_true
.iter()
.zip(x_pred)
Expand All @@ -38,8 +37,7 @@ pub(crate) fn r2(x_true: &[f32], x_pred: &[f32]) -> f32 {
}

/// computes the accuracy of a binary classification. Used for testing.
#[cfg(test)]
pub(crate) fn accuracy(x_true: &[f32], x_pred: &[f32]) -> f32 {
pub fn accuracy(x_true: &[f32], x_pred: &[f32]) -> f32 {
x_true
.iter()
.zip(x_pred)
Expand Down

0 comments on commit 9fb0ff1

Please sign in to comment.