Skip to content

Commit

Permalink
Python: Factor out distributed SMC logic
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Aug 2, 2024
1 parent 590a9d5 commit 738c617
Showing 1 changed file with 22 additions and 101 deletions.
123 changes: 22 additions & 101 deletions algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,57 +1016,12 @@ def initialise(self, ts):
lineage = root_lineages[node]
if lineage is not None:
seg = lineage.head
left_end = seg.left
while seg is not None:
self.set_segment_mass(seg)
lineage.tail = seg
seg = seg.next
self.add_lineage(lineage)

if self.model == "smc_k":
for node in range(ts.num_nodes):
lineage = root_lineages[node]
if lineage is not None:
seg = lineage.head
left_end = seg.left
pop = lineage.population
label = lineage.label
right_end = root_segments_tail[node].right
new_hull = self.alloc_hull(left_end, right_end, lineage)
# insert Hull
floor = self.P[pop].hulls_left[label].floor_key(new_hull)
insertion_order = 0
if floor is not None:
if floor.left == new_hull.left:
insertion_order = floor.insertion_order + 1
new_hull.insertion_order = insertion_order
self.P[pop].hulls_left[label][new_hull] = -1

# initialise the correct coalesceable pairs count
for pop in self.P:
for label, ost_left in enumerate(pop.hulls_left):
avl = ost_left.avl
ost_right = pop.hulls_right[label]
count = 0
for hull in avl.keys():
floor = ost_right.floor_key(HullEnd(hull.left))
num_ending_before_hull = 0
if floor is not None:
num_ending_before_hull = ost_right.rank[floor] + 1
num_pairs = count - num_ending_before_hull
avl[hull] = num_pairs
pop.coal_mass_index[label].set_value(hull.index, num_pairs)
# insert HullEnd
hull_end = HullEnd(hull.right)
floor = ost_right.floor_key(hull_end)
insertion_order = 0
if floor is not None:
if floor.x == hull.right:
insertion_order = floor.insertion_order + 1
hull_end.insertion_order = insertion_order
ost_right[hull_end] = -1
count += 1

def ancestors_remain(self):
"""
Returns True if the simulation is not finished, i.e., there is some ancestral
Expand Down Expand Up @@ -1198,6 +1153,15 @@ def store_edge(self, left, right, parent, child):
tskit.Edge(left=left, right=right, parent=parent, child=child)
)

def update_lineage_right(self, lineage):
if self.model == "smc_k":
# modify original hull
pop = lineage.population
hull = lineage.hull
old_right = hull.right
hull.right = min(lineage.tail.right + self.hull_offset, self.L)
self.P[pop].reset_hull_right(lineage.label, hull, old_right, hull.right)

def add_lineage(self, lineage):
pop = lineage.population
self.P[pop].add(lineage, lineage.label)
Expand All @@ -1208,6 +1172,15 @@ def add_lineage(self, lineage):
assert x.lineage == lineage
x = x.next

if self.model == "smc_k":
head = lineage.head
assert head.prev is None
hull = self.alloc_hull(head.left, head.right, lineage)
right = lineage.tail.right
hull.right = min(right + self.hull_offset, self.L)
pop = self.P[lineage.population]
pop.add_hull(lineage.label, hull)

def finalise(self):
"""
Finalises the simulation returns an msprime tree sequence object.
Expand Down Expand Up @@ -1836,22 +1809,11 @@ def hudson_recombination_event(self, label):
left_lineage.tail = x
lhs_tail = x

self.update_lineage_right(left_lineage)
right_lineage = self.alloc_lineage(alpha, left_lineage.population, label=label)
self.set_segment_mass(alpha)
self.add_lineage(right_lineage)

if self.model == "smc_k":
# modify original hull
pop = left_lineage.population
lhs_hull = lhs_tail.get_hull()
rhs_right = lhs_hull.right
lhs_hull.right = min(lhs_tail.right + self.hull_offset, self.L)
self.P[pop].reset_hull_right(label, lhs_hull, rhs_right, lhs_hull.right)

# create hull for alpha
alpha_hull = self.alloc_hull(alpha.left, rhs_right, right_lineage)
self.P[pop].add_hull(label, alpha_hull)

if self.additional_nodes.value & msprime.NODE_IS_RE_EVENT > 0:
self.store_node(left_lineage.population, flags=msprime.NODE_IS_RE_EVENT)
self.store_arg_edges(lhs_tail)
Expand Down Expand Up @@ -1890,11 +1852,8 @@ def wiuf_gene_conversion_within_event(self, label):
# lbp rbp
return None
self.num_gc_events += 1
hull = y.get_hull()
assert (self.model == "smc_k") == (hull is not None)
lineage = y.lineage
pop = lineage.population
reset_right = -1

# Process left break
insert_alpha = True
Expand All @@ -1913,7 +1872,6 @@ def wiuf_gene_conversion_within_event(self, label):
insert_alpha = False
else:
x.next = None
reset_right = x.right
y.prev = None
alpha = y
tail = x
Expand All @@ -1935,15 +1893,11 @@ def wiuf_gene_conversion_within_event(self, label):
y.right = left_breakpoint
self.set_segment_mass(y)
tail = y
reset_right = left_breakpoint
self.set_segment_mass(alpha)

# Find the segment z that the right breakpoint falls in
z = alpha
hull_left = z.left
hull_right = -1
while z is not None and right_breakpoint >= z.right:
hull_right = z.right
z = z.next

head = None
Expand All @@ -1967,7 +1921,6 @@ def wiuf_gene_conversion_within_event(self, label):
z.right = right_breakpoint
z.next = None
self.set_segment_mass(z)
hull_right = right_breakpoint
else:
# tail z
# ======
Expand All @@ -1985,12 +1938,6 @@ def wiuf_gene_conversion_within_event(self, label):
tail.next = head
head.prev = tail
self.set_segment_mass(head)
else:
# rbp lies beyond segment chain, regular recombination logic applies
if insert_alpha and self.model == "smc_k":
assert reset_right > 0
reset_right = min(reset_right + self.hull_offset, self.L)
self.P[pop].reset_hull_right(label, hull, hull.right, reset_right)

# y z
# | ========== ... ===== |
Expand All @@ -2005,12 +1952,8 @@ def wiuf_gene_conversion_within_event(self, label):
if new_individual_head is not None:
# FIXME when doing the smc_k update
lineage.reset_segments()
self.update_lineage_right(lineage)
new_lineage = self.alloc_lineage(new_individual_head, pop)
if self.model == "smc_k":
assert hull_left < hull_right
hull_right = min(self.L, hull_right + self.hull_offset)
hull = self.alloc_hull(hull_left, hull_right, new_lineage)
self.P[new_lineage.population].add_hull(new_lineage.label, hull)
self.add_lineage(new_lineage)

def wiuf_gene_conversion_left_event(self, label):
Expand Down Expand Up @@ -2042,8 +1985,6 @@ def wiuf_gene_conversion_left_event(self, label):
x = y.prev
lineage = y.lineage
pop = lineage.population
lhs_hull = y.get_hull()
assert (self.model == "smc_k") == (lhs_hull is not None)
if y.left < bp:
# x y
# ===== =====|====
Expand All @@ -2061,7 +2002,6 @@ def wiuf_gene_conversion_left_event(self, label):
y.next = None
y.right = bp
self.set_segment_mass(y)
right = y.right
else:
# x y
# ===== | =========
Expand All @@ -2075,19 +2015,10 @@ def wiuf_gene_conversion_left_event(self, label):
x.next = None
y.prev = None
alpha = y
right = x.right

if self.model == "smc_k":
# lhs logic is identical to the lhs recombination event
lhs_old_right = lhs_hull.right
lhs_new_right = min(self.L, right + self.hull_offset)
self.P[pop].reset_hull_right(label, lhs_hull, lhs_old_right, lhs_new_right)

# rhs
hull = self.alloc_hull(alpha.left, lhs_old_right, alpha)
self.P[pop].add_hull(label, hull)

# FIXME
lineage.reset_segments()
self.update_lineage_right(lineage)

self.set_segment_mass(alpha)
assert alpha.prev is None
Expand Down Expand Up @@ -2574,16 +2505,6 @@ def insert_merged_lineage(
# assert tail == new_lineage.tail
self.add_lineage(new_lineage)

if self.model == "smc_k":
merged_head = new_lineage.head
assert merged_head.prev is None
hull = self.alloc_hull(merged_head.left, merged_head.right, new_lineage)
while merged_head is not None:
right = merged_head.right
merged_head = merged_head.next
hull.right = min(right + self.hull_offset, self.L)
pop = self.P[new_lineage.population]
pop.add_hull(new_lineage.label, hull)
return new_lineage

def print_state(self, verify=False):
Expand Down

0 comments on commit 738c617

Please sign in to comment.