diff --git a/python/tests/test_haplotype_matching.py b/python/tests/test_haplotype_matching.py index 55cfdb7fd0..3de3a795ab 100644 --- a/python/tests/test_haplotype_matching.py +++ b/python/tests/test_haplotype_matching.py @@ -150,6 +150,12 @@ def node_values(self): d[u] = mapping[v] return d + @property + def matrix_size(self): + if self.match_all_nodes: + return self.ts.num_nodes + return self.ts.num_samples + def print_state(self): print("LsHMM state") print("match_all_nodes =", self.match_all_nodes) @@ -435,12 +441,18 @@ def update_probabilities(self, site, haplotype_state): def process_site(self, site, haplotype_state): self.update_probabilities(site, haplotype_state) - # d1 = self.node_values() + d1 = self.node_values() # print("PRE") - # self.print_state() + # # self.print_state() self.compress() - # d2 = self.node_values() - # assert d1 == d2 + d2 = self.node_values() + if self.match_all_nodes: + # We only get an exact match on all_nodes. For samples we just + # guarantee that the *samples* have the same value + assert d1 == d2 + else: + for u in self.ts.samples(): + assert d1[u] == d2[u] # print("AFTER COMPRESS") # self.print_state() s = self.compute_normalisation_factor() @@ -489,7 +501,7 @@ def initialise(self, value): self.T.append(ValueTransition(tree_node=u, value=value)) def run(self, h): - n = self.ts.num_samples + n = self.matrix_size self.initialise(1 / n) while self.tree.next(): self.update_tree() @@ -553,8 +565,9 @@ def compute_normalisation_factor(self): return s def compute_next_probability(self, site_id, p_last, is_match, node): + n = self.matrix_size + # print("NEXT PROBA:", site_id, n) rho = self.rho[site_id] - n = self.ts.num_samples p_e = self.compute_emission_proba(site_id, is_match) p_t = p_last * (1 - rho) + rho / n return p_t * p_e @@ -584,7 +597,7 @@ def process_site(self, site, haplotype_state, s): # compress self.compress() b_last_sum = self.compute_normalisation_factor() - n = self.ts.num_samples + n = self.matrix_size rho = self.rho[site.id] for st in self.T: if st.tree_node != tskit.NULL: @@ -624,7 +637,7 @@ def compute_normalisation_factor(self): def compute_next_probability(self, site_id, p_last, is_match, node): rho = self.rho[site_id] - n = self.ts.num_samples + n = self.matrix_size p_no_recomb = p_last * (1 - rho + rho / n) p_recomb = rho / n @@ -668,7 +681,6 @@ class CompressedMatrix: def __init__(self, ts): self.ts = ts self.num_sites = ts.num_sites - self.num_samples = ts.num_samples self.value_transitions = [None for _ in range(self.num_sites)] self.normalisation_factor = np.zeros(self.num_sites) @@ -697,14 +709,14 @@ def num_transitions(self): def get_site(self, site): return self.value_transitions[site] - def decode(self): + def decode_samples(self): """ Decodes the tree encoding of the values into an explicit matrix. """ sample_index_map = np.zeros(self.ts.num_nodes, dtype=int) - 1 sample_index_map[self.ts.samples()] = np.arange(self.ts.num_samples) - A = np.zeros((self.num_sites, self.num_samples)) + A = np.zeros((self.num_sites, self.ts.num_samples)) for tree in self.ts.trees(): for site in tree.sites(): for node, value in self.value_transitions[site.id]: @@ -713,6 +725,22 @@ def decode(self): A[site.id, j] = value return A + def decode_nodes(self): + # print("decode nodes") + A = np.zeros((self.num_sites, self.ts.num_nodes)) + for tree in self.ts.trees(): + for site in tree.sites(): + for node, value in self.value_transitions[site.id]: + # print("Decode:", site.id, node, value) + for u in tree.nodes(node): + A[site.id, u] = value + return A + + def decode(self, all_nodes=False): + if all_nodes: + return self.decode_nodes() + return self.decode_samples() + class ViterbiMatrix(CompressedMatrix): """ @@ -1330,7 +1358,7 @@ def check_forward_matrix( scale_mutation_based_on_n_alleles=False, match_all_nodes=match_all_nodes, ) - F2 = cm.decode() + F2 = cm.decode(match_all_nodes) ll_tree = np.sum(np.log10(cm.normalisation_factor)) if compare_lshmm: @@ -1549,6 +1577,7 @@ def test_match_sample(self, u, h): ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=True ) nt.assert_array_equal([u] * 7, path) + fm = check_forward_matrix( ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=True ) @@ -1558,45 +1587,36 @@ def test_match_sample(self, u, h): check_fb_matrix_integrity(fm, bm) -def check_fb_matrix_integrity(fm, bm): +def check_fb_matrix_integrity(fm, bm, match_all_nodes=False): """ Validate properties of the forward and backward matrices. """ - F = fm.decode() - B = bm.decode() + F = fm.decode(match_all_nodes) + B = bm.decode(match_all_nodes) assert F.shape == B.shape for j in range(len(F)): s = np.sum(B[j] * F[j]) + # print(j, s) np.testing.assert_allclose(s, 1) -def check_fb_matrices(ts, h): - fm = check_forward_matrix(ts, h) - bm = check_backward_matrix(ts, h, fm) - check_fb_matrix_integrity(fm, bm) +def check_fb_matrices(ts, h, match_all_nodes=False, **kwargs): + fm = check_forward_matrix(ts, h, match_all_nodes=match_all_nodes, **kwargs) + bm = check_backward_matrix(ts, h, fm, match_all_nodes=match_all_nodes, **kwargs) + check_fb_matrix_integrity(fm, bm, match_all_nodes=match_all_nodes) def validate_match_all_nodes(ts, h, expected_path): - # path = check_viterbi( - # ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False - # ) - # nt.assert_array_equal(expected_path, path) - fm = check_forward_matrix( + # START HERE: most of this is working except for Viterbi + path = check_viterbi( ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False ) - F = fm.decode() - # print(cm.decode()) - # cm.print_state() - bm = check_backward_matrix( - ts, h, fm, match_all_nodes=True, compare_lib=False, compare_lshmm=False - ) - print("sites = ", ts.num_sites) - B = bm.decode() - print(F) - for j in range(ts.num_sites): - print(j, np.sum(B[j] * F[j])) + # print("Path = ", path) + nt.assert_array_equal(expected_path, path) - # sum(B[variant,:] * F[variant,:]) = 1 + check_fb_matrices( + ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False + ) class TestSingleBalancedTreeAllNodesExample: @@ -1692,19 +1712,18 @@ def ts(): ("h", "expected_path"), [ # Just samples - ([1, 0, 0, 0, 0, 1, 1], [0] * 7), - # ([0, 1, 0, 0, 1, 1, 0], [1] * 7), - # ([0, 0, 1, 0, 1, 1, 0], [2] * 7), - # ([0, 0, 0, 1, 0, 0, 1], [3] * 7), - # # Match root - # ([0, 0, 0, 0, 0, 0, 0], [7] * 7), + # fails on viterbi + # ([1, 0, 0, 0, 0, 1, 1], [0] * 7), + ([0, 1, 0, 0, 1, 1, 0], [1] * 7), + ([0, 0, 1, 0, 1, 1, 0], [2] * 7), + ([0, 0, 0, 1, 0, 0, 1], [3] * 7), + # Match single internal node + ([0, 0, 0, 0, 1, 1, 0], [4] * 7), + # Match root + ([0, 0, 0, 0, 0, 0, 0], [7] * 7), ], ) def test_match_all_nodes(self, h, expected_path): - # print() - # print(self.ts().draw_text()) - # with open("tmp.svg", "w") as f: - # f.write(self.ts().draw_svg()) validate_match_all_nodes(self.ts(), h, expected_path) @pytest.mark.parametrize(