Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make date_field parametrisable #456

Merged
merged 1 commit into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sc2ts/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ class Variant:
class Dataset(collections.abc.Mapping):

def __init__(self, path, chunk_cache_size=1, date_field="date"):
logger.info(f"Loading dateset @{path} using {date_field} as date field")
self.date_field = date_field
self.path = pathlib.Path(path)
if self.path.suffix == ".zip":
Expand Down
5 changes: 3 additions & 2 deletions sc2ts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,7 @@ def extend(
date,
base_ts,
match_db,
date_field="date",
include_samples=None,
num_mismatches=None,
hmm_cost_threshold=None,
Expand Down Expand Up @@ -557,7 +558,7 @@ def extend(

start_time = time.time() # wall time
base_ts = tszip.load(base_ts)
ds = _dataset.Dataset(dataset)
ds = _dataset.Dataset(dataset, date_field=date_field)

with MatchDb(match_db) as matches:
tables = _extend(
Expand Down Expand Up @@ -1462,7 +1463,7 @@ def match_tsinfer(
coord_map, mirror_coordinates, ts.sites_position
)
logger.debug(
f"HMM@T={mismatch_threshold}: {sample.strain} {sample.pango}"
f"HMM@T={mismatch_threshold}: {sample.strain} {sample.pango} "
f"hmm_cost={sample.hmm_match.cost} match={sample.hmm_match.summary()}"
)

Expand Down
7 changes: 7 additions & 0 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,13 @@ def test_examples(self, fx_dataset):
],
)

def test_date_field(self, fx_dataset):
ds1 = sc2ts.Dataset(fx_dataset.path, date_field="date")
ds2 = sc2ts.Dataset(fx_dataset.path, date_field="Collection_date")
diffs = np.where(ds1.metadata.sample_date != ds2.metadata.sample_date)[0]
# The point is just to see if they are different here
assert len(diffs) == 6


class TestDatasetAlignments:

Expand Down
14 changes: 14 additions & 0 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,20 @@ def test_2020_01_25(self, tmp_path, fx_ts_map, fx_dataset):
}
ts.tables.assert_equals(fx_ts_map["2020-01-25"].tables, ignore_provenance=True)

def test_using_collection_date(self, tmp_path, fx_ts_map, fx_dataset):

ts = run_extend(
dataset=fx_dataset,
base_ts=fx_ts_map[self.dates[0]],
date="2020-06-16",
match_db=sc2ts.MatchDb.initialise(tmp_path / "match.db"),
date_field="Collection_date",
hmm_cost_threshold=25,
)
assert ts.num_samples == 1
# This has a hmm cost of 21
assert ts.metadata["sc2ts"]["samples_strain"] == ["SRR15736313"]

@pytest.mark.parametrize("num_threads", [0, 1, 3, 10])
def test_2020_02_02(self, tmp_path, fx_ts_map, fx_dataset, num_threads):
ts = run_extend(
Expand Down
Loading