diff --git a/src/main/scala/com/fulcrumgenomics/bam/EstimatePoolingFractions.scala b/src/main/scala/com/fulcrumgenomics/bam/EstimatePoolingFractions.scala index 15aeab690..aa50027c0 100644 --- a/src/main/scala/com/fulcrumgenomics/bam/EstimatePoolingFractions.scala +++ b/src/main/scala/com/fulcrumgenomics/bam/EstimatePoolingFractions.scala @@ -25,6 +25,7 @@ package com.fulcrumgenomics.bam import java.lang.Math.{max, min} +import java.util import com.fulcrumgenomics.FgBioDef._ import com.fulcrumgenomics.bam.api.SamSource @@ -33,11 +34,9 @@ import com.fulcrumgenomics.commons.util.LazyLogging import com.fulcrumgenomics.sopt.{arg, clp} import com.fulcrumgenomics.util.Metric.{Count, Proportion} import com.fulcrumgenomics.util.{Io, Metric, Sequences} -import com.fulcrumgenomics.vcf.ByIntervalListVariantContextIterator +import com.fulcrumgenomics.vcf.api.{Variant, VcfSource} import htsjdk.samtools.util.SamLocusIterator.LocusInfo import htsjdk.samtools.util._ -import htsjdk.variant.variantcontext.VariantContext -import htsjdk.variant.vcf.VCFFileReader import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression @clp(group=ClpGroups.SamOrBam, description= @@ -48,6 +47,11 @@ import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression |for the alternative allele fractions at each SNP locus, using as inputs the individual sample's genotypes. |Only SNPs that are bi-allelic within the pooled samples are used. | + |Each sample's contribution of REF vs. ALT alleles at each site is derived in one of two ways. If + |the sample's genotype in the VCF has the `AF` attribute then the value from that field will be used. If the + |genotype has no AF attribute then the contribution is estimated based on the genotype (e.g. 0/0 will be 100% + |ref, 0/1 will be 50% ref and 50% alt, etc.). + | |Various filtering parameters can be used to control which loci are used: | |- _--intervals_ will restrict analysis to variants within the described intervals @@ -66,84 +70,111 @@ class EstimatePoolingFractions @arg(flag='g', doc="Minimum genotype quality. Use -1 to disable.") val minGenotypeQuality: Int = 30, @arg(flag='c', doc="Minimum (sequencing coverage @ SNP site / n_samples).") val minMeanSampleCoverage: Int = 6, @arg(flag='m', doc="Minimum mapping quality.") val minMappingQuality: Int = 20, - @arg(flag='q', doc="Minimum base quality.") val minBaseQuality:Int = 5 + @arg(flag='q', doc="Minimum base quality.") val minBaseQuality:Int = 5, + @arg(doc="Examine input reads by sample given in each read's read group.") bySample: Boolean = false ) extends FgBioTool with LazyLogging { Io.assertReadable(vcf :: bam :: intervals.toList) private val Ci99Width = 2.58 // Width of a 99% confidence interval in units of std err + private val AllReadGroupsName: String = "all" + /* Class to hold information about a single locus. */ - case class Locus(chrom: String, pos: Int, ref: Char, alt: Char, expectedSampleFractions: Array[Double], var observedFraction: Option[Double] = None) + case class Locus(chrom: String, + pos: Int, + ref: Char, + alt: Char, + expectedSampleFractions: Array[Double], + var observedFraction: Map[String, Double] = Map.empty) override def execute(): Unit = { - val vcfReader = new VCFFileReader(vcf.toFile) + val vcfReader = VcfSource(vcf) val sampleNames = pickSamplesToUse(vcfReader) val intervals = loadIntervals // Get the expected fractions from the VCF val vcfIterator = constructVcfIterator(vcfReader, intervals, sampleNames) - val loci = vcfIterator.filterNot(v => this.nonAutosomes.contains(v.getContig)).map { v => Locus( - chrom = v.getContig, - pos = v.getStart, - ref = v.getReference.getBaseString.charAt(0), - alt = v.getAlternateAllele(0).getBaseString.charAt(0), - expectedSampleFractions = sampleNames.map { s => val gt = v.getGenotype(s); if (gt.isHomRef) 0 else if (gt.isHet) 0.5 else 1.0 } + val loci = vcfIterator.filterNot(v => this.nonAutosomes.contains(v.chrom)).map { v => Locus( + chrom = v.chrom, + pos = v.pos, + ref = v.alleles.ref.bases.charAt(0), + alt = v.alleles.alts.head.value.charAt(0), + expectedSampleFractions = sampleNames.map { s => + val gt = v.gt(s) + gt.get[IndexedSeq[Float]]("AF") match { + case None => if (gt.isHomRef) 0 else if (gt.isHet) 0.5 else 1.0 + case Some(afs) => afs(0) + } + } )}.toArray logger.info(s"Loaded ${loci.length} bi-allelic SNPs from VCF.") - val coveredLoci = fillObserveredFractionAndFilter(loci, this.minMeanSampleCoverage * sampleNames.length) + fillObserveredFractionAndFilter(loci, this.minMeanSampleCoverage * sampleNames.length) + + val observedSamples = loci.flatMap { locus => locus.observedFraction.keySet }.distinct.sorted + logger.info(f"Regressing on ${observedSamples.length}%,d input samples.") - logger.info(s"Regressing on ${coveredLoci.length} of ${loci.length} that met coverage requirements.") val regression = new OLSMultipleLinearRegression regression.setNoIntercept(true) // Intercept should be at 0! - regression.newSampleData( - coveredLoci.map(_.observedFraction.getOrElse(unreachable("observed fraction must be defined"))), - coveredLoci.map(_.expectedSampleFractions) - ) - - val regressionParams = regression.estimateRegressionParameters() - val total = regressionParams.sum - val fractions = regressionParams.map(_ / total) - val stderrs = regression.estimateRegressionParametersStandardErrors().map(_ / total) - logger.info(s"R^2 = ${regression.calculateRSquared()}") - logger.info(s"Sum of regression parameters = ${total}") - - val metrics = sampleNames.toSeq.zipWithIndex.map { case (sample, index) => - val sites = coveredLoci.count(l => l.expectedSampleFractions(index) > 0) - val singletons = coveredLoci.count(l => l.expectedSampleFractions(index) > 0 && l.expectedSampleFractions.sum == l.expectedSampleFractions(index)) - PoolingFractionMetric( - sample = sample, - variant_sites = sites, - singletons = singletons, - estimated_fraction = fractions(index), - standard_error = stderrs(index), - ci99_low = max(0, fractions(index) - stderrs(index)*Ci99Width), - ci99_high = min(1, fractions(index) + stderrs(index)*Ci99Width)) - } - Metric.write(output, metrics) + val metrics = observedSamples.flatMap { observedSample => + logger.info(f"Examining $observedSample") + val (observedFractions, lociExpectedSampleFractions) = loci.flatMap { locus => + locus.observedFraction.get(observedSample).map { observedFraction => + (observedFraction, locus.expectedSampleFractions) + } + }.unzip + logger.info(f"Regressing on ${observedFractions.length}%,d of ${loci.length}%,d loci that met coverage requirements.") + regression.newSampleData( + observedFractions, + lociExpectedSampleFractions + ) + + val regressionParams = regression.estimateRegressionParameters() + val total = regressionParams.sum + val fractions = regressionParams.map(_ / total) + val stderrs = regression.estimateRegressionParametersStandardErrors().map(_ / total) + logger.info(s"R^2 = ${regression.calculateRSquared()}") + logger.info(s"Sum of regression parameters = ${total}") + + if (regression.estimateRegressionParameters().exists(_ < 0)) { + logger.error("#################################################################################") + logger.error("# One or more samples is estimated to have fraction < 0. This is likely due to #") + logger.error("# incorrect samples being used, insufficient coverage and/or too few SNPs. #") + logger.error("#################################################################################") + fail(1) + } - if (regression.estimateRegressionParameters().exists(_ < 0)) { - logger.error("#################################################################################") - logger.error("# One or more samples is estimated to have fraction < 0. This is likely due to #") - logger.error("# incorrect samples being used, insufficient coverage and/or too few SNPs. #") - logger.error("#################################################################################") - fail(1) + sampleNames.toSeq.zipWithIndex.map { case (pool_sample, index) => + val sites = lociExpectedSampleFractions.count(expectedSampleFractions => expectedSampleFractions(index) > 0) + val singletons = lociExpectedSampleFractions.count { expectedSampleFractions => + expectedSampleFractions(index) > 0 && expectedSampleFractions.sum == expectedSampleFractions(index) + } + PoolingFractionMetric( + observed_sample = observedSample, + pool_sample = pool_sample, + variant_sites = sites, + singletons = singletons, + estimated_fraction = fractions(index), + standard_error = stderrs(index), + ci99_low = max(0, fractions(index) - stderrs(index)*Ci99Width), + ci99_high = min(1, fractions(index) + stderrs(index)*Ci99Width)) + } } + + logger.info("Writing metrics") + Metric.write(output, metrics) } /** Verify a provided sample list, or if none provided retrieve the set of samples from the VCF. */ - private def pickSamplesToUse(vcfReader: VCFFileReader): Array[String] = { - if (samples.nonEmpty) { - val samplesInVcf = vcfReader.getFileHeader.getSampleNamesInOrder.iterator.toSet - val missingSamples = samples.filterNot(samplesInVcf.contains) + private def pickSamplesToUse(vcfIn: VcfSource): Array[String] = { + if (this.samples.isEmpty) vcfIn.header.samples.toArray else { + val samplesInVcf = vcfIn.header.samples + val missingSamples = samples.toSet.diff(samplesInVcf.toSet) if (missingSamples.nonEmpty) fail(s"Samples not present in VCF: ${missingSamples.mkString(", ")}") else samples.toArray.sorted } - else { - vcfReader.getFileHeader.getSampleNamesInOrder.iterator.toSeq.toArray.sorted // toSeq.toArray is necessary cos util.ArrayList.toArray() exists - } } /** Loads up and merges all the interval lists provided. Returns None if no intervals were specified. */ @@ -163,20 +194,18 @@ class EstimatePoolingFractions } /** Generates an iterator over non-filtered bi-allelic SNPs where all the required samples are genotyped. */ - def constructVcfIterator(in: VCFFileReader, intervals: Option[IntervalList], samples: Array[String]): Iterator[VariantContext] = { - val vcfIterator: Iterator[VariantContext] = intervals match { + def constructVcfIterator(in: VcfSource, intervals: Option[IntervalList], samples: Seq[String]): Iterator[Variant] = { + val iterator: Iterator[Variant] = intervals match { case None => in.iterator - case Some(is) => ByIntervalListVariantContextIterator(in, is) + case Some(is) => is.flatMap(i => in.query(i.getContig, i.getStart, i.getEnd)) } - val samplesAsUtilSet = CollectionUtil.makeSet(samples:_*) - - vcfIterator - .filterNot(_.isFiltered) - .map(_.subContextFromSamples(samplesAsUtilSet, true)) - .filter(v => v.isSNP && v.isBiallelic && !v.isMonomorphicInSamples) - .filter(_.getNoCallCount == 0) - .filter(v => v.getGenotypesOrderedByName.iterator.forall(gt => gt.getGQ >= this.minGenotypeQuality)) + iterator + .filter(v => v.filters.isEmpty || v.filters == Variant.PassingFilters) + .filter(v => v.alleles.size == 2 && v.alleles.forall(a => a.value.length == 1)) // Just biallelic SNPs + .filter(v => samples.map(v.gt).forall(gt => gt.isFullyCalled && (this.minGenotypeQuality <= 0 || gt.get[Int]("GQ").exists(_ >= this.minGenotypeQuality)))) + .map (v => v.copy(genotypes=v.genotypes.filter { case (s, _) => samples.contains(s)} )) + .filter(v => v.gts.flatMap(_.calls).toSet.size > 1) // Not monomorphic } /** Constructs a SamLocusIterator that will visit every locus in the input. */ @@ -192,29 +221,54 @@ class EstimatePoolingFractions javaIteratorAsScalaIterator(iterator) } + /** Computes the observed fraction of the alternate allele at the given locus*/ + private def getObservedFraction(recordAndOffsets: Seq[SamLocusIterator.RecordAndOffset], + locus: Locus, + minCoverage: Int): Option[Double] = { + if (recordAndOffsets.length < minCoverage) None else { + val counts = BaseCounts(recordAndOffsets) + val (ref, alt) = (counts(locus.ref), counts(locus.alt)) + + // Somewhat redundant with check above, but this protects against a large fraction + // of Ns or other alleles, and also against a large proportion of overlapping reads + if (ref + alt < minCoverage) None else { + Some(alt / (ref + alt).toDouble) + } + } + } + /** * Fills in the observedFraction field for each locus that meets coverage and then returns * the subset of loci that met coverage. */ - def fillObserveredFractionAndFilter(loci: Array[Locus], minCoverage: Int): Array[Locus] = { + def fillObserveredFractionAndFilter(loci: Array[Locus], minCoverage: Int): Unit = { val locusIterator = constructBamIterator(loci) locusIterator.zip(loci.iterator).foreach { case (locusInfo, locus) => if (locusInfo.getSequenceName != locus.chrom || locusInfo.getPosition != locus.pos) fail("VCF and BAM iterators out of sync.") - // A gross coverage check here to avoid a lot of work; better check below - if (locusInfo.getRecordAndOffsets.size() > minCoverage) { - val counts = BaseCounts(locusInfo) - val (ref, alt) = (counts(locus.ref), counts(locus.alt)) - - // Somewhat redundant with check above, but this protects against a large fraction - // of Ns or other alleles, and also against a large proportion of overlapping reads - if (ref + alt >= minCoverage) { - locus.observedFraction = Some(alt / (ref + alt).toDouble) + if (bySample) { + locus.observedFraction = locusInfo.getRecordAndOffsets.toSeq + .groupBy(_.getRecord.getReadGroup.getSample) + .flatMap { case (sample, recordAndOffsets) => + val observedFraction = getObservedFraction( + recordAndOffsets = recordAndOffsets, + locus = locus, + minCoverage = minCoverage + ) + observedFraction.map(frac => sample -> frac) + } + } + else { + val observedFraction = getObservedFraction( + recordAndOffsets = locusInfo.getRecordAndOffsets.toSeq, + locus = locus, + minCoverage = minCoverage + ) + observedFraction.foreach { frac => + locus.observedFraction = Map(AllReadGroupsName -> frac) } } } - - loci.filter(_.observedFraction.isDefined) } } @@ -222,7 +276,9 @@ class EstimatePoolingFractions * Metrics produced by `EstimatePoolingFractions` to quantify the estimated proportion of a sample * mixture that is attributable to a specific sample with a known set of genotypes. * - * @param sample The name of the sample within the pool being reported on. + * @param observed_sample The name of the input sample as reported in the read group, or "all" if all read groups are + * being treated as a single input sample. + * @param pool_sample The name of the sample within the pool being reported on. * @param variant_sites How many sites were examined at which the reported sample is known to be variant. * @param singletons How many of the variant sites were sites at which only this sample was variant. * @param estimated_fraction The estimated fraction of the pool that comes from this sample. @@ -230,7 +286,8 @@ class EstimatePoolingFractions * @param ci99_low The lower bound of the 99% confidence interval for the estimated fraction. * @param ci99_high The upper bound of the 99% confidence interval for the estimated fraction. */ -case class PoolingFractionMetric(sample: String, +case class PoolingFractionMetric(observed_sample: String, + pool_sample: String, variant_sites: Count, singletons: Count, estimated_fraction: Proportion, diff --git a/src/test/scala/com/fulcrumgenomics/bam/EstimatePoolingFractionsTest.scala b/src/test/scala/com/fulcrumgenomics/bam/EstimatePoolingFractionsTest.scala index 85c375af7..13f7e58d2 100644 --- a/src/test/scala/com/fulcrumgenomics/bam/EstimatePoolingFractionsTest.scala +++ b/src/test/scala/com/fulcrumgenomics/bam/EstimatePoolingFractionsTest.scala @@ -25,17 +25,16 @@ package com.fulcrumgenomics.bam import java.nio.file.Paths - import com.fulcrumgenomics.FgBioDef._ import com.fulcrumgenomics.bam.api.{SamRecord, SamSource, SamWriter} import com.fulcrumgenomics.testing.UnitSpec import com.fulcrumgenomics.util.Metric +import com.fulcrumgenomics.vcf.api +import com.fulcrumgenomics.vcf.api.{Genotype, VcfCount, VcfFieldType, VcfFormatHeader, VcfSource, VcfWriter} import htsjdk.samtools.SAMFileHeader.SortOrder import htsjdk.samtools.{MergingSamRecordIterator, SamFileHeaderMerger} import org.scalatest.ParallelTestExecution -import scala.collection.JavaConverters._ - class EstimatePoolingFractionsTest extends UnitSpec with ParallelTestExecution { private val Samples = Seq("HG01879", "HG01112", "HG01583", "HG01500", "HG03742", "HG03052") private val DataDir = Paths.get("src/test/resources/com/fulcrumgenomics/bam/estimate_pooling_fractions") @@ -44,12 +43,15 @@ class EstimatePoolingFractionsTest extends UnitSpec with ParallelTestExecution { private val Regions = DataDir.resolve("regions.interval_list") /** Merges one or more BAMs and returns the path to the merged BAM. */ - def merge(bams: Seq[PathToBam]): PathToBam = { + def merge(bams: Seq[PathToBam], sample: Option[String] = None): PathToBam = { val readers = bams.map(bam => SamSource(bam)) // Mangle the library names in the header so that the merger sees duplicate RGs as different RGs. readers.zipWithIndex.foreach { case (reader, index) => - reader.header.getReadGroups.foreach(rg => rg.setLibrary(rg.getLibrary + ":" + index)) + reader.header.getReadGroups.foreach { rg => + rg.setLibrary(rg.getLibrary + ":" + index) + sample.foreach(s => rg.setSample(s)) + } } val headerMerger = new SamFileHeaderMerger(SortOrder.coordinate, readers.iterator.map(_.header).toJavaList, false) val iterator = new MergingSamRecordIterator(headerMerger, readers.iterator.map(_.toSamReader).toJavaList, true) @@ -104,7 +106,92 @@ class EstimatePoolingFractionsTest extends UnitSpec with ParallelTestExecution { val metrics = Metric.read[PoolingFractionMetric](out) metrics should have size 2 metrics.foreach {m => - val expected = if (m.sample == samples.head) 0.75 else 0.25 + val expected = if (m.pool_sample == samples.head) 0.75 else 0.25 + expected should (be >= m.ci99_low and be <= m.ci99_high) + } + } + + it should "accurately estimate a three sample mixture using the AF genotype field" in { + val samples = Samples.take(3) + val Seq(s1, s2, s3) = samples + val bams = Bams.take(3) + val bam = merge(bams) + + val vcf = { + val vcf = makeTempFile("mixture.", ".vcf.gz") + val in = api.VcfSource(Vcf) + val hd = in.header.copy( + samples = IndexedSeq(s1, "two_sample_mixture"), + formats = VcfFormatHeader("AF", VcfCount.OnePerAltAllele, kind=VcfFieldType.Float, description="Allele Frequency") +: in.header.formats + ) + val out = VcfWriter(vcf, hd) + + in.filter(_.alleles.size == 2).foreach { v => + val gts = samples.map(v.gt) + + // Only bother with sites where all samples have called genotypes and there is variation + if (gts.forall((_.isFullyCalled)) && gts.flatMap(_.calls).toSet.size > 1) { + // Make a mixture of the 2nd and 3rd samples + val (mixCalls, mixAf) = { + val input = gts.drop(1) + if (input.forall(_.isHomRef)) (IndexedSeq(v.alleles.ref, v.alleles.ref), 0.0) + else if (input.forall(_.isHomVar)) (IndexedSeq(v.alleles.alts.head, v.alleles.alts.head), 1.0) + else { + val calls = input.flatMap(_.calls) + (IndexedSeq(v.alleles.ref, v.alleles.alts.head), calls.count(_ != v.alleles.ref) / calls.size.toDouble) + } + } + + val mixtureGt = Genotype( + alleles = v.alleles, + sample = "two_sample_mixture", + calls = mixCalls, + attrs = Map("AF" -> IndexedSeq[Float](mixAf.toFloat)) + ) + + out += v.copy(genotypes=Map(s1 -> gts.head, mixtureGt.sample -> mixtureGt)) + } + } + + in.safelyClose() + out.close() + vcf + } + + // Run the estimator and test the outputs + val out = makeTempFile("pooling_metrics.", ".txt") + new EstimatePoolingFractions(vcf=vcf, bam=bam, output=out, minGenotypeQuality = -1).execute() + val metrics = Metric.read[PoolingFractionMetric](out) + + metrics should have size 2 + metrics.foreach {m => + val expected = if (m.pool_sample == samples.head) 1/3.0 else 2/3.0 + expected should (be >= m.ci99_low and be <= m.ci99_high) + } + } + + it should "accurately estimate mixes of two samples across multiple input read groups" in { + val samples = Samples.take(2) + val Seq(bam1, bam2) = Bams.take(2) + val inBam1 = merge(Seq(bam1, bam1, bam1, bam2), sample=Some("Sample1")) // 75% bam1, 25% bam2 + val inBam2 = merge(Seq(bam1, bam2, bam2, bam2), sample=Some("Sample2")) // 25% bam1, 75% bam2 + val bam = merge(Seq(inBam1, inBam2)) + val out = makeTempFile("pooling_metrics.", ".txt") + new EstimatePoolingFractions(vcf=Vcf, bam=bam, output=out, samples=samples, bySample=true).execute() + val metrics = Metric.read[PoolingFractionMetric](out) + metrics should have size 4 + + val metrics1 = metrics.filter(_.observed_sample == "Sample1") + metrics1 should have size 2 + metrics1.foreach {m => + val expected = if (m.pool_sample == samples.head) 0.75 else 0.25 + expected should (be >= m.ci99_low and be <= m.ci99_high) + } + + val metrics2 = metrics.filter(_.observed_sample == "Sample2") + metrics2 should have size 2 + metrics2.foreach {m => + val expected = if (m.pool_sample == samples.head) 0.25 else 0.75 expected should (be >= m.ci99_low and be <= m.ci99_high) } }