From c559f64713ffaec36289f0c276e2a39aede3e71f Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 23 Jun 2023 16:41:09 +0100 Subject: [PATCH] Add almost-working version with Python simplifier --- python/tests/simplify.py | 20 ++-- python/tests/test_forward_sims.py | 174 +++++++++++++++++++++++++++--- 2 files changed, 171 insertions(+), 23 deletions(-) diff --git a/python/tests/simplify.py b/python/tests/simplify.py index 02e0482cca..c009ee8fa4 100644 --- a/python/tests/simplify.py +++ b/python/tests/simplify.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2019-2022 Tskit Developers +# Copyright (c) 2019-2023 Tskit Developers # Copyright (c) 2015-2018 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -114,6 +114,8 @@ def __init__( filter_nodes=True, update_sample_flags=True, ): + # DELETE ME + self.parent_edges_processed = 0 self.ts = ts self.n = len(sample) self.reduce_to_site_topology = reduce_to_site_topology @@ -397,6 +399,7 @@ def process_parent_edges(self, edges): """ Process all of the edges for a given parent. """ + self.parent_edges_processed += len(edges) assert len({e.parent for e in edges}) == 1 parent = edges[0].parent S = [] @@ -535,6 +538,14 @@ def insert_input_roots(self): offset += 1 self.sort_offset = offset + def finalise(self): + if self.keep_input_roots: + self.insert_input_roots() + self.finalise_sites() + self.finalise_references() + if self.sort_offset != -1: + self.tables.sort(edge_start=self.sort_offset) + def simplify(self): if self.ts.num_edges > 0: all_edges = list(self.ts.edges()) @@ -545,12 +556,7 @@ def simplify(self): edges = [] edges.append(e) self.process_parent_edges(edges) - if self.keep_input_roots: - self.insert_input_roots() - self.finalise_sites() - self.finalise_references() - if self.sort_offset != -1: - self.tables.sort(edge_start=self.sort_offset) + self.finalise() ts = self.tables.tree_sequence() return ts, self.node_id_map diff --git a/python/tests/test_forward_sims.py b/python/tests/test_forward_sims.py index 17d91fd7ab..54ea9039fd 100644 --- a/python/tests/test_forward_sims.py +++ b/python/tests/test_forward_sims.py @@ -22,28 +22,147 @@ """ Python implementation of the low-level supporting code for forward simulations. """ -import collections +import itertools import random import numpy as np import pytest import tskit +from tests import simplify -def simplify_with_buffer(tables, parent_buffer, samples, verbose): - # Pretend this was done efficiently internally without any sorting - # by creating a simplifier object and adding the ancstry for the - # new parents appropriately before flushing through the rest of the - # edges. - for parent, edges in parent_buffer.items(): - for left, right, child in edges: +class BirthBuffer: + def __init__(self): + self.edges = {} + self.parents = [] + + def add_edge(self, left, right, parent, child): + if parent not in self.edges: + self.parents.append(parent) + self.edges[parent] = [] + self.edges[parent].append((child, left, right)) + + def clear(self): + self.edges = {} + self.parents = [] + + def __str__(self): + s = "" + for parent in self.parents: + for child, left, right in self.edges[parent]: + s += f"{parent}\t{child}\t{left:0.3f}\t{right:0.3f}\n" + return s + + +def add_younger_edges_to_simplifier(simplifier, t, tables, edge_offset): + parent_edges = [] + while ( + edge_offset < len(tables.edges) + and tables.nodes.time[tables.edges.parent[edge_offset]] <= t + ): + print("edge offset = ", edge_offset) + if len(parent_edges) == 0: + last_parent = tables.edges.parent[edge_offset] + else: + last_parent = parent_edges[-1].parent + if last_parent == tables.edges.parent[edge_offset]: + parent_edges.append(tables.edges[edge_offset]) + else: + print( + "Flush ", tables.nodes.time[parent_edges[-1].parent], len(parent_edges) + ) + simplifier.process_parent_edges(parent_edges) + parent_edges = [] + edge_offset += 1 + if len(parent_edges) > 0: + print("Flush ", tables.nodes.time[parent_edges[-1].parent], len(parent_edges)) + simplifier.process_parent_edges(parent_edges) + return edge_offset + + +def simplify_with_births(tables, births, alive, verbose): + total_edges = len(tables.edges) + for edges in births.edges.values(): + total_edges += len(edges) + if verbose > 0: + print("Simplify with births") + # print(births) + print("total_input edges = ", total_edges) + print("alive = ", alive) + print("\ttable edges:", len(tables.edges)) + print("\ttable nodes:", len(tables.nodes)) + + simplifier = simplify.Simplifier(tables.tree_sequence(), alive) + nodes_time = tables.nodes.time + # This should be almost sorted, because + parent_time = nodes_time[births.parents] + index = np.argsort(parent_time) + print(index) + offset = 0 + for parent in np.array(births.parents)[index]: + offset = add_younger_edges_to_simplifier( + simplifier, nodes_time[parent], tables, offset + ) + edges = [ + tskit.Edge(left, right, parent, child) + for child, left, right in sorted(births.edges[parent]) + ] + # print("Adding parent from time", nodes_time[parent], len(edges)) + # print("edges = ", edges) + simplifier.process_parent_edges(edges) + # simplifier.print_state() + + # FIXME should probably reuse the add_younger_edges_to_simplifier function + # for this - doesn't quite seem to work though + for _, edges in itertools.groupby(tables.edges[offset:], lambda e: e.parent): + edges = list(edges) + simplifier.process_parent_edges(edges) + + simplifier.check_state() + assert simplifier.parent_edges_processed == total_edges + # if simplifier.parent_edges_processed != total_edges: + # print("HERE!!!!", total_edges) + simplifier.finalise() + + tables.nodes.replace_with(simplifier.tables.nodes) + tables.edges.replace_with(simplifier.tables.edges) + + # This is needed because we call .tree_sequence here and later. + # Can be removed is we change the Simplifier to take a set of + # tables which it modifies, like the C version. + tables.drop_index() + # Just to check + tables.tree_sequence() + + births.clear() + # Add back all the edges with an alive parent to the buffer, so that + # we store them contiguously + keep = np.ones(len(tables.edges), dtype=bool) + for u in alive: + u = simplifier.node_id_map[u] + for e in np.where(tables.edges.parent == u)[0]: + keep[e] = False + edge = tables.edges[e] + # print(edge) + births.add_edge(edge.left, edge.right, edge.parent, edge.child) + + if verbose > 0: + print("Done") + print(births) + print("\ttable edges:", len(tables.edges)) + print("\ttable nodes:", len(tables.nodes)) + + +def simplify_with_births_easy(tables, births, alive, verbose): + for parent, edges in births.edges.items(): + for child, left, right in edges: tables.edges.add_row(left, right, parent, child) tables.sort() - tables.simplify(samples) - # We've exhausted the parent buffer, so clear it out. In reality we'd - # do this more carefully, like KT does in the post_simplify step. - parent_buffer.clear() + tables.simplify(alive) + births.clear() + + # print(tables.nodes.time[tables.edges.parent]) def wright_fisher( @@ -52,7 +171,7 @@ def wright_fisher( rng = random.Random(seed) tables = tskit.TableCollection(L) alive = [tables.nodes.add_row(time=T) for _ in range(N)] - parent_buffer = collections.defaultdict(list) + births = BirthBuffer() t = T while t > 0: @@ -66,12 +185,16 @@ def wright_fisher( a = rng.randint(0, N - 1) b = rng.randint(0, N - 1) x = rng.uniform(0, L) - parent_buffer[alive[a]].append((0, x, u)) - parent_buffer[alive[b]].append((x, L, u)) + # TODO Possibly more natural do this like + # births.add(u, parents=[a, b], breaks=[0, x, L]) + births.add_edge(0, x, alive[a], u) + births.add_edge(x, L, alive[b], u) alive = next_alive if t % simplify_interval == 0 or t == 0: - simplify_with_buffer(tables, parent_buffer, alive, verbose=verbose) + simplify_with_births(tables, births, alive, verbose=verbose) + # simplify_with_births_easy(tables, births, alive, verbose=verbose) alive = list(range(N)) + # print(tables.tree_sequence()) return tables.tree_sequence() @@ -115,3 +238,22 @@ def test_full_simulation(self): ts = wright_fisher(N=5, T=500, death_proba=0.9, simplify_interval=1000) for tree in ts.trees(): assert tree.num_roots == 1 + + +class TestSimplifyIntervals: + @pytest.mark.parametrize("interval", [1, 10, 33, 100]) + def test_non_overlapping_generations(self, interval): + N = 10 + ts = wright_fisher(N, T=100, death_proba=1, simplify_interval=interval) + assert ts.num_samples == N + + @pytest.mark.parametrize("interval", [1, 10, 33, 100]) + @pytest.mark.parametrize("death_proba", [0.33, 0.5, 0.9]) + def test_overlapping_generations(self, interval, death_proba): + N = 4 + ts = wright_fisher( + N, T=20, death_proba=death_proba, simplify_interval=interval, verbose=1 + ) + assert ts.num_samples == N + print() + print(ts.draw_text())