Skip to content

Commit

Permalink
Rewrote move apply method for TreeState as an implementation of an Iterator
Browse files Browse the repository at this point in the history
  • Loading branch information
jhellewell14 committed Dec 4, 2024
1 parent cca6fe3 commit 476a28c
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 79 deletions.
26 changes: 17 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ mod treestate;
use rate_matrix::RateMatrix;
use state_data::create_dummy_statedata;
use topology::Topology;
use treestate::TreeStateIter;
use treestate::{TreeState, hillclimb_accept};

use crate::newick_to_vec::*;
Expand Down Expand Up @@ -53,17 +54,24 @@ pub fn main() {

if !args.no_optimise {
let start = Instant::now();
for i in 0..5 {
println!{"Step {}", i};
// let new_v = random_vector(27);
// let mv = ExactMove{target_vector: new_v};
let mv = ChildSwap{};
// let mv = PeturbVec{n: 1};
ts.apply_move(mv, hillclimb_accept, &mut gen_data);
let mv = ChildSwap{};
let mut tsi = TreeStateIter{ts, move_fn: mv, accept_fn: hillclimb_accept, gen_data: &mut gen_data};

let res = tsi.nth(100).unwrap();



// for i in 0..100 {
// // println!{"Step {}", i};
// // let new_v = random_vector(27);
// // let mv = ExactMove{target_vector: new_v};
// let mv = ChildSwap{};
// // let mv = PeturbVec{n: 1};
// ts.apply_move(mv, hillclimb_accept, &mut gen_data);

}
// }
let end = Instant::now();
println!("New likelihood: {:?}", ts.likelihood(&gen_data));
println!("New likelihood: {:?}", res.likelihood(&gen_data));
eprintln!("Done in {}s", end.duration_since(start).as_secs());
eprintln!("Done in {}ms", end.duration_since(start).as_millis());
}
Expand Down
17 changes: 13 additions & 4 deletions src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ mod tests {
use crate::topology::Topology;
use crate::moves::ExactMove;
use crate::create_dummy_gendata;
use crate::treestate::{TreeState, always_accept};
use crate::treestate::{always_accept, TreeMove, TreeState, TreeStateIter};

#[test]
fn check_topology_build_manual() {
Expand Down Expand Up @@ -99,10 +99,16 @@ mod tests {
let vecs: Vec<Vec<usize>> = vec![vec![0, 0, 0, 0], vec![0, 0, 1, 0], vec![0, 0, 1, 2], vec![0, 0, 1, 1]];
let n = ts.top.nodes.len();

let nomv = ExactMove{target_vector: vec![0, 0, 1, 0]};
let mut ti = TreeStateIter{ts, move_fn: nomv, accept_fn: always_accept, gen_data: &mut gen_data};

for vec in vecs {
let t_2 = Topology::from_vec(&vec);
let mv = ExactMove{target_vector: vec};
ts.apply_move(mv, always_accept, &mut gen_data);
ti.move_fn = mv;
let ts = ti.nth(0).unwrap();

// ts.apply_move(mv, always_accept, &mut gen_data);
// t_1.apply_move(mv, always_accept, &mut gen_data, &p.get_matrix());

for i in 0..n {
Expand All @@ -123,11 +129,14 @@ mod tests {
let old_likelihood = ts.likelihood(&gen_data);

let mv = ExactMove{target_vector: vec![0, 0, 0, 1]};
ts.apply_move(mv, always_accept, &mut gen_data);
let mut tsi = TreeStateIter{ts, move_fn: mv, accept_fn: always_accept, gen_data: &mut gen_data};
// ts.apply_move(mv, always_accept, &mut gen_data);
ts = tsi.nth(0).unwrap();
// t.apply_move(mv, always_accept, &mut gen_data, &p.get_matrix());

let mv = ExactMove{target_vector: vec![0, 0, 0, 0]};
ts.apply_move(mv, always_accept, &mut gen_data);
tsi = TreeStateIter{ts, move_fn: mv, accept_fn: always_accept, gen_data: &mut gen_data};
ts = tsi.nth(0).unwrap();

let new_likelihood = ts.likelihood(&gen_data);

Expand Down
219 changes: 153 additions & 66 deletions src/treestate.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::topology::NodeTuple;
use crate::Topology;
use crate::RateMatrix;
use crate::{base_freq_logse, matrix_exp, slice_data, node_likelihood, BF_DEFAULT};
Expand Down Expand Up @@ -28,53 +29,133 @@ impl<R: RateMatrix> TreeState<R> {
}


pub fn apply_move<T: TreeMove<R>>(&mut self,
move_fn: T,
accept_fn: fn(&f64, &f64) -> bool,
gen_data: &mut ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 3]>>) -> () {
// pub fn apply_move<T: TreeMove<R>>(&mut self,
// move_fn: T,
// accept_fn: fn(&f64, &f64) -> bool,
// gen_data: &mut ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 3]>>) -> TreeState<R> {

// if self.ll.is_none() {
// self.ll = Some(self.likelihood(gen_data));
// }
// let old_ll = self.ll.unwrap();

// let rate_mat = self.mat.get_matrix();
// let new_ts = move_fn.generate(self);

// // If move did nothing, keep old TreeState
// if new_ts.changed_nodes.is_none() {
// return *self
// }

// // Do minimal likelihood updates (and push new values into HashMap temporarily)
// let nodes = new_ts.top.changes_iter_notips(new_ts.changed_nodes.unwrap());
// let mut temp_likelihoods: HashMap<usize, ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 2]>>> = HashMap::new();

// for node in nodes {
// // check if in HM
// // println!("{:?}", node);
// let lchild = node.get_lchild().unwrap();
// let rchild = node.get_rchild().unwrap();
// let seql: ndarray::ArrayBase<ndarray::ViewRepr<&f64>, ndarray::Dim<[usize; 2]>>;
// let seqr: ndarray::ArrayBase<ndarray::ViewRepr<&f64>, ndarray::Dim<[usize; 2]>>;

// match (temp_likelihoods.contains_key(&lchild), temp_likelihoods.contains_key(&rchild)) {
// (true, true) => {
// seql = temp_likelihoods.get(&lchild).unwrap().slice(s![.., ..]);
// seqr = temp_likelihoods.get(&rchild).unwrap().slice(s![.., ..]);
// },
// (true, false) => {
// seql = temp_likelihoods.get(&lchild).unwrap().slice(s![.., ..]);
// seqr = slice_data(rchild, gen_data);
// },
// (false, true) => {
// seql = slice_data(lchild, gen_data);
// seqr = temp_likelihoods.get(&rchild).unwrap().slice(s![.., ..]);
// },
// (false, false) => {
// seql = slice_data(lchild, gen_data);
// seqr = slice_data(rchild, gen_data);
// },
// };

// let node_ll = node_likelihood(seql, seqr,
// &matrix_exp(&rate_mat, new_ts.top.nodes[lchild].get_branchlen()),
// &matrix_exp(&rate_mat, new_ts.top.nodes[rchild].get_branchlen()));

// temp_likelihoods.insert(node.get_id(), node_ll);
// }

// // Calculate whole new topology likelihood at root
// let new_ll = temp_likelihoods
// .get(&new_ts.top.get_root().get_id())
// .unwrap()
// .rows()
// .into_iter()
// .fold(0.0, |acc, base | acc + base_freq_logse(base, BF_DEFAULT));

// // Likelihood decision rule
// if accept_fn(&old_ll, &new_ll) {
// // Drain hashmap into gen_data
// for (i, ll_data) in temp_likelihoods.drain() {
// gen_data.slice_mut(s![i, .., ..]).assign(&ll_data);
// }
// // Update Topology
// self.top.nodes = new_ts.top.nodes;
// self.top.tree_vec = new_ts.top.tree_vec;
// self.mat = new_ts.mat;
// self.ll = Some(new_ll);
// };
// *self
// }
}

/// Hill-climbing acceptance rule for a proposed move.
///
/// Returns `true` only when `new_ll` is strictly greater than `old_ll`.
/// Equal values are rejected, and any comparison involving NaN yields `false`.
pub fn hillclimb_accept(old_ll: &f64, new_ll: &f64) -> bool {
    *new_ll > *old_ll
}

/// Acceptance rule that approves every proposed move.
///
/// Both arguments are ignored; they are present only so this function has the
/// `fn(&f64, &f64) -> bool` shape expected of an acceptance rule.
pub fn always_accept(_old_ll: &f64, _new_ll: &f64) -> bool {
    true
}


/// Bundle of everything needed to step a [`TreeState`] through proposed moves.
///
/// The `Iterator` implementation for this type reads all four fields on each
/// call to `next`.
pub struct TreeStateIter<'a, R, M>
where
    R: RateMatrix,
    M: TreeMove<R>,
{
    /// Tree state passed to `move_fn` when a proposal is generated.
    pub ts: TreeState<R>,
    /// Move generator; its `generate` method produces the candidate state.
    pub move_fn: M,
    /// Decision rule, called with the old and the new likelihood value.
    pub accept_fn: fn(&f64, &f64) -> bool,
    /// Mutable borrow of the 3-D `f64` array that node likelihoods are read
    /// from and, when a move is accepted, written back into.
    pub gen_data: &'a mut ndarray::Array3<f64>,
}

if self.ll.is_none() {
self.ll = Some(self.likelihood(gen_data));

impl<'a, R: RateMatrix + 'a, M: TreeMove<R>> Iterator for TreeStateIter<'a, R, M> {
type Item = TreeState<R>;
fn next(&mut self) -> Option<Self::Item> {

if self.ts.ll.is_none() {
self.ts.ll = Some(self.ts.likelihood(self.gen_data));
}
let old_ll = self.ll.unwrap();
let old_ll = self.ts.ll.unwrap();

let rate_mat = self.mat.get_matrix();
let new_ts = move_fn.generate(self);
let rate_mat = self.ts.mat.get_matrix();
let mut new_ts = self.move_fn.generate(&self.ts);

// If move did nothing, keep old TreeState
if new_ts.changed_nodes.is_none() {
return ()
return Some(new_ts)
}

// Do minimal likelihood updates (and push new values into HashMap temporarily)
let nodes = new_ts.top.changes_iter_notips(new_ts.changed_nodes.unwrap());
let changed_nodes = new_ts.changed_nodes.clone().unwrap();
let nodes = new_ts.top.changes_iter_notips(changed_nodes);
let mut temp_likelihoods: HashMap<usize, ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 2]>>> = HashMap::new();

for node in nodes {
// check if in HM
// println!("{:?}", node);
let lchild = node.get_lchild().unwrap();
let rchild = node.get_rchild().unwrap();
let seql: ndarray::ArrayBase<ndarray::ViewRepr<&f64>, ndarray::Dim<[usize; 2]>>;
let seqr: ndarray::ArrayBase<ndarray::ViewRepr<&f64>, ndarray::Dim<[usize; 2]>>;

match (temp_likelihoods.contains_key(&lchild), temp_likelihoods.contains_key(&rchild)) {
(true, true) => {
seql = temp_likelihoods.get(&lchild).unwrap().slice(s![.., ..]);
seqr = temp_likelihoods.get(&rchild).unwrap().slice(s![.., ..]);
},
(true, false) => {
seql = temp_likelihoods.get(&lchild).unwrap().slice(s![.., ..]);
seqr = slice_data(rchild, gen_data);
},
(false, true) => {
seql = slice_data(lchild, gen_data);
seqr = temp_likelihoods.get(&rchild).unwrap().slice(s![.., ..]);
},
(false, false) => {
seql = slice_data(lchild, gen_data);
seqr = slice_data(rchild, gen_data);
},
let (lchild, rchild) = (node.get_lchild().unwrap(), node.get_rchild().unwrap());

let seql = match temp_likelihoods.contains_key(&lchild) {
true => {temp_likelihoods.get(&lchild).unwrap().slice(s![.., ..])},
false => {slice_data(lchild, self.gen_data)},
};
let seqr = match temp_likelihoods.contains_key(&rchild) {
true => {temp_likelihoods.get(&rchild).unwrap().slice(s![.., ..])},
false => {slice_data(rchild, self.gen_data)},
};

let node_ll = node_likelihood(seql, seqr,
Expand All @@ -85,35 +166,41 @@ impl<R: RateMatrix> TreeState<R> {
}

// Calculate whole new topology likelihood at root
let new_ll = temp_likelihoods
.get(&new_ts.top.get_root().get_id())
.unwrap()
.rows()
.into_iter()
.fold(0.0, |acc, base | acc + base_freq_logse(base, BF_DEFAULT));

// Likelihood decision rule
if accept_fn(&old_ll, &new_ll) {
// Drain hashmap into gen_data
for (i, ll_data) in temp_likelihoods.drain() {
gen_data.slice_mut(s![i, .., ..]).assign(&ll_data);
}
// Update Topology
self.top.nodes = new_ts.top.nodes;
self.top.tree_vec = new_ts.top.tree_vec;
self.mat = new_ts.mat;
self.ll = Some(new_ll);
};

}
}


/// Hill-climbing acceptance rule for a proposed move.
///
/// Returns `true` only when `new_ll` is strictly greater than `old_ll`.
/// Equal values are rejected, and any comparison involving NaN yields `false`.
pub fn hillclimb_accept(old_ll: &f64, new_ll: &f64) -> bool {
    *new_ll > *old_ll
}
let new_ll = temp_likelihoods
.get(&new_ts.top.get_root().get_id())
.unwrap()
.rows()
.into_iter()
.fold(0.0, |acc, base | acc + base_freq_logse(base, BF_DEFAULT));

if (self.accept_fn)(&old_ll, &new_ll) {
// Drain hashmap into gen_data
for (i, ll_data) in temp_likelihoods.drain() {
self.gen_data.slice_mut(s![i, .., ..]).assign(&ll_data);
}
new_ts.ll = Some(new_ll);
// Return new TreeState
return Some(new_ts)
} else {
// Return old TreeState
let top = Topology{
nodes: self.ts.top.nodes.clone(),
tree_vec: self.ts.top.tree_vec.clone(),
likelihood: None,
};

pub fn always_accept(_old_ll: &f64, _new_ll: &f64) -> bool {
true
return Some(TreeState{
top,
mat: self.ts.mat,
ll: Some(old_ll),
changed_nodes: None,
})
}
}
}

// impl<'a, R: RateMatrix, M: TreeMove<R>> TreeState<R> {
// pub fn moveiter() -> TreeStateIter<'a, R, M> {
// todo!()
// }
// }

0 comments on commit 476a28c

Please sign in to comment.