Skip to content

Commit

Permalink
[CZID-9262] Only call taxa with species level lineages in long-read-mngs (#349)
Browse files Browse the repository at this point in the history
  • Loading branch information
lvreynoso authored Apr 26, 2024
1 parent 7e6d8a4 commit dfb8475
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 33 deletions.
22 changes: 12 additions & 10 deletions lib/idseq-dag/idseq_dag/steps/blast_contigs.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,15 @@ def run(self):
command.write_text_to_file('[]', contig_summary_json)
return # return in the middle of the function

lineage_db = s3.fetch_reference(
self.additional_files["lineage_db"],
self.ref_dir_local,
allow_s3mi=False) # Too small to waste s3mi

accession2taxid_dict = s3.fetch_reference(self.additional_files["accession2taxid"], self.ref_dir_local)

(read_dict, accession_dict, _selected_genera) = m8.summarize_hits(hit_summary)
PipelineStepBlastContigs.run_blast(db_type, blast_m8, assembled_contig, reference_fasta, blast_top_m8)
PipelineStepBlastContigs.run_blast(db_type, blast_m8, assembled_contig, reference_fasta, blast_top_m8, lineage_db, accession2taxid_dict)
read2contig = {}
generate_info_from_sam(bowtie_sam, read2contig, duplicate_cluster_sizes_path=duplicate_cluster_sizes_path)

Expand All @@ -106,11 +113,6 @@ def run(self):
refined_hit_summary, refined_m8)

# Generating taxon counts based on updated results
lineage_db = s3.fetch_reference(
self.additional_files["lineage_db"],
self.ref_dir_local,
allow_s3mi=False) # Too small to waste s3mi

deuterostome_db = None
if self.additional_files.get("deuterostome_db"):
deuterostome_db = s3.fetch_reference(self.additional_files["deuterostome_db"],
Expand Down Expand Up @@ -264,7 +266,7 @@ def update_read_dict(read2contig, blast_top_blastn_6_path, read_dict, accession_
return (consolidated_dict, read2blastm8, contig2lineage, added_reads)

@staticmethod
def run_blast_nt(blast_index_path, blast_m8, assembled_contig, reference_fasta, blast_top_m8):
def run_blast_nt(blast_index_path, blast_m8, assembled_contig, reference_fasta, blast_top_m8, lineage_db, accession2taxid_dict):
blast_type = 'nucl'
blast_command = 'blastn'
command.execute(
Expand Down Expand Up @@ -308,10 +310,10 @@ def run_blast_nt(blast_index_path, blast_m8, assembled_contig, reference_fasta,
)
)
# further processing of getting the top m8 entry for each contig.
get_top_m8_nt(blast_m8, blast_top_m8)
get_top_m8_nt(blast_m8, lineage_db, accession2taxid_dict, blast_top_m8)

@staticmethod
def run_blast_nr(blast_index_path, blast_m8, assembled_contig, reference_fasta, blast_top_m8):
def run_blast_nr(blast_index_path, blast_m8, assembled_contig, reference_fasta, blast_top_m8, lineage_db, accession2taxid_dict):
blast_type = 'prot'
blast_command = 'blastx'
command.execute(
Expand Down Expand Up @@ -349,4 +351,4 @@ def run_blast_nr(blast_index_path, blast_m8, assembled_contig, reference_fasta,
)
)
# further processing of getting the top m8 entry for each contig.
get_top_m8_nr(blast_m8, blast_top_m8)
get_top_m8_nr(blast_m8, lineage_db, accession2taxid_dict, blast_top_m8)
74 changes: 56 additions & 18 deletions lib/idseq-dag/idseq_dag/util/top_hits.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from collections import defaultdict

import idseq_dag.util.lineage as lineage

from idseq_dag.util.dict import open_file_db_by_extension
from idseq_dag.util.m8 import NT_MIN_ALIGNMENT_LEN
from idseq_dag.util.parsing import BlastnOutput6Reader, BlastnOutput6Writer
from idseq_dag.util.parsing import BlastnOutput6NTReader
Expand Down Expand Up @@ -56,6 +59,14 @@ def _intersects(needle, haystack):
''' Return True iff needle intersects haystack. Ignore overlap < NT_MIN_OVERLAP_FRACTION. '''
return any(_hsp_overlap(needle, hay) for hay in haystack)

def _get_lineage(accession_id, lineage_map, accession2taxid_dict, lineage_cache={}):
    """Return the lineage taxid tuple for an accession id, memoizing results.

    The accession's version suffix (e.g. ".1") is stripped before the
    accession2taxid lookup; a miss at either lookup stage falls back to
    lineage.NULL_LINEAGE. NOTE: the mutable-default ``lineage_cache`` is a
    deliberate process-lifetime memo shared across calls — it assumes every
    call in a run uses the same lineage_map/accession2taxid_dict pair.
    """
    try:
        # Fast path: EAFP lookup in the shared memo.
        return lineage_cache[accession_id]
    except KeyError:
        pass
    versionless_accession = accession_id.split(".")[0]
    taxid = accession2taxid_dict.get(versionless_accession, "NA")
    lineage_taxids = lineage_map.get(taxid, lineage.NULL_LINEAGE)
    lineage_cache[accession_id] = lineage_taxids
    return lineage_taxids

class BlastCandidate:
'''Documented in function get_top_m8_nt() below.'''
Expand All @@ -64,6 +75,7 @@ def __init__(self, hsps):
self.hsps = hsps
self.optimal_cover = None
self.total_score = None

if hsps:
# All HSPs here must be for the same query and subject sequence.
h_0 = hsps[0]
Expand Down Expand Up @@ -105,23 +117,34 @@ def summary(self):
return r


def _optimal_hit_for_each_query_nr(blast_output_path, max_evalue):
def _optimal_hit_for_each_query_nr(blast_output_path, lineage_map, accession2taxid_dict, max_evalue):
contigs_to_best_alignments = defaultdict(list)
accession_counts = defaultdict(lambda: 0)

with open(blast_output_path) as blastn_6_f:
# For each contig, get the alignments that have the best total score (may be multiple if there are ties).
# Prioritize the specificity of the hit.
specificity_to_best_alignments = defaultdict(dict)
for alignment in BlastnOutput6Reader(blastn_6_f):
if alignment["evalue"] > max_evalue:
continue
query = alignment["qseqid"]
best_alignments = contigs_to_best_alignments[query]

if len(best_alignments) == 0 or best_alignments[0]["bitscore"] < alignment["bitscore"]:
contigs_to_best_alignments[query] = [alignment]
lineage_taxids = _get_lineage(alignment["sseqid"], lineage_map, accession2taxid_dict)
specificity = next((level for level, taxid_at_level in enumerate(lineage_taxids) if int(taxid_at_level) > 0), float("inf"))

best_alignments = specificity_to_best_alignments[query]

if (specificity not in best_alignments) or best_alignments[specificity][0]["bitscore"] < alignment["bitscore"]:
specificity_to_best_alignments[query][specificity] = [alignment]
# Add all ties to best_hits.
elif len(best_alignments) > 0 and best_alignments[0]["bitscore"] == alignment["bitscore"]:
contigs_to_best_alignments[query].append(alignment)
elif len(best_alignments[specificity]) > 0 and best_alignments[specificity][0]["bitscore"] == alignment["bitscore"]:
specificity_to_best_alignments[query][specificity].append(alignment)

# Choose the best alignments with the most specific taxid information.
for contig_id, specificity_alignment_dict in specificity_to_best_alignments.items():
specific_best_alignments = next(specificity_alignment_dict[specificity] for specificity in sorted(specificity_alignment_dict.keys()))
contigs_to_best_alignments[contig_id] = specific_best_alignments

# Create a map of accession to best alignment count.
for _contig_id, alignments in contigs_to_best_alignments.items():
Expand Down Expand Up @@ -170,26 +193,33 @@ def _filter_and_group_hits_by_query(blast_output_path, min_alignment_length, min


# An iterator that, for contig, yields to optimal hit for the contig in the nt blast_output.
def _optimal_hit_for_each_query_nt(blast_output, min_alignment_length, min_pident, max_evalue, summary=True):
def _optimal_hit_for_each_query_nt(blast_output, lineage_map, accession2taxid_dict,
min_alignment_length, min_pident, max_evalue, summary=True):
contigs_to_blast_candidates = {}
accession_counts = defaultdict(lambda: 0)

# For each contig, get the collection of blast candidates that have the best total score (may be multiple if there are ties).
for query_hits in _filter_and_group_hits_by_query(blast_output, min_alignment_length, min_pident, max_evalue):
best_hits = []
best_hits = {}
for _subject, hit_fragments in query_hits.items():
# For NT, we take a special approach where we try to find a subset of disjoint HSPs
# with maximum sum of bitscores.
hit = BlastCandidate(hit_fragments)
hit.optimize()

if len(best_hits) == 0 or best_hits[0].total_score < hit.total_score:
best_hits = [hit]
# Add all ties to best_hits.
elif len(best_hits) > 0 and best_hits[0].total_score == hit.total_score:
best_hits.append(hit)
# We prioritize the specificity of the hit; hits with species taxids are taken before hits without
# Specificity is just the index of the tuple returned by _get_lineage(); 0 for species, 1 for genus, etc.
lineage_taxids = _get_lineage(hit.sseqid, lineage_map, accession2taxid_dict)
specificity = next((level for level, taxid_at_level in enumerate(lineage_taxids) if int(taxid_at_level) > 0), float("inf"))

if (specificity not in best_hits) or best_hits[specificity][0].total_score < hit.total_score:
best_hits[specificity] = [hit]
# Add all ties to best_hits[specificity].
elif len(best_hits[specificity]) > 0 and best_hits[specificity][0].total_score == hit.total_score:
best_hits[specificity].append(hit)

contigs_to_blast_candidates[best_hits[0].qseqid] = best_hits
specific_best_hits = next(best_hits[specificity] for specificity in sorted(best_hits.keys()))
contigs_to_blast_candidates[specific_best_hits[0].qseqid] = specific_best_hits

# Create a map of accession to blast candidate count.
for _contig_id, blast_candidates in contigs_to_blast_candidates.items():
Expand All @@ -213,18 +243,24 @@ def _optimal_hit_for_each_query_nt(blast_output, min_alignment_length, min_piden

def get_top_m8_nr(
blast_output,
lineage_map_path,
accession2taxid_dict_path,
blast_top_blastn_6_path,
max_evalue=MAX_EVALUE_THRESHOLD,
):
''' Get top m8 file entry for each contig from blast_output and output to blast_top_m8 '''
with open(blast_top_blastn_6_path, "w") as blast_top_blastn_6_f:
with open(blast_top_blastn_6_path, "w") as blast_top_blastn_6_f, \
open_file_db_by_extension(lineage_map_path, "lll") as lineage_map, \
open_file_db_by_extension(accession2taxid_dict_path, "L") as accession2taxid_dict: # noqa
BlastnOutput6Writer(blast_top_blastn_6_f).writerows(
_optimal_hit_for_each_query_nr(blast_output, max_evalue)
_optimal_hit_for_each_query_nr(blast_output, lineage_map, accession2taxid_dict, max_evalue)
)


def get_top_m8_nt(
blast_output,
lineage_map_path,
accession2taxid_dict_path,
blast_top_blastn_6_path,
min_alignment_length=NT_MIN_ALIGNMENT_LEN,
min_pident=_NT_MIN_PIDENT,
Expand All @@ -251,7 +287,9 @@ def get_top_m8_nt(

# Output the optimal hit for each query.

with open(blast_top_blastn_6_path, "w") as blast_top_blastn_6_f:
with open(blast_top_blastn_6_path, "w") as blast_top_blastn_6_f, \
open_file_db_by_extension(lineage_map_path, "lll") as lineage_map, \
open_file_db_by_extension(accession2taxid_dict_path, "L") as accession2taxid_dict: # noqa
BlastnOutput6NTRerankedWriter(blast_top_blastn_6_f).writerows(
_optimal_hit_for_each_query_nt(blast_output, min_alignment_length, min_pident, max_evalue, False)
_optimal_hit_for_each_query_nt(blast_output, lineage_map, accession2taxid_dict, min_alignment_length, min_pident, max_evalue, False)
)
22 changes: 20 additions & 2 deletions workflows/long-read-mngs/run.wdl
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,8 @@ task RunNRAlignment {
task FindTopHitsNT {
input {
File nt_m8
File lineage_db
File accession2taxid_db
String docker_image_id
}
Expand All @@ -652,7 +654,12 @@ task FindTopHitsNT {
python3 <<CODE
from idseq_dag.steps.blast_contigs import get_top_m8_nt
get_top_m8_nt("~{nt_m8}", "gsnap.blast.top.m8")
get_top_m8_nt(
"~{nt_m8}",
"~{lineage_db}",
"~{accession2taxid_db}",
"gsnap.blast.top.m8",
)
CODE
python3 - << 'EOF'
Expand Down Expand Up @@ -680,6 +687,8 @@ task FindTopHitsNT {
task FindTopHitsNR {
input {
File nr_m8
File lineage_db
File accession2taxid_db
String docker_image_id
}
Expand All @@ -688,7 +697,12 @@ task FindTopHitsNR {
python3 <<CODE
from idseq_dag.steps.blast_contigs import get_top_m8_nr
get_top_m8_nr("~{nr_m8}", "rapsearch2.blast.top.m8")
get_top_m8_nr(
"~{nr_m8}",
"~{lineage_db}",
"~{accession2taxid_db}",
"rapsearch2.blast.top.m8"
)
CODE
python3 - << 'EOF'
Expand Down Expand Up @@ -1399,12 +1413,16 @@ workflow czid_long_read_mngs {
call FindTopHitsNT {
input:
nt_m8 = RunNTAlignment.nt_m8,
lineage_db = lineage_db,
accession2taxid_db = accession2taxid_db,
docker_image_id = docker_image_id,
}
call FindTopHitsNR {
input:
nr_m8 = RunNRAlignment.nr_m8,
lineage_db = lineage_db,
accession2taxid_db = accession2taxid_db,
docker_image_id = docker_image_id,
}
Expand Down
9 changes: 7 additions & 2 deletions workflows/short-read-mngs/postprocess.wdl
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ task BlastContigs_refined_gsnap_out {
File assembly_nt_refseq_fasta
File duplicate_cluster_sizes_tsv
File lineage_db
File accession2taxid
File taxon_blacklist
File deuterostome_db
Boolean use_deuterostome_filter
Expand All @@ -171,7 +172,7 @@ task BlastContigs_refined_gsnap_out {
--input-files '[["~{gsnap_out_gsnap_m8}", "~{gsnap_out_gsnap_deduped_m8}", "~{gsnap_out_gsnap_hitsummary_tab}", "~{gsnap_out_gsnap_counts_with_dcr_json}"], ["~{assembly_contigs_fasta}", "~{assembly_scaffolds_fasta}", "~{assembly_read_contig_sam}", "~{assembly_contig_stats_json}"], ["~{assembly_nt_refseq_fasta}"], ["~{duplicate_cluster_sizes_tsv}"]]' \
--output-files '["assembly/gsnap.blast.m8", "assembly/gsnap.reassigned.m8", "assembly/gsnap.hitsummary2.tab", "assembly/refined_gsnap_counts_with_dcr.json", "assembly/gsnap_contig_summary.json", "assembly/gsnap.blast.top.m8"]' \
--output-dir-s3 '~{s3_wd_uri}' \
--additional-files '{"lineage_db": "~{lineage_db}", "taxon_blacklist": "~{taxon_blacklist}", "deuterostome_db": "~{if use_deuterostome_filter then '~{deuterostome_db}' else ''}"}' \
--additional-files '{"lineage_db": "~{lineage_db}", "accession2taxid": "~{accession2taxid}", "taxon_blacklist": "~{taxon_blacklist}", "deuterostome_db": "~{if use_deuterostome_filter then '~{deuterostome_db}' else ''}"}' \
--additional-attributes '{"db_type": "nt", "use_taxon_whitelist": ~{use_taxon_whitelist}}'
>>>
output {
Expand Down Expand Up @@ -204,6 +205,7 @@ task BlastContigs_refined_rapsearch2_out {
File assembly_nr_refseq_fasta
File duplicate_cluster_sizes_tsv
File lineage_db
File accession2taxid
File taxon_blacklist
Boolean use_taxon_whitelist
}
Expand All @@ -216,7 +218,7 @@ task BlastContigs_refined_rapsearch2_out {
--input-files '[["~{rapsearch2_out_rapsearch2_m8}", "~{rapsearch2_out_rapsearch2_deduped_m8}", "~{rapsearch2_out_rapsearch2_hitsummary_tab}", "~{rapsearch2_out_rapsearch2_counts_with_dcr_json}"], ["~{assembly_contigs_fasta}", "~{assembly_scaffolds_fasta}", "~{assembly_read_contig_sam}", "~{assembly_contig_stats_json}"], ["~{assembly_nr_refseq_fasta}"], ["~{duplicate_cluster_sizes_tsv}"]]' \
--output-files '["assembly/rapsearch2.blast.m8", "assembly/rapsearch2.reassigned.m8", "assembly/rapsearch2.hitsummary2.tab", "assembly/refined_rapsearch2_counts_with_dcr.json", "assembly/rapsearch2_contig_summary.json", "assembly/rapsearch2.blast.top.m8"]' \
--output-dir-s3 '~{s3_wd_uri}' \
--additional-files '{"lineage_db": "~{lineage_db}", "taxon_blacklist": "~{taxon_blacklist}"}' \
--additional-files '{"lineage_db": "~{lineage_db}", "accession2taxid": "~{accession2taxid}", "taxon_blacklist": "~{taxon_blacklist}"}' \
--additional-attributes '{"db_type": "nr", "use_taxon_whitelist": ~{use_taxon_whitelist}}'
>>>
output {
Expand Down Expand Up @@ -489,6 +491,7 @@ workflow czid_postprocess {
String nr_db = "s3://czid-public-references/ncbi-sources/2021-01-22/nr"
File nr_loc_db = "s3://czid-public-references/alignment_data/2021-01-22/nr_loc.db"
File lineage_db = "s3://czid-public-references/taxonomy/2021-01-22/taxid-lineages.db"
File accession2taxid_db = "s3://czid-public-references/ncbi-indexes-prod/2021-01-22/index-generation-2/accession2taxid.marisa"
File taxon_blacklist = "s3://czid-public-references/taxonomy/2021-01-22/taxon_blacklist.txt"
File deuterostome_db = "s3://czid-public-references/taxonomy/2021-01-22/deuterostome_taxids.txt"
Boolean use_deuterostome_filter = true
Expand Down Expand Up @@ -556,6 +559,7 @@ workflow czid_postprocess {
assembly_nt_refseq_fasta = DownloadAccessions_gsnap_accessions_out.assembly_nt_refseq_fasta,
duplicate_cluster_sizes_tsv = duplicate_cluster_sizes_tsv,
lineage_db = lineage_db,
accession2taxid = accession2taxid_db,
taxon_blacklist = taxon_blacklist,
deuterostome_db = deuterostome_db,
use_deuterostome_filter = use_deuterostome_filter,
Expand All @@ -577,6 +581,7 @@ workflow czid_postprocess {
assembly_nr_refseq_fasta = DownloadAccessions_rapsearch2_accessions_out.assembly_nr_refseq_fasta,
duplicate_cluster_sizes_tsv = duplicate_cluster_sizes_tsv,
lineage_db = lineage_db,
accession2taxid = accession2taxid_db,
taxon_blacklist = taxon_blacklist,
use_taxon_whitelist = use_taxon_whitelist
}
Expand Down
3 changes: 2 additions & 1 deletion workflows/short-read-mngs/test/local_test_viral.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ postprocess.nt_db: s3://czid-public-references/test/viral-alignment-indexes/vira
postprocess.nt_loc_db: s3://czid-public-references/test/viral-alignment-indexes/viral_nt_loc.marisa
postprocess.nr_db: s3://czid-public-references/test/viral-alignment-indexes/viral_nr
postprocess.nr_loc_db: s3://czid-public-references/test/viral-alignment-indexes/viral_nr_loc.marisa
postprocess.accession2taxid_db: s3://czid-public-references/mini-database/alignment_indexes/2020-08-20-viral/viral_accessions2taxid.marisa
experimental.nt_db: s3://czid-public-references/test/viral-alignment-indexes/viral_nt
experimental.nt_loc_db: s3://czid-public-references/test/viral-alignment-indexes/viral_nt_loc.marisa
experimental.nt_info_db: s3://czid-public-references/test/viral-alignment-indexes/viral_nt_info.marisa

0 comments on commit dfb8475

Please sign in to comment.