Skip to content

Commit

Permalink
optimise dcor fn (#96)
Browse files Browse the repository at this point in the history
- compute half the matrix
- row/col means are equal since it's mirrored
  • Loading branch information
chungg authored Sep 3, 2024
1 parent 08744f7 commit acd7569
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 55 deletions.
22 changes: 10 additions & 12 deletions benches/traquer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ fn criterion_benchmark(c: &mut Criterion) {
});
c.bench_function("sig-trend-zigzag", |b| {
b.iter(|| {
black_box(trend::zigzag(&stats.close, &stats.close, Some(10.0)).collect::<Vec<f64>>())
black_box(trend::zigzag(&stats.high, &stats.low, Some(10.0)).collect::<Vec<f64>>())
})
});
c.bench_function("sig-trend-chandelier", |b| {
Expand Down Expand Up @@ -586,33 +586,31 @@ fn criterion_benchmark(c: &mut Criterion) {
});

c.bench_function("correlation-pcc", |b| {
b.iter(|| black_box(correlation::pcc(&stats.close, &stats.close, 16).collect::<Vec<_>>()))
b.iter(|| black_box(correlation::pcc(&stats.open, &stats.close, 16).collect::<Vec<_>>()))
});
c.bench_function("correlation-rsq", |b| {
b.iter(|| black_box(correlation::rsq(&stats.close, &stats.close, 16).collect::<Vec<_>>()))
b.iter(|| black_box(correlation::rsq(&stats.open, &stats.close, 16).collect::<Vec<_>>()))
});
c.bench_function("correlation-beta", |b| {
b.iter(|| black_box(correlation::beta(&stats.close, &stats.close, 16).collect::<Vec<_>>()))
b.iter(|| black_box(correlation::beta(&stats.open, &stats.close, 16).collect::<Vec<_>>()))
});
c.bench_function("correlation-rsc", |b| {
b.iter(|| black_box(correlation::rsc(&stats.close, &stats.close).collect::<Vec<_>>()))
b.iter(|| black_box(correlation::rsc(&stats.open, &stats.close).collect::<Vec<_>>()))
});
c.bench_function("correlation-perf", |b| {
b.iter(|| black_box(correlation::perf(&stats.close, &stats.close, 16).collect::<Vec<_>>()))
b.iter(|| black_box(correlation::perf(&stats.open, &stats.close, 16).collect::<Vec<_>>()))
});
c.bench_function("correlation-srcc", |b| {
b.iter(|| black_box(correlation::srcc(&stats.close, &stats.close, 16).collect::<Vec<_>>()))
b.iter(|| black_box(correlation::srcc(&stats.open, &stats.close, 16).collect::<Vec<_>>()))
});
c.bench_function("correlation-krcc", |b| {
b.iter(|| black_box(correlation::krcc(&stats.close, &stats.close, 16).collect::<Vec<_>>()))
b.iter(|| black_box(correlation::krcc(&stats.open, &stats.close, 16).collect::<Vec<_>>()))
});
c.bench_function("correlation-hoeffd", |b| {
b.iter(|| {
black_box(correlation::hoeffd(&stats.close, &stats.close, 16).collect::<Vec<_>>())
})
b.iter(|| black_box(correlation::hoeffd(&stats.open, &stats.close, 16).collect::<Vec<_>>()))
});
c.bench_function("correlation-dcor", |b| {
b.iter(|| black_box(correlation::dcor(&stats.close, &stats.close, 16).collect::<Vec<_>>()))
b.iter(|| black_box(correlation::dcor(&stats.open, &stats.close, 16).collect::<Vec<_>>()))
});

c.bench_function("stats-dist-variance", |b| {
Expand Down
51 changes: 23 additions & 28 deletions src/correlation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -450,9 +450,14 @@ pub fn dcor<'a, T: ToPrimitive>(
let n = x.len();
// flattened NxN distance matrix, where [x_00..x0j, ... ,x_i0..x_ij]
let mut matrix = vec![0.0; n * n];
let mut matrix_sum = 0_f64;
for i in 0..n {
for j in 0..n {
matrix[(i * n) + j] = (x[i].to_f64().unwrap() - x[j].to_f64().unwrap()).abs();
for j in i..n {
let idx = (i * n) + j;
let mirror_idx = idx / n + (idx % n) * n;
matrix[idx] = (x[i].to_f64().unwrap() - x[j].to_f64().unwrap()).abs();
matrix[mirror_idx] = matrix[idx];
matrix_sum += matrix[idx] * 2.;
}
}

Expand All @@ -461,18 +466,18 @@ pub fn dcor<'a, T: ToPrimitive>(
.step_by(n)
.map(|i| matrix[i..i + n].iter().sum::<f64>() / n as f64)
.collect();
let col_means: Vec<f64> = (0..n)
.map(|i| {
(i..matrix.len())
.step_by(n)
.fold(0.0, |acc, j| acc + matrix[j])
/ n as f64
})
.collect();
let matrix_mean: f64 = matrix.iter().sum::<f64>() / (n * n) as f64;
for i in 0..n {
for j in 0..n {
matrix[(i * n) + j] += -row_means[i] - col_means[j] + matrix_mean;
let col_means = &row_means;
// undo the double count of mirror line rather than add if clause above
matrix_sum -= (0..matrix.len())
.step_by(n + 1)
.fold(0.0, |acc, x| acc + matrix[x]);
let matrix_mean: f64 = matrix_sum / (n * n) as f64;
for (i, row_mean) in row_means.iter().enumerate() {
for (j, col_mean) in col_means.iter().enumerate().skip(i) {
let idx = (i * n) + j;
let mirror_idx = idx / n + (idx % n) * n;
matrix[idx] += -row_mean - col_mean + matrix_mean;
matrix[mirror_idx] = matrix[idx];
}
}
matrix
Expand All @@ -492,20 +497,10 @@ pub fn dcor<'a, T: ToPrimitive>(
.sum::<f64>()
/ window.pow(2) as f64)
.sqrt();
let dvar_x = (centred_x
.iter()
.zip(&centred_x)
.map(|(a, b)| a * b)
.sum::<f64>()
/ window.pow(2) as f64)
.sqrt();
let dvar_y = (centred_y
.iter()
.zip(&centred_y)
.map(|(a, b)| a * b)
.sum::<f64>()
/ window.pow(2) as f64)
.sqrt();
let dvar_x =
(centred_x.iter().map(|a| a * a).sum::<f64>() / window.pow(2) as f64).sqrt();
let dvar_y =
(centred_y.iter().map(|a| a * a).sum::<f64>() / window.pow(2) as f64).sqrt();

dcov / (dvar_x * dvar_y).sqrt()
}),
Expand Down
76 changes: 61 additions & 15 deletions tests/correlation_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,25 +304,71 @@ fn test_dcor() {
assert_eq!(ln_ret1.len(), result.len());
assert_eq!(
vec![
0.39390614319365574,
0.39845847602318907,
0.4012684752778961,
0.4532725521408008,
0.5335623994772698,
0.5899262972738498,
0.6886450053961184,
0.7578847633388898,
0.3939061431936555,
0.39845847602318923,
0.4012684752778963,
0.45327255214080076,
0.53356239947727,
0.5899262972738499,
0.6886450053961186,
0.7578847633388897,
0.7748853182014356,
0.7670646492585647,
0.8058436110412499,
0.7670646492585645,
0.8058436110412497,
0.818014456822133,
0.8064793069072755,
0.687572965245447,
0.6362198562198043,
0.5628963827860937,
0.8064793069072757,
0.6875729652454472,
0.6362198562198045,
0.5628963827860936,
0.5699710776508861,
0.44892542156220505
0.4489254215622051
],
result[16 - 1..]
);
let x: Vec<f64> = (-20..=20).map(|x| x as f64).collect();
let y = [
0.370149038,
0.288083480,
-0.200846331,
-0.511259482,
-0.539743970,
0.775007267,
-0.714104606,
-0.533221882,
-1.354827474,
-0.748386194,
0.815451687,
-0.646052383,
-0.652485422,
-0.574213246,
-0.152848918,
0.477079479,
0.787394877,
0.808239865,
0.681665474,
0.965605080,
1.187225618,
0.947623657,
1.042522240,
0.998102383,
0.673526316,
0.241418652,
-0.004380181,
-0.134488734,
-0.058152176,
-0.378003031,
1.197777442,
-0.594539181,
-1.050862113,
-1.101031114,
-0.690866613,
0.987395489,
-0.434915046,
-0.529045544,
-0.154515749,
0.169430594,
0.275099501,
];
let result = correlation::dcor(&x, &y, y.len()).collect::<Vec<_>>();
assert_eq!(0.30282497848099954, result[result.len() - 1]);
}

0 comments on commit acd7569

Please sign in to comment.