diff --git a/prymer/offtarget/offtarget_detector.py b/prymer/offtarget/offtarget_detector.py index 4c6e053..fb23075 100644 --- a/prymer/offtarget/offtarget_detector.py +++ b/prymer/offtarget/offtarget_detector.py @@ -75,6 +75,7 @@ """ # noqa: E501 import itertools +from collections import defaultdict from contextlib import AbstractContextManager from dataclasses import dataclass from dataclasses import field @@ -131,24 +132,6 @@ class OffTargetResult: right_primer_spans: list[Span] = field(default_factory=list) -@dataclass(init=True) -class PrimerPairBwaHitsBySideAndStrand: - """A helper class for storing BWA hits for the left and right primers of a primer pair organized - by left/right and hit strand. - - Attributes: - left_positive: A list of BwaHit objects on the positive strand for the left primer. - left_negative: A list of BwaHit objects on the negative strand for the left primer. - right_positive: A list of BwaHit objects on the positive strand for the right primer. - right_negative: A list of BwaHit objects on the negative strand for the right primer. - """ - - left_positive: list[BwaHit] = field(default_factory=list) - left_negative: list[BwaHit] = field(default_factory=list) - right_positive: list[BwaHit] = field(default_factory=list) - right_negative: list[BwaHit] = field(default_factory=list) - - class OffTargetDetector(AbstractContextManager): """A class for detecting off-target mappings of primers and primer pairs that uses a custom version of "bwa aln" named "bwa-aln-interactive". @@ -370,40 +353,44 @@ def _build_off_target_result( self._primer_pair_cache[primer_pair] = replace(result, cached=True) return result - # Get the set of reference names with hits - hits_by_refname: dict[str, PrimerPairBwaHitsBySideAndStrand] = { - hit.refname: PrimerPairBwaHitsBySideAndStrand() - for hit in itertools.chain(left_bwa_result.hits, right_bwa_result.hits) - } + # Map the hits by reference name + left_positive_hits: defaultdict[ReferenceName, list[BwaHit]] = defaultdict(list) + left_negative_hits: defaultdict[ReferenceName, list[BwaHit]] = defaultdict(list) + right_positive_hits: defaultdict[ReferenceName, list[BwaHit]] = defaultdict(list) + right_negative_hits: defaultdict[ReferenceName, list[BwaHit]] = defaultdict(list) # Split the hits for left and right by reference name and strand for hit in left_bwa_result.hits: if hit.negative: - hits_by_refname[hit.refname].left_negative.append(hit) + left_negative_hits[hit.refname].append(hit) else: - hits_by_refname[hit.refname].left_positive.append(hit) + left_positive_hits[hit.refname].append(hit) for hit in right_bwa_result.hits: if hit.negative: - hits_by_refname[hit.refname].right_negative.append(hit) + right_negative_hits[hit.refname].append(hit) else: - hits_by_refname[hit.refname].right_positive.append(hit) + right_positive_hits[hit.refname].append(hit) + + refnames: set[ReferenceName] = { + h.refname for h in itertools.chain(left_bwa_result.hits, right_bwa_result.hits) + } # Build amplicons from hits on the same reference with valid relative orientation amplicons: list[Span] = [] - for hits in hits_by_refname.values(): + for refname in refnames: amplicons.extend( self._to_amplicons( - positive_hits=hits.left_positive, - negative_hits=hits.right_negative, + positive_hits=left_positive_hits[refname], + negative_hits=right_negative_hits[refname], max_len=self._max_amplicon_size, strand=Strand.POSITIVE, ) ) amplicons.extend( self._to_amplicons( - positive_hits=hits.right_positive, - negative_hits=hits.left_negative, + positive_hits=right_positive_hits[refname], + negative_hits=left_negative_hits[refname], max_len=self._max_amplicon_size, strand=Strand.NEGATIVE, )