Skip to content

Commit

Permalink
Implement crmh_ with fewer txops
Browse files Browse the repository at this point in the history
  • Loading branch information
David Norris authored and David Norris committed Mar 24, 2021
1 parent 968ad6e commit 4858476
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 35 deletions.
Binary file modified data/viola_dtp.rda
Binary file not shown.
78 changes: 76 additions & 2 deletions src/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,87 @@ fn crmh_(a: &f64, // NB: *NOT* vectorized on a
w: &[f64],
s: f64,
b: i32) -> f64 {
let mut v = a.powi(b) * (-0.5*(a/s).powi(2)).exp();
// Below was the original, very short computation.
// I count N+1 transcendental ops, compared with 2D+1 in new implementation.
// If an average of 4 patients are enrolled per dose level, this translates
// to a savings of (4D - 2D)/(4D+1) ~ 50% of effort.
// TODO: By precalculating the X[d].ln()'s in the calling function, the cost
// of these txops amortizes to zero! Thus we can bring this routine
// down to just D+1 transcendendal ops, saving 75% of effort!
// I might even dare to call these ln()'s the EXOSKELETON.
/*
let mut v = a.powi(b) * (-0.5*(a/s).powi(2)).exp(); // 1 exp()
if v.is_infinite() { return 0.0; }
for i in 0 .. y.len() {
let p_i = x[i].powf(a.exp()); // 'power model' CRM
let p_i = x[i].powf(a.exp()); // 'power model' CRM // N powf()'s
v = v * if y[i] == 0 { 1.0 - w[i] * p_i } else { p_i };
}
v
*/

// I'm going to undertake some refactoring toward an algorithm
// that reduces transcendental ops to a bare minimum. The key
// opportunity in this regard lies in COMPUTING x[i]^exp_a_
// JUST ONCE PER DOSE LEVEL, rather than once per patient.
// Some bookkeeping preliminaries are needed to make this possible.
// As part of my initial refactoring, I will do this initially here
// in this routine; but this computation is one that ought to be
// bubbled upward as far as possible.

// To enable a dose-wise (rather than patent-wise) computation
// with the x[i]^exp_a_ values, we require a vector of unique dose
// levels thus far tried.
let mut X: Vec<f64> = x.to_vec();
// NB: -^ doses expressed on a prior-prob scale!
X.sort_by(|a, b| a.partial_cmp(b).unwrap());
X.dedup();

// It also proves useful to tally dose-wise sums of toxicities:
let mut Y: [f64; 10] = [0.0; 10]; // Static allocation sufficient for 10 doses
for i in 0 .. y.len() {
match X.iter().position(|&p| p==x[i]) {
Some(d) => {
if y[i]==1 {
Y[d] += 1.0; // tally toxicities
}
}
None => {}
}
}

if a > &709.0 { return 0.0 } // "Not gonna exp() it; wouldn't be prudent."
let exp_a_ = a.exp(); // 1 exp()

// The objective function can factored as: v = vconst * log_vtox.exp() * v_non.
// (Let's call log_vtox the 'toxic term' and v_non the 'non-tox factor'.)
let log_vconst = -0.5 * (a/s).powi(2);
if log_vconst > 709.0 { return 0.0; } // saves time, avoids returning Inf*0=NaN

let mut log_vtox = 0.0;
for d in 0 .. X.len() { // TODO: Use an iterator-based expression?
log_vtox = log_vtox + Y[d]*X[d].ln(); // D ln()'s [TODO: pass]
}
log_vtox = log_vtox * exp_a_; // this completes the toxic term

// The non-tox factor is more difficult, because ln does not
// commute with any operations inside (1 - w * x^exp_a_) 8^(.
// So we must iterate over the patients. BUT... this can be
// done without racking up per-patient transcendental ops,
// provided we collect our factor multiplicatively (not by logs).
// The key point is that the following computation costs only
// X.len() transcendental ops -- just 1 powf() per dose tried.
let mut v_non = 1.0; // to build non-tox factor by straight multiplication
for d in 0 .. X.len() { // For each dose X[d] that was tried
let p_d = X[d].powf(exp_a_); // .. compute p_d = X[d]^exp_a_ JUST ONCE
for i in 0 .. y.len() { // .. then scan over all patients
if y[i] == 0 && x[i] == X[d] { // .. without tox, assigned to X[d]
v_non = v_non * (1.0 - w[i]*p_d); // .. and multiply on (1-w*p_d).
}
} // D powf()'s
}

let vfast = a.powi(b) * (log_vconst + log_vtox).exp() * v_non; // 1 exp()
vfast
}

// Vectorize crmh1 on the 'a' parameter
Expand Down
48 changes: 15 additions & 33 deletions tests/testthat/test-viola-dtp.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,39 +20,21 @@ test_that("calculate_dtps() yields same VIOLA result as dtpcrm's version", {
}
}

timings <- list(
dtpcrm = system.time(
old <- dtpcrm::calculate_dtps(
next_dose = start.dose.level,
cohort_sizes = rep(3, 7),
prior = prior.DLT,
target = target.DLT,
stop_func = stop_func,
scale = sqrt(prior.var),
no_skip_esc = TRUE,
no_skip_deesc = FALSE,
global_coherent_esc = TRUE)
)
, newdtp = system.time(
new <- calculate_dtps(
next_dose = start.dose.level,
cohort_sizes = rep(3, 7),
dose_func = applied_crm, # i.e., precautionary::applied_crm
prior = prior.DLT,
target = target.DLT,
stop_func = stop_func,
scale = sqrt(prior.var),
no_skip_esc = TRUE,
no_skip_deesc = FALSE,
global_coherent_esc = TRUE,
impl = 'rusti')
)
)
new <- calculate_dtps(
next_dose = start.dose.level,
cohort_sizes = rep(3, 7),
dose_func = applied_crm, # i.e., precautionary::applied_crm
prior = prior.DLT,
target = target.DLT,
stop_func = stop_func,
scale = sqrt(prior.var),
no_skip_esc = TRUE,
no_skip_deesc = FALSE,
global_coherent_esc = TRUE,
impl = 'rusti')

with(timings, {
speedup_message(newdtp, dtpcrm)
})
data(viola_dtp) # saved for comparison

rownames(new) <- rownames(old) # don't compare rownames
expect_equal(old, new)
rownames(new) <- rownames(viola_dtp) # don't compare rownames
expect_equal(viola_dtp, new)
})

0 comments on commit 4858476

Please sign in to comment.