Skip to content

Commit

Permalink
feat: speed improvements for primer pair hit building from single pri…
Browse files Browse the repository at this point in the history
…mer hits (#99)

Closes #94 

1. `OffTargetDetector._to_amplicons` signature changes from 

```
def _to_amplicons(lefts: list[BwaHit], rights: list[BwaHit], max_len: int) -> list[Span]:
```

to

```
def _to_amplicons(
    positive_hits: list[BwaHit], negative_hits: list[BwaHit], max_len: int, strand: Strand
) -> list[Span]:
```

- Amplicons are built from a list of single primer hits on the positive
strand and a list of single primer hits on the negative strand.
- New parameter `strand` is used to set the `strand` attribute on each
of the returned `Span`s for the amplicon hits.
- Validates that all positive hits are on the positive strand and all
negative hits are on the negative strand.
- Validates that all hits are on the same reference (e.g. contig/chr).

2. The existing unit test for `OffTargetDetector._to_amplicons` is split
into two, one where no error should be raised, and one where a
`ValueError` should be raised, with expected error messages.

3. `OffTargetDetector._build_off_target_result` does not change
signature.

- exits early, setting `passes=False` on the returned `OffTargetResult`,
if there are too many hits for either the left or the right primer.
- hits are separated by reference (e.g. contig/chr), left/right, and
strand (positive/negative) using defaultdict collections
- this enables passing hits that are on the same reference and in the
correct relative orientation to each other to `_to_amplicons`, along
with a known amplicon strand - positive when the positive hits are from
the left primer of a pair and the negative hits are from the right
primer, and negative in the inverse.
- the full set of left and right primer hits is still returned if
`OffTargetDetector._keep_primer_spans` is `True`.

4. Unit test for `OffTargetDetector._build_off_target_result` is added.
  • Loading branch information
ameynert authored Jan 10, 2025
1 parent 2656adf commit baf2e65
Show file tree
Hide file tree
Showing 2 changed files with 272 additions and 53 deletions.
183 changes: 155 additions & 28 deletions prymer/offtarget/offtarget_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
""" # noqa: E501

import itertools
from collections import defaultdict
from contextlib import AbstractContextManager
from dataclasses import dataclass
from dataclasses import field
Expand All @@ -83,13 +84,15 @@
from types import TracebackType
from typing import Optional
from typing import Self
from typing import TypeAlias
from typing import TypeVar

from ordered_set import OrderedSet

from prymer.api.oligo import Oligo
from prymer.api.primer_pair import PrimerPair
from prymer.api.span import Span
from prymer.api.span import Strand
from prymer.offtarget.bwa import BWA_EXECUTABLE_NAME
from prymer.offtarget.bwa import BwaAlnInteractive
from prymer.offtarget.bwa import BwaHit
Expand All @@ -98,6 +101,9 @@

PrimerType = TypeVar("PrimerType", bound=Oligo)

ReferenceName: TypeAlias = str
"""Alias for a reference sequence name."""


@dataclass(init=True, frozen=True)
class OffTargetResult:
Expand Down Expand Up @@ -344,27 +350,78 @@ def _build_off_target_result(
result: OffTargetResult

# Get the mappings for the left primer and right primer respectively
p1: BwaResult = hits_by_primer[primer_pair.left_primer.bases]
p2: BwaResult = hits_by_primer[primer_pair.right_primer.bases]

# Get all possible amplicons from the left_primer_mappings and right_primer_mappings
# primer hits, filtering if there are too many for either
if p1.hit_count > self._max_primer_hits or p2.hit_count > self._max_primer_hits:
left_bwa_result: BwaResult = hits_by_primer[primer_pair.left_primer.bases]
right_bwa_result: BwaResult = hits_by_primer[primer_pair.right_primer.bases]

# If there are too many hits, this primer pair will not pass. Exit early.
if (
left_bwa_result.hit_count > self._max_primer_hits
or right_bwa_result.hit_count > self._max_primer_hits
):
result = OffTargetResult(primer_pair=primer_pair, passes=False)
else:
amplicons = self._to_amplicons(p1.hits, p2.hits, self._max_amplicon_size)
result = OffTargetResult(
primer_pair=primer_pair,
passes=self._min_primer_pair_hits <= len(amplicons) <= self._max_primer_pair_hits,
spans=amplicons if self._keep_spans else [],
left_primer_spans=(
[self._hit_to_span(h) for h in p1.hits] if self._keep_primer_spans else []
),
right_primer_spans=(
[self._hit_to_span(h) for h in p2.hits] if self._keep_primer_spans else []
),
if self._cache_results:
self._primer_pair_cache[primer_pair] = replace(result, cached=True)
return result

# Map the hits by reference name
left_positive_hits: defaultdict[ReferenceName, list[BwaHit]] = defaultdict(list)
left_negative_hits: defaultdict[ReferenceName, list[BwaHit]] = defaultdict(list)
right_positive_hits: defaultdict[ReferenceName, list[BwaHit]] = defaultdict(list)
right_negative_hits: defaultdict[ReferenceName, list[BwaHit]] = defaultdict(list)

# Split the hits for left and right by reference name and strand
for hit in left_bwa_result.hits:
if hit.negative:
left_negative_hits[hit.refname].append(hit)
else:
left_positive_hits[hit.refname].append(hit)

for hit in right_bwa_result.hits:
if hit.negative:
right_negative_hits[hit.refname].append(hit)
else:
right_positive_hits[hit.refname].append(hit)

refnames: set[ReferenceName] = {
h.refname for h in itertools.chain(left_bwa_result.hits, right_bwa_result.hits)
}

# Build amplicons from hits on the same reference with valid relative orientation
amplicons: list[Span] = []
for refname in refnames:
amplicons.extend(
self._to_amplicons(
positive_hits=left_positive_hits[refname],
negative_hits=right_negative_hits[refname],
max_len=self._max_amplicon_size,
strand=Strand.POSITIVE,
)
)
amplicons.extend(
self._to_amplicons(
positive_hits=right_positive_hits[refname],
negative_hits=left_negative_hits[refname],
max_len=self._max_amplicon_size,
strand=Strand.NEGATIVE,
)
)

result = OffTargetResult(
primer_pair=primer_pair,
passes=self._min_primer_pair_hits <= len(amplicons) <= self._max_primer_pair_hits,
spans=amplicons if self._keep_spans else [],
left_primer_spans=(
[self._hit_to_span(h) for h in left_bwa_result.hits]
if self._keep_primer_spans
else []
),
right_primer_spans=(
[self._hit_to_span(h) for h in right_bwa_result.hits]
if self._keep_primer_spans
else []
),
)

if self._cache_results:
self._primer_pair_cache[primer_pair] = replace(result, cached=True)

Expand Down Expand Up @@ -420,19 +477,89 @@ def mappings_of(self, primers: list[PrimerType]) -> dict[str, BwaResult]:
return hits_by_primer

@staticmethod
def _to_amplicons(lefts: list[BwaHit], rights: list[BwaHit], max_len: int) -> list[Span]:
"""Takes a set of hits for one or more left primers and right primers and constructs
amplicon mappings anywhere a left primer hit and a right primer hit align in F/R
orientation up to `maxLen` apart on the same reference. Primers may not overlap.
def _to_amplicons(
positive_hits: list[BwaHit], negative_hits: list[BwaHit], max_len: int, strand: Strand
) -> list[Span]:
"""Takes lists of positive strand hits and negative strand hits and constructs amplicon
mappings anywhere a positive strand hit and a negative strand hit occur where the end of
the negative strand hit is no more than `max_len` from the start of the positive strand
hit.
Primers may not overlap.
Args:
positive_hits: List of hits on the positive strand for one of the primers in the pair.
negative_hits: List of hits on the negative strand for the other primer in the pair.
max_len: Maximum length of amplicons to consider.
strand: The strand of the amplicon to generate. Set to Strand.POSITIVE if
`positive_hits` are for the left primer and `negative_hits` are for the right
primer. Set to Strand.NEGATIVE if `positive_hits` are for the right primer and
`negative_hits` are for the left primer.
Raises:
ValueError: If any of the positive hits are not on the positive strand, or any of the
negative hits are not on the negative strand. If hits are present on more than one
reference.
"""
if any(h.negative for h in positive_hits):
raise ValueError("Positive hits must be on the positive strand.")
if any(not h.negative for h in negative_hits):
raise ValueError("Negative hits must be on the negative strand.")

refnames: set[ReferenceName] = {
h.refname for h in itertools.chain(positive_hits, negative_hits)
}
if len(refnames) > 1:
raise ValueError(f"Hits are present on more than one reference: {refnames}")

# Exit early if one of the hit lists is empty - this will save unnecessary sorting of the
# other list
if len(positive_hits) == 0 or len(negative_hits) == 0:
return []

# Sort the positive strand hits by start position and the negative strand hits by *end*
# position. The `max_len` cutoff is based on negative_hit.end - positive_hit.start + 1.
positive_hits_sorted = sorted(positive_hits, key=lambda h: h.start)
negative_hits_sorted = sorted(negative_hits, key=lambda h: h.end)

amplicons: list[Span] = []
for h1, h2 in itertools.product(lefts, rights):
if h1.negative == h2.negative or h1.refname != h2.refname: # not F/R orientation
continue

plus, minus = (h2, h1) if h1.negative else (h1, h2)
if minus.start > plus.end and (minus.end - plus.start + 1) <= max_len:
amplicons.append(Span(refname=plus.refname, start=plus.start, end=minus.end))
# Track the position of the previously examined negative hit.
prev_negative_hit_index = 0
for positive_hit in positive_hits_sorted:
# Check only negative hits starting with the previously examined one.
for negative_hit_index, negative_hit in enumerate(
negative_hits_sorted[prev_negative_hit_index:],
start=prev_negative_hit_index,
):
# TODO: Consider allowing overlapping positive and negative hits.
if (
negative_hit.start > positive_hit.end
and negative_hit.end - positive_hit.start + 1 <= max_len
):
# If the negative hit starts to the right of the positive hit, and the amplicon
# length is <= max_len, add it to the list of amplicon hits to be returned.
amplicons.append(
Span(
refname=positive_hit.refname,
start=positive_hit.start,
end=negative_hit.end,
strand=strand,
)
)

if negative_hit.end - positive_hit.start + 1 > max_len:
# Stop searching for negative hits to pair with this positive hit.
# All subsequent negative hits will have amplicon length > max_len
break

if negative_hit.end < positive_hit.start:
# This positive hit is genomically right of the current negative hit.
# All subsequent positive hits will also be genomically right of this negative
# hit, so we should start at the one after this. If this index is past the end
# of the list, the slice `negative_hits_sorted[prev_negative_hit_index:]` will
# be empty.
prev_negative_hit_index = negative_hit_index + 1

return amplicons

Expand Down
Loading

0 comments on commit baf2e65

Please sign in to comment.