Skip to content

Commit

Permalink
feat: add multivariate interface
Browse files Browse the repository at this point in the history
  • Loading branch information
dancixx committed Jan 12, 2025
1 parent 79fb41f commit e0a9d6e
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 6 deletions.
8 changes: 2 additions & 6 deletions src/stats/copulas/bivariate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,13 @@ pub mod gumbel;
pub mod independence;

#[derive(Debug, Clone, Copy)]
pub enum CopulaType {
pub(super) enum CopulaType {
Clayton,
Frank,
Gumbel,
Independence,
}

const EPSILON: f64 = 1e-12;

pub trait Bivariate {
fn r#type(&self) -> CopulaType;

Expand Down Expand Up @@ -128,9 +126,7 @@ pub trait Bivariate {
fn pdf(&self, X: &Array2<f64>) -> Result<Array1<f64>, Box<dyn Error>>;

fn log_pdf(&self, X: &Array2<f64>) -> Result<Array1<f64>, Box<dyn Error>> {
let pdf = self.pdf(X)?;
let log_pdf = pdf.mapv(|val| (val + 1e-32).ln());
Ok(log_pdf)
Ok(self.pdf(X)?.ln())
}

fn cdf(&self, X: &Array2<f64>) -> Result<Array1<f64>, Box<dyn Error>>;
Expand Down
28 changes: 28 additions & 0 deletions src/stats/copulas/multivariate.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,31 @@
use std::error::Error;

use ndarray::{Array1, Array2};

pub mod gaussian;
pub mod tree;
pub mod vine;

pub(super) enum CopulaType {
Gaussian,
Tree,
Vine,
}

pub trait Multivariate {
fn r#type(&self) -> CopulaType;

fn sample(&self, n: usize) -> Result<Array2<f64>, Box<dyn Error>>;

fn fit(&mut self, X: Array2<f64>) -> Result<(), Box<dyn Error>>;

fn check_fit(&self, X: &Array2<f64>) -> Result<(), Box<dyn Error>>;

fn pdf(&self, X: Array2<f64>) -> Result<Array1<f64>, Box<dyn Error>>;

fn log_pdf(&self, X: Array2<f64>) -> Result<Array1<f64>, Box<dyn Error>> {
Ok(self.pdf(X)?.ln())
}

fn cdf(&self, X: Array2<f64>) -> Result<Array1<f64>, Box<dyn Error>>;
}

0 comments on commit e0a9d6e

Please sign in to comment.