From e0a9d6e95fa028cb15a43a43fb024f6ef9217599 Mon Sep 17 00:00:00 2001 From: Daniel Boros Date: Sun, 12 Jan 2025 12:01:20 +0100 Subject: [PATCH] feat: add multivariate interface --- src/stats/copulas/bivariate.rs | 8 ++------ src/stats/copulas/multivariate.rs | 28 ++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/src/stats/copulas/bivariate.rs b/src/stats/copulas/bivariate.rs index fc728fd..3aab8fb 100644 --- a/src/stats/copulas/bivariate.rs +++ b/src/stats/copulas/bivariate.rs @@ -13,15 +13,13 @@ pub mod gumbel; pub mod independence; #[derive(Debug, Clone, Copy)] -pub enum CopulaType { +pub(super) enum CopulaType { Clayton, Frank, Gumbel, Independence, } -const EPSILON: f64 = 1e-12; - pub trait Bivariate { fn r#type(&self) -> CopulaType; @@ -128,9 +126,7 @@ pub trait Bivariate { fn pdf(&self, X: &Array2) -> Result, Box>; fn log_pdf(&self, X: &Array2) -> Result, Box> { - let pdf = self.pdf(X)?; - let log_pdf = pdf.mapv(|val| (val + 1e-32).ln()); - Ok(log_pdf) + Ok(self.pdf(X)?.ln()) } fn cdf(&self, X: &Array2) -> Result, Box>; diff --git a/src/stats/copulas/multivariate.rs b/src/stats/copulas/multivariate.rs index 2016d20..6e70405 100644 --- a/src/stats/copulas/multivariate.rs +++ b/src/stats/copulas/multivariate.rs @@ -1,3 +1,31 @@ +use std::error::Error; + +use ndarray::{Array1, Array2}; + pub mod gaussian; pub mod tree; pub mod vine; + +pub(super) enum CopulaType { + Gaussian, + Tree, + Vine, +} + +pub trait Multivariate { + fn r#type(&self) -> CopulaType; + + fn sample(&self, n: usize) -> Result, Box>; + + fn fit(&mut self, X: Array2) -> Result<(), Box>; + + fn check_fit(&self, X: &Array2) -> Result<(), Box>; + + fn pdf(&self, X: Array2) -> Result, Box>; + + fn log_pdf(&self, X: Array2) -> Result, Box> { + Ok(self.pdf(X)?.ln()) + } + + fn cdf(&self, X: Array2) -> Result, Box>; +}