Skip to content

Commit

Permalink
swig bindings updated (issue #3)
Browse files Browse the repository at this point in the history
  • Loading branch information
WladimirSidorenko committed Mar 5, 2016
1 parent 02ca8fc commit b3c3f00
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 78 deletions.
16 changes: 6 additions & 10 deletions include/crfsuite.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ namespace CRFSuite

// Set the attributes in the item.
i = 0;
n_items = item.size();
n_items = (int) item.size();

if (tr && tr->ftype == FTYPE_CRF1TREE) {
if (n_items < 1)
Expand Down Expand Up @@ -461,14 +461,11 @@ namespace CRFSuite

StringList Tagger::viterbi()
{
int ret;
StringList yseq;
crfsuite_dictionary_t *labels = NULL;

if (model == NULL || tagger == NULL) {
if (model == NULL || tagger == NULL)
throw std::invalid_argument("The tagger is not opened");
}

int ret;
StringList yseq;
// Make sure that the current instance is not empty.
const size_t T = (size_t)tagger->length(tagger);
if (T <= 0)
Expand All @@ -484,7 +481,7 @@ namespace CRFSuite

// Convert the Viterbi path to a label sequence.
yseq.resize(T);
for (size_t t = 0;t < T;++t) {
for (size_t t = 0; t < T; ++t) {
const char *label = NULL;
if (m_labels->to_string(m_labels, path[t], &label) != 0) {
delete[] path;
Expand All @@ -493,8 +490,7 @@ namespace CRFSuite
yseq[t] = label;
m_labels->free(m_labels, label);
}

labels->release(labels);
delete[] path;
return yseq;
}

Expand Down
2 changes: 1 addition & 1 deletion swig/python/prepare.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
ln -fs ../crfsuite.cpp
ln -fs ../export.i

if [ "$1" = "--swig" ]; then
if test "$1" = "--swig"; then
swig -c++ -python -I../../include -o export_wrap.cpp export.i
fi
105 changes: 72 additions & 33 deletions swig/python/sample_tag.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,101 @@
#!/usr/bin/env python
# -*- mode: python; -*-
# -*- mode: python; coding: utf-8; -*-

##################################################################
# Imports
import crfsuite
from crfsuite import Attribute, Item, ItemSequence, Tagger
import sys


##################################################################
# Constants
LINCHAIN = "1d"
SEMIM = "semim"
TREE = "tree"
MTYPE2INT = {LINCHAIN: 1, TREE: 3, SEMIM: 4}


##################################################################
# Methods
def instances(fi):
xseq = crfsuite.ItemSequence()
"""Iterate over instances in the provided file.
Args:
fi (FileInput): input file stream
Yields:
crfsuite instances
"""
item_seen = False
xseq = ItemSequence()
for line in fi:
line = line.strip('\n')
if not line:
# An empty line presents an end of a sequence.
yield xseq
xseq = crfsuite.ItemSequence()
# An empty line presents an end of a sequence.
if item_seen:
yield xseq
xseq = ItemSequence()
item_seen = False
continue

# Split the line on TAB characters.
# Split the line on TAB characters.
fields = line.split('\t')
item = crfsuite.Item()
item_seen = True
item = Item()
for field in fields[1:]:
p = field.rfind(':')
if p == -1:
# Unweighted (weight=1) attribute.
item.append(crfsuite.Attribute(field))
# Unweighted (weight=1) attribute.
item.append(Attribute(field))
else:
# Weighted attribute
item.append(crfsuite.Attribute(field[:p], float(field[p+1:])))

# Weighted attribute
item.append(Attribute(field[:p], float(field[p+1:])))
# Append the item to the item sequence.
xseq.append(item)
if item_seen:
yield xseq

##################################################################
# Main
if __name__ == '__main__':
fi = sys.stdin
fo = sys.stdout

if len(sys.argv) < 2:
raise Exception("Provide model path as the 1-st argument.")
import argparse
parser = argparse.ArgumentParser(description=
"Script for testing CRF models.")
parser.add_argument("-m", "--model",
help="model in which to store the file", type=str,
default="")
parser.add_argument("-t", "--type",
help="type of graphical model to use",
type=str, default=LINCHAIN,
choices=(LINCHAIN, TREE, SEMIM))
parser.add_argument("files", help="input files", nargs='*',
type=argparse.FileType('r'),
default=[sys.stdin])
args = parser.parse_args()

# Create a tagger object.
tagger = crfsuite.Tagger()
tagger = Tagger()

# Load the model to the tagger.
# the second argumend specifies the model type (1 - 1d, 2 - tree, 4 - semim)
tagger.open(sys.argv[1], 2)

for xseq in instances(fi):
# Tag the sequence.
tagger.set(xseq)
# Obtain the label sequence predicted by the tagger.
yseq = tagger.viterbi()
# Output the probability of the predicted label sequence.
print tagger.probability(yseq)
for t, y in enumerate(yseq):
# Output the predicted labels with their marginal probabilities.
print '%s:%f' % (y, tagger.marginal(y, t))
print
# the second argumend specifies the model type (1 - 1d, 3 - tree, 4 -
# semim)
if not tagger.open(args.model, MTYPE2INT[args.type]):
raise RuntimeError("Could not load model file.")

for ifile in args.files:
for xseq in instances(ifile):
# Tag the sequence.
tagger.set(xseq)
# Obtain the label sequence predicted by the tagger.
yseq = tagger.viterbi()
# Output the probability of the predicted label sequence.
# print tagger.probability(yseq)
for t, y in enumerate(yseq):
# Output the predicted labels with their marginal
# probabilities.
if args.type == SEMIM:
print '%s' % (y)
else:
print '%s:%f' % (y, tagger.marginal(y, t))
print
77 changes: 43 additions & 34 deletions swig/python/sample_train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python
# -*- mode: python; codig: utf-8; -*-
# -*- mode: python; coding: utf-8; -*-

##################################################################
# Imports
Expand All @@ -14,16 +14,17 @@
# Variables and Constants
ENCODING = "UTF-8"


##################################################################
# Class

# Inherit crfsuite.Trainer to implement message() function, which receives
# progress messages from a training process.
class Trainer(crfsuite.Trainer):
def message(self, s):
# Simply output the progress messages to STDOUT.
sys.stdout.write(s)


##################################################################
# Methods
def instances(fi):
Expand All @@ -33,24 +34,23 @@ def instances(fi):
for line in fi:
line = line.strip('\n')
if not line:
# An empty line presents an end of a sequence.
# An empty line presents the end of a sequence.
yield xseq, yseq
xseq = crfsuite.ItemSequence()
yseq = crfsuite.StringList()
continue

# Split the line with TAB characters.
fields = line.split('\t')

# Append attributes to the item.
# Append attributes to the item.
item = crfsuite.Item()
for field in fields[1:]:
p = field.rfind(':')
if p == -1:
# Unweighted (weight=1) attribute.
# Unweighted (weight=1) attribute.
item.append(crfsuite.Attribute(field))
else:
# Weighted attribute
# Weighted attribute
item.append(crfsuite.Attribute(field[:p], float(field[p+1:])))

# Append the item to the item sequence.
Expand All @@ -64,33 +64,47 @@ def instances(fi):
"""Train CRF model on the given dataset.
Args:
-----
argv - command line arguments
argv (list(str)): command line arguments
Returns:
--------
(void)
(void):
"""
import argparse
parser = argparse.ArgumentParser(description = """Script for training CRF models.""")
parser.add_argument("--help-params", help = "output CRFSuite parameters")
parser.add_argument("-a", "--algorithm", help = "type of graphical model to use", nargs = 1, \
type = str, default = "lbfgs", choices = ("lbfgs", "l2sgd", "ap", "pa", \
"arow"))
parser.add_argument("-m", "--model", help = "model in which to store the file", type = str, \
default = "")
parser.add_argument("-t", "--type", help = "type of graphical model to use", \
type = str, default = "1d", choices = ("1d", "tree", "semim"))
parser.add_argument("-v", "--version", help = "output CRFSuite version")
parser.add_argument("files", help="input files", nargs = '*', type = argparse.FileType('r'),
default = [sys.stdin])
parser = argparse.ArgumentParser(description=
"Script for training CRF models.")
parser.add_argument("--help-params", help="output CRFSuite parameters",
action="store_true")
parser.add_argument("-a", "--algorithm",
help="type of graphical model to use",
type=str, default="lbfgs", choices=("lbfgs", "l2sgd",
"ap", "pa",
"arow"))
parser.add_argument("-m", "--model",
help="model in which to store the file", type=str,
default="")
parser.add_argument("-t", "--type",
help="type of graphical model to use",
type=str, default="1d",
choices=("1d", "tree", "semim"))
parser.add_argument("-v", "--version",
help="output CRFSuite version", action="store_true")
parser.add_argument("files", help="input files", nargs='*',
type=argparse.FileType('r'),
default=[sys.stdin])
args = parser.parse_args()

# This demonstrates how to obtain the version string of CRFsuite.
if args.version:
print(crfsuite.version())
elif args.help_params:
sys.exit(0)
# Create a Trainer object.
trainer = Trainer()
# Use L2-regularized SGD and 1st-order dyad features.
if not trainer.select(str(args.algorithm), str(args.type)):
raise Exception("Could not initialize trainer.")

if args.help_params:
for name in trainer.params():
print(' '.join([name, trainer.get(name), trainer.help(name)]))
else:
Expand All @@ -100,15 +114,10 @@ def instances(fi):
pass
elif os.path.exists(mdir):
if not os.path.isdir(mdir) or not os.access(mdir, os.R_OK):
print("Can't write to directory '{:s}'.".format(mdir), file = sys.stderr)
print("Can't write to directory '{:s}'.".format(mdir),
file=sys.stderr)
else:
os.makedirs(mdir)
# Create a Trainer object.
trainer = Trainer()

# Use L2-regularized SGD and 1st-order dyad features.
if not trainer.select(str(args.algorithm), str(args.type)):
raise Exception("Could not initialize trainer.")

# Set the coefficient for L2 regularization to 0.1
# trainer.set('c2', '0.1')
Expand All @@ -117,9 +126,9 @@ def instances(fi):
for ifile in args.files:
for xseq, yseq in instances(ifile):
trainer.append(xseq, yseq, 0)
# print("Dataset read...", file = sys.stderr)
# print("Dataset read...", file=sys.stderr)

# Start training; the training process will invoke trainer.message()
# to report the progress.
trainer.train(args.model or "", -1)
# print("Model trained...", file = sys.stderr)
trainer.train(str(args.model), -1)
# print("Model trained...", file=sys.stderr)

0 comments on commit b3c3f00

Please sign in to comment.