From 738c6173b2a8a8b8e53c44778d132e88fbcc3679 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 2 Aug 2024 12:22:03 +0100 Subject: [PATCH] Python: Factor out distributed SMC logic --- algorithms.py | 123 +++++++++----------------------------------------- 1 file changed, 22 insertions(+), 101 deletions(-) diff --git a/algorithms.py b/algorithms.py index f39d7012e..ce0006aef 100644 --- a/algorithms.py +++ b/algorithms.py @@ -1016,57 +1016,12 @@ def initialise(self, ts): lineage = root_lineages[node] if lineage is not None: seg = lineage.head - left_end = seg.left while seg is not None: self.set_segment_mass(seg) lineage.tail = seg seg = seg.next self.add_lineage(lineage) - if self.model == "smc_k": - for node in range(ts.num_nodes): - lineage = root_lineages[node] - if lineage is not None: - seg = lineage.head - left_end = seg.left - pop = lineage.population - label = lineage.label - right_end = root_segments_tail[node].right - new_hull = self.alloc_hull(left_end, right_end, lineage) - # insert Hull - floor = self.P[pop].hulls_left[label].floor_key(new_hull) - insertion_order = 0 - if floor is not None: - if floor.left == new_hull.left: - insertion_order = floor.insertion_order + 1 - new_hull.insertion_order = insertion_order - self.P[pop].hulls_left[label][new_hull] = -1 - - # initialise the correct coalesceable pairs count - for pop in self.P: - for label, ost_left in enumerate(pop.hulls_left): - avl = ost_left.avl - ost_right = pop.hulls_right[label] - count = 0 - for hull in avl.keys(): - floor = ost_right.floor_key(HullEnd(hull.left)) - num_ending_before_hull = 0 - if floor is not None: - num_ending_before_hull = ost_right.rank[floor] + 1 - num_pairs = count - num_ending_before_hull - avl[hull] = num_pairs - pop.coal_mass_index[label].set_value(hull.index, num_pairs) - # insert HullEnd - hull_end = HullEnd(hull.right) - floor = ost_right.floor_key(hull_end) - insertion_order = 0 - if floor is not None: - if floor.x == hull.right: - insertion_order = floor.insertion_order + 1 - hull_end.insertion_order = insertion_order - ost_right[hull_end] = -1 - count += 1 - def ancestors_remain(self): """ Returns True if the simulation is not finished, i.e., there is some ancestral @@ -1198,6 +1153,15 @@ def store_edge(self, left, right, parent, child): tskit.Edge(left=left, right=right, parent=parent, child=child) ) + def update_lineage_right(self, lineage): + if self.model == "smc_k": + # modify original hull + pop = lineage.population + hull = lineage.hull + old_right = hull.right + hull.right = min(lineage.tail.right + self.hull_offset, self.L) + self.P[pop].reset_hull_right(lineage.label, hull, old_right, hull.right) + def add_lineage(self, lineage): pop = lineage.population self.P[pop].add(lineage, lineage.label) @@ -1208,6 +1172,15 @@ def add_lineage(self, lineage): assert x.lineage == lineage x = x.next + if self.model == "smc_k": + head = lineage.head + assert head.prev is None + hull = self.alloc_hull(head.left, head.right, lineage) + right = lineage.tail.right + hull.right = min(right + self.hull_offset, self.L) + pop = self.P[lineage.population] + pop.add_hull(lineage.label, hull) + def finalise(self): """ Finalises the simulation returns an msprime tree sequence object. @@ -1836,22 +1809,11 @@ def hudson_recombination_event(self, label): left_lineage.tail = x lhs_tail = x + self.update_lineage_right(left_lineage) right_lineage = self.alloc_lineage(alpha, left_lineage.population, label=label) self.set_segment_mass(alpha) self.add_lineage(right_lineage) - if self.model == "smc_k": - # modify original hull - pop = left_lineage.population - lhs_hull = lhs_tail.get_hull() - rhs_right = lhs_hull.right - lhs_hull.right = min(lhs_tail.right + self.hull_offset, self.L) - self.P[pop].reset_hull_right(label, lhs_hull, rhs_right, lhs_hull.right) - - # create hull for alpha - alpha_hull = self.alloc_hull(alpha.left, rhs_right, right_lineage) - self.P[pop].add_hull(label, alpha_hull) - if self.additional_nodes.value & msprime.NODE_IS_RE_EVENT > 0: self.store_node(left_lineage.population, flags=msprime.NODE_IS_RE_EVENT) self.store_arg_edges(lhs_tail) @@ -1890,11 +1852,8 @@ def wiuf_gene_conversion_within_event(self, label): # lbp rbp return None self.num_gc_events += 1 - hull = y.get_hull() - assert (self.model == "smc_k") == (hull is not None) lineage = y.lineage pop = lineage.population - reset_right = -1 # Process left break insert_alpha = True @@ -1913,7 +1872,6 @@ def wiuf_gene_conversion_within_event(self, label): insert_alpha = False else: x.next = None - reset_right = x.right y.prev = None alpha = y tail = x @@ -1935,15 +1893,11 @@ def wiuf_gene_conversion_within_event(self, label): y.right = left_breakpoint self.set_segment_mass(y) tail = y - reset_right = left_breakpoint self.set_segment_mass(alpha) # Find the segment z that the right breakpoint falls in z = alpha - hull_left = z.left - hull_right = -1 while z is not None and right_breakpoint >= z.right: - hull_right = z.right z = z.next head = None @@ -1967,7 +1921,6 @@ def wiuf_gene_conversion_within_event(self, label): z.right = right_breakpoint z.next = None self.set_segment_mass(z) - hull_right = right_breakpoint else: # tail z # ====== @@ -1985,12 +1938,6 @@ def wiuf_gene_conversion_within_event(self, label): tail.next = head head.prev = tail self.set_segment_mass(head) - else: - # rbp lies beyond segment chain, regular recombination logic applies - if insert_alpha and self.model == "smc_k": - assert reset_right > 0 - reset_right = min(reset_right + self.hull_offset, self.L) - self.P[pop].reset_hull_right(label, hull, hull.right, reset_right) # y z # | ========== ... ===== | @@ -2005,12 +1952,8 @@ def wiuf_gene_conversion_within_event(self, label): if new_individual_head is not None: # FIXME when doing the smc_k update lineage.reset_segments() + self.update_lineage_right(lineage) new_lineage = self.alloc_lineage(new_individual_head, pop) - if self.model == "smc_k": - assert hull_left < hull_right - hull_right = min(self.L, hull_right + self.hull_offset) - hull = self.alloc_hull(hull_left, hull_right, new_lineage) - self.P[new_lineage.population].add_hull(new_lineage.label, hull) self.add_lineage(new_lineage) def wiuf_gene_conversion_left_event(self, label): @@ -2042,8 +1985,6 @@ def wiuf_gene_conversion_left_event(self, label): x = y.prev lineage = y.lineage pop = lineage.population - lhs_hull = y.get_hull() - assert (self.model == "smc_k") == (lhs_hull is not None) if y.left < bp: # x y # ===== =====|==== @@ -2061,7 +2002,6 @@ def wiuf_gene_conversion_left_event(self, label): y.next = None y.right = bp self.set_segment_mass(y) - right = y.right else: # x y # ===== | ========= @@ -2075,19 +2015,10 @@ def wiuf_gene_conversion_left_event(self, label): x.next = None y.prev = None alpha = y - right = x.right - - if self.model == "smc_k": - # lhs logic is identical to the lhs recombination event - lhs_old_right = lhs_hull.right - lhs_new_right = min(self.L, right + self.hull_offset) - self.P[pop].reset_hull_right(label, lhs_hull, lhs_old_right, lhs_new_right) - - # rhs - hull = self.alloc_hull(alpha.left, lhs_old_right, alpha) - self.P[pop].add_hull(label, hull) + # FIXME lineage.reset_segments() + self.update_lineage_right(lineage) self.set_segment_mass(alpha) assert alpha.prev is None @@ -2574,16 +2505,6 @@ def insert_merged_lineage( # assert tail == new_lineage.tail self.add_lineage(new_lineage) - if self.model == "smc_k": - merged_head = new_lineage.head - assert merged_head.prev is None - hull = self.alloc_hull(merged_head.left, merged_head.right, new_lineage) - while merged_head is not None: - right = merged_head.right - merged_head = merged_head.next - hull.right = min(right + self.hull_offset, self.L) - pop = self.P[new_lineage.population] - pop.add_hull(new_lineage.label, hull) return new_lineage def print_state(self, verify=False):