Skip to content

Commit

Permalink
Experiment with alternative HMM implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Sep 4, 2024
1 parent dffa62c commit 7a3b625
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 55 deletions.
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ authors = [
]
requires-python = ">=3.9"
dependencies = [
"tsinfer==0.3.3", # https://github.com/jeromekelleher/sc2ts/issues/201
# "tsinfer==0.3.3", # https://github.com/jeromekelleher/sc2ts/issues/201
# FIXME
"tsinfer @ git+https://github.com/jeromekelleher/tsinfer.git@experimental-hmm",
"pyfaidx",
"tskit>=0.5.3",
"tszip",
Expand Down
51 changes: 20 additions & 31 deletions sc2ts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,46 +402,45 @@ def match_samples(
show_progress=False,
num_threads=None,
):
# First pass, compute the matches at precision=0.
run_batch = samples

# Values based on https://github.com/jeromekelleher/sc2ts/issues/242,
# but somewhat arbitrary.
for precision, cost_threshold in [(0, 1), (1, 2), (2, 3)]:
logger.info(f"Running batch of {len(run_batch)} at p={precision}")
mu = 0.125 ## FIXME
for k in range(num_mismatches):
# To catch k mismatches we need a likelihood threshold of mu**k
likelihood_threshold = mu**k - 1e-15
logger.info(f"Running match={k} batch of {len(run_batch)} at threshold={likelihood_threshold}")
match_tsinfer(
samples=run_batch,
ts=base_ts,
num_mismatches=num_mismatches,
precision=precision,
likelihood_threshold=likelihood_threshold,
num_threads=num_threads,
show_progress=show_progress,
)

exceeding_threshold = []
for sample in run_batch:
cost = sample.get_hmm_cost(num_mismatches)
logger.debug(f"HMM@p={precision}: hmm_cost={cost} {sample.summary()}")
if cost > cost_threshold:
logger.debug(f"HMM@k={k}: hmm_cost={cost} {sample.summary()}")
if cost > k + 1:
sample.path.clear()
sample.mutations.clear()
exceeding_threshold.append(sample)

num_matches_found = len(run_batch) - len(exceeding_threshold)
logger.info(
f"{num_matches_found} final matches for found p={precision}; "
f"{num_matches_found} final matches found at k={k}; "
f"{len(exceeding_threshold)} remain"
)
run_batch = exceeding_threshold

precision = 6
logger.info(f"Running final batch of {len(run_batch)} at p={precision}")
logger.info(f"Running final batch of {len(run_batch)} at full precision")
match_tsinfer(
samples=run_batch,
ts=base_ts,
num_mismatches=num_mismatches,
precision=precision,
num_threads=num_threads,
likelihood_threshold=1e-200,
show_progress=show_progress,
)
for sample in run_batch:
Expand Down Expand Up @@ -798,36 +797,26 @@ def add_matching_results(
return ts # , excluded_samples, added_samples


def solve_num_mismatches(ts, k):
def solve_num_mismatches(k, num_sites, mu=0.125):
"""
Return the low-level LS parameters corresponding to accepting
k mismatches in favour of a single recombination.
NOTE! This is NOT taking into account the spatial distance along
the genome, and so is not a very good model in some ways.
"""
# We can match against any node in tsinfer
m = ts.num_sites
n = ts.num_nodes
# values of k <= 1 are not relevant for SC2 and lead to awkward corner cases
assert k > 1

# NOTE: the magnitude of mu matters because it puts a limit
# on how low we can push the HMM precision. We should be able to solve
# for the optimal value of this parameter such that the magnitude of the
# values within the HMM are as large as possible (so that we can truncate
# usefully).
# mu = 1e-2
mu = 0.125
denom = (1 - mu) ** k + (n - 1) * mu**k
r = n * mu**k / denom
denom = (1 - mu) ** k
r = mu**k / denom

# Add a little bit of extra mass for recombination so that we deterministically
# chose to recombine over k mutations
# NOTE: the magnitude of this value will depend also on mu, see above.
r += r * 0.01
ls_recomb = np.full(m - 1, r)
ls_mismatch = np.full(m, mu)
r += r * 0.125
ls_recomb = np.full(num_sites - 1, r)
ls_mismatch = np.full(num_sites, mu)
return ls_recomb, ls_mismatch


Expand Down Expand Up @@ -1268,7 +1257,7 @@ def match_tsinfer(
ts,
*,
num_mismatches,
precision=None,
likelihood_threshold=None,
num_threads=0,
show_progress=False,
mirror_coordinates=False,
Expand All @@ -1284,7 +1273,7 @@ def match_tsinfer(
sd = convert_tsinfer_sample_data(ts, genotypes)

L = int(ts.sequence_length)
ls_recomb, ls_mismatch = solve_num_mismatches(ts, num_mismatches)
ls_recomb, ls_mismatch = solve_num_mismatches(num_mismatches, ts.num_sites)
pm = tsinfer.inference._get_progress_monitor(
show_progress,
generate_ancestors=False,
Expand All @@ -1309,7 +1298,7 @@ def match_tsinfer(
mismatch=ls_mismatch,
progress_monitor=pm,
num_threads=num_threads,
precision=precision,
likelihood_threshold=likelihood_threshold
)
results = manager.run_match(np.arange(sd.num_samples))

Expand Down
56 changes: 33 additions & 23 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import numpy.testing as nt
import pytest
import tsinfer
import tskit
Expand All @@ -8,6 +9,18 @@
import util


class TestSolveNumMismatches:

@pytest.mark.parametrize(
["k", "expected_rho"],
[(2, 0.02295918), (3, 0.00327988), (4, 0.00046855), (1000, 0)],
)
def test_examples(self, k, expected_rho):
rho, mu = sc2ts.solve_num_mismatches(k, num_sites=2)
assert mu[0] == 0.125
nt.assert_almost_equal(rho[0], expected_rho)


class TestInitialTs:
def test_reference_sequence(self):
ts = sc2ts.initial_ts()
Expand Down Expand Up @@ -612,13 +625,13 @@ def test_node_mutation_counts(self, fx_ts_map, date):
"2020-02-03": {"nodes": 36, "mutations": 42},
"2020-02-04": {"nodes": 41, "mutations": 48},
"2020-02-05": {"nodes": 42, "mutations": 48},
"2020-02-06": {"nodes": 49, "mutations": 51},
"2020-02-07": {"nodes": 51, "mutations": 57},
"2020-02-08": {"nodes": 57, "mutations": 58},
"2020-02-09": {"nodes": 59, "mutations": 61},
"2020-02-10": {"nodes": 60, "mutations": 65},
"2020-02-11": {"nodes": 62, "mutations": 66},
"2020-02-13": {"nodes": 66, "mutations": 68},
"2020-02-06": {"nodes": 48, "mutations": 51},
"2020-02-07": {"nodes": 50, "mutations": 57},
"2020-02-08": {"nodes": 56, "mutations": 58},
"2020-02-09": {"nodes": 58, "mutations": 61},
"2020-02-10": {"nodes": 59, "mutations": 65},
"2020-02-11": {"nodes": 61, "mutations": 66},
"2020-02-13": {"nodes": 65, "mutations": 68},
}
assert ts.num_nodes == expected[date]["nodes"]
assert ts.num_mutations == expected[date]["mutations"]
Expand All @@ -631,9 +644,9 @@ def test_node_mutation_counts(self, fx_ts_map, date):
(13, "SRR11597132", 10),
(16, "SRR11597177", 10),
(41, "SRR11597156", 10),
(57, "SRR11597216", 1),
(60, "SRR11597207", 40),
(62, "ERR4205570", 58),
(56, "SRR11597216", 1),
(59, "SRR11597207", 40),
(61, "ERR4205570", 57),
],
)
def test_exact_matches(self, fx_ts_map, node, strain, parent):
Expand Down Expand Up @@ -693,10 +706,9 @@ class TestMatchingDetails:
# assert s.path[0].parent == 37

@pytest.mark.parametrize(
("strain", "parent"), [("SRR11597207", 40), ("ERR4205570", 58)]
("strain", "parent"), [("SRR11597207", 40), ("ERR4205570", 57)]
)
@pytest.mark.parametrize("num_mismatches", [2, 3, 4])
@pytest.mark.parametrize("precision", [0, 1, 2, 12])
def test_exact_matches(
self,
fx_ts_map,
Expand All @@ -705,17 +717,18 @@ def test_exact_matches(
strain,
parent,
num_mismatches,
precision,
):
ts = fx_ts_map["2020-02-10"]
samples = sc2ts.preprocess(
[fx_metadata_db[strain]], ts, "2020-02-20", fx_alignment_store
)
# FIXME
mu = 0.125
sc2ts.match_tsinfer(
samples=samples,
ts=ts,
num_mismatches=num_mismatches,
precision=precision,
likelihood_threshold = mu**num_mismatches - 1e-12,
num_threads=0,
)
s = samples[0]
Expand All @@ -725,10 +738,10 @@ def test_exact_matches(

@pytest.mark.parametrize(
("strain", "parent", "position", "derived_state"),
[("SRR11597218", 10, 289, "T"), ("ERR4206593", 58, 26994, "T")],
[("SRR11597218", 10, 289, "T"), ("ERR4206593", 57, 26994, "T")],
)
@pytest.mark.parametrize("num_mismatches", [2, 3, 4])
@pytest.mark.parametrize("precision", [0, 1, 2, 12])
# @pytest.mark.parametrize("precision", [0, 1, 2, 12])
def test_one_mismatch(
self,
fx_ts_map,
Expand All @@ -739,7 +752,6 @@ def test_one_mismatch(
position,
derived_state,
num_mismatches,
precision,
):
ts = fx_ts_map["2020-02-10"]
samples = sc2ts.preprocess(
Expand All @@ -749,7 +761,8 @@ def test_one_mismatch(
samples=samples,
ts=ts,
num_mismatches=num_mismatches,
precision=precision,
# FIXME
likelihood_threshold=0.12499999,
num_threads=0,
)
s = samples[0]
Expand All @@ -760,30 +773,27 @@ def test_one_mismatch(
assert s.path[0].parent == parent

@pytest.mark.parametrize("num_mismatches", [2, 3, 4])
@pytest.mark.parametrize("precision", [0, 1, 2, 12])
def test_two_mismatches(
self,
fx_ts_map,
fx_alignment_store,
fx_metadata_db,
num_mismatches,
precision,
):
strain = "ERR4204459"
ts = fx_ts_map["2020-02-10"]
samples = sc2ts.preprocess(
[fx_metadata_db[strain]], ts, "2020-02-20", fx_alignment_store
)
mu = 0.125
sc2ts.match_tsinfer(
samples=samples,
ts=ts,
num_mismatches=num_mismatches,
precision=precision,
likelihood_threshold=mu**2 - 1e-12,
num_threads=0,
)
s = samples[0]
assert len(s.path) == 1
assert s.path[0].parent == 5
assert len(s.mutations) == 2
# assert s.mutations[0].site_position == position
# assert s.mutations[0].derived_state == derived_state

0 comments on commit 7a3b625

Please sign in to comment.