Skip to content

Commit

Permalink
The Tree struct now has a RateParam struct that can be used to update…
Browse files Browse the repository at this point in the history
… a rate_matrix. Trees initialise with a default RateParam. Likelihood functions no longer require an external rate matrix and now use the Tree's rate_matrix.
  • Loading branch information
jhellewell14 committed Sep 25, 2024
1 parent 0f2387c commit 362e3db
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 68 deletions.
5 changes: 5 additions & 0 deletions src/build_tree.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::node::Node;
use crate::rate_matrix::RateParam;
use crate::Tree;
use cxx::let_cxx_string;
use ndarray::*;
Expand Down Expand Up @@ -64,6 +65,7 @@ pub fn vector_to_tree(v: &[usize]) -> Tree {
}

tree.max_depth = tree.max_treedepth();
tree.update_rate_matrix_GTR();

tree
}
Expand Down Expand Up @@ -167,6 +169,8 @@ pub fn newick_to_tree(rjstr: String) -> Tree {
label_dictionary,
changes: HashMap::new(),
mutation_lists: Vec::new(),
rate_param: RateParam::default(),
rate_matrix: na::Matrix4::identity(),
};

// Add nodes to Tree from parent vector, give correct branch length
Expand All @@ -182,6 +186,7 @@ pub fn newick_to_tree(rjstr: String) -> Tree {

proto_tree.tree_vec = newick_to_vector(&proto_tree.newick(), proto_tree.count_leaves());
proto_tree.max_depth = proto_tree.max_treedepth();
proto_tree.update_rate_matrix_GTR();

proto_tree
}
Expand Down
6 changes: 3 additions & 3 deletions src/hillclimb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ pub fn peturb_vector(v: &[usize], n: usize) -> Vec<usize> {

impl Tree {
// Hill climbing optimisation algorithm
pub fn hillclimb(&mut self, q: &na::Matrix4<f64>, iterations: usize) {
pub fn hillclimb(&mut self, iterations: usize) {
let mut candidate_vec: Vec<usize> = Vec::with_capacity(self.tree_vec.len());
let mut best_vec: Vec<usize> = self.tree_vec.clone();
let mut best_likelihood: f64 = self.get_tree_likelihood();
Expand All @@ -47,7 +47,7 @@ impl Tree {
candidate_vec = peturb_vector(&best_vec, self.tree_vec.len());
println!("new vec: {:?}", candidate_vec);
self.update(&candidate_vec);
self.update_likelihood(q);
self.update_likelihood();
new_likelihood = self.get_tree_likelihood();
println!(
"Candidate likelihood: {} \n Current likelihood: {}",
Expand All @@ -62,6 +62,6 @@ impl Tree {
}

self.update(&best_vec);
self.update_likelihood(q);
self.update_likelihood();
}
}
74 changes: 22 additions & 52 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod tests;
mod tree;
mod tree_iterators;
mod tree_to_newick;
mod rate_matrix;

use crate::build_tree::*;
use crate::tree::Tree;
Expand All @@ -20,70 +21,39 @@ pub fn main() {
let args = cli_args();
let start = Instant::now();

// Define rate matrix
let q: na::Matrix4<f64> = na::Matrix4::new(
-1.0,
1.0 / 3.0,
1.0 / 3.0,
1.0 / 3.0,
1.0 / 3.0,
-1.0,
1.0 / 3.0,
1.0 / 3.0,
1.0 / 3.0,
1.0 / 3.0,
-1.0,
1.0 / 3.0,
1.0 / 3.0,
1.0 / 3.0,
1.0 / 3.0,
-1.0,
);

/// Builds a GTR-style rate matrix from six exchange rates and a base
/// frequency vector, then rescales it.
///
/// `pv` is the frequency vector ordered as (piA, piC, piG, piT). It must hold
/// at least four entries; indexing panics otherwise.
///
/// Each off-diagonal entry is an exchange rate multiplied by the frequency of
/// the column's base, with `a`: A/C, `b`: A/G, `c`: A/T, `d`: C/G, `e`: C/T and
/// `f`: G/T. Each diagonal entry is the negated sum of the other three entries
/// in its row.
///
/// The matrix is divided by `-sum_i q[(i, i)] * pv[i]` before it is returned,
/// so that this quantity equals 1 for the result.
fn create_GTR_ratematrix(
    a: f64,
    b: f64,
    c: f64,
    d: f64,
    e: f64,
    f: f64,
    pv: Vec<f64>,
) -> na::Matrix4<f64> {
    // The sixteen components are listed row by row.
    let q = na::Matrix4::new(
        -(a * pv[1] + b * pv[2] + c * pv[3]),
        a * pv[1],
        b * pv[2],
        c * pv[3],
        a * pv[0],
        -(a * pv[0] + d * pv[2] + e * pv[3]),
        d * pv[2],
        e * pv[3],
        b * pv[0],
        d * pv[1],
        -(b * pv[0] + d * pv[1] + f * pv[3]),
        f * pv[3],
        c * pv[0],
        e * pv[1],
        f * pv[2],
        // The three terms are summed (not multiplied) so the last row sums to
        // zero like the other rows.
        -(c * pv[0] + e * pv[1] + f * pv[2]),
    );

    // Scaling factor: the negated, frequency-weighted sum of the diagonal.
    let mut diag = 0.0;
    for i in 0..4 {
        diag -= q[(i, i)] * pv[i];
    }

    q / diag
}

let q2 = create_GTR_ratematrix(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, vec![0.25, 0.25, 0.25, 0.25]);
// println!("{:?}", q2);
// Trees initialise with a default rate matrix equal to this
// let q: na::Matrix4<f64> = na::Matrix4::new(
// -1.0,
// 1.0 / 3.0,
// 1.0 / 3.0,
// 1.0 / 3.0,
// 1.0 / 3.0,
// -1.0,
// 1.0 / 3.0,
// 1.0 / 3.0,
// 1.0 / 3.0,
// 1.0 / 3.0,
// -1.0,
// 1.0 / 3.0,
// 1.0 / 3.0,
// 1.0 / 3.0,
// 1.0 / 3.0,
// -1.0,
// );

// let mut tr = vector_to_tree(&random_vector(4));
// tr.add_genetic_data(&String::from("/Users/joel/Downloads/listeria0.aln"));
let mut tr = vector_to_tree(&random_vector(27));
tr.add_genetic_data(&args.alignment);

tr.initialise_likelihood(&q);
tr.initialise_likelihood();
println!("{}", tr.get_tree_likelihood());
println!("{:?}", tr.newick());
println!("{:?}", tr.tree_vec);

if !args.no_optimise {
let start = Instant::now();
tr.hillclimb(&q, 0);
tr.hillclimb(5);
let end = Instant::now();

eprintln!("Done in {}s", end.duration_since(start).as_secs());
Expand Down
14 changes: 7 additions & 7 deletions src/likelihoods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ pub fn base_freq_logse(muta: &Mutation, bf: [f64; 4]) -> f64 {

impl Tree {
// Updates the genetic likelihood at a given node
pub fn update_node_likelihood(&mut self, index: usize, rate_matrix: &na::Matrix4<f64>) {
pub fn update_node_likelihood(&mut self, index: usize) {
if let (Some(ch1), Some(ch2)) = self.get_node(index).unwrap().children {
let p1 = na::Matrix::exp(&(rate_matrix * self.get_branchlength(ch1)));
let p2 = na::Matrix::exp(&(rate_matrix * self.get_branchlength(ch2)));
let p1 = na::Matrix::exp(&(self.rate_matrix * self.get_branchlength(ch1)));
let p2 = na::Matrix::exp(&(self.rate_matrix * self.get_branchlength(ch2)));

let seq1 = self.mutation_lists.get(ch1).unwrap();
let seq2 = self.mutation_lists.get(ch2).unwrap();
Expand All @@ -79,7 +79,7 @@ impl Tree {

// Goes through all nodes that have changed and updates genetic likelihood
// Used after tree.update()
pub fn update_likelihood(&mut self, rate_matrix: &na::Matrix4<f64>) {
pub fn update_likelihood(&mut self) {
if self.changes.is_empty() {
return;
}
Expand All @@ -98,7 +98,7 @@ impl Tree {

// Traverse all nodes at current_depth
for node in nodes {
self.update_node_likelihood(node, rate_matrix);
self.update_node_likelihood(node);

if current_depth > 0 {
// Put parent into HashMap so that they are updated
Expand All @@ -119,15 +119,15 @@ impl Tree {

// Traverses tree in post-order below given node (except leaves), updating likelihood
// Used after initial tree constructions to fill in likelihood at all internal nodes
pub fn initialise_likelihood(&mut self, rate_matrix: &na::Matrix4<f64>) {
pub fn initialise_likelihood(&mut self) {
let nodes: Vec<usize> = self
.postorder_notips(self.get_root())
.map(|n| n.index)
.collect();

for node in nodes {
// println!("Node: {}", node);
self.update_node_likelihood(node, rate_matrix);
self.update_node_likelihood(node);
}
}

Expand Down
42 changes: 42 additions & 0 deletions src/rate_matrix.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
use crate::Tree;

/// Parameters from which a tree's 4x4 rate matrix is built.
///
/// Fields `0` to `5` are the six rates named `a` to `f` by
/// `Tree::update_rate_matrix_GTR`, and field `6` is the base frequency vector
/// ordered as (piA, piC, piG, piT).
///
/// `Clone` and `PartialEq` are derived so that parameter sets can be copied
/// and compared.
#[derive(Debug, Clone, PartialEq)]
pub struct RateParam(
    /// Rate `a`, applied to the A/C matrix entries.
    pub f64,
    /// Rate `b`, applied to the A/G matrix entries.
    pub f64,
    /// Rate `c`, applied to the A/T matrix entries.
    pub f64,
    /// Rate `d`, applied to the C/G matrix entries.
    pub f64,
    /// Rate `e`, applied to the C/T matrix entries.
    pub f64,
    /// Rate `f`, applied to the G/T matrix entries.
    pub f64,
    /// Base frequencies ordered as (piA, piC, piG, piT).
    pub Vec<f64>,
);

impl Default for RateParam {
fn default() -> Self {
RateParam(4.0 / 3.0, 4.0 / 3.0, 4.0 / 3.0, 4.0 / 3.0, 4.0 / 3.0, 4.0 / 3.0, vec![0.25, 0.25, 0.25, 0.25])
}
}

impl Tree {
    /// Replaces the tree's stored rate parameters with the supplied values.
    ///
    /// Only `rate_param` is overwritten; `rate_matrix` is left as it was, so
    /// `update_rate_matrix_GTR` has to be called afterwards for the matrix to
    /// reflect the new parameters.
    pub fn update_rate_param(&mut self, a: f64, b: f64, c: f64, d: f64, e: f64, f: f64, pv: Vec<f64>) {
        let params = RateParam(a, b, c, d, e, f, pv);
        self.rate_param = params;
    }

    /// Rebuilds `rate_matrix` from the current `rate_param`.
    ///
    /// Off-diagonal entries are a rate multiplied by the frequency of the
    /// column's base; each diagonal entry is the negated sum of the other
    /// three entries in its row. Panics if the frequency vector holds fewer
    /// than four values.
    pub fn update_rate_matrix_GTR(&mut self) {
        // Copy the six rates out and borrow the frequency vector, which is
        // ordered as (piA, piC, piG, piT).
        let RateParam(a, b, c, d, e, f, ref pv) = self.rate_param;

        // The sixteen components are listed row by row.
        let rebuilt = na::Matrix4::new(
            -(a * pv[1] + b * pv[2] + c * pv[3]),
            a * pv[1],
            b * pv[2],
            c * pv[3],
            a * pv[0],
            -(a * pv[0] + d * pv[2] + e * pv[3]),
            d * pv[2],
            e * pv[3],
            b * pv[0],
            d * pv[1],
            -(b * pv[0] + d * pv[1] + f * pv[3]),
            f * pv[3],
            c * pv[0],
            e * pv[1],
            f * pv[2],
            -(c * pv[0] + e * pv[1] + f * pv[2]),
        );

        self.rate_matrix = rebuilt;
    }
}
12 changes: 6 additions & 6 deletions src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ mod tests {

#[test]
fn likelihood_internal_consistency_check() {
let q: na::Matrix4<f64> = na::Matrix4::new(
-3.0, 1.0, 1.0, 1.0, 1.0, -3.0, 1.0, 1.0, 1.0, 1.0, -3.0, 1.0, 1.0, 1.0, 1.0, -3.0,
);
// let q: na::Matrix4<f64> = na::Matrix4::new(
// -3.0, 1.0, 1.0, 1.0, 1.0, -3.0, 1.0, 1.0, 1.0, 1.0, -3.0, 1.0, 1.0, 1.0, 1.0, -3.0,
// );

let mut tr = vector_to_tree(&vec![0, 0, 0, 0]);

Expand Down Expand Up @@ -143,15 +143,15 @@ mod tests {

tr.mutation_lists = genetic_data;

tr.initialise_likelihood(&q);
tr.initialise_likelihood();

let old_likelihood = tr.get_tree_likelihood();

tr.update(&vec![0, 0, 0, 1]);
tr.update_likelihood(&q);
tr.update_likelihood();

tr.update(&vec![0, 0, 0, 0]);
tr.update_likelihood(&q);
tr.update_likelihood();

let new_likelihood = tr.get_tree_likelihood();

Expand Down
6 changes: 6 additions & 0 deletions src/tree.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::mutation::{create_list, Mutation};
use crate::node::Node;
use crate::rate_matrix::RateParam;
use crate::vector_to_tree;
use needletail::*;
use std::collections::HashMap;
Expand All @@ -12,6 +13,8 @@ pub struct Tree {
pub label_dictionary: HashMap<usize, String>,
pub changes: HashMap<usize, Vec<usize>>,
pub mutation_lists: Vec<Vec<Mutation>>,
pub rate_param: RateParam,
pub rate_matrix: na::Matrix4<f64>,
}

// Tree methods
Expand All @@ -27,6 +30,8 @@ impl Tree {
label_dictionary: HashMap::new(),
changes: HashMap::new(),
mutation_lists: Vec::with_capacity(n_nodes),
rate_param: RateParam::default(),
rate_matrix: na::Matrix4::identity(),
}
}

Expand Down Expand Up @@ -124,4 +129,5 @@ impl Tree {
/// Returns the largest `depth` found among the tree's nodes, or 0 when the
/// tree has no nodes.
pub fn max_treedepth(&self) -> usize {
    // Starting the fold at 0 gives the same answer for an empty node list.
    self.nodes
        .iter()
        .fold(0, |deepest, node| deepest.max(node.depth))
}

}

0 comments on commit 362e3db

Please sign in to comment.