From dfb84756632857ce71b6ac6b9f0e1e14f0a99312 Mon Sep 17 00:00:00 2001
From: Lucia Reynoso
Date: Fri, 26 Apr 2024 12:07:31 -0700
Subject: [PATCH] [CZID-9262] Only call taxa with species level lineages in
 long-read-mngs (#349)

---
 .../idseq_dag/steps/blast_contigs.py          | 22 +++---
 lib/idseq-dag/idseq_dag/util/top_hits.py      | 74 ++++++++++++++-----
 workflows/long-read-mngs/run.wdl              | 22 +++++-
 workflows/short-read-mngs/postprocess.wdl     |  9 ++-
 .../short-read-mngs/test/local_test_viral.yml |  3 +-
 5 files changed, 97 insertions(+), 33 deletions(-)

diff --git a/lib/idseq-dag/idseq_dag/steps/blast_contigs.py b/lib/idseq-dag/idseq_dag/steps/blast_contigs.py
index 643e677dd..6da4382a8 100644
--- a/lib/idseq-dag/idseq_dag/steps/blast_contigs.py
+++ b/lib/idseq-dag/idseq_dag/steps/blast_contigs.py
@@ -94,8 +94,15 @@ def run(self):
             command.write_text_to_file('[]', contig_summary_json)
             return  # return in the middle of the function
 
+        lineage_db = s3.fetch_reference(
+            self.additional_files["lineage_db"],
+            self.ref_dir_local,
+            allow_s3mi=False)  # Too small to waste s3mi
+
+        accession2taxid_dict = s3.fetch_reference(self.additional_files["accession2taxid"], self.ref_dir_local)
+
         (read_dict, accession_dict, _selected_genera) = m8.summarize_hits(hit_summary)
-        PipelineStepBlastContigs.run_blast(db_type, blast_m8, assembled_contig, reference_fasta, blast_top_m8)
+        PipelineStepBlastContigs.run_blast(db_type, blast_m8, assembled_contig, reference_fasta, blast_top_m8, lineage_db, accession2taxid_dict)
         read2contig = {}
         generate_info_from_sam(bowtie_sam, read2contig, duplicate_cluster_sizes_path=duplicate_cluster_sizes_path)
 
@@ -106,11 +113,6 @@ def run(self):
                                                 refined_hit_summary, refined_m8)
 
         # Generating taxon counts based on updated results
-        lineage_db = s3.fetch_reference(
-            self.additional_files["lineage_db"],
-            self.ref_dir_local,
-            allow_s3mi=False)  # Too small to waste s3mi
-
         deuterostome_db = None
         if self.additional_files.get("deuterostome_db"):
             deuterostome_db = s3.fetch_reference(self.additional_files["deuterostome_db"],
@@ -264,7 +266,7 @@ def update_read_dict(read2contig, blast_top_blastn_6_path, read_dict, accession_
         return (consolidated_dict, read2blastm8, contig2lineage, added_reads)
 
     @staticmethod
-    def run_blast_nt(blast_index_path, blast_m8, assembled_contig, reference_fasta, blast_top_m8):
+    def run_blast_nt(blast_index_path, blast_m8, assembled_contig, reference_fasta, blast_top_m8, lineage_db, accession2taxid_dict):
         blast_type = 'nucl'
         blast_command = 'blastn'
         command.execute(
@@ -308,10 +310,10 @@ def run_blast_nt(blast_index_path, blast_m8, assembled_contig, reference_fasta,
             )
         )
         # further processing of getting the top m8 entry for each contig.
-        get_top_m8_nt(blast_m8, blast_top_m8)
+        get_top_m8_nt(blast_m8, lineage_db, accession2taxid_dict, blast_top_m8)
 
     @staticmethod
-    def run_blast_nr(blast_index_path, blast_m8, assembled_contig, reference_fasta, blast_top_m8):
+    def run_blast_nr(blast_index_path, blast_m8, assembled_contig, reference_fasta, blast_top_m8, lineage_db, accession2taxid_dict):
         blast_type = 'prot'
         blast_command = 'blastx'
         command.execute(
@@ -349,4 +351,4 @@ def run_blast_nr(blast_index_path, blast_m8, assembled_contig, reference_fasta,
             )
         )
         # further processing of getting the top m8 entry for each contig.
-        get_top_m8_nr(blast_m8, blast_top_m8)
+        get_top_m8_nr(blast_m8, lineage_db, accession2taxid_dict, blast_top_m8)
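This change threads the two taxonomy references through run_blast() so the NT/NR "top hit" pickers can consult lineages instead of bitscore alone. A minimal sketch of the resulting call, assuming the references were already localized with s3.fetch_reference(); the file names below are placeholders, not the pipeline's real intermediate names:

    # Sketch only: how run_blast_nr() now hands the references to the reranker.
    from idseq_dag.util.top_hits import get_top_m8_nr

    get_top_m8_nr(
        "rapsearch2.blast.m8",      # raw blastx hits in m8 (outfmt 6) format
        "taxid-lineages.db",        # lineage_db: taxid -> (species, genus, family) taxids
        "accession2taxid.marisa",   # accession2taxid: accession -> taxid
        "rapsearch2.blast.top.m8",  # reranked best hit(s) per contig
    )

Previously the best hit per contig was chosen on bitscore alone; the two extra arguments let the chooser prefer accessions whose lineage resolves to a species-level taxid.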
diff --git a/lib/idseq-dag/idseq_dag/util/top_hits.py b/lib/idseq-dag/idseq_dag/util/top_hits.py
index 06bd83790..fd367e76a 100644
--- a/lib/idseq-dag/idseq_dag/util/top_hits.py
+++ b/lib/idseq-dag/idseq_dag/util/top_hits.py
@@ -1,5 +1,8 @@
 from collections import defaultdict
 
+import idseq_dag.util.lineage as lineage
+
+from idseq_dag.util.dict import open_file_db_by_extension
 from idseq_dag.util.m8 import NT_MIN_ALIGNMENT_LEN
 from idseq_dag.util.parsing import BlastnOutput6Reader, BlastnOutput6Writer
 from idseq_dag.util.parsing import BlastnOutput6NTReader
@@ -56,6 +59,14 @@ def _intersects(needle, haystack):
     ''' Return True iff needle intersects haystack.  Ignore overlap < NT_MIN_OVERLAP_FRACTION. '''
     return any(_hsp_overlap(needle, hay) for hay in haystack)
 
+def _get_lineage(accession_id, lineage_map, accession2taxid_dict, lineage_cache={}):
+    if accession_id in lineage_cache:
+        return lineage_cache[accession_id]
+    accession_taxid = accession2taxid_dict.get(
+        accession_id.split(".")[0], "NA")
+    result = lineage_map.get(accession_taxid, lineage.NULL_LINEAGE)
+    lineage_cache[accession_id] = result
+    return result
 
 class BlastCandidate:
     '''Documented in function get_top_m8_nt() below.'''
@@ -64,6 +75,7 @@ def __init__(self, hsps):
         self.hsps = hsps
         self.optimal_cover = None
         self.total_score = None
+
         if hsps:
             # All HSPs here must be for the same query and subject sequence.
             h_0 = hsps[0]
@@ -105,23 +117,34 @@ def summary(self):
         return r
 
 
-def _optimal_hit_for_each_query_nr(blast_output_path, max_evalue):
+def _optimal_hit_for_each_query_nr(blast_output_path, lineage_map, accession2taxid_dict, max_evalue):
     contigs_to_best_alignments = defaultdict(list)
     accession_counts = defaultdict(lambda: 0)
 
     with open(blast_output_path) as blastn_6_f:
         # For each contig, get the alignments that have the best total score (may be multiple if there are ties).
+        # Prioritize the specificity of the hit.
+        specificity_to_best_alignments = defaultdict(dict)
         for alignment in BlastnOutput6Reader(blastn_6_f):
             if alignment["evalue"] > max_evalue:
                 continue
             query = alignment["qseqid"]
-            best_alignments = contigs_to_best_alignments[query]
-            if len(best_alignments) == 0 or best_alignments[0]["bitscore"] < alignment["bitscore"]:
-                contigs_to_best_alignments[query] = [alignment]
+            lineage_taxids = _get_lineage(alignment["sseqid"], lineage_map, accession2taxid_dict)
+            specificity = next((level for level, taxid_at_level in enumerate(lineage_taxids) if int(taxid_at_level) > 0), float("inf"))
+
+            best_alignments = specificity_to_best_alignments[query]
+
+            if (specificity not in best_alignments) or best_alignments[specificity][0]["bitscore"] < alignment["bitscore"]:
+                specificity_to_best_alignments[query][specificity] = [alignment]
             # Add all ties to best_hits.
-            elif len(best_alignments) > 0 and best_alignments[0]["bitscore"] == alignment["bitscore"]:
-                contigs_to_best_alignments[query].append(alignment)
+            elif len(best_alignments[specificity]) > 0 and best_alignments[specificity][0]["bitscore"] == alignment["bitscore"]:
+                specificity_to_best_alignments[query][specificity].append(alignment)
+
+        # Choose the best alignments with the most specific taxid information.
+        for contig_id, specificity_alignment_dict in specificity_to_best_alignments.items():
+            specific_best_alignments = next(specificity_alignment_dict[specificity] for specificity in sorted(specificity_alignment_dict.keys()))
+            contigs_to_best_alignments[contig_id] = specific_best_alignments
 
     # Create a map of accession to best alignment count.
     for _contig_id, alignments in contigs_to_best_alignments.items():
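To make the new ranking rule concrete, here is a small standalone sketch (not pipeline code) of the "specificity" value computed in _optimal_hit_for_each_query_nr() above. The lineage is the (species_taxid, genus_taxid, family_taxid) tuple returned by _get_lineage(); the taxid values below are made up, and non-positive placeholders stand for "no call at that level":

    def specificity(lineage_taxids):
        """Index of the first positive taxid: 0 = species, 1 = genus, 2 = family."""
        return next(
            (level for level, taxid in enumerate(lineage_taxids) if int(taxid) > 0),
            float("inf"),  # accession has no usable lineage at any level
        )

    assert specificity(("573", "570", "543")) == 0                # species-level call
    assert specificity(("-100", "570", "543")) == 1               # genus-level only
    assert specificity(("-100", "-100", "-100")) == float("inf")  # no lineage info

Alignments are then bucketed per query by this index, and within each bucket only the highest-bitscore alignment (plus ties) is kept, exactly as in the loop above.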
@@ -170,26 +193,33 @@ def _filter_and_group_hits_by_query(blast_output_path, min_alignment_length, min
 
 # An iterator that, for contig, yields to optimal hit for the contig in the nt blast_output.
-def _optimal_hit_for_each_query_nt(blast_output, min_alignment_length, min_pident, max_evalue, summary=True):
+def _optimal_hit_for_each_query_nt(blast_output, lineage_map, accession2taxid_dict,
+                                   min_alignment_length, min_pident, max_evalue, summary=True):
     contigs_to_blast_candidates = {}
     accession_counts = defaultdict(lambda: 0)
 
     # For each contig, get the collection of blast candidates that have the best total score (may be multiple if there are ties).
     for query_hits in _filter_and_group_hits_by_query(blast_output, min_alignment_length, min_pident, max_evalue):
-        best_hits = []
+        best_hits = {}
         for _subject, hit_fragments in query_hits.items():
             # For NT, we take a special approach where we try to find a subset of disjoint HSPs
             # with maximum sum of bitscores.
             hit = BlastCandidate(hit_fragments)
             hit.optimize()
 
-            if len(best_hits) == 0 or best_hits[0].total_score < hit.total_score:
-                best_hits = [hit]
-            # Add all ties to best_hits.
-            elif len(best_hits) > 0 and best_hits[0].total_score == hit.total_score:
-                best_hits.append(hit)
+            # We prioritize the specificity of the hit; hits with species taxids are taken before hits without
+            # Specificity is just the index of the tuple returned by _get_lineage(); 0 for species, 1 for genus, etc.
+            lineage_taxids = _get_lineage(hit.sseqid, lineage_map, accession2taxid_dict)
+            specificity = next((level for level, taxid_at_level in enumerate(lineage_taxids) if int(taxid_at_level) > 0), float("inf"))
+
+            if (specificity not in best_hits) or best_hits[specificity][0].total_score < hit.total_score:
+                best_hits[specificity] = [hit]
+            # Add all ties to best_hits[specificity].
+            elif len(best_hits[specificity]) > 0 and best_hits[specificity][0].total_score == hit.total_score:
+                best_hits[specificity].append(hit)
 
-        contigs_to_blast_candidates[best_hits[0].qseqid] = best_hits
+        specific_best_hits = next(best_hits[specificity] for specificity in sorted(best_hits.keys()))
+        contigs_to_blast_candidates[specific_best_hits[0].qseqid] = specific_best_hits
 
     # Create a map of accession to blast candidate count.
     for _contig_id, blast_candidates in contigs_to_blast_candidates.items():
@@ -213,18 +243,24 @@ def _optimal_hit_for_each_query_nt(blast_output, min_alignment_length, min_piden
 
 def get_top_m8_nr(
     blast_output,
+    lineage_map_path,
+    accession2taxid_dict_path,
     blast_top_blastn_6_path,
     max_evalue=MAX_EVALUE_THRESHOLD,
 ):
     ''' Get top m8 file entry for each contig from blast_output and output to blast_top_m8 '''
-    with open(blast_top_blastn_6_path, "w") as blast_top_blastn_6_f:
+    with open(blast_top_blastn_6_path, "w") as blast_top_blastn_6_f, \
+            open_file_db_by_extension(lineage_map_path, "lll") as lineage_map, \
+            open_file_db_by_extension(accession2taxid_dict_path, "L") as accession2taxid_dict:  # noqa
         BlastnOutput6Writer(blast_top_blastn_6_f).writerows(
-            _optimal_hit_for_each_query_nr(blast_output, max_evalue)
+            _optimal_hit_for_each_query_nr(blast_output, lineage_map, accession2taxid_dict, max_evalue)
         )
 
 
 def get_top_m8_nt(
     blast_output,
+    lineage_map_path,
+    accession2taxid_dict_path,
     blast_top_blastn_6_path,
     min_alignment_length=NT_MIN_ALIGNMENT_LEN,
     min_pident=_NT_MIN_PIDENT,
@@ -251,7 +287,9 @@ def get_top_m8_nt(
 
     # Output the optimal hit for each query.
-    with open(blast_top_blastn_6_path, "w") as blast_top_blastn_6_f:
+    with open(blast_top_blastn_6_path, "w") as blast_top_blastn_6_f, \
+            open_file_db_by_extension(lineage_map_path, "lll") as lineage_map, \
+            open_file_db_by_extension(accession2taxid_dict_path, "L") as accession2taxid_dict:  # noqa
         BlastnOutput6NTRerankedWriter(blast_top_blastn_6_f).writerows(
-            _optimal_hit_for_each_query_nt(blast_output, min_alignment_length, min_pident, max_evalue, False)
+            _optimal_hit_for_each_query_nt(blast_output, lineage_map, accession2taxid_dict, min_alignment_length, min_pident, max_evalue, False)
         )
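Both the NR and NT paths then resolve each contig the same way: take the lowest-numbered specificity bucket that exists, which is what the sorted()/next() idiom above does. A toy illustration, with plain strings standing in for alignment records or BlastCandidate objects:

    # best_hits maps specificity -> best-scoring (tied) hits at that level;
    # float("inf") is the bucket for accessions with no lineage information.
    best_hits = {
        float("inf"): ["hit A, bitscore 1200, no taxid"],
        1: ["hit B, bitscore 1100, genus-level lineage"],
        0: ["hit C, bitscore 950, species-level lineage"],
    }

    # sorted() orders the buckets 0, 1, inf, so the species-level bucket wins
    # even though other buckets hold higher-scoring hits.
    specific_best_hits = next(best_hits[s] for s in sorted(best_hits.keys()))
    assert specific_best_hits == ["hit C, bitscore 950, species-level lineage"]

This is the behavior change behind the PR title: whenever any of a contig's best hits carries a species-level lineage, those hits win; genus-level and lineage-less hits are used only as fallbacks.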
diff --git a/workflows/long-read-mngs/run.wdl b/workflows/long-read-mngs/run.wdl
index 507a021f1..820a5c85a 100644
--- a/workflows/long-read-mngs/run.wdl
+++ b/workflows/long-read-mngs/run.wdl
@@ -644,6 +644,8 @@ task RunNRAlignment {
 task FindTopHitsNT {
   input {
     File nt_m8
+    File lineage_db
+    File accession2taxid_db
     String docker_image_id
   }
 
@@ -652,7 +654,12 @@ task FindTopHitsNT {
     python3 <>>

   output {

diff --git a/workflows/short-read-mngs/postprocess.wdl b/workflows/short-read-mngs/postprocess.wdl
--- a/workflows/short-read-mngs/postprocess.wdl
+++ b/workflows/short-read-mngs/postprocess.wdl
@@ -204,6 +205,7 @@ task BlastContigs_refined_rapsearch2_out {
     File assembly_nr_refseq_fasta
     File duplicate_cluster_sizes_tsv
     File lineage_db
+    File accession2taxid
     File taxon_blacklist
     Boolean use_taxon_whitelist
   }
@@ -216,7 +218,7 @@ task BlastContigs_refined_rapsearch2_out {
     --input-files '[["~{rapsearch2_out_rapsearch2_m8}", "~{rapsearch2_out_rapsearch2_deduped_m8}", "~{rapsearch2_out_rapsearch2_hitsummary_tab}", "~{rapsearch2_out_rapsearch2_counts_with_dcr_json}"], ["~{assembly_contigs_fasta}", "~{assembly_scaffolds_fasta}", "~{assembly_read_contig_sam}", "~{assembly_contig_stats_json}"], ["~{assembly_nr_refseq_fasta}"], ["~{duplicate_cluster_sizes_tsv}"]]' \
     --output-files '["assembly/rapsearch2.blast.m8", "assembly/rapsearch2.reassigned.m8", "assembly/rapsearch2.hitsummary2.tab", "assembly/refined_rapsearch2_counts_with_dcr.json", "assembly/rapsearch2_contig_summary.json", "assembly/rapsearch2.blast.top.m8"]' \
     --output-dir-s3 '~{s3_wd_uri}' \
-    --additional-files '{"lineage_db": "~{lineage_db}", "taxon_blacklist": "~{taxon_blacklist}"}' \
+    --additional-files '{"lineage_db": "~{lineage_db}", "accession2taxid": "~{accession2taxid}", "taxon_blacklist": "~{taxon_blacklist}"}' \
     --additional-attributes '{"db_type": "nr", "use_taxon_whitelist": ~{use_taxon_whitelist}}'
   >>>
   output {
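The Python heredoc inside FindTopHitsNT (and its NR counterpart) is not visible in the hunk above, so the following is only an assumed call pattern inferred from the new get_top_m8_nt() signature; the output file name and the comments mapping arguments to the WDL inputs are assumptions:

    # Hypothetical sketch of the FindTopHitsNT heredoc body; not the actual task code.
    from idseq_dag.util.top_hits import get_top_m8_nt

    get_top_m8_nt(
        "nt.m8",                   # ~{nt_m8}: alignment hits against NT in m8 format
        "taxid-lineages.db",       # ~{lineage_db}
        "accession2taxid.marisa",  # ~{accession2taxid_db}
        "nt.blast.top.m8",         # assumed output name for the reranked hits
    )

Opening the reference databases inside get_top_m8_nt()/get_top_m8_nr() keeps these WDL tasks thin: they only localize the reference files and forward paths.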
"s3://czid-public-references/ncbi-sources/2021-01-22/nr" File nr_loc_db = "s3://czid-public-references/alignment_data/2021-01-22/nr_loc.db" File lineage_db = "s3://czid-public-references/taxonomy/2021-01-22/taxid-lineages.db" + File accession2taxid_db = "s3://czid-public-references/ncbi-indexes-prod/2021-01-22/index-generation-2/accession2taxid.marisa" File taxon_blacklist = "s3://czid-public-references/taxonomy/2021-01-22/taxon_blacklist.txt" File deuterostome_db = "s3://czid-public-references/taxonomy/2021-01-22/deuterostome_taxids.txt" Boolean use_deuterostome_filter = true @@ -556,6 +559,7 @@ workflow czid_postprocess { assembly_nt_refseq_fasta = DownloadAccessions_gsnap_accessions_out.assembly_nt_refseq_fasta, duplicate_cluster_sizes_tsv = duplicate_cluster_sizes_tsv, lineage_db = lineage_db, + accession2taxid = accession2taxid_db, taxon_blacklist = taxon_blacklist, deuterostome_db = deuterostome_db, use_deuterostome_filter = use_deuterostome_filter, @@ -577,6 +581,7 @@ workflow czid_postprocess { assembly_nr_refseq_fasta = DownloadAccessions_rapsearch2_accessions_out.assembly_nr_refseq_fasta, duplicate_cluster_sizes_tsv = duplicate_cluster_sizes_tsv, lineage_db = lineage_db, + accession2taxid = accession2taxid_db, taxon_blacklist = taxon_blacklist, use_taxon_whitelist = use_taxon_whitelist } diff --git a/workflows/short-read-mngs/test/local_test_viral.yml b/workflows/short-read-mngs/test/local_test_viral.yml index dad3be785..d10061be4 100644 --- a/workflows/short-read-mngs/test/local_test_viral.yml +++ b/workflows/short-read-mngs/test/local_test_viral.yml @@ -26,6 +26,7 @@ postprocess.nt_db: s3://czid-public-references/test/viral-alignment-indexes/vira postprocess.nt_loc_db: s3://czid-public-references/test/viral-alignment-indexes/viral_nt_loc.marisa postprocess.nr_db: s3://czid-public-references/test/viral-alignment-indexes/viral_nr postprocess.nr_loc_db: s3://czid-public-references/test/viral-alignment-indexes/viral_nr_loc.marisa +postprocess.accession2taxid_db: s3://czid-public-references/mini-database/alignment_indexes/2020-08-20-viral/viral_accessions2taxid.marisa experimental.nt_db: s3://czid-public-references/test/viral-alignment-indexes/viral_nt experimental.nt_loc_db: s3://czid-public-references/test/viral-alignment-indexes/viral_nt_loc.marisa -experimental.nt_info_db: s3://czid-public-references/test/viral-alignment-indexes/viral_nt_info.marisa \ No newline at end of file +experimental.nt_info_db: s3://czid-public-references/test/viral-alignment-indexes/viral_nt_info.marisa