-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
02ca8fc
commit b3c3f00
Showing
4 changed files
with
122 additions
and
78 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,62 +1,101 @@ | ||
#!/usr/bin/env python | ||
# -*- mode: python; -*- | ||
# -*- mode: python; coding: utf-8; -*- | ||
|
||
################################################################## | ||
# Imports | ||
import crfsuite | ||
from crfsuite import Attribute, Item, ItemSequence, Tagger | ||
import sys | ||
|
||
|
||
################################################################## | ||
# Constants | ||
LINCHAIN = "1d" | ||
SEMIM = "semim" | ||
TREE = "tree" | ||
MTYPE2INT = {LINCHAIN: 1, TREE: 3, SEMIM: 4} | ||
|
||
|
||
################################################################## | ||
# Methods | ||
def instances(fi): | ||
xseq = crfsuite.ItemSequence() | ||
"""Iterate over instances in the provided file. | ||
Args: | ||
fi (FileInput): input file stream | ||
Yields: | ||
crfsuite instances | ||
""" | ||
item_seen = False | ||
xseq = ItemSequence() | ||
for line in fi: | ||
line = line.strip('\n') | ||
if not line: | ||
# An empty line marks the end of a sequence. | ||
yield xseq | ||
xseq = crfsuite.ItemSequence() | ||
# An empty line marks the end of a sequence. | ||
if item_seen: | ||
yield xseq | ||
xseq = ItemSequence() | ||
item_seen = False | ||
continue | ||
|
||
# Split the line on TAB characters. | ||
# Split the line on TAB characters. | ||
fields = line.split('\t') | ||
item = crfsuite.Item() | ||
item_seen = True | ||
item = Item() | ||
for field in fields[1:]: | ||
p = field.rfind(':') | ||
if p == -1: | ||
# Unweighted (weight=1) attribute. | ||
item.append(crfsuite.Attribute(field)) | ||
# Unweighted (weight=1) attribute. | ||
item.append(Attribute(field)) | ||
else: | ||
# Weighted attribute | ||
item.append(crfsuite.Attribute(field[:p], float(field[p+1:]))) | ||
|
||
# Weighted attribute | ||
item.append(Attribute(field[:p], float(field[p+1:]))) | ||
# Append the item to the item sequence. | ||
xseq.append(item) | ||
if item_seen: | ||
yield xseq | ||
|
||
################################################################## | ||
# Main | ||
if __name__ == '__main__': | ||
fi = sys.stdin | ||
fo = sys.stdout | ||
|
||
if len(sys.argv) < 2: | ||
raise Exception("Provide model path as the 1-st argument.") | ||
import argparse | ||
parser = argparse.ArgumentParser(description= | ||
"Script for testing CRF models.") | ||
parser.add_argument("-m", "--model", | ||
help="model in which to store the file", type=str, | ||
default="") | ||
parser.add_argument("-t", "--type", | ||
help="type of graphical model to use", | ||
type=str, default=LINCHAIN, | ||
choices=(LINCHAIN, TREE, SEMIM)) | ||
parser.add_argument("files", help="input files", nargs='*', | ||
type=argparse.FileType('r'), | ||
default=[sys.stdin]) | ||
args = parser.parse_args() | ||
|
||
# Create a tagger object. | ||
tagger = crfsuite.Tagger() | ||
tagger = Tagger() | ||
|
||
# Load the model to the tagger. | ||
# the second argument specifies the model type (1 - 1d, 2 - tree, 4 - semim) | ||
tagger.open(sys.argv[1], 2) | ||
|
||
for xseq in instances(fi): | ||
# Tag the sequence. | ||
tagger.set(xseq) | ||
# Obtain the label sequence predicted by the tagger. | ||
yseq = tagger.viterbi() | ||
# Output the probability of the predicted label sequence. | ||
print tagger.probability(yseq) | ||
for t, y in enumerate(yseq): | ||
# Output the predicted labels with their marginal probabilities. | ||
print '%s:%f' % (y, tagger.marginal(y, t)) | ||
# the second argument specifies the model type (1 - 1d, 3 - tree, 4 - | ||
# semim) | ||
if not tagger.open(args.model, MTYPE2INT[args.type]): | ||
raise RuntimeError("Could not load model file.") | ||
|
||
for ifile in args.files: | ||
for xseq in instances(ifile): | ||
# Tag the sequence. | ||
tagger.set(xseq) | ||
# Obtain the label sequence predicted by the tagger. | ||
yseq = tagger.viterbi() | ||
# Output the probability of the predicted label sequence. | ||
# print tagger.probability(yseq) | ||
for t, y in enumerate(yseq): | ||
# Output the predicted labels with their marginal | ||
# probabilities. | ||
if args.type == SEMIM: | ||
print '%s' % (y) | ||
else: | ||
print '%s:%f' % (y, tagger.marginal(y, t)) | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters