diff --git a/src/build_tree.rs b/src/build_tree.rs index 7feb195..a8d8193 100644 --- a/src/build_tree.rs +++ b/src/build_tree.rs @@ -1,4 +1,5 @@ use crate::node::Node; +use crate::rate_matrix::RateParam; use crate::Tree; use cxx::let_cxx_string; use ndarray::*; @@ -64,6 +65,7 @@ pub fn vector_to_tree(v: &[usize]) -> Tree { } tree.max_depth = tree.max_treedepth(); + tree.update_rate_matrix_GTR(); tree } @@ -167,6 +169,8 @@ pub fn newick_to_tree(rjstr: String) -> Tree { label_dictionary, changes: HashMap::new(), mutation_lists: Vec::new(), + rate_param: RateParam::default(), + rate_matrix: na::Matrix4::identity(), }; // Add nodes to Tree from parent vector, give correct branch length @@ -182,6 +186,7 @@ pub fn newick_to_tree(rjstr: String) -> Tree { proto_tree.tree_vec = newick_to_vector(&proto_tree.newick(), proto_tree.count_leaves()); proto_tree.max_depth = proto_tree.max_treedepth(); + proto_tree.update_rate_matrix_GTR(); proto_tree } diff --git a/src/hillclimb.rs b/src/hillclimb.rs index 8fe555f..e9dc022 100644 --- a/src/hillclimb.rs +++ b/src/hillclimb.rs @@ -36,7 +36,7 @@ pub fn peturb_vector(v: &[usize], n: usize) -> Vec { impl Tree { // Hill climbing optimisation algorithm - pub fn hillclimb(&mut self, q: &na::Matrix4, iterations: usize) { + pub fn hillclimb(&mut self, iterations: usize) { let mut candidate_vec: Vec = Vec::with_capacity(self.tree_vec.len()); let mut best_vec: Vec = self.tree_vec.clone(); let mut best_likelihood: f64 = self.get_tree_likelihood(); @@ -47,7 +47,7 @@ impl Tree { candidate_vec = peturb_vector(&best_vec, self.tree_vec.len()); println!("new vec: {:?}", candidate_vec); self.update(&candidate_vec); - self.update_likelihood(q); + self.update_likelihood(); new_likelihood = self.get_tree_likelihood(); println!( "Candidate likelihood: {} \n Current likelihood: {}", @@ -62,6 +62,6 @@ impl Tree { } self.update(&best_vec); - self.update_likelihood(q); + self.update_likelihood(); } } diff --git a/src/lib.rs b/src/lib.rs index 4d630f7..b120b12 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,6 +7,7 @@ mod tests; mod tree; mod tree_iterators; mod tree_to_newick; +mod rate_matrix; use crate::build_tree::*; use crate::tree::Tree; @@ -20,70 +21,39 @@ pub fn main() { let args = cli_args(); let start = Instant::now(); - // Define rate matrix - let q: na::Matrix4 = na::Matrix4::new( - -1.0, - 1.0 / 3.0, - 1.0 / 3.0, - 1.0 / 3.0, - 1.0 / 3.0, - -1.0, - 1.0 / 3.0, - 1.0 / 3.0, - 1.0 / 3.0, - 1.0 / 3.0, - -1.0, - 1.0 / 3.0, - 1.0 / 3.0, - 1.0 / 3.0, - 1.0 / 3.0, - -1.0, - ); - - fn create_GTR_ratematrix(a: f64, b: f64, c: f64, d: f64, e: f64, f: f64, pv: Vec) -> na::Matrix4 { - // pv = pivec defined as (piA, piC, piG, piT) - let mut q = na::Matrix4::new( - -(a * pv[1] + b * pv[2] + c * pv[3]), - a * pv[1], - b * pv[2], - c * pv[3], - a * pv[0], - -(a * pv[0] + d * pv[2] + e * pv[3]), - d * pv[2], - e * pv[3], - b * pv[0], - d * pv[1], - -(b * pv[0] + d * pv[1] + f * pv[3]), - f * pv[3], - c * pv[0], - e * pv[1], - f * pv[2], - -(c * pv[0] + e * pv[1] * f * pv[2])); - - let mut diag = 0.0; - for i in 0..=3 { - diag -= q[(i, i)] * pv[i]; - } - - q / diag - } - - let q2 = create_GTR_ratematrix(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, vec![0.25, 0.25, 0.25, 0.25]); - // println!("{:?}", q2); + // Trees initialise with a default rate matrix equal to this + // let q: na::Matrix4 = na::Matrix4::new( + // -1.0, + // 1.0 / 3.0, + // 1.0 / 3.0, + // 1.0 / 3.0, + // 1.0 / 3.0, + // -1.0, + // 1.0 / 3.0, + // 1.0 / 3.0, + // 1.0 / 3.0, + // 1.0 / 3.0, + // -1.0, + // 1.0 / 3.0, + // 1.0 / 3.0, + // 1.0 / 3.0, + // 1.0 / 3.0, + // -1.0, + // ); // let mut tr = vector_to_tree(&random_vector(4)); // tr.add_genetic_data(&String::from("/Users/joel/Downloads/listeria0.aln")); let mut tr = vector_to_tree(&random_vector(27)); tr.add_genetic_data(&args.alignment); - tr.initialise_likelihood(&q); + tr.initialise_likelihood(); println!("{}", tr.get_tree_likelihood()); println!("{:?}", tr.newick()); println!("{:?}", tr.tree_vec); if !args.no_optimise { let start = Instant::now(); - tr.hillclimb(&q, 0); + tr.hillclimb(5); let end = Instant::now(); eprintln!("Done in {}s", end.duration_since(start).as_secs()); diff --git a/src/likelihoods.rs b/src/likelihoods.rs index e78a4cc..eb4f9b0 100644 --- a/src/likelihoods.rs +++ b/src/likelihoods.rs @@ -65,10 +65,10 @@ pub fn base_freq_logse(muta: &Mutation, bf: [f64; 4]) -> f64 { impl Tree { // Updates the genetic likelihood at a given node - pub fn update_node_likelihood(&mut self, index: usize, rate_matrix: &na::Matrix4) { + pub fn update_node_likelihood(&mut self, index: usize) { if let (Some(ch1), Some(ch2)) = self.get_node(index).unwrap().children { - let p1 = na::Matrix::exp(&(rate_matrix * self.get_branchlength(ch1))); - let p2 = na::Matrix::exp(&(rate_matrix * self.get_branchlength(ch2))); + let p1 = na::Matrix::exp(&(self.rate_matrix * self.get_branchlength(ch1))); + let p2 = na::Matrix::exp(&(self.rate_matrix * self.get_branchlength(ch2))); let seq1 = self.mutation_lists.get(ch1).unwrap(); let seq2 = self.mutation_lists.get(ch2).unwrap(); @@ -79,7 +79,7 @@ impl Tree { // Goes through all nodes that have changed and updates genetic likelihood // Used after tree.update() - pub fn update_likelihood(&mut self, rate_matrix: &na::Matrix4) { + pub fn update_likelihood(&mut self) { if self.changes.is_empty() { return; } @@ -98,7 +98,7 @@ impl Tree { // Traverse all nodes at current_depth for node in nodes { - self.update_node_likelihood(node, rate_matrix); + self.update_node_likelihood(node); if current_depth > 0 { // Put parent into HashMap so that they are updated @@ -119,7 +119,7 @@ impl Tree { // Traverses tree in post-order below given node (except leaves), updating likelihood // Used after initial tree constructions to fill in likelihood at all internal nodes - pub fn initialise_likelihood(&mut self, rate_matrix: &na::Matrix4) { + pub fn initialise_likelihood(&mut self) { let nodes: Vec = self .postorder_notips(self.get_root()) .map(|n| n.index) @@ -127,7 +127,7 @@ impl Tree { for node in nodes { // println!("Node: {}", node); - self.update_node_likelihood(node, rate_matrix); + self.update_node_likelihood(node); } } diff --git a/src/rate_matrix.rs b/src/rate_matrix.rs new file mode 100644 index 0000000..a097c08 --- /dev/null +++ b/src/rate_matrix.rs @@ -0,0 +1,42 @@ +use crate::Tree; + +#[derive(Debug)] +pub struct RateParam(pub f64, pub f64, pub f64, pub f64, pub f64, pub f64, pub Vec); + +impl Default for RateParam { + fn default() -> Self { + RateParam(4.0 / 3.0, 4.0 / 3.0, 4.0 / 3.0, 4.0 / 3.0, 4.0 / 3.0, 4.0 / 3.0, vec![0.25, 0.25, 0.25, 0.25]) + } +} + +impl Tree { + + pub fn update_rate_param(&mut self, a: f64, b: f64, c: f64, d: f64, e: f64, f: f64, pv: Vec) { + self.rate_param = RateParam(a, b, c, d, e, f, pv); + // self.update_rate_matrix_GTR(); + } + + pub fn update_rate_matrix_GTR(&mut self) { + let RateParam(a, b, c, d, e, f, pv) = &self.rate_param; + // pv = pivec defined as (piA, piC, piG, piT) + let mut q = na::Matrix4::new( + -(a * pv[1] + b * pv[2] + c * pv[3]), + a * pv[1], + b * pv[2], + c * pv[3], + a * pv[0], + -(a * pv[0] + d * pv[2] + e * pv[3]), + d * pv[2], + e * pv[3], + b * pv[0], + d * pv[1], + -(b * pv[0] + d * pv[1] + f * pv[3]), + f * pv[3], + c * pv[0], + e * pv[1], + f * pv[2], + -(c * pv[0] + e * pv[1] + f * pv[2])); + + self.rate_matrix = q; + } +} \ No newline at end of file diff --git a/src/tests.rs b/src/tests.rs index ed04398..0c64b29 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -104,9 +104,9 @@ mod tests { #[test] fn likelihood_internal_consistency_check() { - let q: na::Matrix4 = na::Matrix4::new( - -3.0, 1.0, 1.0, 1.0, 1.0, -3.0, 1.0, 1.0, 1.0, 1.0, -3.0, 1.0, 1.0, 1.0, 1.0, -3.0, - ); + // let q: na::Matrix4 = na::Matrix4::new( + // -3.0, 1.0, 1.0, 1.0, 1.0, -3.0, 1.0, 1.0, 1.0, 1.0, -3.0, 1.0, 1.0, 1.0, 1.0, -3.0, + // ); let mut tr = vector_to_tree(&vec![0, 0, 0, 0]); @@ -143,15 +143,15 @@ mod tests { tr.mutation_lists = genetic_data; - tr.initialise_likelihood(&q); + tr.initialise_likelihood(); let old_likelihood = tr.get_tree_likelihood(); tr.update(&vec![0, 0, 0, 1]); - tr.update_likelihood(&q); + tr.update_likelihood(); tr.update(&vec![0, 0, 0, 0]); - tr.update_likelihood(&q); + tr.update_likelihood(); let new_likelihood = tr.get_tree_likelihood(); diff --git a/src/tree.rs b/src/tree.rs index 6f63724..670a771 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -1,5 +1,6 @@ use crate::mutation::{create_list, Mutation}; use crate::node::Node; +use crate::rate_matrix::RateParam; use crate::vector_to_tree; use needletail::*; use std::collections::HashMap; @@ -12,6 +13,8 @@ pub struct Tree { pub label_dictionary: HashMap, pub changes: HashMap>, pub mutation_lists: Vec>, + pub rate_param: RateParam, + pub rate_matrix: na::Matrix4, } // Tree methods @@ -27,6 +30,8 @@ impl Tree { label_dictionary: HashMap::new(), changes: HashMap::new(), mutation_lists: Vec::with_capacity(n_nodes), + rate_param: RateParam::default(), + rate_matrix: na::Matrix4::identity(), } } @@ -124,4 +129,5 @@ impl Tree { pub fn max_treedepth(&self) -> usize { self.nodes.iter().map(|node| node.depth).max().unwrap_or(0) } + }