-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'feature-mms' into develop
- Loading branch information
Showing
19 changed files
with
673 additions
and
49 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -90,3 +90,4 @@ ENV/ | |
|
||
# Corpora and Fixtures | ||
corpus/fixtures/debates | ||
corpus/fixtures/*.pickle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# arbiter | ||
# A Django app that implements an MMS for the red/blue models. | ||
# | ||
# Author: Benjamin Bengfort <[email protected]> | ||
# Created: Tue Aug 02 09:13:41 2016 -0400 | ||
# | ||
# Copyright (C) 2016 District Data Labs | ||
# For license information, see LICENSE.txt | ||
# | ||
# ID: __init__.py [] [email protected] $ | ||
|
||
""" | ||
A Django app that implements an MMS for the red/blue models. | ||
""" | ||
|
||
########################################################################## | ||
## Imports | ||
########################################################################## | ||
|
||
|
||
########################################################################## | ||
## Configuration | ||
########################################################################## | ||
|
||
default_app_config = 'arbiter.apps.ArbiterConfig' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# arbiter.admin | ||
# Django admin CMS definitions and registrations for the arbiter app. | ||
# | ||
# Author: Benjamin Bengfort <[email protected]> | ||
# Created: Tue Aug 02 09:18:18 2016 -0400 | ||
# | ||
# Copyright (C) 2016 District Data Labs | ||
# For license information, see LICENSE.txt | ||
# | ||
# ID: admin.py [] [email protected] $ | ||
|
||
""" | ||
Django admin CMS definitions and registrations for the arbiter app. | ||
""" | ||
|
||
########################################################################## | ||
## Imports | ||
########################################################################## | ||
|
||
from django.contrib import admin | ||
from arbiter.models import Estimator, Score | ||
|
||
########################################################################## | ||
## Register Admin | ||
########################################################################## | ||
|
||
admin.site.register(Estimator) | ||
admin.site.register(Score) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# arbiter.apps | ||
# Application definition for the arbiter app. | ||
# | ||
# Author: Benjamin Bengfort <[email protected]> | ||
# Created: Tue Aug 02 09:14:47 2016 -0400 | ||
# | ||
# Copyright (C) 2016 District Data Labs | ||
# For license information, see LICENSE.txt | ||
# | ||
# ID: apps.py [] [email protected] $ | ||
|
||
""" | ||
Application definition for the arbiter app. | ||
""" | ||
|
||
########################################################################## | ||
## Imports | ||
########################################################################## | ||
|
||
from django.apps import AppConfig | ||
|
||
|
||
########################################################################## | ||
## Arbiter Config | ||
########################################################################## | ||
|
||
class ArbiterConfig(AppConfig):
    """
    Django application configuration for the arbiter app.

    Registers the app under the label ``arbiter`` with the human readable
    name ``Arbiter``.
    """

    name = "arbiter"
    verbose_name = "Arbiter"

    def ready(self):
        """
        Hook called by Django once the app registry is populated.

        Currently a no-op; the signals import below is intentionally left
        disabled.
        """
        # import arbiter.signals
        return None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# arbiter.management | ||
# A module that specifies Django management commands for the arbiter app. | ||
# | ||
# Author: Benjamin Bengfort <[email protected]> | ||
# Created: Tue Aug 02 10:36:54 2016 -0400 | ||
# | ||
# Copyright (C) 2016 District Data Labs | ||
# For license information, see LICENSE.txt | ||
# | ||
# ID: __init__.py [] [email protected] $ | ||
|
||
""" | ||
A module that specifies Django management commands for the arbiter app. | ||
""" | ||
|
||
########################################################################## | ||
## Imports | ||
########################################################################## |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# arbiter.management.commands | ||
# Module that contains each individual management command for Django. | ||
# | ||
# Author: Benjamin Bengfort <[email protected]> | ||
# Created: Tue Aug 02 10:37:24 2016 -0400 | ||
# | ||
# Copyright (C) 2016 District Data Labs | ||
# For license information, see LICENSE.txt | ||
# | ||
# ID: __init__.py [] [email protected] $ | ||
|
||
""" | ||
Module that contains each individual management command for Django. | ||
""" | ||
|
||
########################################################################## | ||
## Imports | ||
########################################################################## |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
# arbiter.management.commands.train | ||
# Command to train red/blue classifiers from the command line. | ||
# | ||
# Author: Benjamin Bengfort <[email protected]> | ||
# Created: Tue Aug 02 10:38:54 2016 -0400 | ||
# | ||
# Copyright (C) 2016 District Data Labs | ||
# For license information, see LICENSE.txt | ||
# | ||
# ID: train.py [] [email protected] $ | ||
|
||
""" | ||
Command to train red/blue classifiers from the command line. | ||
""" | ||
|
||
########################################################################## | ||
## Imports | ||
########################################################################## | ||
|
||
import numpy as np

from django.contrib.auth.models import User
from django.core.management.base import BaseCommand, CommandError
from django.db import transaction

from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import SGDClassifier
from sklearn.naive_bayes import MultinomialNB

from arbiter.models import Estimator, Score
from corpus.learn import CorpusLoader, build_model
from corpus.reader import TranscriptCorpusReader
|
||
|
||
########################################################################## | ||
## Training Command | ||
########################################################################## | ||
|
||
class Command(BaseCommand):
    """
    Management command that trains a red/blue classifier on a corpus on
    disk and stores the fitted estimator and its scores in the database.
    """

    help = "Trains red/blue classifiers and stores them in the database."

    # Maps each --model choice to an (estimator class, keyword arguments)
    # pair; the keyword arguments are forwarded to build_model.
    estimators = {
        'maxent': (LogisticRegression, {}),
        'svm': (SGDClassifier, {'loss':'hinge', 'penalty':'l2', 'alpha':1e-3}),
        'nbayes': (MultinomialNB, {}),
    }

    def add_arguments(self, parser):
        """
        Add command line argparse arguments.
        """
        # Model selection argument
        parser.add_argument(
            '-m', '--model', choices=self.estimators, default='maxent',
            help='specify the model form to fit on the given corpus',
        )

        # Number of folds for cross-validation
        parser.add_argument(
            '-f', '--folds', type=int, default=12,
            help='number of folds to use in cross-validation',
        )

        # Optional ownership argument
        parser.add_argument(
            '-u', '--username', default=None,
            help='specify the username to associate with the model',
        )

        # TODO: Change this to allow for a query or a path on disk
        parser.add_argument('corpus', nargs=1, help='path to the corpus on disk')

    def handle(self, *args, **options):
        """
        Handles the model training process.

        Resolves the requested model and optional owner, builds the model
        from the corpus, then saves the estimator together with all of its
        scores. Raises CommandError (via get_user) if a username was given
        but does not exist.
        """
        # Get the details from the command line arguments
        model, kwargs = self.estimators[options['model']]
        owner = self.get_user(options['username'])
        path = options['corpus'][0]  # nargs=1 yields a one-element list

        # Construct the corpus and loader in preparation for training.
        # TODO: Make the corpus loader construction a method to handle querysets
        corpus = TranscriptCorpusReader(path)
        loader = CorpusLoader(corpus, options['folds'])

        # Inform the user that the training process is beginning
        # NOTE(review): n_folds + 1 presumably counts one model per fold plus
        # a final model -- confirm against build_model.
        self.stdout.write((
            "Starting training of {} {} models on the corpus at {}\n"
            "This may take quite a bit of time, please be patient!\n"
        ).format(
            loader.n_folds + 1, model.__name__, path
        ))

        # Fit the model; this is the long-running step.
        (clf, scores), total_time = build_model(loader, model, **kwargs)

        # Persist the estimator and its scores atomically so that a failure
        # while saving scores does not leave behind an estimator with only
        # part of its evaluation data.
        with transaction.atomic():
            estimator = Estimator.objects.create(
                model_type=Estimator.TYPES.classifier,
                model_class=model.__name__,
                model_form=repr(clf),
                estimator=clf,
                build_time=total_time,
                owner=owner,
            )
            self._save_scores(estimator, scores)

        # Report model construction complete
        self.stdout.write(
            "Training complete in {}! Estimator saved to the database\n".format(total_time)
        )

    def _save_scores(self, estimator, scores):
        """
        Creates one Score row per (metric, label) entry in ``scores`` and
        attaches each to ``estimator``.

        The 'times' entry is stored as a single time metric built from its
        'final' and 'folds' values; every other entry is treated as a
        mapping of label to per-fold values whose mean becomes the score.
        """
        for metric, values in scores.items():

            # Handle the time key in particular.
            if metric == 'times':
                Score.objects.create(
                    metric=Score.METRICS.time,
                    score=values['final'].total_seconds(),
                    folds=[td.total_seconds() for td in values['folds']],
                    estimator=estimator,
                )
                continue

            # Handle generic scores for the model
            for label, folds in values.items():
                if metric == 'support' and label == 'average':
                    # This will be an array of None values, so skip.
                    continue

                Score.objects.create(
                    metric=metric,
                    score=np.asarray(folds).mean(),
                    label=label,
                    folds=folds,
                    estimator=estimator,
                )

    def get_user(self, username):
        """
        Returns a user or None, raising a command error if no user with the
        specified username is found in the database.
        """
        if username is None:
            return None

        try:
            return User.objects.get(username=username)
        except User.DoesNotExist:
            raise CommandError(
                "No user with username '{}' in the database".format(username)
            )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# -*- coding: utf-8 -*- | ||
# Generated by Django 1.9.7 on 2016-08-02 17:06 | ||
from __future__ import unicode_literals | ||
|
||
from django.conf import settings | ||
import django.contrib.postgres.fields | ||
from django.db import migrations, models | ||
import django.db.models.deletion | ||
import django.utils.timezone | ||
import model_utils.fields | ||
import picklefield.fields | ||
|
||
|
||
class Migration(migrations.Migration):
    """
    Initial migration for the arbiter app.

    Creates the Estimator model (table ``estimators``) and the Score model
    (table ``evaluations``), where each Score references an Estimator.
    """

    initial = True

    # The Estimator owner is a foreign key to the configured user model.
    dependencies = [
        migrations.swappable_dependency(settings.AUTH_USER_MODEL),
    ]

    operations = [
        migrations.CreateModel(
            name='Estimator',
            fields=[
                ('id', models.AutoField(
                    auto_created=True,
                    primary_key=True,
                    serialize=False,
                    verbose_name='ID',
                )),
                ('created', model_utils.fields.AutoCreatedField(
                    default=django.utils.timezone.now,
                    editable=False,
                    verbose_name='created',
                )),
                ('modified', model_utils.fields.AutoLastModifiedField(
                    default=django.utils.timezone.now,
                    editable=False,
                    verbose_name='modified',
                )),
                ('model_type', models.CharField(
                    choices=[
                        ('classifier', 'classifier'),
                        ('regression', 'regression'),
                        ('clusters', 'clusters'),
                        ('decomposition', 'decomposition'),
                    ],
                    max_length=32,
                )),
                ('model_class', models.CharField(
                    blank=True,
                    default=None,
                    max_length=255,
                    null=True,
                )),
                ('model_form', models.CharField(
                    blank=True,
                    default=None,
                    max_length=512,
                    null=True,
                )),
                ('estimator', picklefield.fields.PickledObjectField(
                    blank=True,
                    default=None,
                    editable=False,
                    null=True,
                )),
                ('build_time', models.DurationField(
                    blank=True,
                    default=None,
                    null=True,
                )),
                ('owner', models.ForeignKey(
                    blank=True,
                    default=None,
                    null=True,
                    on_delete=django.db.models.deletion.CASCADE,
                    to=settings.AUTH_USER_MODEL,
                )),
            ],
            options={
                'get_latest_by': 'created',
                'db_table': 'estimators',
            },
        ),
        migrations.CreateModel(
            name='Score',
            fields=[
                ('id', models.AutoField(
                    auto_created=True,
                    primary_key=True,
                    serialize=False,
                    verbose_name='ID',
                )),
                ('created', model_utils.fields.AutoCreatedField(
                    default=django.utils.timezone.now,
                    editable=False,
                    verbose_name='created',
                )),
                ('modified', model_utils.fields.AutoLastModifiedField(
                    default=django.utils.timezone.now,
                    editable=False,
                    verbose_name='modified',
                )),
                ('metric', models.CharField(
                    choices=[
                        ('accuracy', 'accuracy'),
                        ('auc', 'auc'),
                        ('brier', 'brier'),
                        ('f1', 'f1'),
                        ('fbeta', 'fbeta'),
                        ('hamming', 'hamming'),
                        ('hinge', 'hinge'),
                        ('jaccard', 'jaccard'),
                        ('logloss', 'logloss'),
                        ('mcc', 'mcc'),
                        ('precision', 'precision'),
                        ('recall', 'recall'),
                        ('roc', 'roc'),
                        ('support', 'support'),
                        ('mae', 'mae'),
                        ('mse', 'mse'),
                        ('mdae', 'mdae'),
                        ('r2', 'r2'),
                        ('rand', 'rand'),
                        ('completeness', 'completeness'),
                        ('homogeneity', 'homogeneity'),
                        ('mutual', 'mutual'),
                        ('silhouette', 'silhouette'),
                        ('v', 'v'),
                        ('time', 'time'),
                    ],
                    max_length=32,
                )),
                ('score', models.FloatField(
                    blank=True,
                    default=None,
                    null=True,
                )),
                ('label', models.CharField(
                    blank=True,
                    default=None,
                    max_length=32,
                    null=True,
                )),
                ('folds', django.contrib.postgres.fields.ArrayField(
                    base_field=models.FloatField(),
                    blank=True,
                    default=None,
                    null=True,
                    size=None,
                )),
                ('estimator', models.ForeignKey(
                    on_delete=django.db.models.deletion.CASCADE,
                    related_name='scores',
                    to='arbiter.Estimator',
                )),
            ],
            options={
                'get_latest_by': 'created',
                'db_table': 'evaluations',
            },
        ),
    ]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# arbiter.migrations | ||
# Database migrations for arbiter models. | ||
# | ||
# Author: Benjamin Bengfort <[email protected]> | ||
# Created: Tue Aug 02 09:13:04 2016 -0400 | ||
# | ||
# Copyright (C) 2016 District Data Labs | ||
# For license information, see LICENSE.txt | ||
# | ||
# ID: __init__.py [] [email protected] $ | ||
|
||
""" | ||
Database migrations for arbiter models. | ||
""" | ||
|
||
########################################################################## | ||
## Imports | ||
########################################################################## |
Oops, something went wrong.