From f35cd021d9c3a419d29d25873c307999770d42c3 Mon Sep 17 00:00:00 2001 From: vortic Date: Thu, 7 Jun 2012 13:45:34 -0700 Subject: [PATCH] cleaned up Lda.java removed many commented sections of code no longer being used added functionality to "cheat" and use true parameters instead of learned ones --- src/plusone/clustering/Lda.java | 153 ++++++++++++-------------------- test | 2 +- 2 files changed, 60 insertions(+), 95 deletions(-) diff --git a/src/plusone/clustering/Lda.java b/src/plusone/clustering/Lda.java index 00bae8f..770ac1f 100644 --- a/src/plusone/clustering/Lda.java +++ b/src/plusone/clustering/Lda.java @@ -27,6 +27,8 @@ public class Lda extends ClusteringTest { private int numTopics; private SimpleMatrix beta; private SimpleMatrix gammas; + //flag to take true parameters (only for synthesized data) + private static boolean CHEAT; public Lda(List trainingSet, Indexer wordIndexer, Terms terms, int numTopics) { @@ -38,29 +40,6 @@ public Lda(List trainingSet, Indexer wordIndexer, train(); } - /** - * deprecated - */ - public void analysis() { - // double trainPercent, double testWordPercent) { - - // super.analysis(0,0); - - /* - * List trainingSet = this.documents.subList(0, - * ((int)(documents.size() * trainPercent))); List - * testingSet = this.documents.subList((int)(documents.size() * - * trainPercent) + 1, documents.size()); - * - * for (PaperAbstract a : testingSet) { - * a.generateTestset(testWordPercent, this.wordIndexer); - * //trainingSet.add(a); } - */ - - // this.train(documents, testingSet); - // this.test(testingSet, testWordPercent); - } - /** * Runs lda-c-dist on the training set to learn the beta matrix and alpha * parameter (in this case, all alphas to the dirichlet are equal) @@ -79,37 +58,19 @@ private void train() { + " lib/lda-c-dist/settings.txt " + trainingData + " random lda", false); - double[][] betaMatrix = readLdaResultFile("lda/final.beta", 0, true); + CHEAT = false; //CHANGE TO false WHEN TRAINING ON REAL DATA + double[][] betaMatrix; + if (CHEAT) { + System.out.println("We are cheating and using the true beta"); + betaMatrix = getRealBeta("src/datageneration/output/" + + "documents_model-out"); + implantRealBeta(betaMatrix, "lda/final.beta"); + } else { + betaMatrix = readLdaResultFile("lda/final.beta", 0, true); + } beta = new SimpleMatrix(betaMatrix); - // SimpleMatrix results = gammas.mult(beta); - - /* - * Integer[][] predictedWords = this.predictTopKWords(beta, gammas, - * testingAbstracts, k, outputUsedWords); - * - * int predicted = 0, total = 0; double tfidfScore = 0.0, idfScore = 0; - * for (int document = 0; document < predictedWords.length; document ++) - * { //System.out.println("document: " + document + - * " number of predicted words: " + predictedWords[document].length); - * for (int predict = 0; predict < predictedWords[document].length; - * predict ++) { Integer wordID = predictedWords[document][predict]; if - * (testingAbstracts.get(document).predictionWords.isEmpty()) - * System.out.println("no prediction words in testing set?"); - * - * if (testingAbstracts.get(document).predictionWords .contains(wordID)) - * { predicted ++; tfidfScore += this.tfidf.tfidf(abstracts.size() - - * testingAbstracts.size() + document, wordID); idfScore += - * this.tfidf.idf(wordID); } - * - * total ++; } } System.out.println("Predicted " + - * ((double)predicted/total)*100 + " percent of the words"); - * System.out.println("total attempts: " + total); - * System.out.println("TFIDF score: " + tfidfScore); - * System.out.println("IDF score: " + idfScore); - */ } - /** * Given a set of test documents, runs lda-c-dist inference to learn the * final gammas. Then, subtracts alpha from each gamma to find the expected @@ -151,46 +112,7 @@ public double[][] predict(List testDocs){ System.out.println("Perplexity is " + getPerplexity(testDocs)); return result; } - - /*public Integer[][] predictTopKWords(int k, boolean outputUsedWords) { - train(); - SimpleMatrix matrix = gammas.mult(beta); - Integer[][] results = new Integer[testingSet.size()][]; - for (int row = 0; row < matrix.numRows(); row++) { - PriorityQueue queue = - new PriorityQueue(k + 1); - - for (int col = 0; col < matrix.numCols(); col++) { - if (!outputUsedWords && testingSet.get(row).tf[col][0] > 0) - continue; - if (queue.size() < k - || matrix.get(row, col) > queue.peek().score) { - if (queue.size() >= k) - queue.poll(); - queue.add(new ItemAndScore(col, matrix.get(row, col), false)); - } - } - // if (outputUsedWords) { - results[row] = new Integer[Math.min(k, queue.size())]; - for (int i = 0; i < k && !queue.isEmpty(); i++) { - results[row][i] = queue.poll().wordID; - } - - * } else { //System.out.println("Predicting results for row: " + - * row); List lst = new ArrayList(); for - * (int i = 0; i < k && !queue.isEmpty(); i ++) { WordAndScore cur = - * queue.poll(); //System.out.println("predicted word: " + - * wordIndexer.get(cur.wordID) + " score: " + cur.score); if - * (!abstracts.get(row).outputWords.contains(cur)) lst.add(cur); - * else i --; } - * - * results[row] = new Integer[lst.size()]; for (int i = 0; i < - * lst.size(); i ++) { results[row][i] = lst.get(i).wordID; } } - - } - return results; - }*/ private void createLdaInput(String filename, List papers){ System.out.print("creating lda input in file: " + filename + " ... "); @@ -209,6 +131,7 @@ private void createLdaInput(String filename, List papers){ System.out.println("done."); } + /** * Takes a list of PaperAbstract documents and writes them to file according * to the format specified by lda-c-dist @@ -237,6 +160,51 @@ private void createLdaInputTest(String filename, List papers) { System.out.println("done."); } + private void implantRealBeta(double[][] betaMatrix, String filename) { + System.out.print("Replacing trained betas with true betas..."); + PlusoneFileWriter fileWriter = new PlusoneFileWriter(filename); + for (int row = 0; row < betaMatrix.length; row++) { + for (int col = 0; col < betaMatrix[row].length; col++) { + fileWriter.write(Math.log(betaMatrix[row][col]) + " "); + } + fileWriter.write("\n"); + } + fileWriter.close(); + System.out.println("done"); + } + + /** + * Only used for synthesized data. Reads in the distribution matrix that was + * used to generate the data. + * @param filename location of stored matrix + * @return the beta matrix from which the documents were generated + */ + private double[][] getRealBeta(String filename) { + double[][] res = null; + List topics = new ArrayList(); + try { + FileInputStream fstream = new FileInputStream(filename); + DataInputStream in = new DataInputStream(fstream); + BufferedReader br = new BufferedReader(new InputStreamReader(in)); + String strLine; + + while (!(strLine = br.readLine()).equals("V")) { + topics.add(strLine.trim().split(" ")); + } + + res = new double[topics.size()][]; + for (int i = 0; i < topics.size(); i++) { + res[i] = new double[topics.get(i).length]; + for (int j = 0; j < topics.get(i).length; j++) { + res[i][j] = new Double(topics.get(i)[j]); + } + } + } catch (Exception e) { + e.printStackTrace(); + } + return res; + } + /** * Takes a file output by lda-c-dist and stores it in a matrix. * @@ -249,8 +217,7 @@ private double[][] readLdaResultFile(String filename, int start, boolean exp) { List gammas = new ArrayList(); double[][] results = null; - // System.out.println("reading lda results file starting at : " + - // start); + try { FileInputStream fstream = new FileInputStream(filename); DataInputStream in = new DataInputStream(fstream); @@ -264,7 +231,6 @@ private double[][] readLdaResultFile(String filename, int start, } c++; } - // System.out.println("C got to " + c); results = new double[gammas.size()][]; for (int i = 0; i < gammas.size(); i++) { @@ -273,7 +239,6 @@ private double[][] readLdaResultFile(String filename, int start, results[i][j] = new Double(gammas.get(i)[j]); if (exp) results[i][j] = Math.exp(results[i][j]); - } } diff --git a/test b/test index d31ff37..a5dd94d 100755 --- a/test +++ b/test @@ -60,7 +60,7 @@ args=\ -Dplusone.localCO.termEnzs=15000 -Dplusone.localCO.dtNs=500 -Dplusone.localCO.tdNs=1200 - -Dplusone.lda.dimensions=5,10,15 + -Dplusone.lda.dimensions=15 -Dplusone.enableTest.localCO=false \ -Dplusone.enableTest.lda=true \ -Dplusone.enableTest.knnc=false \