From 4ce0db8870ea4ba0ee87a6dd66c21035e65e4361 Mon Sep 17 00:00:00 2001 From: relf Date: Fri, 20 Oct 2023 12:48:15 +0200 Subject: [PATCH] Fix covariances update --- .../src/gaussian_mixture/algorithm.rs | 33 ++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs b/algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs index 5a34cd1e9..6f23485a0 100644 --- a/algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs +++ b/algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs @@ -311,8 +311,9 @@ impl GaussianMixtureModel { )?; self.means = means; self.weights = weights / F::cast(n_samples); + self.covariances = covariances; // GmmCovarType = Full() - self.precisions_chol = Self::compute_precisions_cholesky_full(&covariances)?; + self.precisions_chol = Self::compute_precisions_cholesky_full(&self.covariances)?; Ok(()) } @@ -488,7 +489,9 @@ mod tests { use ndarray::{array, concatenate, ArrayView1, ArrayView2, Axis}; use ndarray_rand::rand::prelude::ThreadRng; use ndarray_rand::rand::SeedableRng; + use ndarray_rand::rand_distr::Normal; use ndarray_rand::rand_distr::{Distribution, StandardNormal}; + use ndarray_rand::RandomExt; #[test] fn autotraits() { @@ -570,6 +573,34 @@ mod tests { ); } + #[test] + fn test_gmm_covariances() { + let rng = rand_xoshiro::Xoshiro256Plus::seed_from_u64(123); + + let data_0 = ndarray::Array::random((500,), Normal::new(0., 0.5).unwrap()); + let data_1 = ndarray::Array::random((500,), Normal::new(1., 0.5).unwrap()); + let data_2 = ndarray::Array::random((500,), Normal::new(2., 0.5).unwrap()); + let data = ndarray::concatenate![ndarray::Axis(0), data_0, data_1, data_2]; + + let data_2d = data.insert_axis(ndarray::Axis(1)).to_owned(); + let dataset = linfa::DatasetBase::from(data_2d); + + let gmm = GaussianMixtureModel::params(3) + .n_runs(1) + .tolerance(1e-4) + .with_rng(rng) + .max_n_iterations(500) + .fit(&dataset) + .expect("GMM fit"); + + // expected results from scikit-learn 1.3.1 + let expected = array![[[0.22564062]], [[0.26204446]], [[0.23393885]]]; + let expected = Array::from_iter(expected.iter().cloned()); + let actual = gmm.covariances(); + let actual = Array::from_iter(actual.iter().cloned()); + assert_abs_diff_eq!(expected, actual, epsilon = 1e-1); + } + fn function_test_1d(x: &Array2) -> Array2 { let mut y = Array2::zeros(x.dim()); Zip::from(&mut y).and(x).for_each(|yi, &xi| {