Skip to content

Commit

Permalink
feat: ensure FileBasedVariantLookup is used as a context manager (#71)
Browse files Browse the repository at this point in the history
The
[`FileBasedVariantLookup`](https://prymer.readthedocs.io/en/latest/reference/prymer/api/variant_lookup.html#prymer.api.variant_lookup.FileBasedVariantLookup)
opens one or more
[`pysam.VariantFile`](https://pysam.readthedocs.io/en/latest/api.html#pysam.VariantFile)
objects but provides no public API for [closing
them](https://pysam.readthedocs.io/en/latest/api.html#pysam.VariantFile.close).
The file-based lookup should have a public method for closing the file
handles and should also support use as a context manager so there are
safe ways to clean up IO resources.

Closes: #27
  • Loading branch information
clintval authored Nov 13, 2024
1 parent 8cbc1a2 commit 7e2e784
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 49 deletions.
50 changes: 46 additions & 4 deletions prymer/api/variant_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@
from enum import auto
from enum import unique
from pathlib import Path
from types import TracebackType
from typing import ContextManager
from typing import Optional
from typing import final

Expand Down Expand Up @@ -320,10 +322,20 @@ def _query(self, refname: str, start: int, end: int) -> list[SimpleVariant]:
"""Subclasses must implement this method."""


class FileBasedVariantLookup(VariantLookup):
"""Implementation of VariantLookup that queries against indexed VCF files each time a query is
class FileBasedVariantLookup(ContextManager, VariantLookup):
"""Implementation of `VariantLookup` that queries against indexed VCF files each time a query is
performed. Assumes the index is located adjacent to the VCF file and has the same base name with
either a .csi or .tbi suffix."""
either a .csi or .tbi suffix.
Example:
```python
>>> with FileBasedVariantLookup([Path("./tests/api/data/miniref.variants.vcf.gz")], min_maf=0.0, include_missing_mafs=False) as lookup:
... lookup.query(refname="chr2", start=7999, end=8000)
[SimpleVariant(id='complex-variant-sv-1/1', refname='chr2', pos=8000, ref='T', alt='<DEL>', end=8000, variant_type=<VariantType.OTHER: 'OTHER'>, maf=None)]
```
""" # noqa: E501

def __init__(self, vcf_paths: list[Path], min_maf: Optional[float], include_missing_mafs: bool):
self._readers: list[VariantFile] = []
Expand All @@ -341,6 +353,20 @@ def __init__(self, vcf_paths: list[Path], min_maf: Optional[float], include_miss
open_fh = pysam.VariantFile(str(path))
self._readers.append(open_fh)

def __enter__(self) -> "FileBasedVariantLookup":
"""Enter the context manager.

Returns:
    This same lookup instance, so that `with ... as lookup` binds the object itself.
"""
return self

def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
"""Exit this context manager while closing the underlying VCF handles.

The three exception arguments are accepted to satisfy the context manager protocol but are
not inspected.

Args:
    exc_type: The type of the exception raised in the `with` block, if any.
    exc_value: The exception raised in the `with` block, if any.
    traceback: The traceback of the exception raised in the `with` block, if any.

Returns:
    Always `None`, so an exception raised inside the `with` block is never suppressed.
"""
self.close()
# A falsy return value tells Python to let any in-flight exception propagate.
return None

def _query(self, refname: str, start: int, end: int) -> list[SimpleVariant]:
"""Queries variants from the VCFs used by this lookup and returns a `SimpleVariant`."""
simple_variants: list[SimpleVariant] = []
Expand All @@ -353,6 +379,11 @@ def _query(self, refname: str, start: int, end: int) -> list[SimpleVariant]:
simple_variants.extend(self.to_variants(variants, source_vcf=path))
return sorted(simple_variants, key=lambda x: x.pos)

def close(self) -> None:
    """Close every underlying VCF file handle.

    The handles in `self._readers` are closed in order. A failure while closing one handle
    does not stop the remaining handles from being closed, so a single bad handle cannot leak
    the others.

    Raises:
        Exception: The first error raised by a handle's `close()`, re-raised only after every
            handle has been given the chance to close.
    """
    first_error: Optional[Exception] = None
    for handle in self._readers:
        try:
            handle.close()
        except Exception as error:
            # Remember the failure but keep going so the remaining handles are still released.
            if first_error is None:
                first_error = error
    if first_error is not None:
        raise first_error


class VariantOverlapDetector(VariantLookup):
"""Implements `VariantLookup` by reading the entire VCF into memory and loading the resulting
Expand Down Expand Up @@ -443,7 +474,18 @@ def disk_based(
vcf_paths: list[Path], min_maf: float, include_missing_mafs: bool = False
) -> FileBasedVariantLookup:
"""Constructs a `VariantLookup` that queries indexed VCFs on disk for each lookup.
Appropriate for large VCFs."""
Appropriate for large VCFs.
Example:
```python
>>> with disk_based([Path("./tests/api/data/miniref.variants.vcf.gz")], min_maf=0.0) as lookup:
... lookup.query(refname="chr2", start=7999, end=8000)
[SimpleVariant(id='complex-variant-sv-1/1', refname='chr2', pos=8000, ref='T', alt='<DEL>', end=8000, variant_type=<VariantType.OTHER: 'OTHER'>, maf=None)]
```
""" # noqa: E501
return FileBasedVariantLookup(
vcf_paths=vcf_paths, min_maf=min_maf, include_missing_mafs=include_missing_mafs
)
44 changes: 24 additions & 20 deletions tests/api/test_picking.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,25 +383,29 @@ def test_build_primer_pairs_fails_when_primers_on_wrong_reference(
assert next(picks) is not None

with pytest.raises(ValueError, match="Left primers exist on different reference"):
_picks = list(picking.build_primer_pairs(
left_primers=invalid_lefts,
right_primers=valid_rights,
target=target,
amplicon_sizes=MinOptMax(0, 100, 500),
amplicon_tms=MinOptMax(0, 80, 150),
max_heterodimer_tm=None,
weights=weights,
fasta_path=fasta,
))
_picks = list(
picking.build_primer_pairs(
left_primers=invalid_lefts,
right_primers=valid_rights,
target=target,
amplicon_sizes=MinOptMax(0, 100, 500),
amplicon_tms=MinOptMax(0, 80, 150),
max_heterodimer_tm=None,
weights=weights,
fasta_path=fasta,
)
)

with pytest.raises(ValueError, match="Right primers exist on different reference"):
_picks = list(picking.build_primer_pairs(
left_primers=valid_lefts,
right_primers=invalid_rights,
target=target,
amplicon_sizes=MinOptMax(0, 100, 500),
amplicon_tms=MinOptMax(0, 80, 150),
max_heterodimer_tm=None,
weights=weights,
fasta_path=fasta,
))
_picks = list(
picking.build_primer_pairs(
left_primers=valid_lefts,
right_primers=invalid_rights,
target=target,
amplicon_sizes=MinOptMax(0, 100, 500),
amplicon_tms=MinOptMax(0, 80, 150),
max_heterodimer_tm=None,
weights=weights,
fasta_path=fasta,
)
)
61 changes: 36 additions & 25 deletions tests/api/test_variant_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from dataclasses import replace
from pathlib import Path
from typing import Optional
from typing import Type

import fgpyo.vcf.builder
import pytest
Expand All @@ -17,7 +16,6 @@
from prymer.api.span import Strand
from prymer.api.variant_lookup import FileBasedVariantLookup
from prymer.api.variant_lookup import SimpleVariant
from prymer.api.variant_lookup import VariantLookup
from prymer.api.variant_lookup import VariantOverlapDetector
from prymer.api.variant_lookup import VariantType
from prymer.api.variant_lookup import cached
Expand Down Expand Up @@ -435,13 +433,24 @@ def test_simple_variant_conversion(vcf_path: Path, sample_vcf: list[VariantRecor
assert actual_simple_variants == VALID_SIMPLE_VARIANTS_APPROX


@pytest.mark.parametrize("variant_lookup_class", [FileBasedVariantLookup, VariantOverlapDetector])
def test_simple_variant_conversion_logs(
variant_lookup_class: Type[VariantLookup], vcf_path: Path, caplog: pytest.LogCaptureFixture
def test_simple_variant_conversion_logs_file_based(
vcf_path: Path, caplog: pytest.LogCaptureFixture
) -> None:
"""Test that `to_variants()` logs a debug message with no pysam.VariantRecords to convert."""
caplog.set_level(logging.DEBUG)
variant_lookup = variant_lookup_class(
with FileBasedVariantLookup(
vcf_paths=[vcf_path], min_maf=0.01, include_missing_mafs=False
) as variant_lookup:
variant_lookup.query(refname="foo", start=1, end=2)
assert "No variants extracted from region of interest" in caplog.text


def test_simple_variant_conversion_logs_non_file_based(
vcf_path: Path, caplog: pytest.LogCaptureFixture
) -> None:
"""Test that `to_variants()` logs a debug message with no pysam.VariantRecords to convert."""
caplog.set_level(logging.DEBUG)
variant_lookup = VariantOverlapDetector(
vcf_paths=[vcf_path], min_maf=0.01, include_missing_mafs=False
)
variant_lookup.query(refname="foo", start=1, end=2)
Expand All @@ -451,15 +460,17 @@ def test_simple_variant_conversion_logs(
def test_missing_index_file_raises(temp_missing_path: Path) -> None:
"""Test that both VariantLookup objects raise an error with a missing index file."""
with pytest.raises(ValueError, match="Cannot perform fetch with missing index file for VCF"):
disk_based(vcf_paths=[temp_missing_path], min_maf=0.01, include_missing_mafs=False)
with disk_based(vcf_paths=[temp_missing_path], min_maf=0.01, include_missing_mafs=False):
pass
with pytest.raises(ValueError, match="Cannot perform fetch with missing index file for VCF"):
cached(vcf_paths=[temp_missing_path], min_maf=0.01, include_missing_mafs=False)


def test_missing_vcf_files_raises() -> None:
"""Test that an error is raised when no VCF_paths are provided."""
with pytest.raises(ValueError, match="No VCF paths given to query"):
disk_based(vcf_paths=[], min_maf=0.01, include_missing_mafs=False)
with disk_based(vcf_paths=[], min_maf=0.01, include_missing_mafs=False):
pass
with pytest.raises(ValueError, match="No VCF paths given to query"):
cached(vcf_paths=[], min_maf=0.01, include_missing_mafs=False)

Expand All @@ -480,12 +491,12 @@ def test_vcf_header_missing_chrom(
caplog.set_level(logging.DEBUG)
vcf_paths = [vcf_path, mini_chr1_vcf, mini_chr3_vcf]
random.Random(random_seed).shuffle(vcf_paths)
variant_lookup = FileBasedVariantLookup(
with FileBasedVariantLookup(
vcf_paths=vcf_paths, min_maf=0.00, include_missing_mafs=True
)
variants_of_interest = variant_lookup.query(
refname="chr2", start=7999, end=9900
) # (chr2 only in vcf_path)
) as variant_lookup:
variants_of_interest = variant_lookup.query(
refname="chr2", start=7999, end=9900
) # (chr2 only in vcf_path)
# Should find all 12 variants from vcf_path (no filtering), with two variants having two
# alternate alleles
assert len(variants_of_interest) == 14
Expand Down Expand Up @@ -587,19 +598,19 @@ def test_variant_overlap_query_maf_filter(vcf_path: Path, include_missing_mafs:
@pytest.mark.parametrize("include_missing_mafs", [False, True])
def test_file_based_variant_query(vcf_path: Path, include_missing_mafs: bool) -> None:
"""Test that `FileBasedVariantLookup.query()` MAF filtering is as expected."""
file_based_vcf_query = FileBasedVariantLookup(
with FileBasedVariantLookup(
vcf_paths=[vcf_path], min_maf=0.0, include_missing_mafs=include_missing_mafs
)
query = [
_round_simple_variant(simple_variant)
for simple_variant in file_based_vcf_query.query(
refname="chr2",
start=8000,
end=9100, # while "common-mixed-2/2" starts at 9101, in the VCf is starts at 9100
maf=0.05,
include_missing_mafs=include_missing_mafs,
)
]
) as file_based_vcf_query:
query = [
_round_simple_variant(simple_variant)
for simple_variant in file_based_vcf_query.query(
refname="chr2",
start=8000,
end=9100, # while "common-mixed-2/2" starts at 9101, in the VCf is starts at 9100
maf=0.05,
include_missing_mafs=include_missing_mafs,
)
]

if not include_missing_mafs:
assert query == get_simple_variant_approx_by_id(
Expand Down

0 comments on commit 7e2e784

Please sign in to comment.