-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
542 additions
and
96 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
[package] | ||
name = "stochastic-rs-ai" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
[dependencies] | ||
anyhow = "1.0.88" | ||
candle-core = {version = "0.6.0", features = ["accelerate", "metal"]} | ||
candle-datasets = {version = "0.6.0"} | ||
candle-metal-kernels = "0.6.0" | ||
candle-nn = {version = "0.6.0", features = ["accelerate", "metal"]} | ||
indicatif = "0.17.7" | ||
ndarray = "0.16.1" | ||
ndarray-rand = "0.15.0" | ||
polars = "0.43.1" | ||
polars-io = {version = "0.43.1", features = ["csv"]} | ||
rand_distr = "0.4.3" | ||
stochastic-rs = {version = "0.8.0", path = "../stochastic-rs-core"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
pub mod datasets; | ||
pub mod lstm_model_1_d; | ||
pub mod lstm_model_2_d; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
use anyhow::Result; | ||
use candle_core::{Device, Tensor}; | ||
use candle_datasets::{batcher::IterResult2, Batcher}; | ||
use indicatif::{ProgressBar, ProgressStyle}; | ||
use ndarray::{s, Array1}; | ||
use ndarray_rand::RandomExt; | ||
use rand_distr::Uniform; | ||
use std::vec::IntoIter; | ||
use stochastic_rs::{diffusion::fou::Fou, Sampling}; | ||
|
||
pub fn test_vasicek_1_d( | ||
epoch_size: usize, | ||
batch_size: usize, | ||
n: usize, | ||
device: &Device, | ||
) -> Result<( | ||
Batcher<IterResult2<IntoIter<Result<(Tensor, Tensor), candle_core::Error>>>>, | ||
Vec<f64>, | ||
)> { | ||
let mut paths = Vec::with_capacity(epoch_size); | ||
let mu = 2.8; | ||
let sigma = 1.0; | ||
let thetas = Array1::random(epoch_size, Uniform::new(0.0, 10.0)).to_vec(); | ||
let hursts = Array1::random(epoch_size, Uniform::new(0.01, 0.99)).to_vec(); | ||
let progress_bar = ProgressBar::new(epoch_size as u64); | ||
progress_bar.set_style( | ||
ProgressStyle::with_template( | ||
"{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] ({eta})", | ||
)? | ||
.progress_chars("#>-"), | ||
); | ||
for idx in 0..epoch_size { | ||
let hurst = hursts[idx]; | ||
let theta = thetas[idx]; | ||
let fou = Fou::new(&Fou { | ||
hurst, | ||
mu, | ||
sigma, | ||
theta, | ||
n, | ||
x0: Some(0.0), | ||
t: Some(16.0), | ||
..Default::default() | ||
}); | ||
let mut path = fou.sample(); | ||
let mean = path.mean().unwrap(); | ||
let std = path.std(0.0); | ||
path = (path - mean) / std; | ||
|
||
paths.push(Ok(( | ||
Tensor::from_iter(path, device)?, | ||
Tensor::new(&[thetas[idx]], device)?, | ||
))); | ||
progress_bar.inc(1); | ||
} | ||
progress_bar.finish(); | ||
|
||
let batcher = Batcher::new_r2(paths.into_iter()) | ||
.batch_size(batch_size) | ||
.return_last_incomplete_batch(false); | ||
|
||
Ok((batcher, hursts)) | ||
} | ||
|
||
pub fn test_vasicek_2_d( | ||
epoch_size: usize, | ||
batch_size: usize, | ||
n: usize, | ||
device: &Device, | ||
) -> Result<( | ||
Batcher<IterResult2<IntoIter<Result<(Tensor, Tensor), candle_core::Error>>>>, | ||
Vec<f64>, | ||
)> { | ||
let mut paths = Vec::with_capacity(epoch_size); | ||
let mu = 2.8; | ||
let sigma = 1.0; | ||
let thetas = Array1::random(epoch_size, Uniform::new(0.0, 10.0)).to_vec(); | ||
let hursts = Array1::random(epoch_size, Uniform::new(0.01, 0.99)).to_vec(); | ||
let progress_bar = ProgressBar::new(epoch_size as u64); | ||
progress_bar.set_style( | ||
ProgressStyle::with_template( | ||
"{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] ({eta})", | ||
)? | ||
.progress_chars("#>-"), | ||
); | ||
for idx in 0..epoch_size { | ||
let hurst = hursts[idx]; | ||
let theta = thetas[idx]; | ||
let fou = Fou::new(&Fou { | ||
hurst, | ||
mu, | ||
sigma, | ||
theta, | ||
n, | ||
x0: Some(0.0), | ||
t: Some(16.0), | ||
..Default::default() | ||
}); | ||
let mut path = fou.sample(); | ||
let mean = path.mean().unwrap(); | ||
let std = path.std(0.0); | ||
path = (path - mean) / std; | ||
|
||
let diff = &path.slice(s![1..]) - &path.slice(s![..-1]); | ||
let path = path.slice(s![..-1]); | ||
let paired = path.iter().zip(diff.iter()).collect::<Vec<_>>(); | ||
let paired_tensors = paired | ||
.iter() | ||
.map(|pair| { | ||
let (x, y) = *pair; | ||
Tensor::new(&[*x, *y], device).unwrap() | ||
}) | ||
.collect::<Vec<_>>(); | ||
|
||
paths.push(Ok(( | ||
Tensor::stack(&paired_tensors, 0)?, | ||
Tensor::new(&[thetas[idx]], device)?, | ||
))); | ||
progress_bar.inc(1); | ||
} | ||
progress_bar.finish(); | ||
|
||
let batcher = Batcher::new_r2(paths.into_iter()) | ||
.batch_size(batch_size) | ||
.return_last_incomplete_batch(false); | ||
|
||
Ok((batcher, hursts)) | ||
} |
Oops, something went wrong.