
Commit

Merge branch 'main' into tmorse-nt-compression
phoenixAja authored May 6, 2024
2 parents 36395a7 + 4ccee1a commit 8afaba7
Showing 9 changed files with 284 additions and 63 deletions.
22 changes: 12 additions & 10 deletions lib/idseq-dag/idseq_dag/steps/blast_contigs.py
@@ -94,8 +94,15 @@ def run(self):
             command.write_text_to_file('[]', contig_summary_json)
             return  # return in the middle of the function
 
+        lineage_db = s3.fetch_reference(
+            self.additional_files["lineage_db"],
+            self.ref_dir_local,
+            allow_s3mi=False)  # Too small to waste s3mi
+
+        accession2taxid_dict = s3.fetch_reference(self.additional_files["accession2taxid"], self.ref_dir_local)
+
         (read_dict, accession_dict, _selected_genera) = m8.summarize_hits(hit_summary)
-        PipelineStepBlastContigs.run_blast(db_type, blast_m8, assembled_contig, reference_fasta, blast_top_m8)
+        PipelineStepBlastContigs.run_blast(db_type, blast_m8, assembled_contig, reference_fasta, blast_top_m8, lineage_db, accession2taxid_dict)
         read2contig = {}
         generate_info_from_sam(bowtie_sam, read2contig, duplicate_cluster_sizes_path=duplicate_cluster_sizes_path)

@@ -106,11 +113,6 @@ def run(self):
                                            refined_hit_summary, refined_m8)
 
         # Generating taxon counts based on updated results
-        lineage_db = s3.fetch_reference(
-            self.additional_files["lineage_db"],
-            self.ref_dir_local,
-            allow_s3mi=False)  # Too small to waste s3mi
-
         deuterostome_db = None
         if self.additional_files.get("deuterostome_db"):
             deuterostome_db = s3.fetch_reference(self.additional_files["deuterostome_db"],
@@ -264,7 +266,7 @@ def update_read_dict(read2contig, blast_top_blastn_6_path, read_dict, accession_
         return (consolidated_dict, read2blastm8, contig2lineage, added_reads)
 
     @staticmethod
-    def run_blast_nt(blast_index_path, blast_m8, assembled_contig, reference_fasta, blast_top_m8):
+    def run_blast_nt(blast_index_path, blast_m8, assembled_contig, reference_fasta, blast_top_m8, lineage_db, accession2taxid_dict):
         blast_type = 'nucl'
         blast_command = 'blastn'
         command.execute(
@@ -308,10 +310,10 @@ def run_blast_nt(blast_index_path, blast_m8, assembled_contig, reference_fasta, blast_top_m8):
             )
         )
         # further processing of getting the top m8 entry for each contig.
-        get_top_m8_nt(blast_m8, blast_top_m8)
+        get_top_m8_nt(blast_m8, lineage_db, accession2taxid_dict, blast_top_m8)

     @staticmethod
-    def run_blast_nr(blast_index_path, blast_m8, assembled_contig, reference_fasta, blast_top_m8):
+    def run_blast_nr(blast_index_path, blast_m8, assembled_contig, reference_fasta, blast_top_m8, lineage_db, accession2taxid_dict):
         blast_type = 'prot'
         blast_command = 'blastx'
         command.execute(
@@ -349,4 +351,4 @@ def run_blast_nr(blast_index_path, blast_m8, assembled_contig, reference_fasta, blast_top_m8):
             )
         )
         # further processing of getting the top m8 entry for each contig.
-        get_top_m8_nr(blast_m8, blast_top_m8)
+        get_top_m8_nr(blast_m8, lineage_db, accession2taxid_dict, blast_top_m8)
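
Note: run_blast itself is not shown in this diff — only its call site and the two helpers change — so the routing below is an assumption based on the updated signatures. A minimal sketch of how the new lineage_db and accession2taxid_dict arguments flow from run() down to the NT/NR helpers (the index-path construction is a placeholder, not the step's real logic):

import os

def run_blast(db_type, blast_m8, assembled_contig, reference_fasta, blast_top_m8,
              lineage_db, accession2taxid_dict):
    # Placeholder path; the real step derives its BLAST index elsewhere.
    blast_index_path = os.path.join(os.path.dirname(blast_m8), f"{db_type}_blastindex")
    helper = run_blast_nt if db_type == "nt" else run_blast_nr  # assumes db_type in {"nt", "nr"}
    helper(blast_index_path, blast_m8, assembled_contig, reference_fasta,
           blast_top_m8, lineage_db, accession2taxid_dict)

Here run_blast_nt and run_blast_nr stand in for the static methods shown above.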
74 changes: 56 additions & 18 deletions lib/idseq-dag/idseq_dag/util/top_hits.py
@@ -1,5 +1,8 @@
 from collections import defaultdict
 
+import idseq_dag.util.lineage as lineage
+
+from idseq_dag.util.dict import open_file_db_by_extension
 from idseq_dag.util.m8 import NT_MIN_ALIGNMENT_LEN
 from idseq_dag.util.parsing import BlastnOutput6Reader, BlastnOutput6Writer
 from idseq_dag.util.parsing import BlastnOutput6NTReader
@@ -56,6 +59,14 @@ def _intersects(needle, haystack):
     ''' Return True iff needle intersects haystack. Ignore overlap < NT_MIN_OVERLAP_FRACTION. '''
     return any(_hsp_overlap(needle, hay) for hay in haystack)
 
+def _get_lineage(accession_id, lineage_map, accession2taxid_dict, lineage_cache={}):
+    if accession_id in lineage_cache:
+        return lineage_cache[accession_id]
+    accession_taxid = accession2taxid_dict.get(
+        accession_id.split(".")[0], "NA")
+    result = lineage_map.get(accession_taxid, lineage.NULL_LINEAGE)
+    lineage_cache[accession_id] = result
+    return result
 
 class BlastCandidate:
     '''Documented in function get_top_m8_nt() below.'''
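
_get_lineage memoizes lookups in a mutable default argument (lineage_cache={}), which persists for the life of the process: each accession is resolved at most once, first stripping the version suffix (e.g. ".2") before the accession2taxid lookup. A toy illustration with in-memory dicts standing in for the on-disk databases — the values are hypothetical, and the assumption that lineage entries are (species, genus, family) taxid triples follows from the three-value "lll" type used when the db is opened in get_top_m8_nt/nr below:

# Hypothetical stand-ins for the on-disk reference databases.
accession2taxid = {"NC_045512": "2697049"}                  # accession (no version) -> taxid
lineage_map = {"2697049": ("2697049", "694009", "11118")}   # taxid -> (species, genus, family)
NULL_LINEAGE = ("-100", "-200", "-300")  # assumed sentinel for unknown accessions

def get_lineage(accession_id, cache={}):
    if accession_id in cache:
        return cache[accession_id]
    taxid = accession2taxid.get(accession_id.split(".")[0], "NA")
    result = lineage_map.get(taxid, NULL_LINEAGE)
    cache[accession_id] = result
    return result

print(get_lineage("NC_045512.2"))  # ('2697049', '694009', '11118'), now cached
print(get_lineage("ABC000001.1"))  # unknown accession falls back to NULL_LINEAGE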
@@ -64,6 +75,7 @@ def __init__(self, hsps):
         self.hsps = hsps
         self.optimal_cover = None
         self.total_score = None
+
         if hsps:
             # All HSPs here must be for the same query and subject sequence.
             h_0 = hsps[0]
@@ -105,23 +117,34 @@ def summary(self):
         return r
 
 
-def _optimal_hit_for_each_query_nr(blast_output_path, max_evalue):
+def _optimal_hit_for_each_query_nr(blast_output_path, lineage_map, accession2taxid_dict, max_evalue):
     contigs_to_best_alignments = defaultdict(list)
     accession_counts = defaultdict(lambda: 0)
 
     with open(blast_output_path) as blastn_6_f:
         # For each contig, get the alignments that have the best total score (may be multiple if there are ties).
+        # Prioritize the specificity of the hit.
+        specificity_to_best_alignments = defaultdict(dict)
         for alignment in BlastnOutput6Reader(blastn_6_f):
             if alignment["evalue"] > max_evalue:
                 continue
             query = alignment["qseqid"]
-            best_alignments = contigs_to_best_alignments[query]
-
-            if len(best_alignments) == 0 or best_alignments[0]["bitscore"] < alignment["bitscore"]:
-                contigs_to_best_alignments[query] = [alignment]
+            lineage_taxids = _get_lineage(alignment["sseqid"], lineage_map, accession2taxid_dict)
+            specificity = next((level for level, taxid_at_level in enumerate(lineage_taxids) if int(taxid_at_level) > 0), float("inf"))
+
+            best_alignments = specificity_to_best_alignments[query]
+
+            if (specificity not in best_alignments) or best_alignments[specificity][0]["bitscore"] < alignment["bitscore"]:
+                specificity_to_best_alignments[query][specificity] = [alignment]
             # Add all ties to best_hits.
-            elif len(best_alignments) > 0 and best_alignments[0]["bitscore"] == alignment["bitscore"]:
-                contigs_to_best_alignments[query].append(alignment)
+            elif len(best_alignments[specificity]) > 0 and best_alignments[specificity][0]["bitscore"] == alignment["bitscore"]:
+                specificity_to_best_alignments[query][specificity].append(alignment)
+
+        # Choose the best alignments with the most specific taxid information.
+        for contig_id, specificity_alignment_dict in specificity_to_best_alignments.items():
+            specific_best_alignments = next(specificity_alignment_dict[specificity] for specificity in sorted(specificity_alignment_dict.keys()))
+            contigs_to_best_alignments[contig_id] = specific_best_alignments
 
     # Create a map of accession to best alignment count.
     for _contig_id, alignments in contigs_to_best_alignments.items():
@@ -170,26 +193,33 @@ def _filter_and_group_hits_by_query(blast_output_path, min_alignment_length, min_pident, max_evalue):
 
 
 # An iterator that, for each contig, yields the optimal hit for the contig in the nt blast_output.
-def _optimal_hit_for_each_query_nt(blast_output, min_alignment_length, min_pident, max_evalue, summary=True):
+def _optimal_hit_for_each_query_nt(blast_output, lineage_map, accession2taxid_dict,
+                                   min_alignment_length, min_pident, max_evalue, summary=True):
     contigs_to_blast_candidates = {}
     accession_counts = defaultdict(lambda: 0)
 
     # For each contig, get the collection of blast candidates that have the best total score (may be multiple if there are ties).
     for query_hits in _filter_and_group_hits_by_query(blast_output, min_alignment_length, min_pident, max_evalue):
-        best_hits = []
+        best_hits = {}
         for _subject, hit_fragments in query_hits.items():
             # For NT, we take a special approach where we try to find a subset of disjoint HSPs
             # with maximum sum of bitscores.
             hit = BlastCandidate(hit_fragments)
             hit.optimize()
 
-            if len(best_hits) == 0 or best_hits[0].total_score < hit.total_score:
-                best_hits = [hit]
-            # Add all ties to best_hits.
-            elif len(best_hits) > 0 and best_hits[0].total_score == hit.total_score:
-                best_hits.append(hit)
+            # We prioritize the specificity of the hit; hits with species taxids are taken before hits without.
+            # Specificity is just the index into the tuple returned by _get_lineage(): 0 for species, 1 for genus, etc.
+            lineage_taxids = _get_lineage(hit.sseqid, lineage_map, accession2taxid_dict)
+            specificity = next((level for level, taxid_at_level in enumerate(lineage_taxids) if int(taxid_at_level) > 0), float("inf"))
+
+            if (specificity not in best_hits) or best_hits[specificity][0].total_score < hit.total_score:
+                best_hits[specificity] = [hit]
+            # Add all ties to best_hits[specificity].
+            elif len(best_hits[specificity]) > 0 and best_hits[specificity][0].total_score == hit.total_score:
+                best_hits[specificity].append(hit)
 
-        contigs_to_blast_candidates[best_hits[0].qseqid] = best_hits
+        specific_best_hits = next(best_hits[specificity] for specificity in sorted(best_hits.keys()))
+        contigs_to_blast_candidates[specific_best_hits[0].qseqid] = specific_best_hits
 
     # Create a map of accession to blast candidate count.
     for _contig_id, blast_candidates in contigs_to_blast_candidates.items():
@@ -213,18 +243,24 @@ def _optimal_hit_for_each_query_nt(blast_output, min_alignment_length, min_pident, max_evalue, summary=True):
 
 def get_top_m8_nr(
     blast_output,
+    lineage_map_path,
+    accession2taxid_dict_path,
     blast_top_blastn_6_path,
     max_evalue=MAX_EVALUE_THRESHOLD,
 ):
     ''' Get top m8 file entry for each contig from blast_output and output to blast_top_m8 '''
-    with open(blast_top_blastn_6_path, "w") as blast_top_blastn_6_f:
+    with open(blast_top_blastn_6_path, "w") as blast_top_blastn_6_f, \
+            open_file_db_by_extension(lineage_map_path, "lll") as lineage_map, \
+            open_file_db_by_extension(accession2taxid_dict_path, "L") as accession2taxid_dict:  # noqa
         BlastnOutput6Writer(blast_top_blastn_6_f).writerows(
-            _optimal_hit_for_each_query_nr(blast_output, max_evalue)
+            _optimal_hit_for_each_query_nr(blast_output, lineage_map, accession2taxid_dict, max_evalue)
         )
 
 
 def get_top_m8_nt(
     blast_output,
+    lineage_map_path,
+    accession2taxid_dict_path,
     blast_top_blastn_6_path,
     min_alignment_length=NT_MIN_ALIGNMENT_LEN,
     min_pident=_NT_MIN_PIDENT,
@@ -251,7 +287,9 @@ def get_top_m8_nt(
 
     # Output the optimal hit for each query.
 
-    with open(blast_top_blastn_6_path, "w") as blast_top_blastn_6_f:
+    with open(blast_top_blastn_6_path, "w") as blast_top_blastn_6_f, \
+            open_file_db_by_extension(lineage_map_path, "lll") as lineage_map, \
+            open_file_db_by_extension(accession2taxid_dict_path, "L") as accession2taxid_dict:  # noqa
         BlastnOutput6NTRerankedWriter(blast_top_blastn_6_f).writerows(
-            _optimal_hit_for_each_query_nt(blast_output, min_alignment_length, min_pident, max_evalue, False)
+            _optimal_hit_for_each_query_nt(blast_output, lineage_map, accession2taxid_dict, min_alignment_length, min_pident, max_evalue, False)
         )
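
The net effect of the specificity bucketing in both _optimal_hit_for_each_query_nr and _optimal_hit_for_each_query_nt: a slightly weaker hit with a species-level taxid now beats a stronger hit whose accession only resolves to, say, family level. A self-contained sketch of the selection rule on made-up scores and lineages:

from collections import defaultdict

# Made-up candidates for one contig: (accession, bitscore, lineage taxid triple).
# A "0" at a level means no taxid is known at that level (illustrative values).
candidates = [
    ("acc_A", 980.0, ("0", "0", "543")),      # family-level only -> specificity 2
    ("acc_B", 950.0, ("562", "561", "543")),  # species-level     -> specificity 0
    ("acc_C", 950.0, ("563", "561", "543")),  # species-level tie -> specificity 0
]

best = defaultdict(list)  # specificity -> list of tied best hits
for acc, score, lineage_taxids in candidates:
    spec = next((i for i, t in enumerate(lineage_taxids) if int(t) > 0), float("inf"))
    if not best[spec] or best[spec][0][1] < score:
        best[spec] = [(acc, score)]
    elif best[spec][0][1] == score:
        best[spec].append((acc, score))

# Most specific bucket wins, even though acc_A has the highest raw bitscore.
print(best[min(best.keys())])  # [('acc_B', 950.0), ('acc_C', 950.0)]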
1 change: 1 addition & 0 deletions workflows/bulk-download/manifest.yml
@@ -31,6 +31,7 @@ input_loaders:
       bulk_download_type: ~
     outputs:
       files: ~
+      concatenated_output_name: ~
   - name: passthrough
     version: ">=0.0.1"
     inputs:
16 changes: 13 additions & 3 deletions workflows/bulk-download/run.wdl
@@ -9,6 +9,7 @@ workflow bulk_download {
     input {
         String action
         Array[BulkDownloads] files
+        String concatenated_output_name = "concatenated.txt"
         String docker_image_id = "czid-bulk-download"
     }

@@ -25,6 +26,7 @@
         call concatenate {
             input:
                 files = rename.file,
+                concatenated_output_name = concatenated_output_name,
                 docker_image_id = docker_image_id
         }
     }
@@ -63,14 +65,15 @@ task rename {
 task concatenate {
     input {
         String docker_image_id
+        String concatenated_output_name = "concatenated.txt"
         Array[File] files
     }
     command <<<
         set -euxo pipefail
-        cat ~{sep=" " files} > concatenated.txt
+        cat ~{sep=" " files} > "~{concatenated_output_name}"
     >>>
     output {
-        File file = "concatenated.txt"
+        File file = "~{concatenated_output_name}"
     }
     runtime {
         docker: docker_image_id
@@ -85,7 +88,14 @@ task zip {
     command <<<
         set -euxo pipefail
         # Don't store full path of original files in the .zip file
-        zip --junk-paths result.zip ~{sep=" " files}
+        if [[ "~{select_first(files)}" == *.zip ]]; then
+            mkdir zip_folders
+            for f in ~{sep=" " files}; do unzip "$f" -d zip_folders/$(basename "${f%.zip}"); done
+            cd zip_folders
+            zip -r ../result.zip *
+        else
+            zip --junk-paths result.zip ~{sep=" " files}
+        fi
     >>>
     output {
         File file = "result.zip"
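
In task zip, the new branch handles bulk downloads whose inputs are themselves zip archives: rather than flattening everything with --junk-paths, each input zip is unpacked into a folder named after it and the folders are re-zipped, so the per-input structure is preserved and same-named inner files cannot collide. A rough Python equivalent of that shell logic, with hypothetical file names, just to make the behavior concrete:

import os
import zipfile

def bundle(files, out_path="result.zip"):
    # If the inputs are zips, nest each one's contents under its own folder;
    # otherwise zip the files flat (the --junk-paths case).
    with zipfile.ZipFile(out_path, "w") as out:
        if files and files[0].endswith(".zip"):
            for f in files:
                folder = os.path.basename(f)[:-len(".zip")]
                with zipfile.ZipFile(f) as inner:
                    for name in inner.namelist():
                        out.writestr(f"{folder}/{name}", inner.read(name))
        else:
            for f in files:
                out.write(f, arcname=os.path.basename(f))

# bundle(["sample1.zip", "sample2.zip"])  # hypothetical per-sample outputs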
20 changes: 20 additions & 0 deletions workflows/consensus-genome/manifest.yml
@@ -45,6 +45,26 @@ raw_inputs:
     description: CZID's snapshot version of NCBI's index i.e. '2021-01-22'
     type: str
     required: True
+  accession_name:
+    name: Accession Name
+    description: accession name
+    type: str
+    required: False
+  accession_id:
+    name: Accession ID
+    description: accession id
+    type: str
+    required: False
+  taxon_name:
+    name: Taxon
+    description: taxon name
+    type: str
+    required: False
+  taxon_level:
+    name: Taxon Level
+    description: taxon level
+    type: str
+    required: False
 input_loaders:
   - name: sample
     version: ">=0.0.1"
(Diffs for the remaining 4 changed files are not shown.)
