-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #51 from clulab/kwalcock/mathWithDualSum
Factor out the math
- Loading branch information
Showing
41 changed files
with
2,158 additions
and
465 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,10 @@ | ||
+ **0.6.2** - Shade additional package internal to EJML | ||
+ **0.6.1** - Shade EJML | ||
+ **0.6.0** - Calculate with EJML rather than Breeze | ||
+ **0.6.0** - Use sum instead of concat | ||
+ **0.5.0** - Support Linux on aarch64 | ||
+ **0.5.0** - Isolate dependencies on models to the apps subproject | ||
+ **0.4.0** - Account for maxTokens | ||
+ **0.3.0** - Include some tokenizers as resources and only fall back to network if necessary | ||
+ **0.2.0** - Faster version using JNI | ||
+ **0.1.0** - Initial version using the J4rs library | ||
+ | ||
+ **0.1.0** - Initial version using the J4rs library |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
name := "scala-transformers-apps" | ||
description := "Houses apps, particularly those having extra library dependencies" | ||
|
||
resolvers ++= Seq( | ||
"Artifactory" at "https://artifactory.clulab.org/artifactory/sbt-release" | ||
) | ||
|
||
libraryDependencies ++= { | ||
Seq( | ||
"org.clulab" % "deberta-onnx-model" % "0.2.0", | ||
"org.clulab" % "electra-onnx-model" % "0.2.0", | ||
"org.clulab" % "roberta-onnx-model" % "0.2.0", | ||
"org.scalatest" %% "scalatest" % "3.2.15" % "test" | ||
) | ||
} | ||
|
||
fork := true | ||
|
||
// assembly / mainClass := Some("com.keithalcock.tokenizer.scalapy.apps.ExampleApp") |
2 changes: 1 addition & 1 deletion
2
...ormers/encoder/apps/BlasInstanceApp.scala → ...ansformers/apps/BlasInstanceApp.scala.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
9 changes: 7 additions & 2 deletions
9
...encoder/apps/LoadExampleFromFileApp.scala → ...formers/apps/LoadExampleFromFileApp.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
8 changes: 6 additions & 2 deletions
8
...der/apps/LoadExampleFromResourceApp.scala → ...ers/apps/LoadExampleFromResourceApp.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
19 changes: 8 additions & 11 deletions
19
...oder/apps/TokenClassifierExampleApp.scala → ...mers/apps/TokenClassifierExampleApp.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
130 changes: 130 additions & 0 deletions
130
apps/src/main/scala/org/clulab/scala_transformers/apps/TokenClassifierTimerApp.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
package org.clulab.scala_transformers.apps | ||
|
||
import org.clulab.scala_transformers.common.Timers | ||
import org.clulab.scala_transformers.encoder.{EncoderMaxTokensRuntimeException, TokenClassifier} | ||
import org.clulab.scala_transformers.tokenizer.LongTokenization | ||
|
||
import scala.io.Source | ||
|
||
object TokenClassifierTimerApp extends App { | ||
|
||
class TimedTokenClassifier(tokenClassifier: TokenClassifier) extends TokenClassifier( | ||
tokenClassifier.encoder, tokenClassifier.maxTokens, tokenClassifier.tasks, tokenClassifier.tokenizer | ||
) { | ||
val tokenizeTimer = Timers.getOrNew("Tokenizer") | ||
val forwardTimer = Timers.getOrNew("Encoder.forward") | ||
val predictTimers = tokenClassifier.tasks.indices.map { index => | ||
val name = tasks(index).name | ||
|
||
Timers.getOrNew(s"Encoder.predict $index\t$name") | ||
} | ||
|
||
// NOTE: This should be copied from the base class and then instrumented with timers. | ||
override def predictWithScores(words: Seq[String], headTaskName: String = "Deps Head"): Array[Array[Array[(String, Float)]]] = { | ||
// This condition must be met in order for allLabels to be filled properly without nulls. | ||
// The condition is not checked at runtime! | ||
// if (tasks.exists(_.dual)) | ||
// require(tasks.count(task => !task.dual && task.name == headTaskName) == 1) | ||
|
||
// tokenize to subword tokens | ||
val tokenization = tokenizeTimer.time { | ||
LongTokenization(tokenizer.tokenize(words.toArray)) | ||
} | ||
val inputIds = tokenization.tokenIds | ||
val wordIds = tokenization.wordIds | ||
val tokens = tokenization.tokens | ||
|
||
if (inputIds.length > maxTokens) { | ||
throw new EncoderMaxTokensRuntimeException(s"Encoder error: the following text contains more tokens than the maximum number accepted by this encoder ($maxTokens): ${tokens.mkString(", ")}") | ||
} | ||
|
||
// run the sentence through the transformer encoder | ||
val encOutput = forwardTimer.time { | ||
encoder.forward(inputIds) | ||
} | ||
|
||
// outputs for all tasks stored here: task x tokens in sentence x scores per token | ||
val allLabels = new Array[Array[Array[(String, Float)]]](tasks.length) | ||
// all heads predicted for every token | ||
// dimensions: token x heads | ||
var heads: Option[Array[Array[Int]]] = None | ||
|
||
// now generate token label predictions for all primary tasks (not dual!) | ||
for (i <- tasks.indices) { | ||
if (!tasks(i).dual) { | ||
val tokenLabels = predictTimers(i).time { | ||
tasks(i).predictWithScores(encOutput, None, None) | ||
} | ||
val wordLabels = TokenClassifier.mapTokenLabelsAndScoresToWords(tokenLabels, tokenization.wordIds) | ||
allLabels(i) = wordLabels | ||
|
||
// if this is the task that predicts head positions, then save them for the dual tasks | ||
// we save all the heads predicted for each token | ||
if (tasks(i).name == headTaskName) { | ||
heads = Some(tokenLabels.map(_.map(_._1.toInt))) | ||
} | ||
} | ||
} | ||
|
||
// generate outputs for the dual tasks, if heads were predicted by one of the primary tasks | ||
// the dual task(s) must be aligned with the heads. | ||
// that is, we predict the top label for each of the head candidates | ||
if (heads.isDefined) { | ||
//println("Tokens: " + tokens.mkString(", ")) | ||
//println("Heads:\n\t" + heads.get.map(_.slice(0, 3).mkString(", ")).mkString("\n\t")) | ||
//println("Masks: " + TokenClassifier.mkTokenMask(wordIds).mkString(", ")) | ||
val masks = Some(TokenClassifier.mkTokenMask(wordIds)) | ||
|
||
for (i <- tasks.indices) { | ||
if (tasks(i).dual) { | ||
val tokenLabels = predictTimers(i).time { | ||
tasks(i).predictWithScores(encOutput, heads, masks) | ||
} | ||
val wordLabels = TokenClassifier.mapTokenLabelsAndScoresToWords(tokenLabels, tokenization.wordIds) | ||
allLabels(i) = wordLabels | ||
} | ||
} | ||
} | ||
|
||
allLabels | ||
} | ||
} | ||
|
||
val verbose = false | ||
val fileName = args.lift(0).getOrElse("../corpora/sentences/sentences.txt") | ||
// Choose one of these. | ||
val untimedTokenClassifier = TokenClassifier.fromFiles("../models/microsoft_deberta_v3_base_mtl/avg_export") | ||
// val untimedTokenClassifier = TokenClassifier.fromFiles("../models/google_electra_small_discriminator_mtl/avg_export") | ||
// val untimedTokenClassifier = TokenClassifier.fromFiles("../models/roberta_base_mtl/avg_export") | ||
|
||
val tokenClassifier = new TimedTokenClassifier(untimedTokenClassifier) | ||
val lines = { | ||
val source = Source.fromFile(fileName) | ||
val lines = source.getLines().take(100).toArray | ||
|
||
source.close | ||
lines | ||
} | ||
val elapsedTimer = Timers.getOrNew("Elapsed") | ||
|
||
elapsedTimer.time { | ||
lines.zipWithIndex/*.par*/.foreach { case (line, index) => | ||
println(s"$index $line") | ||
if (index != 1382) { | ||
val words = line.split(" ").toSeq | ||
val allLabelSeqs = tokenClassifier.predictWithScores(words) | ||
|
||
if (verbose) { | ||
println(s"Words: ${words.mkString(", ")}") | ||
for (layer <- allLabelSeqs) { | ||
val words = layer.map(_.head) // Collapse the next layer by just taking the head. | ||
val wordLabels = words.map(_._1) | ||
|
||
println(s"Labels: ${wordLabels.mkString(", ")}") | ||
} | ||
} | ||
} | ||
} | ||
} | ||
Timers.summarize() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 1 addition & 1 deletion
2
...la_transformers/encoder/timer/Timer.scala → ...lab/scala_transformers/common/Timer.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.