Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

N-gram filters using KenLM #32

Merged
merged 6 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ lazy val core = (project in file("core"))
.enablePlugins(sbtassembly.AssemblyPlugin)
.settings(
name := "uzushio",
libraryDependencies ++= coreDependencies ++ sparkDependencies.map(
libraryDependencies ++= sparkDependencies.map(
_ % Provided
)
)
Expand All @@ -84,7 +84,7 @@ lazy val lib = (project in file("lib"))
} else {
Seq.empty
}
)
),
)
.settings(commonSettings)
.settings(lintSettings)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.worksap.nlp.uzushio.lib.cleaning

import com.typesafe.config.{Config, ConfigFactory}
import com.typesafe.config.{Config, ConfigFactory, ConfigResolveOptions}
import com.worksap.nlp.uzushio.lib.filters.base.{DocFilter, ParagraphFilter}
import com.worksap.nlp.uzushio.lib.stats.NgramHashExtractor
import com.worksap.nlp.uzushio.lib.utils.{MathUtil, Paragraphs}
Expand Down Expand Up @@ -197,6 +197,7 @@ object Pipeline {
index: Int
): AnyRef = {
if (!cfg.hasPath(par.getName)) {
// try to use default parameter for constructor, if such exist
val defFnName = "$lessinit$greater$default$" + index
try {
val defMethod = clz.getMethod(defFnName) // should be static
Expand Down Expand Up @@ -262,34 +263,35 @@ object Pipeline {
)
}

def make(cfg: Config): Pipeline = {
val filterCfgs = cfg.getConfigList("filters")
val filters = filterCfgs.asScala.map(cfg => instantiateFilter(cfg)).toArray
def make(cfg: Config, props: Config): Pipeline = {
val resolved = cfg.resolveWith(props, ConfigResolveOptions.noSystem())
val filterCfgs = resolved.getConfigList("filters")
val filters = filterCfgs.asScala.map(c => instantiateFilter(c)).toArray
new Pipeline(filters)
}

def make(path: Path): Pipeline = {
def make(path: Path, props: Config): Pipeline = {
val cfg = ConfigFactory.parseFile(path.toFile)
make(cfg)
make(cfg, props)
}

def make(url: URL): Pipeline = {
def make(url: URL, props: Config): Pipeline = {
val cfg = ConfigFactory.parseURL(url)
make(cfg)
make(cfg, props)
}

def make(name: String): Pipeline = {
def make(name: String, props: Config): Pipeline = {
val p = Paths.get(name)
if (Files.exists(p)) {
return make(p)
return make(p, props)
}
val basicUri = getClass.getClassLoader.getResource(name)
if (basicUri != null) {
return make(basicUri)
return make(basicUri, props)
}
val pipelinesUri = getClass.getClassLoader.getResource(s"pipeline/$name")
if (pipelinesUri != null) {
return make(pipelinesUri)
return make(pipelinesUri, props)
}
throw new IllegalArgumentException(
s"failed to find pipeline description: $name"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package com.worksap.nlp.uzushio.lib.filters

import com.github.jbaiter.kenlm.BufferEvaluator
import com.worksap.nlp.sudachi.{Dictionary, Morpheme}
import com.worksap.nlp.uzushio.lib.cleaning.{Document, Paragraph}
import com.worksap.nlp.uzushio.lib.filters.base.{DocFilter, HighLowDocFilter}
import com.worksap.nlp.uzushio.lib.resources.{KenLM, Sudachi}
import com.worksap.nlp.uzushio.lib.utils.Paragraphs

/** Document filter that removes documents whose length-weighted average
  * per-character perplexity (under a KenLM n-gram model over Sudachi tokens)
  * falls outside the `[low, high]` band.
  *
  * @param sudachi  identifier of the Sudachi system dictionary resource
  * @param kenlm    identifier of the KenLM model resource
  * @param outliers fraction of outlier tokens to drop when scoring (0 disables)
  * @param high     upper perplexity bound; documents above it are filtered
  * @param low      lower perplexity bound; documents below it are filtered
  */
class KenLMDocAvgPerplexity(
    sudachi: String,
    kenlm: String,
    outliers: Float = 0,
    override val high: Float = 1e6f,
    override val low: Float = 0f
) extends HighLowDocFilter {

  // Evaluator holds native/heavy resources; created lazily per executor and
  // excluded from serialization.
  @transient
  private lazy val processor = KenLMEvaluator.make(sudachi, kenlm, outliers)

  override def checkDocument(doc: Document): Document =
    maybeFilter(doc, measureDoc(doc))

  /** Average perplexity over alive paragraphs, weighted by paragraph length.
    * Note: a document with no alive paragraphs yields NaN (0.0 / 0), which
    * `maybeFilter`'s comparisons treat as in-band (document passes).
    */
  def measureDoc(doc: Document): Float = {
    val proc = processor
    var weightedPpxSum = 0.0
    var totalChars = 0
    val iter = doc.aliveParagraphs
    while (iter.hasNext) {
      val paragraph = iter.next()
      // scoreParagraph returns a log10-probability-style score; convert to perplexity
      val perplexity = Math.pow(10, -proc.scoreParagraph(paragraph))
      weightedPpxSum += perplexity * paragraph.text.length
      totalChars += paragraph.text.length
    }
    (weightedPpxSum / totalChars).toFloat
  }

  override def describeFilter: String = s"KenLMAvgDoc($outliers)"
}

/** Scores paragraph text with a KenLM n-gram model over Sudachi morphemes.
  *
  * Tokenizes a paragraph, feeds accepted token surfaces into a reusable
  * [[BufferEvaluator]], and reduces the buffer to a single score.
  *
  * @param sudachi identifier of the Sudachi system dictionary resource
  * @param kenlm   identifier of the KenLM model resource
  */
class KenLMEvaluator(sudachi: String, kenlm: String) {
  private val dictionary: Dictionary = Sudachi.get(sudachi)
  final protected val tokenizer = dictionary.create()
  final protected val evaluator = KenLM.get(kenlm).bufferEvaluator(64 * 1024, 1024)

  /** Fills the shared evaluator buffer with the paragraph's accepted tokens.
    * Stops early once the buffer reports no remaining capacity.
    */
  def processParagraph(p: Paragraph): BufferEvaluator = {
    val buffer = evaluator
    buffer.clear()
    val morphemes = tokenizer.tokenize(p.text).iterator()
    var hasCapacity = true
    while (hasCapacity && morphemes.hasNext) {
      val morpheme = morphemes.next()
      if (acceptedToken(morpheme)) {
        // append returns remaining capacity; 0 means the buffer is full
        hasCapacity = buffer.append(morpheme.surface()) > 0
      }
    }
    buffer
  }

  /** Rejects whitespace tokens and single-character markup/control tokens
    * (HTML link markers, newline); everything else is scored.
    */
  def acceptedToken(x: Morpheme): Boolean = {
    if (x.normalizedForm() == " ") {
      false
    } else {
      val surface = x.surface()
      if (surface.length == 1) {
        val c = surface.charAt(0)
        c != Paragraphs.HTML_LINK_START && c != Paragraphs.HTML_LINK_END && c != '\n'
      } else {
        true
      }
    }
  }

  /** Reduces a filled buffer to a score; subclasses may override. */
  def extractScore(ev: BufferEvaluator): Double = ev.evaluate()

  /** Convenience: tokenize, buffer, and score one paragraph. */
  def scoreParagraph(p: Paragraph): Double = extractScore(processParagraph(p))
}

object KenLMEvaluator {
  /** Factory for paragraph evaluators.
    *
    * A `ratio` below 1e-3 is treated as "no outlier removal" and yields the
    * plain evaluator; otherwise the outlier-dropping variant is used.
    */
  def make(sudachi: String, kenlm: String, ratio: Float): KenLMEvaluator =
    if (ratio < 1e-3) new KenLMEvaluator(sudachi, kenlm)
    else new KenLMEvaluatorNoOutliers(sudachi, kenlm, ratio)
}

/** [[KenLMEvaluator]] variant that excludes a fraction of outlier tokens when
  * reducing the buffer to a score.
  *
  * @param ratio fraction of tokens to treat as outliers — NOTE(review): exact
  *              semantics are defined by `BufferEvaluator.evaluateNoOutliers`
  *              in the kenlm-jni library; confirm against its documentation.
  */
class KenLMEvaluatorNoOutliers(sudachi: String, kenlm: String, ratio: Float)
    extends KenLMEvaluator(sudachi, kenlm) {
  override def extractScore(ev: BufferEvaluator): Double = ev.evaluateNoOutliers(ratio)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package com.worksap.nlp.uzushio.lib.filters

import com.worksap.nlp.uzushio.lib.cleaning.{Document, Paragraph}
import com.worksap.nlp.uzushio.lib.filters.base.DocFilter

import scala.collection.mutable

/** Pairs a paragraph with its language-model score.
  *
  * NOTE(review): `ppx` stores the raw result of `KenLMEvaluator.scoreParagraph`
  * (a log10-probability-style score), not a perplexity proper — the name is
  * historical; confirm before relying on its scale.
  */
final case class ParagraphWithPerplexity(p: Paragraph, ppx: Float) {
  // liveness is delegated to the wrapped paragraph
  def isAlive: Boolean = p.isAlive

  // marks the wrapped paragraph as removed by filter `x`, keeping the score
  def remove(x: AnyRef): ParagraphWithPerplexity = copy(p = p.copy(remove = x))
}

/** Document filter that removes runs of consecutive paragraphs which all score
  * worse than a perplexity `threshold` under a KenLM model.
  *
  * A paragraph is removed when it sits in a window of `count` consecutive
  * "bad" paragraphs (looking backward or forward); up to `count` bad
  * paragraphs immediately preceding a hit are swept up as well.
  *
  * @param sudachi   identifier of the Sudachi system dictionary resource
  * @param kenlm     identifier of the KenLM model resource
  * @param outliers  outlier-token fraction passed to the evaluator
  * @param count     window size for consecutive-bad-paragraph detection
  * @param threshold perplexity above which a paragraph counts as "bad"
  */
class KenLMParagraphPerplexity(
    sudachi: String,
    kenlm: String,
    outliers: Float = 0.02f,
    count: Int = 3,
    threshold: Float = 1e6f
) extends DocFilter {
  // scoreParagraph yields a log10-probability-style score (KenLMDocAvgPerplexity
  // computes perplexity as 10^(-score)), so a perplexity ceiling of `threshold`
  // becomes a score floor of -log10(threshold): scores <= lmScore are "bad".
  private val lmScore = -Math.log10(threshold).toFloat

  // heavy native resources; built lazily per executor, never serialized
  @transient
  private lazy val processor = KenLMEvaluator.make(sudachi, kenlm, outliers)

  override def checkDocument(doc: Document): Document = {
    val proc = processor
    // score every paragraph once up front, then mark in place
    val paragraphs = doc.paragraphs
      .map(p => ParagraphWithPerplexity(p, proc.scoreParagraph(p).toFloat)).toBuffer

    val nchanged = markParagraphs(paragraphs)

    // copy the document only if something was actually removed
    if (nchanged > 0) {
      doc.copy(paragraphs = paragraphs.map(_.p))
    } else {
      doc
    }
  }

  /** Single left-to-right pass: removes each alive paragraph that anchors a
    * bad window (backward or forward), plus bad paragraphs just before it.
    * Returns the number of paragraphs removed.
    */
  def markParagraphs(paragraphs: mutable.Buffer[ParagraphWithPerplexity]): Int = {
    var nchanged = 0
    var idx = 0
    val len = paragraphs.length
    while (idx < len) {
      val p = paragraphs(idx)
      if (p.isAlive && (shouldRemoveBack(paragraphs, idx) || shouldRemoveFwd(paragraphs, idx, len))) {
        paragraphs(idx) = p.remove(this)
        // sweep up preceding bad paragraphs that earlier windows missed
        nchanged += removePrev(paragraphs, idx)
        nchanged += 1
      }
      idx += 1
    }
    nchanged
  }

  /** Removes up to `count` still-alive bad paragraphs directly before
    * `offset`; returns how many were removed.
    */
  def removePrev(paragraphs: mutable.Buffer[ParagraphWithPerplexity], offset: Int): Int = {
    var result = 0
    val end = math.max(offset - count, 0)
    var idx = offset - 1
    while (idx >= end) {
      val p = paragraphs(idx)
      if (p.isAlive && p.ppx <= lmScore) {
        paragraphs(idx) = p.remove(this)
        result += 1
      }

      idx -= 1
    }
    result
  }

  /** True when the `count` paragraphs ending at `offset` (inclusive) are all
    * bad; any good score in the window vetoes removal.
    */
  def shouldRemoveBack(
      paragraphs: mutable.Buffer[ParagraphWithPerplexity],
      offset: Int
  ): Boolean = {
    var idx = offset
    val end = math.max(offset - count + 1, 0)
    while (idx >= end) {
      val p = paragraphs(idx)
      if (p.ppx > lmScore) {
        return false
      }
      idx -= 1
    }
    true
  }

  /** True when the `count` paragraphs starting at `offset` (inclusive) are all
    * bad; any good score in the window vetoes removal.
    */
  def shouldRemoveFwd(
      paragraphs: mutable.Buffer[ParagraphWithPerplexity],
      offset: Int,
      length: Int
  ): Boolean = {
    var idx = offset
    val end = math.min(offset + count, length)
    while (idx < end) {
      val p = paragraphs(idx)
      if (p.ppx > lmScore) {
        return false
      }
      idx += 1
    }
    true
  }

  override val toString = s"KenLMPar($outliers,$count,$threshold)"
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@ trait HighLowDocFilter extends DocFilter { self =>
} else doc
}

def describeFilter: String = self.getClass.getSimpleName

@transient object Low {
override val toString = s"${self.getClass.getSimpleName}.Low($low)"
override val toString = s"$describeFilter.Low($low)"
}

@transient object High {
override val toString = s"${self.getClass.getSimpleName}.High($high)"
override val toString = s"$describeFilter.High($high)"
}

override def toString = s"${self.getClass.getSimpleName}($low,$high)"
override def toString = s"$describeFilter($low,$high)"
}

trait HighLowDocIntFilter extends DocFilter { self =>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package com.worksap.nlp.uzushio.lib.resources

import com.github.jbaiter.kenlm.Model
import com.worksap.nlp.sudachi.{Config, Dictionary, DictionaryFactory}
import org.apache.spark.SparkFiles

import java.nio.file.{Files, Path, Paths}
import java.util.concurrent.ConcurrentHashMap

/** Caches heavyweight resources (dictionaries, LM models) by resolved local
  * path, so each file is loaded at most once per JVM.
  *
  * Resolution tries the string as a plain filesystem path first, then as a
  * Spark-distributed file (`SparkFiles`).
  *
  * @tparam T resource type produced by [[create]]
  */
trait CachedLocalResource[T] {
  final private val cache = new ConcurrentHashMap[Path, T]()

  /** Loads the resource from a resolved local file. */
  def create(p: Path): T

  /** Resolves `dict` to a local file and returns the cached resource,
    * creating it on first access.
    *
    * @throws IllegalArgumentException if the name resolves to no local file
    */
  def get(dict: String): T = {
    val path = resolveLocalPath(dict).orElse(resolveSparkPath(dict)) match {
      case Some(found) => found
      case None =>
        throw new IllegalArgumentException(s"could not find file: $dict")
    }
    // computeIfAbsent guarantees a single creation per path under concurrency
    cache.computeIfAbsent(path, create(_))
  }

  /** Some(path) when `str` names an existing regular file, else None. */
  def resolveLocalPath(str: String): Option[Path] = {
    val candidate = Paths.get(str)
    if (!Files.exists(candidate) || !Files.isRegularFile(candidate)) None
    else Some(candidate)
  }

  /** Resolves `str` through Spark's distributed-file mechanism. */
  def resolveSparkPath(str: String): Option[Path] =
    resolveLocalPath(SparkFiles.get(str))
}

/** Process-wide cache of Sudachi dictionaries, keyed by resolved file path. */
object Sudachi extends CachedLocalResource[Dictionary] {
  // builds a dictionary with `p` as the system dictionary on default config
  override def create(p: Path): Dictionary =
    new DictionaryFactory().create(Config.defaultConfig().systemDictionary(p))
}

/** Process-wide cache of loaded KenLM models, keyed by resolved file path. */
object KenLM extends CachedLocalResource[Model] {
  override def create(p: Path): Model = new Model(p)
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.worksap.nlp.uzushio.lib.runners

import com.typesafe.config.ConfigFactory
import com.worksap.nlp.uzushio.lib.cleaning.{Document, Paragraph, Pipeline}
import com.worksap.nlp.uzushio.lib.runners.DuplicateCandidateRow._
import com.worksap.nlp.uzushio.lib.stats.{NgramBitSignatures, NgramHashExtractor, SimHashProcessor}
Expand Down Expand Up @@ -917,6 +918,8 @@ object DeduplicateParagraphs {

// noinspection TypeAnnotation,ScalaWeakerAccess
class ArgParser(args: Seq[String]) extends ScallopConf(args) {
import scala.collection.JavaConverters._

val input = opt[List[String]]()
val output = opt[String]()
val numShifts = opt[Int](default = Some(5))
Expand Down Expand Up @@ -947,6 +950,7 @@ object DeduplicateParagraphs {
descr = "Spark StorageLevel for caching operations"
)
val textOnly = toggle(default = Some(false), descrYes = "output only text")
val replacements = props[String]('P', descr = "Properties to resolve in filter chains")
verify()

def toArgs: Args = Args(
Expand All @@ -965,7 +969,7 @@ object DeduplicateParagraphs {
format = format(),
compression = compression(),
intermediate = intermediate(),
pipeline = Pipeline.make(filters()),
pipeline = Pipeline.make(filters(), ConfigFactory.parseMap(replacements.asJava, "props")),
bufferSizeInBytes = bufferSize(),
cacheLevel = cacheLevel.toOption.map(StorageLevel.fromString)
.getOrElse(StorageLevel.MEMORY_AND_DISK),
Expand Down
Loading