From f2938df15344d8a814761ac0ef9fb1efa146de16 Mon Sep 17 00:00:00 2001 From: Michael Dahlin Date: Sat, 28 Dec 2024 22:01:35 -0600 Subject: [PATCH 01/13] feat(stats_tests): implement f_oneway --- src/stats_tests/f_oneway.rs | 273 ++++++++++++++++++++++++++++++++++++ src/stats_tests/mod.rs | 13 ++ 2 files changed, 286 insertions(+) create mode 100644 src/stats_tests/f_oneway.rs diff --git a/src/stats_tests/f_oneway.rs b/src/stats_tests/f_oneway.rs new file mode 100644 index 00000000..a8e46680 --- /dev/null +++ b/src/stats_tests/f_oneway.rs @@ -0,0 +1,273 @@ +//! Provides the [one-way ANOVA F-test](https://en.wikipedia.org/wiki/One-way_analysis_of_variance) +//! and related functions + +use crate::distribution::{ContinuousCDF, FisherSnedecor}; +use crate::stats_tests::NaNPolicy; + +/// Represents the errors that occur when computing the f_oneway function +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum FOneWayTestError { + /// must be at least two samples + NotEnoughSamples, + /// one sample must be length greater than 1 + SampleTooSmall, + /// samples must not contain all of the same values + SampleContainsSameConstants, + /// samples can not contain NaN when `nan_policy` is set to `NaNPolicy::Error` + SampleContainsNaN, +} + +impl std::fmt::Display for FOneWayTestError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + FOneWayTestError::NotEnoughSamples => write!(f, "must be at least two samples"), + FOneWayTestError::SampleTooSmall => { + write!(f, "one sample must be length greater than 1") + } + FOneWayTestError::SampleContainsSameConstants => { + write!(f, "samples must not contain all of the same values") + } + FOneWayTestError::SampleContainsNaN => { + write!( + f, + "samples can not contain NaN when `nan_policy` is set to `NaNPolicy::Error`" + ) + } + } + } +} + +impl std::error::Error for FOneWayTestError {} + +/// Perform a one-way Analysis of 
Variance (ANOVA) F-test +/// +/// Takes in a set (outer vector) of samples (inner vector) and returns the F-statistic and p-value +/// +/// # Remarks +/// Implementation based on [statsdirect](https://www.statsdirect.com/help/analysis_of_variance/one_way.htm) +/// and [scipy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.f_oneway.html#scipy.stats.f_oneway) +/// +/// # Examples +/// +/// ``` +/// use statrs::stats_tests::f_oneway::f_oneway; +/// use statrs::stats_tests::NaNPolicy; +/// +/// // based on wikipedia example +/// let a1 = Vec::from([6f64, 8f64, 4f64, 5f64, 3f64, 4f64]); +/// let a2 = Vec::from([8f64, 12f64, 9f64, 11f64, 6f64, 8f64]); +/// let a3 = Vec::from([13f64, 9f64, 11f64, 8f64, 7f64, 12f64]); +/// let sample_input = Vec::from([a1, a2, a3]); +/// let (statistic, pvalue) = f_oneway(sample_input, NaNPolicy::Error).unwrap(); // (9.3, 0.002) +/// ``` +pub fn f_oneway( + samples: Vec>, + nan_policy: NaNPolicy, +) -> Result<(f64, f64), FOneWayTestError> { + // samples as mutable in case it needs to be modified via NaNPolicy::Emit + let mut samples = samples; + let k = samples.len(); + + // initial input validation + if k < 2 { + return Err(FOneWayTestError::NotEnoughSamples); + } + + let has_nans = samples.iter().flatten().any(|x| x.is_nan()); + if has_nans { + match nan_policy { + NaNPolicy::Propogate => { + return Ok((f64::NAN, f64::NAN)); + } + NaNPolicy::Error => { + return Err(FOneWayTestError::SampleContainsNaN); + } + NaNPolicy::Emit => { + samples = samples + .into_iter() + .map(|v| v.into_iter().filter(|x| !x.is_nan()).collect::>()) + .collect::>(); + } + } + } + + // do remaining input validation after potential subset from Emit + let n_i: Vec = samples.iter().map(|v| v.len()).collect(); + if !n_i.iter().all(|x| *x >= 1) || !n_i.iter().any(|x| *x >= 2) { + return Err(FOneWayTestError::SampleTooSmall); + } + + if samples.iter().any(|v| { + if v.len() > 1 { + let mut it = v.iter(); + let first = it.next().unwrap(); + it.all(|x| 
x == first) + } else { + false + } + }) { + return Err(FOneWayTestError::SampleContainsSameConstants); + } + + let n = n_i.iter().sum::(); + let g = samples.iter().flatten().sum::(); + + let tsq = samples + .iter() + .map(|v| v.iter().sum::().powi(2) / v.len() as f64) + .sum::(); + let ysq = samples.iter().flatten().map(|x| x.powi(2)).sum::(); + + // Sum of Squares (SS) and Mean Square (MS) between and within groups + let sst = tsq - (g.powi(2) / n as f64); + let mst = sst / (k - 1) as f64; + + let sse = ysq - tsq; + let mse = sse / (n - k) as f64; + + let fstat = mst / mse; + + // degrees of freedom for between groups (t) and within groups (e) + let dft = (k - 1) as f64; + let dfe = (n - k) as f64; + // k >= 2 meaning dft = (k-1) > 0 or Err(NotEnoughSamples) + // one group must be at least 2 and all other groups must be at least 1 or Err(SampleTooSmall) + // meaning that the minimum value of n will always be at least one greater than k so dfe must + // be > 0 + let f_dist = FisherSnedecor::new(dft, dfe).expect("degrees of freedom should always be >0 "); + let pvalue = 1.0 - f_dist.cdf(fstat); + + Ok((fstat, pvalue)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::prec; + + #[test] + fn test_scipy_example() { + // Test against the scipy example + // https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.f_oneway.html#scipy.stats.f_oneway + let tillamook = Vec::from([ + 0.0571, 0.0813, 0.0831, 0.0976, 0.0817, 0.0859, 0.0735, 0.0659, 0.0923, 0.0836, + ]); + let newport = Vec::from([ + 0.0873, 0.0662, 0.0672, 0.0819, 0.0749, 0.0649, 0.0835, 0.0725, + ]); + let petersburg = Vec::from([0.0974, 0.1352, 0.0817, 0.1016, 0.0968, 0.1064, 0.105]); + let magadan = Vec::from([ + 0.1033, 0.0915, 0.0781, 0.0685, 0.0677, 0.0697, 0.0764, 0.0689, + ]); + let tvarminne = Vec::from([0.0703, 0.1026, 0.0956, 0.0973, 0.1039, 0.1045]); + let sample_input = Vec::from([tillamook, newport, petersburg, magadan, tvarminne]); + let (statistic, pvalue) = 
f_oneway(sample_input, NaNPolicy::Error).unwrap(); + + assert!(prec::almost_eq(statistic, 7.121019471642447, 1e-1)); + assert!(prec::almost_eq(pvalue, 0.0002812242314534544, 1e-12)); + } + #[test] + fn test_nan_in_data_w_emit() { + // same as scipy example above with NaNs added should give same result + let tillamook = Vec::from([ + 0.0571, + 0.0813, + 0.0831, + 0.0976, + 0.0817, + 0.0859, + 0.0735, + 0.0659, + 0.0923, + 0.0836, + f64::NAN, + ]); + let newport = Vec::from([ + 0.0873, 0.0662, 0.0672, 0.0819, 0.0749, 0.0649, 0.0835, 0.0725, + ]); + let petersburg = Vec::from([0.0974, 0.1352, 0.0817, 0.1016, 0.0968, 0.1064, 0.105]); + let magadan = Vec::from([ + 0.1033, + 0.0915, + 0.0781, + 0.0685, + 0.0677, + 0.0697, + 0.0764, + 0.0689, + f64::NAN, + ]); + let tvarminne = Vec::from([0.0703, 0.1026, 0.0956, 0.0973, 0.1039, 0.1045]); + let sample_input = Vec::from([tillamook, newport, petersburg, magadan, tvarminne]); + let (statistic, pvalue) = f_oneway(sample_input, NaNPolicy::Emit).unwrap(); + + assert!(prec::almost_eq(statistic, 7.121019471642447, 1e-1)); + assert!(prec::almost_eq(pvalue, 0.0002812242314534544, 1e-12)); + } + #[test] + fn test_nan_in_data_w_propogate() { + let group1 = Vec::from([0.0571, 0.0813, f64::NAN, 0.0836]); + let group2 = Vec::from([0.0873, 0.0662, 0.0672, 0.0819, 0.0749]); + let sample_input = Vec::from([group1, group2]); + let (statistic, pvalue) = f_oneway(sample_input, NaNPolicy::Propogate).unwrap(); + assert!(statistic.is_nan()); + assert!(pvalue.is_nan()); + } + #[test] + fn test_nan_in_data_w_error() { + let group1 = Vec::from([0.0571, 0.0813, f64::NAN, 0.0836]); + let group2 = Vec::from([0.0873, 0.0662, 0.0672, 0.0819, 0.0749]); + let sample_input = Vec::from([group1, group2]); + let result = f_oneway(sample_input, NaNPolicy::Error); + assert_eq!(result, Err(FOneWayTestError::SampleContainsNaN)); + } + #[test] + fn test_bad_data_not_enough_samples() { + let group1 = Vec::from([0.0, 0.0]); + let sample_input = Vec::from([group1]); + 
let result = f_oneway(sample_input, NaNPolicy::Propogate); + assert_eq!(result, Err(FOneWayTestError::NotEnoughSamples)) + } + #[test] + fn test_bad_data_sample_too_small() { + let group1 = Vec::new(); + let group2 = Vec::from([0.0873, 0.0662]); + let sample_input = Vec::from([group1, group2]); + let result = f_oneway(sample_input, NaNPolicy::Propogate); + assert_eq!(result, Err(FOneWayTestError::SampleTooSmall)); + + let group1 = Vec::from([f64::NAN]); + let group2 = Vec::from([0.0873, 0.0662]); + let sample_input = Vec::from([group1, group2]); + let result = f_oneway(sample_input, NaNPolicy::Emit); + assert_eq!(result, Err(FOneWayTestError::SampleTooSmall)); + + let group1 = Vec::from([1.0]); + let group2 = Vec::from([0.0873]); + let sample_input = Vec::from([group1, group2]); + let result = f_oneway(sample_input, NaNPolicy::Propogate); + assert_eq!(result, Err(FOneWayTestError::SampleTooSmall)); + + let group1 = Vec::from([1.0, f64::NAN]); + let group2 = Vec::from([0.0873, f64::NAN]); + let sample_input = Vec::from([group1, group2]); + let result = f_oneway(sample_input, NaNPolicy::Emit); + assert_eq!(result, Err(FOneWayTestError::SampleTooSmall)); + } + #[test] + fn test_bad_data_sample_contains_same_constants() { + let group1 = Vec::from([1.0, 1.0]); + let group2 = Vec::from([2.0, 2.0]); + let sample_input = Vec::from([group1, group2]); + let result = f_oneway(sample_input, NaNPolicy::Error); + assert_eq!(result, Err(FOneWayTestError::SampleContainsSameConstants)); + + let group1 = Vec::from([1.0, 1.0, 1.0]); + let group2 = Vec::from([0.0873, 0.0662, 0.0342]); + let sample_input = Vec::from([group1, group2]); + let result = f_oneway(sample_input, NaNPolicy::Error); + assert_eq!(result, Err(FOneWayTestError::SampleContainsSameConstants)); + } +} diff --git a/src/stats_tests/mod.rs b/src/stats_tests/mod.rs index 84a01fc7..b118449e 100644 --- a/src/stats_tests/mod.rs +++ b/src/stats_tests/mod.rs @@ -1,3 +1,4 @@ +pub mod f_oneway; pub mod fisher; /// Specifies an 
[alternative hypothesis](https://en.wikipedia.org/wiki/Alternative_hypothesis) @@ -14,4 +15,16 @@ pub enum Alternative { Greater, } +/// Specifies how to deal with NaNs provided in input data +/// based on scipy treatment +#[derive(Debug, Copy, Clone)] +pub enum NaNPolicy { + /// allow for NaNs; if any exist, the function will return NaN + Propogate, + /// filter out the NaNs before calculations + Emit, + /// if NaNs are in the input data, return an Error + Error, +} + pub use fisher::{fishers_exact, fishers_exact_with_odds_ratio}; From 30ba470f75b3df5b8cb01aa7b25d7ae11abc9eed Mon Sep 17 00:00:00 2001 From: Michael Dahlin Date: Sat, 28 Dec 2024 22:05:11 -0600 Subject: [PATCH 02/13] feat(stats_tests): implement ttest_onesample --- src/stats_tests/mod.rs | 1 + src/stats_tests/ttest_onesample.rs | 203 +++++++++++++++++++++++++++++ 2 files changed, 204 insertions(+) create mode 100644 src/stats_tests/ttest_onesample.rs diff --git a/src/stats_tests/mod.rs b/src/stats_tests/mod.rs index b118449e..57085185 100644 --- a/src/stats_tests/mod.rs +++ b/src/stats_tests/mod.rs @@ -1,5 +1,6 @@ pub mod f_oneway; pub mod fisher; +pub mod ttest_onesample; /// Specifies an [alternative hypothesis](https://en.wikipedia.org/wiki/Alternative_hypothesis) #[derive(Debug, Copy, Clone)] diff --git a/src/stats_tests/ttest_onesample.rs b/src/stats_tests/ttest_onesample.rs new file mode 100644 index 00000000..4c42a1e1 --- /dev/null +++ b/src/stats_tests/ttest_onesample.rs @@ -0,0 +1,203 @@ +//! Provides the [one-sample t-test](https://en.wikipedia.org/wiki/Student%27s_t-test#One-sample_t-test) +//! 
and related functions + +use crate::distribution::{ContinuousCDF, StudentsT}; +use crate::stats_tests::{Alternative, NaNPolicy}; + +/// Represents the errors that can occur when computing the ttest_onesample function +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum TTestOneSampleError { + /// sample must be greater than length 1 + SampleTooSmall, + /// samples can not contain NaN when `nan_policy` is set to `NaNPolicy::Error` + SampleContainsNaN, +} + +impl std::fmt::Display for TTestOneSampleError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + TTestOneSampleError::SampleTooSmall => write!(f, "sample must be len > 1"), + TTestOneSampleError::SampleContainsNaN => { + write!( + f, + "samples can not contain NaN when nan_policy is set to NaNPolicy::Error" + ) + } + } + } +} + +impl std::error::Error for TTestOneSampleError {} + +/// Perform a one sample t-test +/// +/// Returns the t-statistic and p-value +/// +/// # Remarks +/// +/// Implementation based on [scipy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.ttest_1samp.html). 
+/// +/// # Examples +/// +/// ``` +/// use statrs::stats_tests::ttest_onesample::ttest_onesample; +/// use statrs::stats_tests::{Alternative, NaNPolicy}; +/// let data = Vec::from([13f64, 9f64, 11f64, 8f64, 7f64, 12f64]); +/// let (statistic, pvalue) = ttest_onesample(data, 13f64, Alternative::TwoSided, NaNPolicy::Error).unwrap(); +/// ``` +pub fn ttest_onesample( + a: Vec, + popmean: f64, + alternative: Alternative, + nan_policy: NaNPolicy, +) -> Result<(f64, f64), TTestOneSampleError> { + // make a mutable in case it needs to be modified due to NaNPolicy::Emit + let mut a = a; + + let has_nans = a.iter().any(|x| x.is_nan()); + if has_nans { + match nan_policy { + NaNPolicy::Propogate => { + return Ok((f64::NAN, f64::NAN)); + } + NaNPolicy::Error => { + return Err(TTestOneSampleError::SampleContainsNaN); + } + NaNPolicy::Emit => { + a = a.into_iter().filter(|x| !x.is_nan()).collect::>(); + } + } + } + + let n = a.len(); + if n < 2 { + return Err(TTestOneSampleError::SampleTooSmall); + } + let samplemean = a.iter().sum::() / (n as f64); + let df = (n - 1) as f64; + let s = a.iter().map(|x| (x - samplemean).powi(2)).sum::() / df; + let se = (s / n as f64).sqrt(); + + let tstat = (samplemean - popmean) / se; + + let t_dist = + StudentsT::new(0.0, 1.0, df).expect("df should always be non NaN and greater than 0"); + + let pvalue = match alternative { + Alternative::TwoSided => 2.0 * (1.0 - t_dist.cdf(tstat.abs())), + Alternative::Less => t_dist.cdf(tstat), + Alternative::Greater => 1.0 - t_dist.cdf(tstat), + }; + + Ok((tstat, pvalue)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::prec; + + /// Test one sample t-test comparing to + #[test] + fn test_jmp_example() { + // Test against an example from jmp.com + // https://www.jmp.com/en_us/statistics-knowledge-portal/t-test/one-sample-t-test.html + let data = Vec::from([ + 20.70f64, 27.46f64, 22.15f64, 19.85f64, 21.29f64, 24.75f64, 20.75f64, 22.91f64, + 25.34f64, 20.33f64, 21.54f64, 21.08f64, 22.14f64, 
19.56f64, 21.10f64, 18.04f64, + 24.12f64, 19.95f64, 19.72f64, 18.28f64, 16.26f64, 17.46f64, 20.53f64, 22.12f64, + 25.06f64, 22.44f64, 19.08f64, 19.88f64, 21.39f64, 22.33f64, 25.79f64, + ]); + let (statistic, pvalue) = + ttest_onesample(data.clone(), 20.0, Alternative::TwoSided, NaNPolicy::Error).unwrap(); + assert!(prec::almost_eq(statistic, 3.066831635284081, 1e-1)); + assert!(prec::almost_eq(pvalue, 0.004552621060635401, 1e-12)); + + let (statistic, pvalue) = + ttest_onesample(data.clone(), 20.0, Alternative::Greater, NaNPolicy::Error).unwrap(); + assert!(prec::almost_eq(statistic, 3.066831635284081, 1e-1)); + assert!(prec::almost_eq(pvalue, 0.0022763105303177005, 1e-12)); + + let (statistic, pvalue) = + ttest_onesample(data.clone(), 20.0, Alternative::Less, NaNPolicy::Error).unwrap(); + assert!(prec::almost_eq(statistic, 3.066831635284081, 1e-1)); + assert!(prec::almost_eq(pvalue, 0.9977236894696823, 1e-12)); + } + #[test] + fn test_nan_in_data_w_emit() { + // results should be the same as the example above since the NaNs should be filtered out + let data = Vec::from([ + 20.70f64, + 27.46f64, + 22.15f64, + 19.85f64, + 21.29f64, + 24.75f64, + 20.75f64, + 22.91f64, + 25.34f64, + 20.33f64, + 21.54f64, + 21.08f64, + 22.14f64, + 19.56f64, + 21.10f64, + 18.04f64, + 24.12f64, + 19.95f64, + 19.72f64, + 18.28f64, + 16.26f64, + 17.46f64, + 20.53f64, + 22.12f64, + 25.06f64, + 22.44f64, + 19.08f64, + 19.88f64, + 21.39f64, + 22.33f64, + 25.79f64, + f64::NAN, + f64::NAN, + f64::NAN, + f64::NAN, + f64::NAN, + ]); + let (statistic, pvalue) = + ttest_onesample(data.clone(), 20.0, Alternative::TwoSided, NaNPolicy::Emit).unwrap(); + assert!(prec::almost_eq(statistic, 3.066831635284081, 1e-1)); + assert!(prec::almost_eq(pvalue, 0.004552621060635401, 1e-12)); + } + #[test] + fn test_nan_in_data_w_propogate() { + let sample_input = Vec::from([1.3, f64::NAN]); + let (statistic, pvalue) = ttest_onesample( + sample_input, + 20.0, + Alternative::TwoSided, + NaNPolicy::Propogate, + ) + 
.unwrap(); + assert!(statistic.is_nan()); + assert!(pvalue.is_nan()); + } + #[test] + fn test_nan_in_data_w_error() { + let sample_input = Vec::from([0.0571, 0.0813, f64::NAN, 0.0836]); + let result = ttest_onesample(sample_input, 20.0, Alternative::TwoSided, NaNPolicy::Error); + assert_eq!(result, Err(TTestOneSampleError::SampleContainsNaN)); + } + #[test] + fn test_bad_data_sample_too_small() { + let sample_input = Vec::new(); + let result = ttest_onesample(sample_input, 20.0, Alternative::TwoSided, NaNPolicy::Error); + assert_eq!(result, Err(TTestOneSampleError::SampleTooSmall)); + + let sample_input = Vec::from([1.0]); + let result = ttest_onesample(sample_input, 20.0, Alternative::TwoSided, NaNPolicy::Error); + assert_eq!(result, Err(TTestOneSampleError::SampleTooSmall)); + } +} From f686f6cac8c1cac6a0828cef8dfd86dd42bf4b8f Mon Sep 17 00:00:00 2001 From: Michael Dahlin Date: Wed, 1 Jan 2025 21:54:31 -0600 Subject: [PATCH 03/13] feat(stats_tests): implement mannwhitneyu --- src/stats_tests/mannwhitneyu.rs | 503 ++++++++++++++++++++++++++++++++ src/stats_tests/mod.rs | 1 + 2 files changed, 504 insertions(+) create mode 100644 src/stats_tests/mannwhitneyu.rs diff --git a/src/stats_tests/mannwhitneyu.rs b/src/stats_tests/mannwhitneyu.rs new file mode 100644 index 00000000..81eee64b --- /dev/null +++ b/src/stats_tests/mannwhitneyu.rs @@ -0,0 +1,503 @@ +//! Provides the [Mann-Whitney U test](https://en.wikipedia.org/wiki/Mann–Whitney_U_test#) and related +//! 
functions + +use num_traits::clamp; + +use crate::distribution::{ContinuousCDF, Normal}; +use crate::stats_tests::Alternative; + +/// Represents the errors that can occur when computing the mannwhitneyu function +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum MannWhitneyUError { + /// at least one element of the input data can not be compared to another element (possibly due + /// to float NaNs) + UncomparableData, + /// the samples for both `x` and `y` must be at least length 1 + SampleTooSmall, + /// `MannWhitneyUMethod::Exact` is not implemented for data where ties exist + ExactMethodWithTiesInData, +} + +impl std::fmt::Display for MannWhitneyUError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + MannWhitneyUError::UncomparableData => { + write!(f, "elements in the data are not comparable") + } + MannWhitneyUError::SampleTooSmall => write!( + f, + "the samples for both `x` and `y` must be at least length 1" + ), + MannWhitneyUError::ExactMethodWithTiesInData => write!( + f, + "using the Exact method with ties in input data is not supported" + ), + } + } +} + +impl std::error::Error for MannWhitneyUError {} + +/// Represents the different methods that can be used when calculating the p-value for the +/// mannwhitneyu function +pub enum MannWhitneyUMethod { + /// determine method based on input data provided in `x` and `y`. 
Will use `Exact` for smaller + /// sample sizes and `AsymptoticInclContinuityCorrection` for larger samples and when there + /// are ties in the data + Automatic, + /// calculate the exact p-value + Exact, + /// calculate an approximated (via normal distribution) p-value including a continuity + /// correction + AsymptoticInclContinuityCorrection, + /// calculate an approximated (via normal distribution) p-value excluding a continuity + /// correction + AsymptoticExclContinuityCorrection, +} + +/// ranks data and accounts for ties to calculate the U statistic +fn rankdata_mwu(xy: Vec) -> Result<(Vec, Vec), MannWhitneyUError> { + let mut j = (0..xy.len()).collect::>(); + let mut y = xy; + + // check to make sure data can be compared to generate the ranks + for i in 0..y.len() { + for k in i + 1..y.len() { + if y[i].partial_cmp(&y[k]).is_none() { + return Err(MannWhitneyUError::UncomparableData); + } + } + } + + // calculate the ordinal rank minus 1 (ordinal index) in j which is roughly equivalent to + // np.argsort. 
Additionally sort xy at the same time + let mut zipped: Vec<_> = j.into_iter().zip(y).collect(); + zipped.sort_by(|(_, a), (_, b)| { + a.partial_cmp(b) + .expect("NaN should not exist or be filtered out by this point") + }); + (j, y) = zipped.into_iter().unzip(); + + let mut ranks_sorted: Vec = vec![999.0; y.len()]; + let mut t: Vec = vec![999; y.len()]; + + let mut k = 0; + let mut count = 1; + let n = y.len(); + + for i in 1..n { + if y[i] != y[i - 1] { + let ordinal_rank = k + 1; + let rank = ordinal_rank as f64 + (count as f64 - 1.0) / 2.0; + // repeat the rank in the event of ties + ranks_sorted[k..i].fill(rank); + // for ties, match scipy logic and have first occurrence be the count + // and all additional occurrences be 0 + t[k] = count; + t[(k + 1)..i].fill(0); + + // reset to handle next occurrence of a unique value + k = i; + count = 0; + } + count += 1; + } + + // handle from the last set of unique values to the end + // same logic as above except goes until n (instead of i) including the last count increment + let ordinal_rank = k + 1; + let rank = ordinal_rank as f64 + (count as f64 - 1.0) / 2.0; + ranks_sorted[k..n].fill(rank); + t[k] = count; + t[(k + 1)..n].fill(0); + + // leverage the ordinal indices from j to reverse into to the original ordering + let mut ranks = ranks_sorted; + let mut zipped: Vec<_> = j.into_iter().zip(ranks).collect(); + zipped.sort_by(|(i, _), (j, _)| i.partial_cmp(j).unwrap()); + (_, ranks) = zipped.into_iter().unzip::, Vec<_>>(); + + Ok((ranks, t)) +} + +/// based on https://github.com/scipy/scipy/blob/92d2a8592782ee19a1161d0bf3fc2241ba78bb63/scipy/stats/_mannwhitneyu.py#L149 +fn calc_mwu_asymptotic_pvalue( + u: f64, + n1: usize, + n2: usize, + t: Vec, + continuity: bool, +) -> f64 { + let mu = ((n1 * n2) as f64) / 2.0; + + let tie_term = t.iter().map(|x| x.pow(3) - x).sum::(); + + let n1 = n1 as f64; + let n2 = n2 as f64; + let n = n1 + n2; + + let s: f64 = (n1 * n2 / 12.0 * ((n + 1.0) - tie_term as f64 / (n * (n - 
1.0)))).sqrt(); + + let mut numerator = u - mu; + if continuity { + numerator -= 0.5; + } + + let z = numerator / s; + + // NOTE: z could be infinity (if all input values are the same for example) + // but the Normal CDF should handle this in a consistent way with scipy + let norm_dist = Normal::default(); + 1.0 - norm_dist.cdf(z) +} + +fn calc_mwu_exact_pvalue(u: f64, n1: usize, n2: usize) -> f64 { + let n = n1 + n2; + let k = n1.min(n2); // use the smaller of the two for fewer combinations to go through + let mut a: Vec = (0..n).collect(); + + // placeholder for number of times U (observed) is smaller than the universe of U values + let mut numerator = 0; + let mut total = 0; // total combinations (universe of U values) + + loop { + // calculate the number of times the hypothesis is rejected + // + // add k since the indices are 0-based, so all of them need to be shifted by 1 to represent ranks + let r1 = a[0..k].iter().sum::() + k; + let u_generic = r1 - (k * (k + 1)) / 2; + if u <= (u_generic as f64) { + numerator += 1; + } + total += 1; + + // handle generating the next combination of n choose k (non-recursively) + // + // figure out the rightmost index g + let mut i = k; + while i > 0 { + i -= 1; + if a[i] != i + n - k { + break; + } + } + + // all combinations have been generated since the first index is at its max value + if i == 0 && a[i] == n - k { + break; + } + + a[i] += 1; + + for j in i + 1..k { + a[j] = a[j - 1] + 1; + } + } + + if k == n1 { + 1.0 - numerator as f64 / total as f64 + } else { + // if k was set to n2, return the complement p-value + numerator as f64 / total as f64 + } +} + +/// Perform a Mann-Whitney U (Wilcoxon rank-sum) test +/// +/// Returns the U statistic (based on `x`) and p-value +/// +/// # Remarks +/// +/// For larger sample sizes, the Exact method can become computationally expensive. Per Wikipedia, +/// sample sizes (length of `x` + length of `y`) above 20 are approximated fairly well using the +/// asymptotic (normal) methods. 
+/// +/// Implementation was largely based on the [scipy version](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.mannwhitneyu.html#scipy.stats.mannwhitneyu). +/// There are a few deviations including, not supporting calculation of the value via permutation +/// tests, not supporting calculation of the exact p-value where input data includes ties, and not +/// supporting the NaN policy due to being generic on T which might not have NaN values. +/// +/// # Examples +/// +/// ``` +/// use statrs::stats_tests::mannwhitneyu::{mannwhitneyu, MannWhitneyUMethod}; +/// use statrs::stats_tests::Alternative; +/// +/// // based on scipy example +/// let male = Vec::from([19, 22, 16, 29, 24]); +/// let female = Vec::from([20, 11, 17, 12]); +/// +/// let (statistic, pvalue) = mannwhitneyu( +/// &male, +/// &female, +/// MannWhitneyUMethod::Automatic, +/// Alternative::TwoSided, +/// ) +/// .unwrap(); +/// ``` +pub fn mannwhitneyu( + x: &[T], + y: &[T], + method: MannWhitneyUMethod, + alternative: Alternative, +) -> Result<(f64, f64), MannWhitneyUError> { + let n1 = x.len(); + let n2 = y.len(); + + if n1 == 0 || n2 == 0 { + return Err(MannWhitneyUError::SampleTooSmall); + } + + let mut x = x.to_vec(); + let mut y = y.to_vec(); + x.append(&mut y); + + let (ranks, t) = rankdata_mwu(x)?; + // NOTE: in the case of ties (eg: x = &[1, 2, 3] and y = &[3, 4, 5]), the U statistic can be a float + // (being #.5). 
When there are no ties, U will always be a whole number + let r1 = ranks[..n1].iter().sum::(); + let u1 = r1 - (n1 * (n1 + 1) / 2) as f64; + let u2 = (n1 * n2) as f64 - u1; + + // f is a factor to apply to the p-value in a two-sided test + let (u, f) = match alternative { + Alternative::Greater => (u1, 1), + Alternative::Less => (u2, 1), + Alternative::TwoSided => (u1.max(u2), 2), + }; + + let mut pvalue = match method { + MannWhitneyUMethod::Automatic => { + if (n1 > 8 && n2 > 8) || t.iter().any(|x| x > &1usize) { + calc_mwu_asymptotic_pvalue(u, n1, n2, t, true) + } else { + calc_mwu_exact_pvalue(u, n1, n2) + } + } + MannWhitneyUMethod::Exact => { + if t.iter().any(|x| x > &1usize) { + return Err(MannWhitneyUError::ExactMethodWithTiesInData); + } + calc_mwu_exact_pvalue(u, n1, n2) + } + MannWhitneyUMethod::AsymptoticInclContinuityCorrection => { + calc_mwu_asymptotic_pvalue(u, n1, n2, t, true) + } + MannWhitneyUMethod::AsymptoticExclContinuityCorrection => { + calc_mwu_asymptotic_pvalue(u, n1, n2, t, false) + } + }; + + pvalue *= f as f64; + pvalue = clamp(pvalue, 0.0, 1.0); + + Ok((u1, pvalue)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::prec; + + #[test] + fn test_wikipedia_example() { + // Replicate example from https://en.wikipedia.org/wiki/Mann–Whitney_U_test#Illustration_of_calculation_methods + let data = "THHHHHTTTTTH"; + let mut x = Vec::new(); + let mut y = Vec::new(); + + for (i, c) in data.chars().enumerate() { + if c == 'T' { + x.push(i + 1) + } else { + y.push(i + 1) + } + } + let (statistic, _) = mannwhitneyu( + &x, + &y, + MannWhitneyUMethod::AsymptoticInclContinuityCorrection, + Alternative::Less, + ) + .unwrap(); + assert_eq!(statistic, 25.0); + + let (statistic, _) = mannwhitneyu( + &y, + &x, + MannWhitneyUMethod::AsymptoticInclContinuityCorrection, + Alternative::Greater, + ) + .unwrap(); + assert_eq!(statistic, 11.0); + } + + #[test] + fn test_scipy_example() { + // Test against scipy function including the documentation 
example + // https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.mannwhitneyu.html + // as well as additional validations comparing to examples run in python + let male = Vec::from([19, 22, 16, 29, 24]); + let female = Vec::from([20, 11, 17, 12]); + + let (statistic, pvalue) = mannwhitneyu( + &male, + &female, + MannWhitneyUMethod::Automatic, + Alternative::TwoSided, + ) + .unwrap(); + assert_eq!(statistic, 17.0); + assert!(prec::almost_eq(pvalue, 0.1111111111111111, 1e-9)); + + let (statistic, _) = mannwhitneyu( + &female, + &male, + MannWhitneyUMethod::Automatic, + Alternative::TwoSided, + ) + .unwrap(); + assert_eq!(statistic, 3.0); + + let (statistic, pvalue) = mannwhitneyu( + &male, + &female, + MannWhitneyUMethod::AsymptoticInclContinuityCorrection, + Alternative::TwoSided, + ) + .unwrap(); + assert_eq!(statistic, 17.0); + assert!(prec::almost_eq(pvalue, 0.11134688653314041, 1e-9)); + + // not in scipy's official example but testing other variations against python output + let (_, pvalue) = mannwhitneyu( + &male, + &female, + MannWhitneyUMethod::AsymptoticExclContinuityCorrection, + Alternative::Less, + ) + .unwrap(); + assert!(prec::almost_eq(pvalue, 0.95679463351315, 1e-9)); + + let (_, pvalue) = + mannwhitneyu(&male, &female, MannWhitneyUMethod::Exact, Alternative::Less).unwrap(); + assert!(prec::almost_eq(pvalue, 0.9682539682539683, 1e-9)); + + let (_, pvalue) = mannwhitneyu( + &male, + &female, + MannWhitneyUMethod::AsymptoticInclContinuityCorrection, + Alternative::Greater, + ) + .unwrap(); + assert!(prec::almost_eq(pvalue, 0.055673443266570206, 1e-9)); + + let (statistic, pvalue) = mannwhitneyu( + &[1], + &[2], + MannWhitneyUMethod::AsymptoticInclContinuityCorrection, + Alternative::Less, + ) + .unwrap(); + assert_eq!(statistic, 0.0); + assert!(prec::almost_eq(pvalue, 0.5, 1e-9)); + + // larger deviation from scipy logic for exact so double check here + // also check usage with floats + let x = &[5.0, 2.0, 7.0, 8.0, 9.0, 3.0, 11.0, 12.0]; 
+ let y = &[1.0, 6.0, 10.0, 4.0]; + + let (statistic, pvalue) = + mannwhitneyu(x, y, MannWhitneyUMethod::Exact, Alternative::Greater).unwrap(); + assert_eq!(statistic, 21.0); + assert!(prec::almost_eq(pvalue, 0.23030303030303031, 1e-9)); + + let (statistic, pvalue) = + mannwhitneyu(x, y, MannWhitneyUMethod::Exact, Alternative::Less).unwrap(); + assert_eq!(statistic, 21.0); + assert!(prec::almost_eq(pvalue, 0.8161616161616161, 1e-9)); + + let (statistic, pvalue) = + mannwhitneyu(x, y, MannWhitneyUMethod::Exact, Alternative::TwoSided).unwrap(); + assert_eq!(statistic, 21.0); + assert!(prec::almost_eq(pvalue, 0.46060606060606063, 1e-9)); + + let (statistic, pvalue) = mannwhitneyu( + &[1, 1], + &[1, 1, 1], + MannWhitneyUMethod::AsymptoticInclContinuityCorrection, + Alternative::TwoSided, + ) + .unwrap(); + assert_eq!(statistic, 3.0); + assert!(prec::almost_eq(pvalue, 1.0, 1e-9)); + } + + #[test] + fn test_bad_data_nan() { + let male = Vec::from([19.0, 22.0, 16.0, 29.0, 24.0, f64::NAN]); + let female = Vec::from([20.0, 11.0, 17.0, 12.0]); + + let result = mannwhitneyu( + &male, + &female, + MannWhitneyUMethod::Automatic, + Alternative::TwoSided, + ); + assert_eq!(result, Err(MannWhitneyUError::UncomparableData)); + } + #[test] + fn test_bad_data_sample_too_small() { + let result = mannwhitneyu( + &[], + &[1, 2, 3], + MannWhitneyUMethod::Automatic, + Alternative::TwoSided, + ); + assert_eq!(result, Err(MannWhitneyUError::SampleTooSmall)); + + let result = mannwhitneyu::( + &[], + &[], + MannWhitneyUMethod::Automatic, + Alternative::TwoSided, + ); + assert_eq!(result, Err(MannWhitneyUError::SampleTooSmall)); + } + #[test] + fn test_bad_data_exact_with_ties() { + let result = mannwhitneyu( + &[1, 2], + &[1, 2, 3], + MannWhitneyUMethod::Exact, + Alternative::TwoSided, + ); + assert_eq!(result, Err(MannWhitneyUError::ExactMethodWithTiesInData)); + } + #[test] + fn test_rankdata_mwu() { + let data = Vec::from([1, 4, 3]); + let (rank, t) = rankdata_mwu(data).expect("data is 
good"); + assert_eq!(rank, Vec::from([1.0, 3.0, 2.0])); + assert_eq!(t, Vec::from([1, 1, 1])); + + let data = Vec::from([4.0, 2.0, 2.0, 1.0]); + let (rank, t) = rankdata_mwu(data).expect("data is good"); + assert_eq!(rank, Vec::from([4.0, 2.5, 2.5, 1.0])); + assert_eq!(t, Vec::from([1, 2, 0, 1,])); + + let data = Vec::from([1, 2, 2, 2, 3]); + let (rank, t) = rankdata_mwu(data).expect("data is good"); + assert_eq!(rank, Vec::from([1.0, 3.0, 3.0, 3.0, 5.0])); + assert_eq!(t, Vec::from([1, 3, 0, 0, 1])); + } + #[test] + fn test_calc_mwu_exact_pvalue() { + let pvalue = calc_mwu_exact_pvalue(4.0, 3, 2); + assert!(prec::almost_eq(pvalue, 0.4, 1e-9)); + let pvalue = calc_mwu_exact_pvalue(4.0, 2, 3); + assert!(prec::almost_eq(pvalue, 0.6, 1e-9)); + } +} diff --git a/src/stats_tests/mod.rs b/src/stats_tests/mod.rs index 57085185..1f197db7 100644 --- a/src/stats_tests/mod.rs +++ b/src/stats_tests/mod.rs @@ -1,5 +1,6 @@ pub mod f_oneway; pub mod fisher; +pub mod mannwhitneyu; pub mod ttest_onesample; /// Specifies an [alternative hypothesis](https://en.wikipedia.org/wiki/Alternative_hypothesis) From a6f8137d11c8b103eca32bc30a6e9a27d12ead59 Mon Sep 17 00:00:00 2001 From: Michael Dahlin Date: Sat, 11 Jan 2025 14:45:19 -0600 Subject: [PATCH 04/13] feat(stats_tests): implement skewtest --- src/stats_tests/mod.rs | 1 + src/stats_tests/skewtest.rs | 253 ++++++++++++++++++++++++++++++++++++ 2 files changed, 254 insertions(+) create mode 100644 src/stats_tests/skewtest.rs diff --git a/src/stats_tests/mod.rs b/src/stats_tests/mod.rs index 1f197db7..7a7c0499 100644 --- a/src/stats_tests/mod.rs +++ b/src/stats_tests/mod.rs @@ -1,6 +1,7 @@ pub mod f_oneway; pub mod fisher; pub mod mannwhitneyu; +pub mod skewtest; pub mod ttest_onesample; /// Specifies an [alternative hypothesis](https://en.wikipedia.org/wiki/Alternative_hypothesis) diff --git a/src/stats_tests/skewtest.rs b/src/stats_tests/skewtest.rs new file mode 100644 index 00000000..400200f0 --- /dev/null +++ 
b/src/stats_tests/skewtest.rs @@ -0,0 +1,253 @@ +//! Provides the [skewtest](https://docs.scipy.org/doc/scipy-1.15.0/reference/generated/scipy.stats.skewtest.html) +//! to test whether or not provided data is different than a normal distribution + +use crate::distribution::{ContinuousCDF, Normal}; +use crate::stats_tests::{Alternative, NaNPolicy}; + +/// Represents the errors that can occur when computing the skewtest function +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum SkewTestError { + /// sample must contain at least 8 observations + SampleTooSmall, + /// samples can not contain NaN when `nan_policy` is set to `NaNPolicy::Error` + SampleContainsNaN, +} + +impl std::fmt::Display for SkewTestError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + SkewTestError::SampleTooSmall => { + write!(f, "sample must contain at least 8 observations") + } + SkewTestError::SampleContainsNaN => { + write!( + f, + "samples can not contain NaN when nan_policy is set to NaNPolicy::Error" + ) + } + } + } +} + +impl std::error::Error for SkewTestError {} + +fn calc_root_b1(data: &[f64]) -> f64 { + // Fisher's moment coefficient of skewness + // https://en.wikipedia.org/wiki/Skewness#Definition + let n = data.len() as f64; + let mu = data.iter().sum::() / n; + + // NOTE: population not sample skewness + (data.iter().map(|x_i| (x_i - mu).powi(3)).sum::() / n) + / (data.iter().map(|x_i| (x_i - mu).powi(2)).sum::() / n).powf(1.5) +} + +/// Perform a skewness test for whether the skew of the sample provided is different than a normal +/// distribution +/// +/// Returns the z-score and p-value +/// +/// # Remarks +/// +/// Implementation based on [scipy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.skewtest.html#scipy.stats.skewtest). 
+/// and [fintools.com](https://www.fintools.com/docs/normality_correlation.pdf) which both +/// reference D'Agostino, 1970 (but direct access to the paper has been challenging to find) +/// +/// # Examples +/// +/// ``` +/// use statrs::stats_tests::skewtest::skewtest; +/// use statrs::stats_tests::{Alternative, NaNPolicy}; +/// let data = Vec::from([ 1.0f64, 2.0f64, 3.0f64, 4.0f64, 5.0f64, 6.0f64, 7.0f64, 8.0f64, ]); +/// let (statistic, pvalue) = skewtest(data, Alternative::TwoSided, NaNPolicy::Error).unwrap(); +/// ``` +pub fn skewtest( + a: Vec, + alternative: Alternative, + nan_policy: NaNPolicy, +) -> Result<(f64, f64), SkewTestError> { + // make a mutable in case it needs to be modified due to NaNPolicy::Emit + let mut a = a; + + let has_nans = a.iter().any(|x| x.is_nan()); + if has_nans { + match nan_policy { + NaNPolicy::Propogate => { + return Ok((f64::NAN, f64::NAN)); + } + NaNPolicy::Error => { + return Err(SkewTestError::SampleContainsNaN); + } + NaNPolicy::Emit => { + a = a.into_iter().filter(|x| !x.is_nan()).collect::>(); + } + } + } + + let n = a.len(); + if n < 8 { + return Err(SkewTestError::SampleTooSmall); + } + let n = n as f64; + + let root_b1 = calc_root_b1(&a); + let mut y = root_b1 * ((n + 1.0) * (n + 3.0) / (6.0 * (n - 2.0))).sqrt(); + let beta2_root_b1 = 3.0 * (n.powi(2) + 27.0 * n - 70.0) * (n + 1.0) * (n + 3.0) + / ((n - 2.0) * (n + 5.0) * (n + 7.0) * (n + 9.0)); + let w_sq = -1.0 + (2.0 * (beta2_root_b1 - 1.0)).sqrt(); + let delta = 1.0 / (0.5 * w_sq.ln()).sqrt(); + let alpha = (2.0 / (w_sq - 1.0)).sqrt(); + // correction from scipy version to`match scipy example results + if y == 0.0 { + y = 1.0; + } + let zscore = delta * (y / alpha + ((y / alpha).powi(2) + 1.0).sqrt()).ln(); + + let norm_dist = Normal::default(); + + let pvalue = match alternative { + Alternative::TwoSided => 2.0 * (1.0 - norm_dist.cdf(zscore.abs())), + Alternative::Less => norm_dist.cdf(zscore), + Alternative::Greater => 1.0 - norm_dist.cdf(zscore), + }; + + 
Ok((zscore, pvalue)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::prec; + + #[test] + fn test_scipy_example() { + let data = Vec::from([ + 148.0f64, 154.0f64, 158.0f64, 160.0f64, 161.0f64, 162.0f64, 166.0f64, 170.0f64, + 182.0f64, 195.0f64, 236.0f64, + ]); + let (statistic, pvalue) = + skewtest(data.clone(), Alternative::TwoSided, NaNPolicy::Error).unwrap(); + assert!(prec::almost_eq(statistic, 2.7788579769903414, 1e-1)); + assert!(prec::almost_eq(pvalue, 0.005455036974740185, 1e-9)); + + let (statistic, pvalue) = skewtest( + Vec::from([ + 1.0f64, 2.0f64, 3.0f64, 4.0f64, 5.0f64, 6.0f64, 7.0f64, 8.0f64, + ]), + Alternative::TwoSided, + NaNPolicy::Error, + ) + .unwrap(); + assert!(prec::almost_eq(statistic, 1.0108048609177787, 1e-1)); + assert!(prec::almost_eq(pvalue, 0.3121098361421897, 1e-9)); + let (statistic, pvalue) = skewtest( + Vec::from([ + 2.0f64, 8.0f64, 0.0f64, 4.0f64, 1.0f64, 9.0f64, 9.0f64, 0.0f64, + ]), + Alternative::TwoSided, + NaNPolicy::Error, + ) + .unwrap(); + assert!(prec::almost_eq(statistic, 0.44626385374196975, 1e-1)); + assert!(prec::almost_eq(pvalue, 0.6554066631275459, 1e-9)); + let (statistic, pvalue) = skewtest( + Vec::from([ + 1.0f64, 2.0f64, 3.0f64, 4.0f64, 5.0f64, 6.0f64, 7.0f64, 8000.0f64, + ]), + Alternative::TwoSided, + NaNPolicy::Error, + ) + .unwrap(); + assert!(prec::almost_eq(statistic, 3.571773510360407, 1e-1)); + assert!(prec::almost_eq(pvalue, 0.0003545719905823133, 1e-9)); + let (statistic, pvalue) = skewtest( + Vec::from([ + 100.0f64, 100.0f64, 100.0f64, 100.0f64, 100.0f64, 100.0f64, 100.0f64, 101.0f64, + ]), + Alternative::TwoSided, + NaNPolicy::Error, + ) + .unwrap(); + assert!(prec::almost_eq(statistic, 3.5717766638478072, 1e-1)); + assert!(prec::almost_eq(pvalue, 0.000354567720281634, 1e012)); + let (statistic, pvalue) = skewtest( + Vec::from([ + 1.0f64, 2.0f64, 3.0f64, 4.0f64, 5.0f64, 6.0f64, 7.0f64, 8.0f64, + ]), + Alternative::Less, + NaNPolicy::Error, + ) + .unwrap(); + 
assert!(prec::almost_eq(statistic, 1.0108048609177787, 1e-1)); + assert!(prec::almost_eq(pvalue, 0.8439450819289052, 1e-9)); + let (statistic, pvalue) = skewtest( + Vec::from([ + 1.0f64, 2.0f64, 3.0f64, 4.0f64, 5.0f64, 6.0f64, 7.0f64, 8.0f64, + ]), + Alternative::Greater, + NaNPolicy::Error, + ) + .unwrap(); + assert!(prec::almost_eq(statistic, 1.0108048609177787, 1e-1)); + assert!(prec::almost_eq(pvalue, 0.15605491807109484, 1e-9)); + } + #[test] + fn test_nan_in_data_w_emit() { + // results should be the same as the example above since the NaNs should be filtered out + let data = Vec::from([ + 148.0f64, + 154.0f64, + 158.0f64, + 160.0f64, + 161.0f64, + 162.0f64, + 166.0f64, + 170.0f64, + 182.0f64, + 195.0f64, + 236.0f64, + f64::NAN, + ]); + let (statistic, pvalue) = + skewtest(data.clone(), Alternative::TwoSided, NaNPolicy::Emit).unwrap(); + assert!(prec::almost_eq(statistic, 2.7788579769903414, 1e-1)); + assert!(prec::almost_eq(pvalue, 0.005455036974740185, 1e-9)); + } + #[test] + fn test_nan_in_data_w_propogate() { + let sample_input = Vec::from([1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, f64::NAN]); + let (statistic, pvalue) = + skewtest(sample_input, Alternative::TwoSided, NaNPolicy::Propogate).unwrap(); + assert!(statistic.is_nan()); + assert!(pvalue.is_nan()); + } + #[test] + fn test_nan_in_data_w_error() { + let sample_input = Vec::from([0.0571, 0.0813, f64::NAN, 0.0836]); + let result = skewtest(sample_input, Alternative::TwoSided, NaNPolicy::Error); + assert_eq!(result, Err(SkewTestError::SampleContainsNaN)); + } + #[test] + fn test_bad_data_sample_too_small() { + let sample_input = Vec::new(); + let result = skewtest(sample_input, Alternative::TwoSided, NaNPolicy::Error); + assert_eq!(result, Err(SkewTestError::SampleTooSmall)); + + let sample_input = Vec::from([1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, f64::NAN]); + let result = skewtest(sample_input, Alternative::TwoSided, NaNPolicy::Emit); + assert_eq!(result, Err(SkewTestError::SampleTooSmall)); + } + #[test] 
+ fn test_calc_root_b1() { + // compare to https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.skew.html + // since no wikipedia examples + let sample_input = Vec::from([1.0, 2.0, 3.0, 4.0, 5.0]); + assert_eq!(calc_root_b1(&sample_input), 0.0); + + let sample_input = Vec::from([2.0, 8.0, 0.0, 4.0, 1.0, 9.0, 9.0, 0.0]); + let result = calc_root_b1(&sample_input); + assert!(prec::almost_eq(result, 0.2650554122698573, 1e-1)); + } +} From f9a938b415a1576f62b3026715377071fe22c834 Mon Sep 17 00:00:00 2001 From: Michael Dahlin Date: Sun, 12 Jan 2025 21:07:18 -0600 Subject: [PATCH 05/13] feat(stats_tests): implement chisquare --- src/stats_tests/chisquare.rs | 176 +++++++++++++++++++++++++++++++++++ src/stats_tests/mod.rs | 1 + 2 files changed, 177 insertions(+) create mode 100644 src/stats_tests/chisquare.rs diff --git a/src/stats_tests/chisquare.rs b/src/stats_tests/chisquare.rs new file mode 100644 index 00000000..02061c46 --- /dev/null +++ b/src/stats_tests/chisquare.rs @@ -0,0 +1,176 @@ +//! 
Provides the functions related to [Chi-Squared tests](https://en.wikipedia.org/wiki/Chi-squared_test) + +use crate::distribution::{ChiSquared, ContinuousCDF}; + +/// Represents the errors that can occur when computing the chisquare function +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum ChiSquareTestError { + /// `f_obs` must have a length (or number of categories) greater than 1 + FObsInvalid, + /// `f_exp` must have same length and sum as `f_obs` + FExpInvalid, + /// for the p-value to be meaningful, `ddof` must be at least two less + /// than the number of categories, k, which is the length of `f_obs` + DdofInvalid, +} + +impl std::fmt::Display for ChiSquareTestError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + ChiSquareTestError::FObsInvalid => { + write!(f, "`f_obs` must have a length greater than 1") + } + ChiSquareTestError::FExpInvalid => { + write!(f, "`f_exp` must have same length and sum as `f_obs`") + } + ChiSquareTestError::DdofInvalid => { + write!(f, "for the p-value to be meaningful, `ddof` must be at least two less than the number of categories, k, which is the length of `f_obs`") + } + } + } +} + +impl std::error::Error for ChiSquareTestError {} + +/// Perform a Pearson's chi-square test +/// +/// Returns the chi-square test statistic and p-value +/// +/// # Remarks +/// +/// Implementation based on the one-way chi-square test of [scipy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.chisquare.html#scipy.stats.chisquare). +/// and Pearson's chi-squared test [wikipedia] article. +/// +/// `ddof` represents an adjustment that can be made to the degrees of freedom where the unadjusted +/// degrees of freedom is `f_obs.len() - 1`. 
+/// +/// # Examples +/// +/// ``` +/// use statrs::stats_tests::chisquare::chisquare; +/// let (statistic, pvalue) = chisquare(&[16, 18, 16, 14, 12, 12], None, None).unwrap(); +/// let (statistic, pvalue) = chisquare(&[16, 18, 16, 14, 12, 12], None, Some(1)).unwrap(); +/// let (statistic, pvalue) = chisquare( +/// &[16, 18, 16, 14, 12, 12], +/// Some(&[16.0, 16.0, 16.0, 16.0, 16.0, 8.0]), +/// None, +/// ) +/// .unwrap(); +/// ``` +pub fn chisquare( + f_obs: &[usize], + f_exp: Option<&[f64]>, + ddof: Option, +) -> Result<(f64, f64), ChiSquareTestError> { + let n: usize = f_obs.len(); + if n <= 1 { + return Err(ChiSquareTestError::FObsInvalid); + } + let total_samples = f_obs.iter().sum(); + let f_obs: Vec = f_obs.iter().map(|x| *x as f64).collect(); + + let f_exp = match f_exp { + Some(f_to_validate) => { + // same length check + if f_to_validate.len() != n { + return Err(ChiSquareTestError::FExpInvalid); + } + // same sum check + if f_to_validate.iter().sum::() as usize != total_samples { + return Err(ChiSquareTestError::FExpInvalid); + } + f_to_validate.to_vec() + } + None => { + // make the expected assuming equal frequency + vec![total_samples as f64 / n as f64; n] + } + }; + + let ddof = match ddof { + Some(ddof_to_validate) => { + if ddof_to_validate >= (n - 1) { + return Err(ChiSquareTestError::DdofInvalid); + } + ddof_to_validate + } + None => 0, + }; + let dof = n - 1 - ddof; + + let stat = f_obs + .into_iter() + .zip(f_exp) + .map(|(o, e)| (o - e).powi(2) / e) + .sum::(); + + let chi_dist = ChiSquared::new(dof as f64).expect("ddof validity should already be checked"); + let pvalue = 1.0 - chi_dist.cdf(stat); + + Ok((stat, pvalue)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::prec; + + #[test] + fn test_scipy_example() { + let (statistic, pvalue) = chisquare(&[16, 18, 16, 14, 12, 12], None, None).unwrap(); + assert!(prec::almost_eq(statistic, 2.0, 1e-1)); + assert!(prec::almost_eq(pvalue, 0.84914503608460956, 1e-9)); + + let (statistic, 
pvalue) = chisquare( + &[16, 18, 16, 14, 12, 12], + Some(&[16.0, 16.0, 16.0, 16.0, 16.0, 8.0]), + None, + ) + .unwrap(); + assert!(prec::almost_eq(statistic, 3.5, 1e-1)); + assert!(prec::almost_eq(pvalue, 0.62338762774958223, 1e-9)); + + let (statistic, pvalue) = chisquare(&[16, 18, 16, 14, 12, 12], None, Some(1)).unwrap(); + assert!(prec::almost_eq(statistic, 2.0, 1e-1)); + assert!(prec::almost_eq(pvalue, 0.7357588823428847, 1e-9)); + } + #[test] + fn test_wiki_example() { + // fairness of dice - p-value not provided + let (statistic, _) = chisquare(&[5, 8, 9, 8, 10, 20], None, None).unwrap(); + assert!(prec::almost_eq(statistic, 13.4, 1e-1)); + + let (statistic, _) = chisquare(&[5, 8, 9, 8, 10, 20], Some(&[10.0; 6]), None).unwrap(); + assert!(prec::almost_eq(statistic, 13.4, 1e-1)); + + // chi-squared goodness of fit test + let (statistic, pvalue) = chisquare(&[44, 56], Some(&[50.0, 50.0]), None).unwrap(); + assert!(prec::almost_eq(statistic, 1.44, 1e-2)); + assert!(prec::almost_eq(pvalue, 0.24, 1e-2)); + } + + #[test] + fn test_bad_data_f_obs_invalid() { + let result = chisquare(&[16], None, None); + assert_eq!(result, Err(ChiSquareTestError::FObsInvalid)); + let f_exp: &[usize] = &[]; + let result = chisquare(f_exp, None, None); + assert_eq!(result, Err(ChiSquareTestError::FObsInvalid)); + } + #[test] + fn test_bad_data_f_exp_invalid() { + let result = chisquare(&[16, 18, 16, 14, 12, 12], Some(&[1.0, 2.0, 3.0]), None); + assert_eq!(result, Err(ChiSquareTestError::FExpInvalid)); + let result = chisquare(&[16, 18, 16, 14, 12, 12], Some(&[16.0; 6]), None); + assert_eq!(result, Err(ChiSquareTestError::FExpInvalid)); + } + #[test] + fn test_bad_data_ddof_invalid() { + let result = chisquare(&[16, 18, 16, 14, 12, 12], None, Some(5)); + assert_eq!(result, Err(ChiSquareTestError::DdofInvalid)); + let result = chisquare(&[16, 18, 16, 14, 12, 12], None, Some(100)); + assert_eq!(result, Err(ChiSquareTestError::DdofInvalid)); + } +} diff --git a/src/stats_tests/mod.rs 
b/src/stats_tests/mod.rs index 7a7c0499..b971bed3 100644 --- a/src/stats_tests/mod.rs +++ b/src/stats_tests/mod.rs @@ -1,3 +1,4 @@ +pub mod chisquare; pub mod f_oneway; pub mod fisher; pub mod mannwhitneyu; From 29cc70b080efd1d0eeda1bfcdb6516cd5e532b23 Mon Sep 17 00:00:00 2001 From: Michael Dahlin Date: Wed, 15 Jan 2025 13:05:20 -0600 Subject: [PATCH 06/13] refactor: `mut` in function header instead of in function --- src/stats_tests/f_oneway.rs | 6 +++--- src/stats_tests/mannwhitneyu.rs | 5 ++--- src/stats_tests/skewtest.rs | 7 +++---- src/stats_tests/ttest_onesample.rs | 7 +++---- 4 files changed, 11 insertions(+), 14 deletions(-) diff --git a/src/stats_tests/f_oneway.rs b/src/stats_tests/f_oneway.rs index a8e46680..638e5612 100644 --- a/src/stats_tests/f_oneway.rs +++ b/src/stats_tests/f_oneway.rs @@ -49,6 +49,8 @@ impl std::error::Error for FOneWayTestError {} /// Implementation based on [statsdirect](https://www.statsdirect.com/help/analysis_of_variance/one_way.htm) /// and [scipy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.f_oneway.html#scipy.stats.f_oneway) /// +/// `samples` needs to be mutable in case needing to filter out NaNs for NaNPolicy::Emit +/// /// # Examples /// /// ``` @@ -63,11 +65,9 @@ impl std::error::Error for FOneWayTestError {} /// let (statistic, pvalue) = f_oneway(sample_input, NaNPolicy::Error).unwrap(); // (9.3, 0.002) /// ``` pub fn f_oneway( - samples: Vec>, + mut samples: Vec>, nan_policy: NaNPolicy, ) -> Result<(f64, f64), FOneWayTestError> { - // samples as mutable in case it needs to be modified via NaNPolicy::Emit - let mut samples = samples; let k = samples.len(); // initial input validation diff --git a/src/stats_tests/mannwhitneyu.rs b/src/stats_tests/mannwhitneyu.rs index 81eee64b..18504c0e 100644 --- a/src/stats_tests/mannwhitneyu.rs +++ b/src/stats_tests/mannwhitneyu.rs @@ -58,9 +58,8 @@ pub enum MannWhitneyUMethod { } /// ranks data and accounts for ties to calculate the U statistic -fn 
rankdata_mwu(xy: Vec) -> Result<(Vec, Vec), MannWhitneyUError> { - let mut j = (0..xy.len()).collect::>(); - let mut y = xy; +fn rankdata_mwu(mut y: Vec) -> Result<(Vec, Vec), MannWhitneyUError> { + let mut j = (0..y.len()).collect::>(); // check to make sure data can be compared to generate the ranks for i in 0..y.len() { diff --git a/src/stats_tests/skewtest.rs b/src/stats_tests/skewtest.rs index 400200f0..ef0c61aa 100644 --- a/src/stats_tests/skewtest.rs +++ b/src/stats_tests/skewtest.rs @@ -55,6 +55,8 @@ fn calc_root_b1(data: &[f64]) -> f64 { /// and [fintools.com](https://www.fintools.com/docs/normality_correlation.pdf) which both /// reference D'Agostino, 1970 (but direct access to the paper has been challenging to find) /// +/// `a` needs to be mutable in case needing to filter out NaNs for NaNPolicy::Emit +/// /// # Examples /// /// ``` @@ -64,13 +66,10 @@ fn calc_root_b1(data: &[f64]) -> f64 { /// let (statistic, pvalue) = skewtest(data, Alternative::TwoSided, NaNPolicy::Error).unwrap(); /// ``` pub fn skewtest( - a: Vec, + mut a: Vec, alternative: Alternative, nan_policy: NaNPolicy, ) -> Result<(f64, f64), SkewTestError> { - // make a mutable in case it needs to be modified due to NaNPolicy::Emit - let mut a = a; - let has_nans = a.iter().any(|x| x.is_nan()); if has_nans { match nan_policy { diff --git a/src/stats_tests/ttest_onesample.rs b/src/stats_tests/ttest_onesample.rs index 4c42a1e1..f9028dfd 100644 --- a/src/stats_tests/ttest_onesample.rs +++ b/src/stats_tests/ttest_onesample.rs @@ -39,6 +39,8 @@ impl std::error::Error for TTestOneSampleError {} /// /// Implementation based on [scipy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.ttest_1samp.html). 
/// +/// `a` needs to be mutable in case needing to filter out NaNs for NaNPolicy::Emit +/// /// # Examples /// /// ``` @@ -48,14 +50,11 @@ impl std::error::Error for TTestOneSampleError {} /// let (statistic, pvalue) = ttest_onesample(data, 13f64, Alternative::TwoSided, NaNPolicy::Error).unwrap(); /// ``` pub fn ttest_onesample( - a: Vec, + mut a: Vec, popmean: f64, alternative: Alternative, nan_policy: NaNPolicy, ) -> Result<(f64, f64), TTestOneSampleError> { - // make a mutable in case it needs to be modified due to NaNPolicy::Emit - let mut a = a; - let has_nans = a.iter().any(|x| x.is_nan()); if has_nans { match nan_policy { From e21e5b6d2b98b6df22f720f4fbe47f9052cb18fa Mon Sep 17 00:00:00 2001 From: Michael Dahlin Date: Wed, 15 Jan 2025 13:10:23 -0600 Subject: [PATCH 07/13] test: more coverage for `mannwhitneyu` --- src/stats_tests/mannwhitneyu.rs | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/stats_tests/mannwhitneyu.rs b/src/stats_tests/mannwhitneyu.rs index 18504c0e..f4aa189d 100644 --- a/src/stats_tests/mannwhitneyu.rs +++ b/src/stats_tests/mannwhitneyu.rs @@ -476,6 +476,31 @@ mod tests { assert_eq!(result, Err(MannWhitneyUError::ExactMethodWithTiesInData)); } #[test] + fn test_automatic_asymptotic() { + // compare to running same inputs in scipy + // both samples len > 8 + let (statistic, pvalue) = mannwhitneyu( + &[19, 22, 16, 29, 24, 28, 7, 10, 30], + &[20, 11, 17, 12, 5, 31, 18, 2, 34], + MannWhitneyUMethod::Automatic, + Alternative::TwoSided, + ) + .unwrap(); + assert_eq!(statistic, 49.0); + assert!(prec::almost_eq(pvalue, 0.47992869214595724, 1e-9)); + + // ties in data + let (statistic, pvalue) = mannwhitneyu( + &[1, 2, 3, 4, 5, 6], + &[6, 7, 8, 9, 10], + MannWhitneyUMethod::Automatic, + Alternative::TwoSided, + ) + .unwrap(); + assert_eq!(statistic, 0.5); + assert!(prec::almost_eq(pvalue, 0.010411098147110422, 1e-9)); + } + #[test] fn test_rankdata_mwu() { let data = Vec::from([1, 4, 3]); let (rank, t) = 
rankdata_mwu(data).expect("data is good"); From f734aa66eb7feb20cb9cdd2a673b1588018cbd56 Mon Sep 17 00:00:00 2001 From: Michael Dahlin Date: Wed, 15 Jan 2025 13:10:50 -0600 Subject: [PATCH 08/13] test: more coverage for `f_oneway` --- src/stats_tests/f_oneway.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/stats_tests/f_oneway.rs b/src/stats_tests/f_oneway.rs index 638e5612..d9156289 100644 --- a/src/stats_tests/f_oneway.rs +++ b/src/stats_tests/f_oneway.rs @@ -207,6 +207,16 @@ mod tests { assert!(prec::almost_eq(pvalue, 0.0002812242314534544, 1e-12)); } #[test] + fn test_group_length_one_ok() { + // group length 1 doesn't result in error + let group1 = Vec::from([0.5]); + let group2 = Vec::from([0.25, 0.75]); + let sample_input = Vec::from([group1, group2]); + let (statistic, pvalue) = f_oneway(sample_input, NaNPolicy::Propogate).unwrap(); + assert!(prec::almost_eq(statistic, 0.0, 1e-1)); + assert!(prec::almost_eq(pvalue, 1.0, 1e-12)); + } + #[test] fn test_nan_in_data_w_propogate() { let group1 = Vec::from([0.0571, 0.0813, f64::NAN, 0.0836]); let group2 = Vec::from([0.0873, 0.0662, 0.0672, 0.0819, 0.0749]); From f4136d57c8c48646ea3b605287ab31c82f55da6b Mon Sep 17 00:00:00 2001 From: Michael Dahlin Date: Sat, 18 Jan 2025 08:57:59 -0600 Subject: [PATCH 09/13] docs(stats_test): better attribution of sources --- src/stats_tests/chisquare.rs | 7 ++++--- src/stats_tests/f_oneway.rs | 11 ++++++----- src/stats_tests/mannwhitneyu.rs | 15 +++++++++++---- src/stats_tests/skewtest.rs | 11 +++++++---- src/stats_tests/ttest_onesample.rs | 6 ++++-- 5 files changed, 32 insertions(+), 18 deletions(-) diff --git a/src/stats_tests/chisquare.rs b/src/stats_tests/chisquare.rs index 02061c46..f65dc943 100644 --- a/src/stats_tests/chisquare.rs +++ b/src/stats_tests/chisquare.rs @@ -40,12 +40,13 @@ impl std::error::Error for ChiSquareTestError {} /// /// # Remarks /// -/// Implementation based on the one-way chi-square test of 
[scipy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.chisquare.html#scipy.stats.chisquare). -/// and Pearson's chi-squared test [wikipedia] article. -/// /// `ddof` represents an adjustment that can be made to the degrees of freedom where the unadjusted /// degrees of freedom is `f_obs.len() - 1`. /// +/// Implementation based on [wikipedia](https://en.wikipedia.org/wiki/Pearson%27s_chi-squared_test) +/// while aligning to [scipy's](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.chisquare.html) +/// function header where possible. The scipy implementation was also used for testing and validation. +/// /// # Examples /// /// ``` diff --git a/src/stats_tests/f_oneway.rs b/src/stats_tests/f_oneway.rs index d9156289..3216a04b 100644 --- a/src/stats_tests/f_oneway.rs +++ b/src/stats_tests/f_oneway.rs @@ -46,18 +46,21 @@ impl std::error::Error for FOneWayTestError {} /// Takes in a set (outer vector) of samples (inner vector) and returns the F-statistic and p-value /// /// # Remarks -/// Implementation based on [statsdirect](https://www.statsdirect.com/help/analysis_of_variance/one_way.htm) -/// and [scipy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.f_oneway.html#scipy.stats.f_oneway) /// /// `samples` needs to be mutable in case needing to filter out NaNs for NaNPolicy::Emit /// +/// Implementation based on [statsdirect](https://www.statsdirect.com/help/analysis_of_variance/one_way.htm) +/// while aligning to [scipy's](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.f_oneway.html#scipy.stats.f_oneway) +/// function header where possible. The scipy implementation was also used for testing and +/// validation. Includes the use of [McDonald et al. (1991)](doi.org/10.1007/BF01319403) for +/// testing and validation. 
+/// /// # Examples /// /// ``` /// use statrs::stats_tests::f_oneway::f_oneway; /// use statrs::stats_tests::NaNPolicy; /// -/// // based on wikipedia example /// let a1 = Vec::from([6f64, 8f64, 4f64, 5f64, 3f64, 4f64]); /// let a2 = Vec::from([8f64, 12f64, 9f64, 11f64, 6f64, 8f64]); /// let a3 = Vec::from([13f64, 9f64, 11f64, 8f64, 7f64, 12f64]); @@ -149,8 +152,6 @@ mod tests { #[test] fn test_scipy_example() { - // Test against the scipy example - // https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.f_oneway.html#scipy.stats.f_oneway let tillamook = Vec::from([ 0.0571, 0.0813, 0.0831, 0.0976, 0.0817, 0.0859, 0.0735, 0.0659, 0.0923, 0.0836, ]); diff --git a/src/stats_tests/mannwhitneyu.rs b/src/stats_tests/mannwhitneyu.rs index f4aa189d..f0d007c2 100644 --- a/src/stats_tests/mannwhitneyu.rs +++ b/src/stats_tests/mannwhitneyu.rs @@ -213,10 +213,17 @@ fn calc_mwu_exact_pvalue(u: f64, n1: usize, n2: usize) -> f64 { /// samples sizes (length of `x` + length of `y`) above 20 are approximated fairly well using the /// asymptotic (normal) methods. /// -/// Implementation was largely based on the [scipy version](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.mannwhitneyu.html#scipy.stats.mannwhitneyu). -/// There are a few deviations including, not supporting calculation of the value via permutation -/// tests, not supporting calculation of the exact p-value where input data includes ties, and not -/// supporting the NaN policy due to being generic on T which might not have NaN values. +/// +/// Implementation based on [wikipedia](https://en.wikipedia.org/wiki/Mann–Whitney_U_test) +/// while aligning to [scipy's](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.mannwhitneyu.html#scipy.stats.mannwhitneyu) +/// function header where possible. The scipy implementation was also used for testing and +/// validation. 
Includes the use of [Shier (2004)](https://www.statstutor.ac.uk/resources/uploaded/mannwhitney.pdf) for +/// testing and validation. +/// +/// There are a few deviations from the scipy version including, not supporting calculation +/// of the value via permutation tests, not supporting calculation of the exact p-value +/// where input data includes ties, and not supporting the NaN policy due to being generic +/// on T which might not have NaN values. /// /// # Examples /// diff --git a/src/stats_tests/skewtest.rs b/src/stats_tests/skewtest.rs index ef0c61aa..9455e0cf 100644 --- a/src/stats_tests/skewtest.rs +++ b/src/stats_tests/skewtest.rs @@ -51,12 +51,15 @@ fn calc_root_b1(data: &[f64]) -> f64 { /// /// # Remarks /// -/// Implementation based on [scipy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.skewtest.html#scipy.stats.skewtest). -/// and [fintools.com](https://www.fintools.com/docs/normality_correlation.pdf) which both -/// reference D'Agostino, 1970 (but direct access to the paper has been challenging to find) -/// /// `a` needs to be mutable in case needing to filter out NaNs for NaNPolicy::Emit /// +/// Implementation based on [fintools.com](https://www.fintools.com/docs/normality_correlation.pdf) +/// which indirectly uses [D'Agostino, (1970)](https://doi.org/10.2307/2684359) +/// while aligning to [scipy's](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.skewtest.html#scipy.stats.skewtest) +/// function header where possible. The scipy implementation was also used for testing and validation. +/// Includes the use of [Shapiro & Wilk (1965)](https://doi.org/10.2307/2333709) for +/// testing and validation. 
+/// /// # Examples /// /// ``` diff --git a/src/stats_tests/ttest_onesample.rs b/src/stats_tests/ttest_onesample.rs index f9028dfd..e622eb0e 100644 --- a/src/stats_tests/ttest_onesample.rs +++ b/src/stats_tests/ttest_onesample.rs @@ -37,10 +37,12 @@ impl std::error::Error for TTestOneSampleError {} /// /// # Remarks /// -/// Implementation based on [scipy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.ttest_1samp.html). -/// /// `a` needs to be mutable in case needing to filter out NaNs for NaNPolicy::Emit /// +/// Implementation based on [jmp](https://www.jmp.com/en_us/statistics-knowledge-portal/t-test/one-sample-t-test.html) +/// while aligning to [scipy's](https://docs.scipy.org/doc/scipy-1.14.1/reference/generated/scipy.stats.ttest_1samp.html) +/// function header where possible. +/// /// # Examples /// /// ``` From f4658be0d4c8e617275e1c9df80e681c022e38a5 Mon Sep 17 00:00:00 2001 From: Qazalbash Date: Sun, 26 Jan 2025 23:59:42 +0500 Subject: [PATCH 10/13] test: add check for derivative of CDF to ensure it matches PDF --- src/distribution/internal.rs | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index 13435d44..4b63565a 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -418,6 +418,34 @@ pub mod test { assert!(sum <= 1.0 + 1e-10); } + /// cdf should be the integral of the pdf + fn check_derivative_of_cdf_is_pdf + Continuous>( + dist: &D, + x_min: f64, + x_max: f64, + step: f64, + ) { + const DELTA: f64 = 1e-6; + let mut prev_x = x_min; + + loop { + let x = prev_x + step; + let x_ahead = x + DELTA; + let x_behind = x - DELTA; + let density = dist.pdf(x); + + let d_cdf = (dist.cdf(x_ahead) - dist.cdf(x_behind)) / (2.0 * DELTA); + + assert_almost_eq!(d_cdf, density, 1e-6); + + if x >= x_max { + break; + } else { + prev_x = x; + } + } + } + /// Does a series of checks that all continuous distributions must obey. 
/// 99% of the probability mass should be between x_min and x_max. pub fn check_continuous_distribution<D: ContinuousCDF<f64, f64> + Continuous<f64, f64>>( @@ -433,6 +461,7 @@ pub mod test { assert_eq!(dist.cdf(f64::INFINITY), 1.0); check_integrate_pdf_is_cdf(dist, x_min, x_max, (x_max - x_min) / 100000.0); + check_derivative_of_cdf_is_pdf(dist, x_min, x_max, (x_max - x_min) / 100000.0); } /// Does a series of checks that all positive discrete distributions must From 36920f5ad0f395233e2fdf6a4adae162c25e023f Mon Sep 17 00:00:00 2001 From: Qazalbash Date: Mon, 27 Jan 2025 12:09:54 +0500 Subject: [PATCH 11/13] test: improve numerical stability in CDF derivative check --- src/distribution/internal.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index 4b63565a..fa70416a 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -425,7 +425,7 @@ pub mod test { x_max: f64, step: f64, ) { - const DELTA: f64 = 1e-6; + const DELTA: f64 = 1e-12; let mut prev_x = x_min; loop { @@ -433,10 +433,11 @@ pub mod test { let x_ahead = x + DELTA; let x_behind = x - DELTA; let density = dist.pdf(x); + let dx = 2.0 * DELTA; - let d_cdf = (dist.cdf(x_ahead) - dist.cdf(x_behind)) / (2.0 * DELTA); + let d_cdf = dist.cdf(x_ahead) - dist.cdf(x_behind); - assert_almost_eq!(d_cdf, density, 1e-6); + assert_almost_eq!(d_cdf, dx * density, 1e-11); if x >= x_max { break; From 89eaf105f1d9eca89d7f047570893c105e7400fa Mon Sep 17 00:00:00 2001 From: Qazalbash Date: Mon, 27 Jan 2025 13:09:44 +0500 Subject: [PATCH 12/13] test: enhance checks for continuous distributions with panic safety --- src/distribution/internal.rs | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index fa70416a..47e424ff 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -449,7 +449,9 @@ pub mod test { /// Does a series of checks that all
continuous distributions must obey. /// 99% of the probability mass should be between x_min and x_max. - pub fn check_continuous_distribution<D: ContinuousCDF<f64, f64> + Continuous<f64, f64>>( + pub fn check_continuous_distribution< + D: ContinuousCDF<f64, f64> + Continuous<f64, f64> + std::panic::RefUnwindSafe, + >( dist: &D, x_min: f64, x_max: f64, @@ -461,8 +463,16 @@ pub mod test { assert_eq!(dist.cdf(f64::NEG_INFINITY), 0.0); assert_eq!(dist.cdf(f64::INFINITY), 1.0); - check_integrate_pdf_is_cdf(dist, x_min, x_max, (x_max - x_min) / 100000.0); - check_derivative_of_cdf_is_pdf(dist, x_min, x_max, (x_max - x_min) / 100000.0); + let result_integration = std::panic::catch_unwind(|| { + check_integrate_pdf_is_cdf(dist, x_min, x_max, (x_max - x_min) / 100000.0); + }); + let result_differentiation = std::panic::catch_unwind(|| { + check_derivative_of_cdf_is_pdf(dist, x_min, x_max, (x_max - x_min) / 100000.0); + }); + + if result_integration.is_err() && result_differentiation.is_err() { + panic!("Integration of pdf doesn't equal cdf and derivative of cdf doesn't equal pdf!"); + } } /// Does a series of checks that all positive discrete distributions must From 71e010b45ea784280b0cc42b5cd285b8841121be Mon Sep 17 00:00:00 2001 From: Qazalbash Date: Wed, 29 Jan 2025 11:29:00 +0500 Subject: [PATCH 13/13] chore: update docs and result chaining idioms --- src/distribution/internal.rs | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index 47e424ff..68dca14c 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -418,7 +418,7 @@ pub mod test { assert!(sum <= 1.0 + 1e-10); } - /// cdf should be the integral of the pdf + /// pdf should be derivative of cdf fn check_derivative_of_cdf_is_pdf<D: ContinuousCDF<f64, f64> + Continuous<f64, f64>>( dist: &D, x_min: f64, @@ -426,6 +426,7 @@ pub mod test { step: f64, ) { const DELTA: f64 = 1e-12; + const DX: f64 = 2.0 * DELTA; let mut prev_x = x_min; loop { @@ -433,11 +434,10 @@ pub mod test { let x_ahead
= x + DELTA; let x_behind = x - DELTA; let density = dist.pdf(x); - let dx = 2.0 * DELTA; let d_cdf = dist.cdf(x_ahead) - dist.cdf(x_behind); - assert_almost_eq!(d_cdf, dx * density, 1e-11); + assert_almost_eq!(d_cdf, DX * density, 1e-11); if x >= x_max { break; @@ -448,7 +448,8 @@ pub mod test { } /// Does a series of checks that all continuous distributions must obey. - /// 99% of the probability mass should be between x_min and x_max. + /// 99% of the probability mass should be between x_min and x_max or the finite + /// difference of cdf should be near to the pdf for much of the support. pub fn check_continuous_distribution< D: ContinuousCDF<f64, f64> + Continuous<f64, f64> + std::panic::RefUnwindSafe, >( dist: &D, x_min: f64, x_max: f64, @@ -463,14 +464,14 @@ pub mod test { assert_eq!(dist.cdf(f64::NEG_INFINITY), 0.0); assert_eq!(dist.cdf(f64::INFINITY), 1.0); - let result_integration = std::panic::catch_unwind(|| { + if std::panic::catch_unwind(|| { check_integrate_pdf_is_cdf(dist, x_min, x_max, (x_max - x_min) / 100000.0); - }); - let result_differentiation = std::panic::catch_unwind(|| { + }) + .or(std::panic::catch_unwind(|| { check_derivative_of_cdf_is_pdf(dist, x_min, x_max, (x_max - x_min) / 100000.0); - }); - - if result_integration.is_err() && result_differentiation.is_err() { + })) + .is_err() + { panic!("Integration of pdf doesn't equal cdf and derivative of cdf doesn't equal pdf!"); } }