Skip to content

Commit

Permalink
cleaned up Lda.java
Browse files Browse the repository at this point in the history
removed many commented sections of code no longer being used

added functionality to "cheat" and use true parameters instead of
learned ones
  • Loading branch information
vortic committed Jun 7, 2012
1 parent cb2ec15 commit f35cd02
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 95 deletions.
153 changes: 59 additions & 94 deletions src/plusone/clustering/Lda.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ public class Lda extends ClusteringTest {
private int numTopics;
private SimpleMatrix beta;
private SimpleMatrix gammas;
//flag to take true parameters (only for synthesized data)
private static boolean CHEAT;

public Lda(List<TrainingPaper> trainingSet, Indexer<String> wordIndexer,
Terms terms, int numTopics) {
Expand All @@ -38,29 +40,6 @@ public Lda(List<TrainingPaper> trainingSet, Indexer<String> wordIndexer,
train();
}

/**
* deprecated
*/
public void analysis() {
// double trainPercent, double testWordPercent) {

// super.analysis(0,0);

/*
* List<PaperAbstract> trainingSet = this.documents.subList(0,
* ((int)(documents.size() * trainPercent))); List<PaperAbstract>
* testingSet = this.documents.subList((int)(documents.size() *
* trainPercent) + 1, documents.size());
*
* for (PaperAbstract a : testingSet) {
* a.generateTestset(testWordPercent, this.wordIndexer);
* //trainingSet.add(a); }
*/

// this.train(documents, testingSet);
// this.test(testingSet, testWordPercent);
}

/**
* Runs lda-c-dist on the training set to learn the beta matrix and alpha
* parameter (in this case, all alphas to the dirichlet are equal)
Expand All @@ -79,37 +58,19 @@ private void train() {
+ " lib/lda-c-dist/settings.txt " + trainingData
+ " random lda", false);

double[][] betaMatrix = readLdaResultFile("lda/final.beta", 0, true);
CHEAT = false; //CHANGE TO false WHEN TRAINING ON REAL DATA
double[][] betaMatrix;
if (CHEAT) {
System.out.println("We are cheating and using the true beta");
betaMatrix = getRealBeta("src/datageneration/output/" +
"documents_model-out");
implantRealBeta(betaMatrix, "lda/final.beta");
} else {
betaMatrix = readLdaResultFile("lda/final.beta", 0, true);
}
beta = new SimpleMatrix(betaMatrix);
// SimpleMatrix results = gammas.mult(beta);

/*
* Integer[][] predictedWords = this.predictTopKWords(beta, gammas,
* testingAbstracts, k, outputUsedWords);
*
* int predicted = 0, total = 0; double tfidfScore = 0.0, idfScore = 0;
* for (int document = 0; document < predictedWords.length; document ++)
* { //System.out.println("document: " + document +
* " number of predicted words: " + predictedWords[document].length);
* for (int predict = 0; predict < predictedWords[document].length;
* predict ++) { Integer wordID = predictedWords[document][predict]; if
* (testingAbstracts.get(document).predictionWords.isEmpty())
* System.out.println("no prediction words in testing set?");
*
* if (testingAbstracts.get(document).predictionWords .contains(wordID))
* { predicted ++; tfidfScore += this.tfidf.tfidf(abstracts.size() -
* testingAbstracts.size() + document, wordID); idfScore +=
* this.tfidf.idf(wordID); }
*
* total ++; } } System.out.println("Predicted " +
* ((double)predicted/total)*100 + " percent of the words");
* System.out.println("total attempts: " + total);
* System.out.println("TFIDF score: " + tfidfScore);
* System.out.println("IDF score: " + idfScore);
*/
}


/**
* Given a set of test documents, runs lda-c-dist inference to learn the
* final gammas. Then, subtracts alpha from each gamma to find the expected
Expand Down Expand Up @@ -151,46 +112,7 @@ public double[][] predict(List<PredictionPaper> testDocs){
System.out.println("Perplexity is " + getPerplexity(testDocs));
return result;
}

/*public Integer[][] predictTopKWords(int k, boolean outputUsedWords) {
train();
SimpleMatrix matrix = gammas.mult(beta);
Integer[][] results = new Integer[testingSet.size()][];
for (int row = 0; row < matrix.numRows(); row++) {
PriorityQueue<ItemAndScore> queue =
new PriorityQueue<ItemAndScore>(k + 1);
for (int col = 0; col < matrix.numCols(); col++) {
if (!outputUsedWords && testingSet.get(row).tf[col][0] > 0)
continue;
if (queue.size() < k
|| matrix.get(row, col) > queue.peek().score) {
if (queue.size() >= k)
queue.poll();
queue.add(new ItemAndScore(col, matrix.get(row, col), false));
}
}

// if (outputUsedWords) {
results[row] = new Integer[Math.min(k, queue.size())];
for (int i = 0; i < k && !queue.isEmpty(); i++) {
results[row][i] = queue.poll().wordID;
}
* } else { //System.out.println("Predicting results for row: " +
* row); List<WordAndScore> lst = new ArrayList<WordAndScore>(); for
* (int i = 0; i < k && !queue.isEmpty(); i ++) { WordAndScore cur =
* queue.poll(); //System.out.println("predicted word: " +
* wordIndexer.get(cur.wordID) + " score: " + cur.score); if
* (!abstracts.get(row).outputWords.contains(cur)) lst.add(cur);
* else i --; }
*
* results[row] = new Integer[lst.size()]; for (int i = 0; i <
* lst.size(); i ++) { results[row][i] = lst.get(i).wordID; } }
}
return results;
}*/
private void createLdaInput(String filename, List<TrainingPaper> papers){
System.out.print("creating lda input in file: " + filename + " ... ");

Expand All @@ -209,6 +131,7 @@ private void createLdaInput(String filename, List<TrainingPaper> papers){

System.out.println("done.");
}

/**
* Takes a list of PaperAbstract documents and writes them to file according
* to the format specified by lda-c-dist
Expand Down Expand Up @@ -237,6 +160,51 @@ private void createLdaInputTest(String filename, List<PredictionPaper> papers) {
System.out.println("done.");
}

private void implantRealBeta(double[][] betaMatrix, String filename) {
System.out.print("Replacing trained betas with true betas...");
PlusoneFileWriter fileWriter = new PlusoneFileWriter(filename);
for (int row = 0; row < betaMatrix.length; row++) {
for (int col = 0; col < betaMatrix[row].length; col++) {
fileWriter.write(Math.log(betaMatrix[row][col]) + " ");
}
fileWriter.write("\n");
}
fileWriter.close();
System.out.println("done");
}

/**
* Only used for synthesized data. Reads in the distribution matrix that was
* used to generate the data.
* @param filename location of stored matrix
* @return the beta matrix from which the documents were generated
*/
private double[][] getRealBeta(String filename) {
double[][] res = null;
List<String[]> topics = new ArrayList<String[]>();
try {
FileInputStream fstream = new FileInputStream(filename);
DataInputStream in = new DataInputStream(fstream);
BufferedReader br = new BufferedReader(new InputStreamReader(in));
String strLine;

while (!(strLine = br.readLine()).equals("V")) {
topics.add(strLine.trim().split(" "));
}

res = new double[topics.size()][];
for (int i = 0; i < topics.size(); i++) {
res[i] = new double[topics.get(i).length];
for (int j = 0; j < topics.get(i).length; j++) {
res[i][j] = new Double(topics.get(i)[j]);
}
}
} catch (Exception e) {
e.printStackTrace();
}
return res;
}

/**
* Takes a file output by lda-c-dist and stores it in a matrix.
*
Expand All @@ -249,8 +217,7 @@ private double[][] readLdaResultFile(String filename, int start,
boolean exp) {
List<String[]> gammas = new ArrayList<String[]>();
double[][] results = null;
// System.out.println("reading lda results file starting at : " +
// start);

try {
FileInputStream fstream = new FileInputStream(filename);
DataInputStream in = new DataInputStream(fstream);
Expand All @@ -264,7 +231,6 @@ private double[][] readLdaResultFile(String filename, int start,
}
c++;
}
// System.out.println("C got to " + c);

results = new double[gammas.size()][];
for (int i = 0; i < gammas.size(); i++) {
Expand All @@ -273,7 +239,6 @@ private double[][] readLdaResultFile(String filename, int start,
results[i][j] = new Double(gammas.get(i)[j]);
if (exp)
results[i][j] = Math.exp(results[i][j]);

}
}

Expand Down
2 changes: 1 addition & 1 deletion test
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ args=\
-Dplusone.localCO.termEnzs=15000
-Dplusone.localCO.dtNs=500
-Dplusone.localCO.tdNs=1200
-Dplusone.lda.dimensions=5,10,15
-Dplusone.lda.dimensions=15
-Dplusone.enableTest.localCO=false \
-Dplusone.enableTest.lda=true \
-Dplusone.enableTest.knnc=false \
Expand Down

0 comments on commit f35cd02

Please sign in to comment.