diff --git a/src/lib.rs b/src/lib.rs index bc43682..82ccafa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,6 +12,7 @@ mod treestate; use rate_matrix::RateMatrix; use state_data::create_dummy_statedata; use topology::Topology; +use treestate::TreeStateIter; use treestate::{TreeState, hillclimb_accept}; use crate::newick_to_vec::*; @@ -53,17 +54,24 @@ pub fn main() { if !args.no_optimise { let start = Instant::now(); - for i in 0..5 { - println!{"Step {}", i}; - // let new_v = random_vector(27); - // let mv = ExactMove{target_vector: new_v}; - let mv = ChildSwap{}; - // let mv = PeturbVec{n: 1}; - ts.apply_move(mv, hillclimb_accept, &mut gen_data); + let mv = ChildSwap{}; + let mut tsi = TreeStateIter{ts, move_fn: mv, accept_fn: hillclimb_accept, gen_data: &mut gen_data}; + + let res = tsi.nth(100).unwrap(); + + + + // for i in 0..100 { + // // println!{"Step {}", i}; + // // let new_v = random_vector(27); + // // let mv = ExactMove{target_vector: new_v}; + // let mv = ChildSwap{}; + // // let mv = PeturbVec{n: 1}; + // ts.apply_move(mv, hillclimb_accept, &mut gen_data); - } + // } let end = Instant::now(); - println!("New likelihood: {:?}", ts.likelihood(&gen_data)); + println!("New likelihood: {:?}", res.likelihood(&gen_data)); eprintln!("Done in {}s", end.duration_since(start).as_secs()); eprintln!("Done in {}ms", end.duration_since(start).as_millis()); } diff --git a/src/tests.rs b/src/tests.rs index ea382c6..3799382 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -5,7 +5,7 @@ mod tests { use crate::topology::Topology; use crate::moves::ExactMove; use crate::create_dummy_gendata; - use crate::treestate::{TreeState, always_accept}; + use crate::treestate::{always_accept, TreeMove, TreeState, TreeStateIter}; #[test] fn check_topology_build_manual() { @@ -99,10 +99,16 @@ mod tests { let vecs: Vec> = vec![vec![0, 0, 0, 0], vec![0, 0, 1, 0], vec![0, 0, 1, 2], vec![0, 0, 1, 1]]; let n = ts.top.nodes.len(); + let nomv = ExactMove{target_vector: vec![0, 0, 1, 0]}; + let mut ti = TreeStateIter{ts, move_fn: nomv, accept_fn: always_accept, gen_data: &mut gen_data}; + for vec in vecs { let t_2 = Topology::from_vec(&vec); let mv = ExactMove{target_vector: vec}; - ts.apply_move(mv, always_accept, &mut gen_data); + ti.move_fn = mv; + let ts = ti.nth(0).unwrap(); + + // ts.apply_move(mv, always_accept, &mut gen_data); // t_1.apply_move(mv, always_accept, &mut gen_data, &p.get_matrix()); for i in 0..n { @@ -123,11 +129,14 @@ mod tests { let old_likelihood = ts.likelihood(&gen_data); let mv = ExactMove{target_vector: vec![0, 0, 0, 1]}; - ts.apply_move(mv, always_accept, &mut gen_data); + let mut tsi = TreeStateIter{ts, move_fn: mv, accept_fn: always_accept, gen_data: &mut gen_data}; + // ts.apply_move(mv, always_accept, &mut gen_data); + ts = tsi.nth(0).unwrap(); // t.apply_move(mv, always_accept, &mut gen_data, &p.get_matrix()); let mv = ExactMove{target_vector: vec![0, 0, 0, 0]}; - ts.apply_move(mv, always_accept, &mut gen_data); + tsi = TreeStateIter{ts, move_fn: mv, accept_fn: always_accept, gen_data: &mut gen_data}; + ts = tsi.nth(0).unwrap(); let new_likelihood = ts.likelihood(&gen_data); diff --git a/src/treestate.rs b/src/treestate.rs index 162798b..0251256 100644 --- a/src/treestate.rs +++ b/src/treestate.rs @@ -1,3 +1,4 @@ +use crate::topology::NodeTuple; use crate::Topology; use crate::RateMatrix; use crate::{base_freq_logse, matrix_exp, slice_data, node_likelihood, BF_DEFAULT}; @@ -28,53 +29,133 @@ impl TreeState { } - pub fn apply_move>(&mut self, - move_fn: T, - accept_fn: fn(&f64, &f64) -> bool, - gen_data: &mut ndarray::ArrayBase, ndarray::Dim<[usize; 3]>>) -> () { +// pub fn apply_move>(&mut self, +// move_fn: T, +// accept_fn: fn(&f64, &f64) -> bool, +// gen_data: &mut ndarray::ArrayBase, ndarray::Dim<[usize; 3]>>) -> TreeState { + +// if self.ll.is_none() { +// self.ll = Some(self.likelihood(gen_data)); +// } +// let old_ll = self.ll.unwrap(); + +// let rate_mat = self.mat.get_matrix(); +// let new_ts = move_fn.generate(self); + +// // If move did nothing, keep old TreeState +// if new_ts.changed_nodes.is_none() { +// return *self +// } + +// // Do minimal likelihood updates (and push new values into HashMap temporarily) +// let nodes = new_ts.top.changes_iter_notips(new_ts.changed_nodes.unwrap()); +// let mut temp_likelihoods: HashMap, ndarray::Dim<[usize; 2]>>> = HashMap::new(); + +// for node in nodes { +// // check if in HM +// // println!("{:?}", node); +// let lchild = node.get_lchild().unwrap(); +// let rchild = node.get_rchild().unwrap(); +// let seql: ndarray::ArrayBase, ndarray::Dim<[usize; 2]>>; +// let seqr: ndarray::ArrayBase, ndarray::Dim<[usize; 2]>>; + +// match (temp_likelihoods.contains_key(&lchild), temp_likelihoods.contains_key(&rchild)) { +// (true, true) => { +// seql = temp_likelihoods.get(&lchild).unwrap().slice(s![.., ..]); +// seqr = temp_likelihoods.get(&rchild).unwrap().slice(s![.., ..]); +// }, +// (true, false) => { +// seql = temp_likelihoods.get(&lchild).unwrap().slice(s![.., ..]); +// seqr = slice_data(rchild, gen_data); +// }, +// (false, true) => { +// seql = slice_data(lchild, gen_data); +// seqr = temp_likelihoods.get(&rchild).unwrap().slice(s![.., ..]); +// }, +// (false, false) => { +// seql = slice_data(lchild, gen_data); +// seqr = slice_data(rchild, gen_data); +// }, +// }; + +// let node_ll = node_likelihood(seql, seqr, +// &matrix_exp(&rate_mat, new_ts.top.nodes[lchild].get_branchlen()), +// &matrix_exp(&rate_mat, new_ts.top.nodes[rchild].get_branchlen())); + +// temp_likelihoods.insert(node.get_id(), node_ll); +// } + +// // Calculate whole new topology likelihood at root +// let new_ll = temp_likelihoods +// .get(&new_ts.top.get_root().get_id()) +// .unwrap() +// .rows() +// .into_iter() +// .fold(0.0, |acc, base | acc + base_freq_logse(base, BF_DEFAULT)); + +// // Likelihood decision rule +// if accept_fn(&old_ll, &new_ll) { +// // Drain hashmap into gen_data +// for (i, ll_data) in temp_likelihoods.drain() { +// gen_data.slice_mut(s![i, .., ..]).assign(&ll_data); +// } +// // Update Topology +// self.top.nodes = new_ts.top.nodes; +// self.top.tree_vec = new_ts.top.tree_vec; +// self.mat = new_ts.mat; +// self.ll = Some(new_ll); +// }; +// *self +// } +} + +pub fn hillclimb_accept(old_ll: &f64, new_ll: &f64) -> bool { + new_ll.gt(old_ll) +} + +pub fn always_accept(_old_ll: &f64, _new_ll: &f64) -> bool { + true +} + + +pub struct TreeStateIter<'a, R: RateMatrix, M: TreeMove> { + pub ts: TreeState, + pub move_fn: M, + pub accept_fn: fn(&f64, &f64) -> bool, + pub gen_data: &'a mut ndarray::ArrayBase, ndarray::Dim<[usize; 3]>>, +} - if self.ll.is_none() { - self.ll = Some(self.likelihood(gen_data)); + +impl<'a, R: RateMatrix + 'a, M: TreeMove> Iterator for TreeStateIter<'a, R, M> { + type Item = TreeState; + fn next(&mut self) -> Option { + + if self.ts.ll.is_none() { + self.ts.ll = Some(self.ts.likelihood(self.gen_data)); } - let old_ll = self.ll.unwrap(); + let old_ll = self.ts.ll.unwrap(); - let rate_mat = self.mat.get_matrix(); - let new_ts = move_fn.generate(self); + let rate_mat = self.ts.mat.get_matrix(); + let mut new_ts = self.move_fn.generate(&self.ts); - // If move did nothing, keep old TreeState if new_ts.changed_nodes.is_none() { - return () + return Some(new_ts) } - // Do minimal likelihood updates (and push new values into HashMap temporarily) - let nodes = new_ts.top.changes_iter_notips(new_ts.changed_nodes.unwrap()); + let changed_nodes = new_ts.changed_nodes.clone().unwrap(); + let nodes = new_ts.top.changes_iter_notips(changed_nodes); let mut temp_likelihoods: HashMap, ndarray::Dim<[usize; 2]>>> = HashMap::new(); for node in nodes { - // check if in HM - // println!("{:?}", node); - let lchild = node.get_lchild().unwrap(); - let rchild = node.get_rchild().unwrap(); - let seql: ndarray::ArrayBase, ndarray::Dim<[usize; 2]>>; - let seqr: ndarray::ArrayBase, ndarray::Dim<[usize; 2]>>; - - match (temp_likelihoods.contains_key(&lchild), temp_likelihoods.contains_key(&rchild)) { - (true, true) => { - seql = temp_likelihoods.get(&lchild).unwrap().slice(s![.., ..]); - seqr = temp_likelihoods.get(&rchild).unwrap().slice(s![.., ..]); - }, - (true, false) => { - seql = temp_likelihoods.get(&lchild).unwrap().slice(s![.., ..]); - seqr = slice_data(rchild, gen_data); - }, - (false, true) => { - seql = slice_data(lchild, gen_data); - seqr = temp_likelihoods.get(&rchild).unwrap().slice(s![.., ..]); - }, - (false, false) => { - seql = slice_data(lchild, gen_data); - seqr = slice_data(rchild, gen_data); - }, + let (lchild, rchild) = (node.get_lchild().unwrap(), node.get_rchild().unwrap()); + + let seql = match temp_likelihoods.contains_key(&lchild) { + true => {temp_likelihoods.get(&lchild).unwrap().slice(s![.., ..])}, + false => {slice_data(lchild, self.gen_data)}, + }; + let seqr = match temp_likelihoods.contains_key(&rchild) { + true => {temp_likelihoods.get(&rchild).unwrap().slice(s![.., ..])}, + false => {slice_data(rchild, self.gen_data)}, }; let node_ll = node_likelihood(seql, seqr, @@ -85,35 +166,41 @@ impl TreeState { } // Calculate whole new topology likelihood at root - let new_ll = temp_likelihoods - .get(&new_ts.top.get_root().get_id()) - .unwrap() - .rows() - .into_iter() - .fold(0.0, |acc, base | acc + base_freq_logse(base, BF_DEFAULT)); - - // Likelihood decision rule - if accept_fn(&old_ll, &new_ll) { - // Drain hashmap into gen_data - for (i, ll_data) in temp_likelihoods.drain() { - gen_data.slice_mut(s![i, .., ..]).assign(&ll_data); - } - // Update Topology - self.top.nodes = new_ts.top.nodes; - self.top.tree_vec = new_ts.top.tree_vec; - self.mat = new_ts.mat; - self.ll = Some(new_ll); - }; - -} -} - - -pub fn hillclimb_accept(old_ll: &f64, new_ll: &f64) -> bool { - new_ll.gt(old_ll) -} + let new_ll = temp_likelihoods + .get(&new_ts.top.get_root().get_id()) + .unwrap() + .rows() + .into_iter() + .fold(0.0, |acc, base | acc + base_freq_logse(base, BF_DEFAULT)); + + if (self.accept_fn)(&old_ll, &new_ll) { + // Drain hashmap into gen_data + for (i, ll_data) in temp_likelihoods.drain() { + self.gen_data.slice_mut(s![i, .., ..]).assign(&ll_data); + } + new_ts.ll = Some(new_ll); + // Return new TreeState + return Some(new_ts) + } else { + // Return old TreeState + let top = Topology{ + nodes: self.ts.top.nodes.clone(), + tree_vec: self.ts.top.tree_vec.clone(), + likelihood: None, + }; -pub fn always_accept(_old_ll: &f64, _new_ll: &f64) -> bool { - true + return Some(TreeState{ + top, + mat: self.ts.mat, + ll: Some(old_ll), + changed_nodes: None, + }) + } + } } +// impl<'a, R: RateMatrix, M: TreeMove> TreeState { +// pub fn moveiter() -> TreeStateIter<'a, R, M> { +// todo!() +// } +// } \ No newline at end of file