diff --git a/src/edu/msu/cme/rdp/alignment/pairwise/PairwiseKNN.java b/src/edu/msu/cme/rdp/alignment/pairwise/PairwiseKNN.java index 82e5905..c1aa9f4 100644 --- a/src/edu/msu/cme/rdp/alignment/pairwise/PairwiseKNN.java +++ b/src/edu/msu/cme/rdp/alignment/pairwise/PairwiseKNN.java @@ -19,13 +19,20 @@ import edu.msu.cme.rdp.alignment.AlignmentMode; import edu.msu.cme.rdp.alignment.pairwise.rna.DistanceModel; import edu.msu.cme.rdp.alignment.pairwise.rna.IdentityDistanceModel; +import edu.msu.cme.rdp.alignment.pairwise.rna.OverlapCheckFailedException; import edu.msu.cme.rdp.readseq.SequenceType; import edu.msu.cme.rdp.readseq.readers.SeqReader; import edu.msu.cme.rdp.readseq.readers.Sequence; import edu.msu.cme.rdp.readseq.readers.SequenceReader; import edu.msu.cme.rdp.readseq.utils.IUBUtilities; import edu.msu.cme.rdp.readseq.utils.SeqUtils; +import edu.msu.cme.rdp.readseq.utils.kmermatch.KmerMatchCore; +import edu.msu.cme.rdp.readseq.utils.kmermatch.NuclSeqMatch; +import edu.msu.cme.rdp.readseq.utils.kmermatch.ProteinSeqMatch; +import edu.msu.cme.rdp.readseq.utils.orientation.GoodWordIterator; +import edu.msu.cme.rdp.readseq.utils.orientation.ProteinWordGenerator; import java.io.File; +import java.io.IOException; import java.io.PrintStream; import java.util.ArrayList; import java.util.Collections; @@ -42,6 +49,16 @@ */ public class PairwiseKNN { + private File queryFile; + private File refFile; + private int k; + private int prefilter = 0; + private int wordSize; + private AlignmentMode mode; + private List dbSeqs; + private PrintStream out; + private static final String dformat = "%1$.3f"; + public static class Neighbor { PairwiseAlignment alignment; @@ -63,7 +80,7 @@ private static void insert(T n, List list, Comparator comp, int k) { } } - public static List getKNN(Sequence query, List dbSeqs, AlignmentMode mode, int k) { + public static List getKNN(Sequence query, List dbSeqs, AlignmentMode mode, int k, int wordSize, int prefilter) throws IOException { List ret = new ArrayList(); Neighbor n; Comparator c = new Comparator() { @@ -74,19 +91,33 @@ public int compare(Neighbor t, Neighbor t1) { SequenceType seqType = SeqUtils.guessSequenceType(query); ScoringMatrix matrix; + KmerMatchCore kerMatchCore; if (seqType == SequenceType.Nucleotide) { matrix = ScoringMatrix.getDefaultNuclMatrix(); + kerMatchCore = new NuclSeqMatch(dbSeqs, wordSize); } else { matrix = ScoringMatrix.getDefaultProteinMatrix(); + kerMatchCore = new ProteinSeqMatch(dbSeqs, wordSize); } - matrix = ScoringMatrix.getDefaultProteinMatrix(); - for (Sequence dbSeq : dbSeqs) { + List refList; + + if ( prefilter == 0) { // do not pre-filter the reference seqs + refList = dbSeqs; + }else { + refList = new ArrayList(); + ArrayList topKMatches= kerMatchCore.findTopKMatch(query, prefilter); + for (KmerMatchCore.BestMatch bestTarget : topKMatches) { + refList.add(bestTarget.getBestMatch()); + } + } + + for (Sequence dbSeq : refList) { n = new Neighbor(); n.dbSeq = dbSeq; - PairwiseAlignment fwd = PairwiseAligner.align(dbSeq.getSeqString(), query.getSeqString(), matrix, mode); + PairwiseAlignment fwd = PairwiseAligner.align(n.dbSeq.getSeqString(), query.getSeqString(), matrix, mode); if (seqType == SequenceType.Nucleotide) { - PairwiseAlignment rc = PairwiseAligner.align(dbSeq.getSeqString(), IUBUtilities.reverseComplement(query.getSeqString()), matrix, mode); + PairwiseAlignment rc = PairwiseAligner.align(n.dbSeq.getSeqString(), IUBUtilities.reverseComplement(query.getSeqString()), matrix, mode); if (rc.getScore() > fwd.getScore()) { n.alignment = rc; @@ -102,44 +133,91 @@ public int compare(Neighbor t, Neighbor t1) { insert(n, ret, c, k); } - + return ret; } - /*public static void main(String[] args) { - List list = new ArrayList(); - - list.add(10); - list.add(8); - list.add(7); - list.add(6); + public PairwiseKNN(File queryFile, File refFile, PrintStream out, AlignmentMode mode, int k, int wordSize, int prefilter) throws IOException{ + this.queryFile = queryFile; + this.refFile = refFile; + this.out = out; + this.mode = mode; + this.k = k; + this.prefilter = prefilter; + this.wordSize = wordSize; + SequenceType querySeqType = SeqUtils.guessSequenceType(queryFile); + SequenceType refSeqType = SeqUtils.guessSequenceType(refFile); + + if ( querySeqType != refSeqType) { + throw new RuntimeException("reference seqs and query seqs must be the same type, either protein or nucleotide. " ); + } + if ( wordSize == 0 ){ + if ( refSeqType == SequenceType.Protein){ + this.wordSize = ProteinWordGenerator.WORDSIZE; + } else { + this.wordSize = GoodWordIterator.DEFAULT_WORDSIZE ; + } + } + + dbSeqs = SequenceReader.readFully(refFile); + } + + public void match() throws IOException, OverlapCheckFailedException { + DistanceModel dist = new IdentityDistanceModel(); - Comparator comp = new Comparator() { + out.println("#query file: " + queryFile.getName() + " db file: " + refFile.getName() + " k: " + k + " mode: " + mode + " usePrefilter: " + prefilter); + out.println("#seqname\tk\tref seqid\tref desc\torientation\tscore\tident\tquery start\tquery end\tquery length\tref start\tref end"); + Sequence seq; + List alignments; + Neighbor n; + PairwiseAlignment alignment; + SequenceReader queryReader = new SequenceReader(queryFile); + while ((seq = queryReader.readNextSequence()) != null) { + alignments = getKNN(seq, dbSeqs, mode, k, wordSize, prefilter); - public int compare(Integer t, Integer t1) { - return t - t1; - } + for (int index = 0; index < alignments.size(); index++) { + n = alignments.get(index); + alignment = n.alignment; + double ident = 1 - dist.getDistance(alignment.getAlignedSeqi().getBytes(), alignment.getAlignedSeqj().getBytes(), 0); - }; + out.println("@" + seq.getSeqName() + + "\t" + (index + 1) + + "\t" + n.dbSeq.getSeqName() + + "\t" + n.dbSeq.getDesc() + + "\t" + (n.reverse ? "-" : "+") + + "\t" + alignment.getScore() + + "\t" + String.format(dformat,ident) + + "\t" + alignment.getStartj() + + "\t" + alignment.getEndj() + + "\t" + seq.getSeqString().length() + + "\t" + alignment.getStarti() + + "\t" + alignment.getEndi()); - System.out.println(list); - insert(11, list, comp, 5); - System.out.println(list); - insert(9, list, comp, 5); - System.out.println(list); - }*/ - public static void main(String[] args) throws Exception { - SeqReader queryReader; - List dbSeqs; + out.println(">" + alignment.getAlignedSeqj()); + out.println(">" + alignment.getAlignedSeqi()); + } + } + queryReader.close(); + out.close(); + } + + public static void main(String[] args) throws Exception { + File queryFile; + File refFile; AlignmentMode mode = AlignmentMode.glocal; int k = 1; + int wordSize = 0 ; + int prefilter = 10 ; // The top p closest protein targets PrintStream out = new PrintStream(System.out); Options options = new Options(); options.addOption("m", "mode", true, "Alignment mode {global, glocal, local, overlap, overlap_trimmed} (default= glocal)"); - options.addOption("k", true, "K-nearest neighbors to return"); + options.addOption("k", true, "K-nearest neighbors to return. (default = 1)"); options.addOption("o", "out", true, "Redirect output to file instead of stdout"); - + options.addOption("p", "prefilter", true, "The top p closest targets from kmer prefilter step. Set p=0 to disable the prefilter step. (default = 10) "); + options.addOption("w", "word-size", true, "The word size used to find closest targets during prefilter. (default " + ProteinWordGenerator.WORDSIZE + + " for protein, " + GoodWordIterator.DEFAULT_WORDSIZE + " for nucleotide)"); + try { CommandLine line = new PosixParser().parse(options, args); @@ -150,8 +228,25 @@ public static void main(String[] args) throws Exception { if (line.hasOption('k')) { k = Integer.valueOf(line.getOptionValue('k')); + if ( k < 1 ){ + throw new Exception("k must be at least 1"); + } } - + + if (line.hasOption("word-size")) { + wordSize = Integer.parseInt(line.getOptionValue("word-size")); + if ( wordSize < 3 ){ + throw new Exception("Word size must be at least 3"); + } + } + if (line.hasOption("prefilter")) { + prefilter = Integer.parseInt(line.getOptionValue("prefilter")); + // prefilter == 0 means no prefilter + if ( prefilter > 0 && prefilter < k ){ + throw new Exception("prefilter must be at least as big as k " + k); + } + } + if (line.hasOption("out")) { out = new PrintStream(line.getOptionValue("out")); } @@ -162,48 +257,18 @@ public static void main(String[] args) throws Exception { throw new Exception("Unexpected number of command line arguments"); } - queryReader = new SequenceReader(new File(args[0])); - dbSeqs = SequenceReader.readFully(new File(args[1])); + queryFile = new File(args[0]); + refFile = new File(args[1]); } catch (Exception e) { new HelpFormatter().printHelp("PairwiseKNN ", options); System.err.println("ERROR: " + e.getMessage()); return; } - - DistanceModel dist = new IdentityDistanceModel(); - - out.println("#query file: " + args[0] + " db file: " + args[1] + " k: " + k + " mode: " + mode); - out.println("#seqname\tk\tref seqid\tref desc\torientation\tscore\tident\tquery start\tquery end\tquery length\tref start\tref end"); - Sequence seq; - List alignments; - Neighbor n; - PairwiseAlignment alignment; - while ((seq = queryReader.readNextSequence()) != null) { - alignments = getKNN(seq, dbSeqs, mode, k); - - for (int index = 0; index < alignments.size(); index++) { - n = alignments.get(index); - alignment = n.alignment; - double ident = 1 - dist.getDistance(alignment.getAlignedSeqi().getBytes(), alignment.getAlignedSeqj().getBytes(), 0); - - out.println("@" + seq.getSeqName() - + "\t" + (index + 1) - + "\t" + n.dbSeq.getSeqName() - + "\t" + n.dbSeq.getDesc() - + "\t" + (n.reverse ? "-" : "+") - + "\t" + alignment.getScore() - + "\t" + ident - + "\t" + alignment.getStartj() - + "\t" + alignment.getEndj() - + "\t" + seq.getSeqString().length() - + "\t" + alignment.getStarti() - + "\t" + alignment.getEndi()); - - out.println(">" + alignment.getAlignedSeqj()); - out.println(">" + alignment.getAlignedSeqi()); - } - } + + PairwiseKNN theObj = new PairwiseKNN(queryFile, refFile, out, mode, k, wordSize, prefilter); + theObj.match(); + } }