Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: convert samples argument in Genotypes.read into a set and fix tr_harmonizer bug arising when TRTools is also installed #225

Merged
merged 13 commits into from
Jan 12, 2024
Merged
18 changes: 9 additions & 9 deletions haptools/__main__.py
Original file line number Diff line number Diff line change
@@ -491,10 +491,10 @@ def simphenotype(
)
if samples_file:
with samples_file as samps_file:
samples = samps_file.read().splitlines()
samples = set(samps_file.read().splitlines())
elif samples:
# needs to be converted from tuple to list
samples = list(samples)
# needs to be converted from tuple to set
samples = set(samples)
else:
samples = None

@@ -657,10 +657,10 @@ def transform(
)
if samples_file:
with samples_file as samps_file:
samples = samps_file.read().splitlines()
samples = set(samps_file.read().splitlines())
elif samples:
# needs to be converted from tuple to list
samples = list(samples)
# needs to be converted from tuple to set
samples = set(samples)
else:
samples = None

@@ -828,10 +828,10 @@ def ld(
)
if samples_file:
with samples_file as samps_file:
samples = samps_file.read().splitlines()
samples = set(samps_file.read().splitlines())
elif samples:
# needs to be converted from tuple to list
samples = list(samples)
# needs to be converted from tuple to set
samples = set(samples)
else:
samples = None

112 changes: 83 additions & 29 deletions haptools/data/genotypes.py
Original file line number Diff line number Diff line change
@@ -3,15 +3,15 @@
import gc
from csv import reader
from pathlib import Path
from logging import Logger
from typing import Iterator
from logging import getLogger, Logger
from collections import namedtuple, Counter

import pgenlib
import numpy as np
import numpy.typing as npt
from pysam import VariantFile
from cyvcf2 import VCF, Variant
from pysam import VariantFile, TabixFile

try:
import trtools.utils.tr_harmonizer as trh
@@ -77,7 +77,7 @@ def load(
cls: Genotypes,
fname: Path | str,
region: str = None,
samples: list[str] = None,
samples: set[str] = None,
variants: set[str] = None,
) -> Genotypes:
"""
@@ -91,7 +91,7 @@ def load(
See documentation for :py:attr:`~.Data.fname`
region : str, optional
See documentation for :py:meth:`~.Genotypes.read`
samples : list[str], optional
samples : set[str], optional
See documentation for :py:meth:`~.Genotypes.read`
variants : set[str], optional
See documentation for :py:meth:`~.Genotypes.read`
@@ -112,7 +112,7 @@ def load(
def read(
self,
region: str = None,
samples: list[str] = None,
samples: set[str] = None,
variants: set[str] = None,
max_variants: int = None,
):
@@ -132,9 +132,11 @@ def read(
For this to work, the VCF must be indexed and the seqname must match!

Defaults to loading all genotypes
samples : list[str], optional
samples : set[str], optional
A subset of the samples from which to extract genotypes

Note that they are loaded in the same order as in the file

Defaults to loading genotypes from all samples
variants : set[str], optional
A set of variant IDs for which to extract genotypes
@@ -307,7 +309,7 @@ def _iterate(self, vcf: VCF, region: str = None, variants: set[str] = None):
vcf.close()

def __iter__(
self, region: str = None, samples: list[str] = None, variants: set[str] = None
self, region: str = None, samples: set[str] = None, variants: set[str] = None
) -> Iterator[namedtuple]:
"""
Read genotypes from a VCF line by line without storing anything
@@ -316,7 +318,7 @@ def __iter__(
----------
region : str, optional
See documentation for :py:meth:`~.Genotypes.read`
samples : list[str], optional
samples : set[str], optional
See documentation for :py:meth:`~.Genotypes.read`
variants : set[str], optional
See documentation for :py:meth:`~.Genotypes.read`
@@ -326,6 +328,13 @@ def __iter__(
Iterator[namedtuple]
See documentation for :py:meth:`~.Genotypes._iterate`
"""
if samples is not None:
if not isinstance(samples, set):
self.log.warning(
"Samples cannot be loaded in a particular order. "
"Use subset() to reorder the samples after loading them."
)
samples = list(samples)
vcf = VCF(str(self.fname), samples=samples, lazy=True)
self.samples = tuple(vcf.samples)
# call another function to force the lines above to be run immediately
@@ -797,6 +806,37 @@ def write(self):
vcf.close()


class TRRecordHarmonizerRegion(trh.TRRecordHarmonizer):
    """
    A TRRecordHarmonizer whose records are pulled from a caller-supplied iterator

    Parameters
    ----------
    vcffile : VCF
        Forwarded, together with vcftype, to the parent TRRecordHarmonizer
    vcfiter : object
        The iterator that ``__next__`` advances to obtain each raw record.
        Callers in this module pass ``vcf(region)`` so that only the records
        of one region are harmonized.
    vcftype : {'auto', 'gangstr', 'advntr', 'hipstr', 'eh', 'popstr'}, optional
        Type of the VCF file. Default='auto'.
        If vcftype=='auto', attempts to infer the type.

    Attributes
    ----------
    vcffile : VCF
        NOTE(review): presumably set by the parent constructor -- confirm
    vcfiter : object
        The iterator given at construction time; the source of all records
    vcftype : enum
        Type of the VCF file. Must be included in VcfTypes
    """

    def __init__(
        self,
        vcffile: VCF,
        vcfiter: object,
        vcftype: str | trh.VcfTypes = "auto",
    ):
        # the parent only knows about the file and its type; the iterator is
        # tracked by this subclass
        super().__init__(vcffile, vcftype)
        self.vcfiter = vcfiter

    def __next__(self) -> trh.TRRecord:
        """Harmonize and return the next record drawn from ``vcfiter``."""
        raw_record = next(self.vcfiter)
        return trh.HarmonizeRecord(self.vcftype, raw_record)


class GenotypesTR(Genotypes):
"""
A class for processing TR genotypes from a file
@@ -823,7 +863,12 @@ class GenotypesTR(Genotypes):
{'auto', 'gangstr', 'advntr', 'hipstr', 'eh', 'popstr'}
"""

def __init__(self, fname: Path | str, log: Logger = None, vcftype: str = "auto"):
def __init__(
self,
fname: Path | str,
log: Logger = None,
vcftype: str = "auto",
):
super().__init__(fname, log)
self.vcftype = vcftype

@@ -832,7 +877,7 @@ def load(
cls: GenotypesTR,
fname: Path | str,
region: str = None,
samples: list[str] = None,
samples: set[str] = None,
variants: set[str] = None,
vcftype: str = "auto",
) -> Genotypes:
@@ -847,7 +892,7 @@ def load(
See documentation for :py:attr:`~.Data.fname`
region : str, optional
See documentation for :py:meth:`~.Genotypes.read`
samples : list[str], optional
samples : set[str], optional
See documentation for :py:meth:`~.Genotypes.read`
variants : set[str], optional
See documentation for :py:meth:`~.Genotypes.read`
@@ -878,8 +923,8 @@ def _vcf_iter(self, vcf: cyvcf2.VCF, region: str = None):
tr_records: trh.TRRecord
TRRecord objects yielded from TRRecordHarmonizer
"""
for record in trh.TRRecordHarmonizer(
vcffile=vcf, vcfiter=vcf(region), region=region, vcftype=self.vcftype
for record in TRRecordHarmonizerRegion(
vcffile=vcf, vcfiter=vcf(region), vcftype=self.vcftype
):
record.ID = record.record_id
record.CHROM = record.chrom
@@ -938,7 +983,7 @@ class GenotypesPLINK(GenotypesVCF):
----------
data : npt.NDArray
See documentation for :py:attr:`~.GenotypesVCF.data`
samples : tuple
samples : tuple[str]
See documentation for :py:attr:`~.GenotypesVCF.samples`
variants : np.array
See documentation for :py:attr:`~.GenotypesVCF.variants`
@@ -956,11 +1001,16 @@ class GenotypesPLINK(GenotypesVCF):
>>> genotypes = GenotypesPLINK.load('tests/data/simple.pgen')
"""

def __init__(self, fname: Path | str, log: Logger = None, chunk_size: int = None):
def __init__(
self,
fname: Path | str,
log: Logger = None,
chunk_size: int = None,
):
super().__init__(fname, log)
self.chunk_size = chunk_size

def read_samples(self, samples: list[str] = None):
def read_samples(self, samples: set[str] = None):
"""
Read sample IDs from a PSAM file into a list stored in
:py:attr:`~.GenotypesPLINK.samples`
@@ -969,7 +1019,7 @@ def read_samples(self, samples: list[str] = None):

Parameters
----------
samples : list[str], optional
samples : set[str], optional
See documentation for :py:attr:`~.GenotypesVCF.read`

Returns
@@ -980,6 +1030,10 @@ def read_samples(self, samples: list[str] = None):
if len(self.samples) != 0:
self.log.warning("Sample data has already been loaded. Overriding.")
if samples is not None and not isinstance(samples, set):
self.log.warning(
"Samples cannot be loaded in a particular order. "
"Use subset() to reorder the samples after loading them."
)
samples = set(samples)
with self.hook_compressed(self.fname.with_suffix(".psam"), mode="r") as psam:
psamples = reader(psam, delimiter="\t")
@@ -1210,7 +1264,7 @@ def read_variants(
def read(
self,
region: str = None,
samples: list[str] = None,
samples: set[str] = None,
variants: set[str] = None,
max_variants: int = None,
):
@@ -1222,7 +1276,7 @@ def read(
----------
region : str, optional
See documentation for :py:attr:`~.GenotypesVCF.read`
samples : list[str], optional
samples : set[str], optional
See documentation for :py:attr:`~.GenotypesVCF.read`
variants : set[str], optional
See documentation for :py:attr:`~.GenotypesVCF.read`
@@ -1366,7 +1420,7 @@ def _iterate(
pgen.close()

def __iter__(
self, region: str = None, samples: list[str] = None, variants: set[str] = None
self, region: str = None, samples: set[str] = None, variants: set[str] = None
) -> Iterator[namedtuple]:
"""
Read genotypes from a PGEN line by line without storing anything
@@ -1375,7 +1429,7 @@ def __iter__(
----------
region : str, optional
See documentation for :py:meth:`~.Genotypes.read`
samples : list[str], optional
samples : set[str], optional
See documentation for :py:meth:`~.Genotypes.read`
variants : set[str], optional
See documentation for :py:meth:`~.Genotypes.read`
@@ -1552,7 +1606,7 @@ class GenotypesPLINKTR(GenotypesPLINK):
----------
data : npt.NDArray
See documentation for :py:attr:`~.GenotypesPLINK.data`
samples : tuple
samples : tuple[str]
See documentation for :py:attr:`~.GenotypesPLINK.samples`
variants : np.array
See documentation for :py:attr:`~.GenotypesPLINK.variants`
@@ -1562,7 +1616,6 @@ class GenotypesPLINKTR(GenotypesPLINK):
See documentation for :py:attr:`~.GenotypesPLINK.chunk_size`
vcftype: str, optional
See documentation for :py:attr:`~.GenotypesTR.vcftype`

Examples
--------
>>> genotypes = GenotypesPLINK.load('tests/data/simple.pgen')
@@ -1583,7 +1636,7 @@ def load(
cls: GenotypesPLINKTR,
fname: Path | str,
region: str = None,
samples: list[str] = None,
samples: set[str] = None,
variants: set[str] = None,
vcftype: str = "auto",
) -> Genotypes:
@@ -1598,7 +1651,7 @@ def load(
See documentation for :py:attr:`~.Data.fname`
region : str, optional
See documentation for :py:meth:`~.Genotypes.read`
samples : list[str], optional
samples : set[str], optional
See documentation for :py:meth:`~.Genotypes.read`
variants : set[str], optional
See documentation for :py:meth:`~.Genotypes.read`
@@ -1632,12 +1685,12 @@ def _iter_TRRecords(self, region: str = None, variants: set[str] = None):
An iterator over each line of the PVAR file
"""
vcf = VCF(self.fname.with_suffix(".pvar"))
tr_records = trh.TRRecordHarmonizer(
tr_records = TRRecordHarmonizerRegion(
vcffile=vcf,
vcfiter=vcf(region),
region=region,
vcftype=self.vcftype,
)

# filter out TRs that we didn't want
if variants is not None:
tr_records = filter(lambda rec: rec.record_id in variants, tr_records)
@@ -1646,7 +1699,7 @@ def _iter_TRRecords(self, region: str = None, variants: set[str] = None):
def read(
self,
region: str = None,
samples: list[str] = None,
samples: set[str] = None,
variants: set[str] = None,
max_variants: int = None,
):
@@ -1658,14 +1711,15 @@ def read(
----------
region : str, optional
See documentation for :py:attr:`~.GenotypesVCF.read`
samples : list[str], optional
samples : set[str], optional
See documentation for :py:attr:`~.GenotypesVCF.read`
variants : set[str], optional
See documentation for :py:attr:`~.GenotypesVCF.read`
max_variants : int, optional
See documentation for :py:attr:`~.GenotypesVCF.read`
"""
super().read(region, samples, variants, max_variants)

num_variants = len(self.variants)
# initialize a jagged array of allele lengths
max_num_alleles = max(map(len, self.variants["alleles"]))
14 changes: 2 additions & 12 deletions haptools/data/tr_harmonizer.py
Original file line number Diff line number Diff line change
@@ -1633,8 +1633,6 @@ class TRRecordHarmonizer:
Attributes
----------
vcffile : cyvcf2.VCF instance
region : str
Region to grab strs from within the VCF file.
vcftype : enum
Type of the VCF file. Must be included in VcfTypes
Raises
@@ -1644,17 +1642,9 @@ class TRRecordHarmonizer:
See :py:meth:`InferVCFType` for more details.
"""

def __init__(
self,
vcffile: cyvcf2.VCF,
vcfiter: object,
region: str,
vcftype: Union[str, VcfTypes] = "auto",
):
def __init__(self, vcffile: cyvcf2.VCF, vcftype: Union[str, VcfTypes] = "auto"):
self.vcffile = vcffile
self.vcfiter = vcfiter
self.vcftype = InferVCFType(vcffile, vcftype)
self.region = region

def MayHaveImpureRepeats(self) -> bool:
"""
@@ -1725,7 +1715,7 @@ def __iter__(self) -> Iterator[TRRecord]:

def __next__(self) -> TRRecord:
"""Iterate over TRRecord produced from the underlying vcf."""
return HarmonizeRecord(self.vcftype, next(self.vcfiter))
return HarmonizeRecord(self.vcftype, next(self.vcffile))


# TODO check all users of this class for new options
5 changes: 3 additions & 2 deletions haptools/ld.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
import logging
from pathlib import Path
from dataclasses import dataclass, field

@@ -50,7 +51,7 @@ def calc_ld(
genotypes: Path,
haplotypes: Path,
region: str = None,
samples: list[str] = None,
samples: set[str] = None,
ids: tuple[str] = None,
chunk_size: int = None,
discard_missing: bool = False,
@@ -72,7 +73,7 @@ def calc_ld(
region : str, optional
See documentation for :py:meth:`~.data.Genotypes.read`
and :py:meth:`~.data.Haplotypes.read`
samples : list[str], optional
samples : set[str], optional
See documentation for :py:meth:`~.data.Genotypes.read`
ids: set[str], optional
A subset of haplotype IDs to obtain from the .hap file. All others
7 changes: 2 additions & 5 deletions haptools/sim_phenotype.py
Original file line number Diff line number Diff line change
@@ -303,7 +303,7 @@ def simulate_pt(
prevalence: float = None,
normalize: bool = True,
region: str = None,
samples: list[str] = None,
samples: set[str] = None,
haplotype_ids: set[str] = None,
chunk_size: int = None,
repeats: Path = None,
@@ -352,13 +352,10 @@ def simulate_pt(
match!
Defaults to loading all haplotypes
sample : tuple[str], optional
samples : set[str], optional
A subset of the samples from which to extract genotypes
Defaults to loading genotypes from all samples
samples_file : Path, optional
A single column txt file containing a list of the samples (one per line) to
subset from the genotypes file
haplotype_ids: set[str], optional
A list of haplotype IDs to obtain from the .hap file. All others are ignored.
16 changes: 8 additions & 8 deletions haptools/transform.py
Original file line number Diff line number Diff line change
@@ -5,8 +5,8 @@
from dataclasses import dataclass, field

import numpy as np
from cyvcf2 import VCF
import numpy.typing as npt
from cyvcf2 import VCF, Variant
from pysam import VariantFile

from . import data
@@ -28,7 +28,7 @@ class HaplotypeAncestry(data.Haplotype):
default=(data.Extra("ancestry", "s", "Local ancestry"),),
)

def transform(self, genotypes: data.GenotypesVCF) -> npt.NDArray[bool]:
def transform(self, genotypes: data.GenotypesVCF) -> npt.NDArray:
"""
Transform a genotypes matrix via the current haplotype and its ancestral
population
@@ -80,7 +80,7 @@ def __init__(
fname: Path | str,
haplotype: type[HaplotypeAncestry] = HaplotypeAncestry,
variant: type[data.Variant] = data.Variant,
log: Logger = None,
log: logging.Logger = None,
):
"""
Contrasting with the base Haplotypes class: this class uses HaplotypeAncestry
@@ -171,11 +171,11 @@ class GenotypesAncestry(data.GenotypesVCF):
ancestry : np.array
The ancestral population of each allele in each sample of
:py:attr:`~.GenotypesAncestry.data`
log: Logger
log: logging.Logger
See documentation for :py:attr:`~.Genotypes.log`
"""

def __init__(self, fname: Path | str, log: Logger = None):
def __init__(self, fname: Path | str, log: logging.Logger = None):
super().__init__(fname, log)
self.ancestry = None
self.valid_labels = None
@@ -227,7 +227,7 @@ def _iterate(self, vcf: VCF, region: str = None, variants: set[str] = None):
def read(
self,
region: str = None,
samples: list[str] = None,
samples: set[str] = None,
variants: set[str] = None,
max_variants: int = None,
):
@@ -532,7 +532,7 @@ def transform_haps(
genotypes: Path,
haplotypes: Path,
region: str = None,
samples: list[str] = None,
samples: set[str] = None,
haplotype_ids: set[str] = None,
chunk_size: int = None,
discard_missing: bool = False,
@@ -552,7 +552,7 @@ def transform_haps(
region : str, optional
See documentation for :py:meth:`~.data.Genotypes.read`
and :py:meth:`~.data.Haplotypes.read`
samples : list[str], optional
samples : set[str], optional
See documentation for :py:meth:`~.data.Genotypes.read`
haplotype_ids: set[str], optional
A set of haplotype IDs to obtain from the .hap file. All others are ignored.
12 changes: 6 additions & 6 deletions tests/test_data.py
Original file line number Diff line number Diff line change
@@ -20,7 +20,6 @@
Breakpoints,
GenotypesTR,
GenotypesVCF,
GenotypesTR,
GenotypesPLINK,
GenotypesPLINKTR,
)
@@ -195,17 +194,17 @@ def test_load_genotypes_subset(self):

gts = Genotypes(DATADIR / "simple.vcf.gz")
samples = ["HG00097", "HG00100"]
gts.read(region="1:10115-10117", samples=samples)
samples_set = set(samples)
gts.read(region="1:10115-10117", samples=samples_set)
np.testing.assert_allclose(gts.data, expected)
assert gts.samples == tuple(samples)

# subset to just one of the variants
expected = expected[:, [1]]

gts = Genotypes(DATADIR / "simple.vcf.gz")
samples = ["HG00097", "HG00100"]
variants = {"1:10117:C:A"}
gts.read(region="1:10115-10117", samples=samples, variants=variants)
gts.read(region="1:10115-10117", samples=samples_set, variants=variants)
np.testing.assert_allclose(gts.data, expected)
assert gts.samples == tuple(samples)

@@ -501,7 +500,8 @@ def test_load_genotypes_subset(self):

gts = GenotypesPLINK(DATADIR / "simple.pgen")
samples = [expected.samples[1], expected.samples[3]]
gts.read(region="1:10115-10117", samples=samples)
samples_set = set(samples)
gts.read(region="1:10115-10117", samples=samples_set)
gts.check_phase()
np.testing.assert_allclose(gts.data, expected_data)
assert gts.samples == tuple(samples)
@@ -511,7 +511,7 @@ def test_load_genotypes_subset(self):

gts = GenotypesPLINK(DATADIR / "simple.pgen")
variants = {"1:10117:C:A"}
gts.read(region="1:10115-10117", samples=samples, variants=variants)
gts.read(region="1:10115-10117", samples=samples_set, variants=variants)
gts.check_phase()
np.testing.assert_allclose(gts.data, expected_data)
assert gts.samples == tuple(samples)