-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #50 from clulab/dualsum
Dualsum
- Loading branch information
Showing
32 changed files
with
1,878 additions
and
234 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 1 addition & 1 deletion
2
...ormers/encoder/apps/BlasInstanceApp.scala → ...ansformers/apps/BlasInstanceApp.scala.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
130 changes: 130 additions & 0 deletions
130
apps/src/main/scala/org/clulab/scala_transformers/apps/TokenClassifierTimerApp.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
package org.clulab.scala_transformers.apps | ||
|
||
import org.clulab.scala_transformers.common.Timers | ||
import org.clulab.scala_transformers.encoder.{EncoderMaxTokensRuntimeException, TokenClassifier} | ||
import org.clulab.scala_transformers.tokenizer.LongTokenization | ||
|
||
import scala.io.Source | ||
|
||
object TokenClassifierTimerApp extends App { | ||
|
||
class TimedTokenClassifier(tokenClassifier: TokenClassifier) extends TokenClassifier( | ||
tokenClassifier.encoder, tokenClassifier.maxTokens, tokenClassifier.tasks, tokenClassifier.tokenizer | ||
) { | ||
val tokenizeTimer = Timers.getOrNew("Tokenizer") | ||
val forwardTimer = Timers.getOrNew("Encoder.forward") | ||
val predictTimers = tokenClassifier.tasks.indices.map { index => | ||
val name = tasks(index).name | ||
|
||
Timers.getOrNew(s"Encoder.predict $index\t$name") | ||
} | ||
|
||
// NOTE: This should be copied from the base class and then instrumented with timers. | ||
override def predictWithScores(words: Seq[String], headTaskName: String = "Deps Head"): Array[Array[Array[(String, Float)]]] = { | ||
// This condition must be met in order for allLabels to be filled properly without nulls. | ||
// The condition is not checked at runtime! | ||
// if (tasks.exists(_.dual)) | ||
// require(tasks.count(task => !task.dual && task.name == headTaskName) == 1) | ||
|
||
// tokenize to subword tokens | ||
val tokenization = tokenizeTimer.time { | ||
LongTokenization(tokenizer.tokenize(words.toArray)) | ||
} | ||
val inputIds = tokenization.tokenIds | ||
val wordIds = tokenization.wordIds | ||
val tokens = tokenization.tokens | ||
|
||
if (inputIds.length > maxTokens) { | ||
throw new EncoderMaxTokensRuntimeException(s"Encoder error: the following text contains more tokens than the maximum number accepted by this encoder ($maxTokens): ${tokens.mkString(", ")}") | ||
} | ||
|
||
// run the sentence through the transformer encoder | ||
val encOutput = forwardTimer.time { | ||
encoder.forward(inputIds) | ||
} | ||
|
||
// outputs for all tasks stored here: task x tokens in sentence x scores per token | ||
val allLabels = new Array[Array[Array[(String, Float)]]](tasks.length) | ||
// all heads predicted for every token | ||
// dimensions: token x heads | ||
var heads: Option[Array[Array[Int]]] = None | ||
|
||
// now generate token label predictions for all primary tasks (not dual!) | ||
for (i <- tasks.indices) { | ||
if (!tasks(i).dual) { | ||
val tokenLabels = predictTimers(i).time { | ||
tasks(i).predictWithScores(encOutput, None, None) | ||
} | ||
val wordLabels = TokenClassifier.mapTokenLabelsAndScoresToWords(tokenLabels, tokenization.wordIds) | ||
allLabels(i) = wordLabels | ||
|
||
// if this is the task that predicts head positions, then save them for the dual tasks | ||
// we save all the heads predicted for each token | ||
if (tasks(i).name == headTaskName) { | ||
heads = Some(tokenLabels.map(_.map(_._1.toInt))) | ||
} | ||
} | ||
} | ||
|
||
// generate outputs for the dual tasks, if heads were predicted by one of the primary tasks | ||
// the dual task(s) must be aligned with the heads. | ||
// that is, we predict the top label for each of the head candidates | ||
if (heads.isDefined) { | ||
//println("Tokens: " + tokens.mkString(", ")) | ||
//println("Heads:\n\t" + heads.get.map(_.slice(0, 3).mkString(", ")).mkString("\n\t")) | ||
//println("Masks: " + TokenClassifier.mkTokenMask(wordIds).mkString(", ")) | ||
val masks = Some(TokenClassifier.mkTokenMask(wordIds)) | ||
|
||
for (i <- tasks.indices) { | ||
if (tasks(i).dual) { | ||
val tokenLabels = predictTimers(i).time { | ||
tasks(i).predictWithScores(encOutput, heads, masks) | ||
} | ||
val wordLabels = TokenClassifier.mapTokenLabelsAndScoresToWords(tokenLabels, tokenization.wordIds) | ||
allLabels(i) = wordLabels | ||
} | ||
} | ||
} | ||
|
||
allLabels | ||
} | ||
} | ||
|
||
val verbose = false | ||
val fileName = args.lift(0).getOrElse("../corpora/sentences/sentences.txt") | ||
// Choose one of these. | ||
val untimedTokenClassifier = TokenClassifier.fromFiles("../models/microsoft_deberta_v3_base_mtl/avg_export") | ||
// val untimedTokenClassifier = TokenClassifier.fromFiles("../models/google_electra_small_discriminator_mtl/avg_export") | ||
// val untimedTokenClassifier = TokenClassifier.fromFiles("../models/roberta_base_mtl/avg_export") | ||
|
||
val tokenClassifier = new TimedTokenClassifier(untimedTokenClassifier) | ||
val lines = { | ||
val source = Source.fromFile(fileName) | ||
val lines = source.getLines().take(100).toArray | ||
|
||
source.close | ||
lines | ||
} | ||
val elapsedTimer = Timers.getOrNew("Elapsed") | ||
|
||
elapsedTimer.time { | ||
lines.zipWithIndex/*.par*/.foreach { case (line, index) => | ||
println(s"$index $line") | ||
if (index != 1382) { | ||
val words = line.split(" ").toSeq | ||
val allLabelSeqs = tokenClassifier.predictWithScores(words) | ||
|
||
if (verbose) { | ||
println(s"Words: ${words.mkString(", ")}") | ||
for (layer <- allLabelSeqs) { | ||
val words = layer.map(_.head) // Collapse the next layer by just taking the head. | ||
val wordLabels = words.map(_._1) | ||
|
||
println(s"Labels: ${wordLabels.mkString(", ")}") | ||
} | ||
} | ||
} | ||
} | ||
} | ||
Timers.summarize() | ||
} |
2 changes: 1 addition & 1 deletion
2
...la_transformers/encoder/timer/Timer.scala → ...lab/scala_transformers/common/Timer.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
6 changes: 3 additions & 3 deletions
6
...transformers/encoder/BreezeExamples.scala → ...sformers/encoder/BreezeExamples.scala.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
33 changes: 0 additions & 33 deletions
33
encoder/src/main/scala/org/clulab/scala_transformers/encoder/BreezeUtils.scala
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.