# Review notes (text before the first "diff --git" header is ignored by patch tools):
# - This patch was received with its line breaks collapsed; line breaks and leading
#   whitespace were reconstructed. NOTE(review): confirm the context-line indentation
#   of the lib/, Dockerfile and diamond.wdl hunks against the target files before applying.
# - The added short-read-mngs/.../diamond_scatter.py differs from the submitted version:
#   blastx-join now accepts --diamond-args, unset diamond args are no longer rendered as
#   "--None", and docstrings were added. Its hunk length is therefore 294 and the blob id
#   on its "index" line no longer describes the content.
diff --git a/lib/idseq_utils/idseq_utils/diamond_scatter.py b/lib/idseq_utils/idseq_utils/diamond_scatter.py
index e1710171c..c98669ca3 100644
--- a/lib/idseq_utils/idseq_utils/diamond_scatter.py
+++ b/lib/idseq_utils/idseq_utils/diamond_scatter.py
@@ -233,7 +233,7 @@ def blastx_join(chunk_dir: str, out: str, diamond_args: str, *query: str):
             diamond_blastx(
                 cwd=tmp_dir,
                 par_tmpdir="par-tmp",
-                block_size=100 if "long-reads" in diamond_args else 10,
+                block_size=1,
                 database=db.name,
                 out=out,
                 join_chunks=chunks,
diff --git a/short-read-mngs/idseq_utils/idseq_utils/diamond_scatter.py b/short-read-mngs/idseq_utils/idseq_utils/diamond_scatter.py
new file mode 100644
index 000000000..5474caa42
--- /dev/null
+++ b/short-read-mngs/idseq_utils/idseq_utils/diamond_scatter.py
@@ -0,0 +1,294 @@
+import os
+import shlex
+import shutil
+import sys
+import errno
+
+from argparse import ArgumentParser
+from glob import glob
+from multiprocessing import Pool
+from os.path import abspath, basename, join
+from subprocess import run, PIPE
+from tempfile import NamedTemporaryFile, TemporaryDirectory
+from typing import Iterable
+
+from Bio import SeqIO
+from Bio.Seq import Seq
+from Bio.SeqRecord import SeqRecord
+
+
+class DiamondBlastXException(Exception):
+    """Raised when a `diamond blastx` command exits with a non-zero status."""
+
+
+class DiamondJoinException(Exception):
+    """Not raised in this module; presumably reserved for join failures."""
+
+################################################################################################################
+#
+# Diamond
+#
+################################################################################################################
+
+
+def diamond_makedb(cwd: str, reference_fasta: str, database: str):
+    """Run `diamond makedb` in `cwd`, indexing `reference_fasta` into `database`.
+
+    Raises subprocess.CalledProcessError when diamond exits non-zero.
+    """
+    cmd = [
+        "diamond",
+        "makedb",
+        "--in",
+        reference_fasta,
+        "--db",
+        database,
+    ]
+    run(cmd, check=True, cwd=cwd, stdout=PIPE, stderr=PIPE)
+
+
+def diamond_blastx(
+    cwd: str,
+    par_tmpdir: str,
+    block_size: float,
+    database: str,
+    out: str,
+    queries: Iterable[str],
+    chunk=False,
+    diamond_args="",
+    join_chunks=0,
+):
+    """Run `diamond blastx --multiprocessing` in `cwd` and fail loudly on error.
+
+    `diamond_args` holds extra diamond options written without their leading
+    "--" (e.g. "long-reads"); it is skipped when empty or None. `chunk` adds
+    `--single-chunk` and a positive `join_chunks` adds `--join-chunks`.
+    Raises DiamondBlastXException, carrying stderr, on a non-zero exit status.
+    """
+    cmd = [
+        "diamond",
+        "blastx",
+        "--multiprocessing",
+        "--dbsize",
+        "123347875600",
+        "--parallel-tmpdir",
+        par_tmpdir,
+        "--block-size",
+        str(block_size),
+        "--db",
+        database,
+        "--out",
+        out,
+    ]
+    if diamond_args:
+        # Skip unset values (they would render as "--None" or "--") and split
+        # so that every extra option becomes its own argv entry.
+        cmd += shlex.split(f"--{diamond_args}")
+    for query in queries:
+        cmd += ["--query", query]
+    if chunk:
+        cmd += ["--single-chunk"]
+    if join_chunks > 0:
+        cmd += ["--join-chunks", str(join_chunks)]
+    print(cmd)
+    res = run(cmd, cwd=cwd, stdout=PIPE, stderr=PIPE)
+    if res.returncode != 0:
+        for line in res.stderr.decode().split("\n"):
+            print(line)
+        raise DiamondBlastXException(f"Command {' '.join(cmd)} failed with error: {res.stderr.decode()}")
+
+
+################################################################################################################
+#
+# Main
+#
+################################################################################################################
+
+
+def _consume_iter(iterable: Iterable, n: int):
+    """Yield the first `n` items of `iterable` (expects n >= 1) without consuming more."""
+    for i, e in enumerate(iterable):
+        yield e
+        if i == n - 1:
+            break
+
+
+def align_chunk(ref_chunk: int, start: int, size: int, query_chunk: int):
+    """Format the line that blastx_chunk writes to its `align_todo_*` file."""
+    return f"{ref_chunk} {start} {size} # query_chunk={query_chunk}\n"
+
+
+def zero_pad(n: int, m: int):
+    """Return `n` as a decimal string left-padded with zeros to width `m`."""
+    tagged = str(n)
+    return ("0" * (m - len(tagged))) + tagged
+
+
+def make_par_dir(cwd: str, par_tmpdir: str):
+    """Create `cwd/par_tmpdir/parallelizer` with empty command, register and workers files."""
+    os.mkdir(join(cwd, par_tmpdir))
+    p_dir = join(cwd, par_tmpdir, "parallelizer")
+    os.mkdir(p_dir)
+    with open(join(p_dir, "command"), "w"):
+        pass
+    with open(join(p_dir, "register"), "w"):
+        pass
+    with open(join(p_dir, "workers"), "w"):
+        pass
+
+
+def make_db(reference_fasta: str, output_dir: str, chunks: int):
+    """Split `reference_fasta` into `chunks` pieces and index each with diamond.
+
+    Each database is written to `output_dir` as "<i>-<n_seqs>-<n_letters>"; the
+    intermediate "<i>.fasta" lives in the current directory and is removed.
+    """
+    os.mkdir(output_dir)
+    chunk_size = (sum(1 for _ in SeqIO.parse(reference_fasta, "fasta")) // chunks) + 1
+    seqs = SeqIO.parse(reference_fasta, "fasta")
+    for i in range(chunks):
+        print(f"STARTING CHUNK {i}")
+        fasta_name = f"{i}.fasta"
+        SeqIO.write(_consume_iter(seqs, chunk_size), fasta_name, "fasta")
+        n_seqs = n_letters = 0
+        for seq in SeqIO.parse(fasta_name, "fasta"):
+            n_seqs += 1
+            n_letters += len(seq.seq)
+        db_name = f"{i}-{n_seqs}-{n_letters}"
+        print(f"INDEXING CHUNK {i}")
+        diamond_makedb(".", fasta_name, join(output_dir, db_name))
+        os.remove(fasta_name)
+        print(f"COMPLETED CHUNK {i}")
+
+
+def blastx_chunk(db_chunk: str, output_dir: str, diamond_args: str, *query: str):
+    """Align the queries against one database chunk and keep its intermediate files.
+
+    The basename of `db_chunk`, minus its last five characters (presumably
+    ".dmnd"), must read "<chunk>-<n_seqs>-<n_letters>" as produced by make_db.
+    The resulting ref_block_* and ref_dict_* files are copied to `output_dir`.
+    """
+    try:
+        os.mkdir(output_dir)
+    except OSError as e:
+        if e.errno != errno.EEXIST:
+            raise
+    chunk, n_seqs, n_letters = basename(db_chunk)[:-5].split("-")
+    with TemporaryDirectory() as tmp_dir:
+        make_par_dir(tmp_dir, "par-tmp")
+        with open(join(tmp_dir, "par-tmp", f"align_todo_{zero_pad(0, 6)}"), "w") as f:
+            f.writelines([align_chunk(int(chunk), 0, int(n_seqs), 0)])
+        diamond_blastx(
+            cwd=tmp_dir,
+            par_tmpdir="par-tmp",
+            block_size=int(n_letters) / 1e9,
+            database=abspath(db_chunk),
+            out="out.tsv",
+            chunk=True,
+            diamond_args=diamond_args,
+            queries=(abspath(q) for q in query),
+        )
+        ref_block_name = f"ref_block_{zero_pad(0, 6)}_{zero_pad(int(chunk), 6)}"
+        ref_dict_name = f"ref_dict_{zero_pad(0, 6)}_{zero_pad(int(chunk), 6)}"
+        shutil.copy(
+            join(tmp_dir, "par-tmp", ref_block_name), join(output_dir, ref_block_name)
+        )
+        shutil.copy(
+            join(tmp_dir, "par-tmp", ref_dict_name), join(output_dir, ref_dict_name)
+        )
+
+
+def mock_reference_fasta(chunks: int, chunk_size: int):
+    """Yield 100-letter "M" records (id "A") until `chunks` boundaries of `chunk_size` letters are crossed."""
+    letters = chunk = 0
+    while chunk < chunks:
+        n = 100
+        letters += n
+        if letters > (1 + chunk) * chunk_size:
+            chunk += 1
+        yield SeqRecord(Seq("M" * n), "A")
+
+
+def blastx_join(chunk_dir: str, out: str, diamond_args: str, *query: str):
+    """Run diamond with `--join-chunks` over `chunk_dir` and concatenate "<out>_*" into `out`.
+
+    The chunk count is taken as half the number of files in `chunk_dir`
+    (presumably one ref_block and one ref_dict file per chunk, as copied by blastx_chunk).
+    """
+    with TemporaryDirectory() as tmp_dir:
+        make_par_dir(tmp_dir, "par-tmp")
+        with open(join(tmp_dir, "par-tmp", f"join_todo_{zero_pad(0, 6)}"), "w") as f:
+            f.write("TOKEN\n")
+
+        for f in os.listdir(chunk_dir):
+            shutil.copy(join(chunk_dir, f), join(tmp_dir, "par-tmp", f))
+
+        chunks = len(os.listdir(chunk_dir)) // 2
+        with NamedTemporaryFile() as ref_fasta, NamedTemporaryFile() as db:
+            # make fake db to appease diamond
+            SeqIO.write(SeqRecord(Seq("M"), "ID"), ref_fasta.name, "fasta")
+            diamond_makedb(tmp_dir, ref_fasta.name, db.name)
+            diamond_blastx(
+                cwd=tmp_dir,
+                par_tmpdir="par-tmp",
+                block_size=1,
+                database=db.name,
+                out=out,
+                join_chunks=chunks,
+                diamond_args=diamond_args,
+                queries=(abspath(q) for q in query),
+            )
+
+        with open(out, "w") as out_f:
+            for out_chunk in glob(join(tmp_dir, f"{out}_*")):
+                with open(out_chunk) as out_chunk_f:
+                    out_f.writelines(out_chunk_f)
+                os.remove(out_chunk)
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    subparsers = parser.add_subparsers(title="commands", dest="command")
+
+    make_db_parser = subparsers.add_parser("make-db")
+    make_db_parser.add_argument("--db", required=True)
+    make_db_parser.add_argument("--in", required=True)
+    make_db_parser.add_argument("--chunks", type=int, required=True)
+
+    blastx_chunk_parser = subparsers.add_parser("blastx-chunk")
+    blastx_chunk_parser.add_argument("--db", required=True)
+    blastx_chunk_parser.add_argument("--out-dir", required=True)
+    blastx_chunk_parser.add_argument("--diamond-args", required=False)
+    blastx_chunk_parser.add_argument("--query", required=True, action="append")
+
+    blastx_chunks_parser = subparsers.add_parser("blastx-chunks")
+    blastx_chunks_parser.add_argument("--db-dir", required=True)
+    blastx_chunks_parser.add_argument("--out-dir", required=True)
+    blastx_chunks_parser.add_argument("--diamond-args", required=False)
+    blastx_chunks_parser.add_argument("--query", required=True, action="append")
+
+    blastx_join_parser = subparsers.add_parser("blastx-join")
+    blastx_join_parser.add_argument("--chunk-dir", required=True)
+    blastx_join_parser.add_argument("--out", required=True)
+    # blastx_join() is handed args.diamond_args below, so the option must exist here.
+    blastx_join_parser.add_argument("--diamond-args", required=False)
+    blastx_join_parser.add_argument("--query", required=True, action="append")
+
+    args = parser.parse_args(sys.argv[1:])
+    if args.command == "make-db":
+        # "in" is a Python keyword, so the attribute cannot be read as args.in.
+        make_db(getattr(args, "in"), args.db, args.chunks)
+    elif args.command == "blastx-chunk":
+        blastx_chunk(args.db, args.out_dir, args.diamond_args, *args.query)
+    elif args.command == "blastx-chunks":
+
+        def _blastx_chunk(db):
+            print(f"STARTING: {db}")
+            res = blastx_chunk(join(args.db_dir, db), args.out_dir, args.diamond_args, *args.query)
+            print(f"FINISHING: {db}")
+            return res
+
+        with Pool(48) as p:
+            p.map(_blastx_chunk, os.listdir(args.db_dir))
+    elif args.command == "blastx-join":
+        blastx_join(args.chunk_dir, args.out, args.diamond_args, *args.query)
diff --git a/workflows/diamond/Dockerfile b/workflows/diamond/Dockerfile
index 806ef944a..c24bb2e0f 100644
--- a/workflows/diamond/Dockerfile
+++ b/workflows/diamond/Dockerfile
@@ -19,7 +19,7 @@ WORKDIR /tmp
 RUN git clone https://github.com/chanzuckerberg/czid-workflows
 WORKDIR /tmp/czid-workflows
 RUN pip3 install -r requirements-dev.txt
-RUN git checkout rlim-add-diamond-modification
-RUN cp short-read-mngs/idseq_utils/idseq_utils/diamond_scatter.py /usr/local/bin/
 
-WORKDIR /workdir
\ No newline at end of file
+COPY --from=lib idseq_utils/idseq_utils/diamond_scatter.py /usr/local/bin/
+
+WORKDIR /workdir
diff --git a/workflows/diamond/diamond.wdl b/workflows/diamond/diamond.wdl
index 35af2d033..7a8f2f814 100644
--- a/workflows/diamond/diamond.wdl
+++ b/workflows/diamond/diamond.wdl
@@ -34,7 +34,7 @@ task RunDiamond {
     }
 
     command <<<
-        python3 /usr/local/bin/diamond_scatter.py blastx-chunk --db ~{db_chunk} --query ~{query_0} ~{if defined(query_1) then '--query ~{query_1}' else ''} --out-dir chunks --diamond-args "~{extra_args}"
+        python3 /usr/local/bin/diamond_scatter.py blastx-chunk --db ~{db_chunk} --query ~{query_0} ~{if defined(query_1) then '--query ~{query_1}' else ''} --out-dir chunks --diamond-args="~{extra_args}"
     >>>
 
     output {
@@ -45,4 +45,4 @@ task RunDiamond {
     runtime {
         docker: docker_image_id
     }
-}
\ No newline at end of file
+}