Skip to content

Commit

Permalink
plot confusion matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
andabi committed Feb 22, 2018
1 parent 3b891b0 commit c1917cd
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 26 deletions.
12 changes: 7 additions & 5 deletions data_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,12 +379,14 @@ def _get_zero_padded(list_of_arrays):
return np.array(batch)


# Phoneme symbols used as class labels. A symbol's position in this list is
# its integer id (load_vocab builds its index maps by enumerating the list),
# and len(phns) sets the width of net1's final projection in models.py.
# NOTE(review): 61 entries; presumably the TIMIT phone set given the TIMIT
# data_path in the hparams — confirm provenance before editing the order.
phns = ['h#', 'aa', 'ae', 'ah', 'ao', 'aw', 'ax', 'ax-h', 'axr', 'ay', 'b', 'bcl',
'ch', 'd', 'dcl', 'dh', 'dx', 'eh', 'el', 'em', 'en', 'eng', 'epi',
'er', 'ey', 'f', 'g', 'gcl', 'hh', 'hv', 'ih', 'ix', 'iy', 'jh',
'k', 'kcl', 'l', 'm', 'n', 'ng', 'nx', 'ow', 'oy', 'p', 'pau', 'pcl',
'q', 'r', 's', 'sh', 't', 'tcl', 'th', 'uh', 'uw', 'ux', 'v', 'w', 'y', 'z', 'zh']


def load_vocab():
phns = ['h#', 'aa', 'ae', 'ah', 'ao', 'aw', 'ax', 'ax-h', 'axr', 'ay', 'b', 'bcl',
'ch', 'd', 'dcl', 'dh', 'dx', 'eh', 'el', 'em', 'en', 'eng', 'epi',
'er', 'ey', 'f', 'g', 'gcl', 'hh', 'hv', 'ih', 'ix', 'iy', 'jh',
'k', 'kcl', 'l', 'm', 'n', 'ng', 'nx', 'ow', 'oy', 'p', 'pau', 'pcl',
'q', 'r', 's', 'sh', 't', 'tcl', 'th', 'uh', 'uw', 'ux', 'v', 'w', 'y', 'z', 'zh']
phn2idx = {phn: idx for idx, phn in enumerate(phns)}
idx2phn = {idx: phn for idx, phn in enumerate(phns)}

Expand Down
30 changes: 20 additions & 10 deletions eval1.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
import argparse

import tensorflow as tf
from data_load import get_batch
from data_load import get_batch, phns, load_vocab
from hparam import hparam as hp
from models import Model
from utils import plot_confusion_matrix
import numpy as np


def eval(logdir, writer, queue=False):
Expand All @@ -21,8 +23,14 @@ def eval(logdir, writer, queue=False):
# Loss
loss_op = model.loss_net1()

# confusion matrix
y_ppg_1d = tf.reshape(model.y_ppg, shape=(tf.size(model.y_ppg),))
pred_ppg_1d = tf.reshape(model.pred_ppg, shape=(tf.size(model.pred_ppg),))

# Summary
summ_op = summaries(acc_op, loss_op)
tf.summary.scalar('net1/eval/acc', acc_op)
tf.summary.scalar('net1/eval/loss', loss_op)
summ_op = tf.summary.merge_all()

session_conf = tf.ConfigProto(
allow_soft_placement=True,
Expand All @@ -37,12 +45,20 @@ def eval(logdir, writer, queue=False):
model.load(sess, 'train1', logdir=logdir)

if queue:
summ, acc, loss = sess.run([summ_op, acc_op, loss_op])
summ, acc, loss, y_ppg_1d, pred_ppg_1d = sess.run([summ_op, acc_op, loss_op, y_ppg_1d, pred_ppg_1d])
else:
mfcc, ppg = get_batch(model.mode, model.batch_size)
summ, acc, loss = sess.run([summ_op, acc_op, loss_op], feed_dict={model.x_mfcc: mfcc, model.y_ppgs: ppg})
summ, acc, loss, y_ppg_1d, pred_ppg_1d = sess.run([summ_op, acc_op, loss_op, y_ppg_1d, pred_ppg_1d],
feed_dict={model.x_mfcc: mfcc, model.y_ppg: ppg})

# plot confusion matrix
_, idx2phn = load_vocab()
y_ppg_1d = [idx2phn[i] for i in y_ppg_1d]
pred_ppg_1d = [idx2phn[i] for i in pred_ppg_1d]
cm_summ = plot_confusion_matrix(y_ppg_1d, pred_ppg_1d, phns)

writer.add_summary(summ)
writer.add_summary(cm_summ)

print("acc:", acc)
print("loss:", loss)
Expand All @@ -52,12 +68,6 @@ def eval(logdir, writer, queue=False):
coord.join(threads)


def summaries(acc, loss):
    """Register the net1 eval scalar summaries and return the merged summary op.

    Parameters:
        acc: value logged under the tag 'net1/eval/acc'.
        loss: value logged under the tag 'net1/eval/loss'.

    Returns:
        The result of tf.summary.merge_all().
    """
    # Accuracy is registered before loss, same order as the two explicit calls.
    tagged_values = (
        ('net1/eval/acc', acc),
        ('net1/eval/loss', loss),
    )
    for tag, value in tagged_values:
        tf.summary.scalar(tag, value)

    merged = tf.summary.merge_all()
    return merged


def get_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('case', type=str, help='experiment case name')
Expand Down
2 changes: 1 addition & 1 deletion hparams/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ logdir_path: '/data/private/vc/logdir'

train1:
# path
data_path: 'timit/TIMIT/TRAIN/*/*/*.wav'
data_path: '/data/private/vc/datasets/timit/TIMIT/TRAIN/*/*/*.wav'

# model
hidden_units: 256 # alias: E
Expand Down
13 changes: 5 additions & 8 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import tensorflow as tf

from data_load import get_batch_queue, load_vocab
from data_load import get_batch_queue, phns
from modules import prenet, cbhg
import glob, os
from utils import split_path
Expand All @@ -18,7 +18,7 @@ def __init__(self, mode, batch_size, hp, queue=True):
self.is_training = self.get_is_training(mode)

# Input
self.x_mfcc, self.y_ppgs, self.y_spec, self.y_mel, self.num_batch = self.get_input(mode, batch_size, queue)
self.x_mfcc, self.y_ppg, self.y_spec, self.y_mel, self.num_batch = self.get_input(mode, batch_size, queue)

# Networks
self.net_template = tf.make_template('net', self._net2)
Expand Down Expand Up @@ -61,9 +61,6 @@ def get_is_training(self, mode):

def _net1(self):
with tf.variable_scope('net1'):
# Load vocabulary
phn2idx, idx2phn = load_vocab()

# Pre-net
prenet_out = prenet(self.x_mfcc,
num_units=[self.hp.train1.hidden_units, self.hp.train1.hidden_units // 2],
Expand All @@ -74,22 +71,22 @@ def _net1(self):
out = cbhg(prenet_out, self.hp.train1.num_banks, self.hp.train1.hidden_units // 2, self.hp.train1.num_highway_blocks, self.hp.train1.norm_type, self.is_training)

# Final linear projection
logits = tf.layers.dense(out, len(phn2idx)) # (N, T, V)
logits = tf.layers.dense(out, len(phns)) # (N, T, V)
ppgs = tf.nn.softmax(logits / self.hp.train1.t) # (N, T, V)
preds = tf.to_int32(tf.arg_max(logits, dimension=-1)) # (N, T)

return ppgs, preds, logits

def loss_net1(self):
    """Build the net1 loss: masked sparse softmax cross-entropy, mean-reduced.

    Reads self.x_mfcc, self.logits_ppg, self.y_ppg and the temperature
    self.hp.train1.t.

    Returns:
        Scalar tensor: per-position cross-entropy multiplied by the target
        indicator, then averaged with tf.reduce_mean.
    """
    # Indicator (N, T): 1 where the feature vector of a frame sums to a
    # non-zero value, 0 where it sums to zero.
    # NOTE(review): presumably the zero frames are padding — confirm.
    istarget = tf.sign(tf.abs(tf.reduce_sum(self.x_mfcc, -1)))

    # Logits are divided by the temperature before the cross-entropy.
    # The superseded assignment that used the old attribute name `y_ppgs`
    # was dead (immediately overwritten) and has been dropped.
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=self.logits_ppg / self.hp.train1.t,
        labels=self.y_ppg,
    )
    loss *= istarget

    # NOTE(review): the mean runs over every (N, T) position, including the
    # masked ones, not over the number of targets. Left as-is so the loss
    # scale is unchanged.
    loss = tf.reduce_mean(loss)
    return loss

def acc_net1(self):
    """Build the net1 accuracy: fraction of target positions predicted correctly.

    Reads self.x_mfcc, self.pred_ppg and self.y_ppg.

    Returns:
        Scalar tensor: masked hit count divided by the number of targets.
    """
    # Indicator (N, T): 1 where the feature vector of a frame sums to a
    # non-zero value, 0 where it sums to zero.
    # NOTE(review): presumably the zero frames are padding — confirm.
    istarget = tf.sign(tf.abs(tf.reduce_sum(self.x_mfcc, -1)))

    # Count matches only at target positions. The superseded assignment
    # that compared against the old attribute name `y_ppgs` was dead
    # (immediately overwritten) and has been dropped.
    hits = tf.to_float(tf.equal(self.pred_ppg, self.y_ppg))
    num_hits = tf.reduce_sum(hits * istarget)
    num_targets = tf.reduce_sum(istarget)

    # NOTE(review): num_targets is 0 when every frame is masked; the
    # division is left unguarded, as in the original.
    acc = num_hits / num_targets
    return acc
Expand Down
2 changes: 1 addition & 1 deletion train1.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def train(logdir, queue=True):
sess.run(train_op)
else:
mfcc, ppg = get_batch(model.mode, model.batch_size)
sess.run(train_op, feed_dict={model.x_mfcc: mfcc, model.y_ppgs: ppg})
sess.run(train_op, feed_dict={model.x_mfcc: mfcc, model.y_ppg: ppg})

# Write checkpoint files at every epoch
summ, gs = sess.run([summ_op, global_step])
Expand Down
63 changes: 62 additions & 1 deletion utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
# -*- coding: utf-8 -*-
#!/usr/bin/env python

import os
import glob
import itertools
import os
import re
from textwrap import wrap

import matplotlib
import numpy as np
import tfplot
from sklearn.metrics import confusion_matrix


def split_path(path):
Expand Down Expand Up @@ -31,3 +38,57 @@ def normalize_0_1(values, max, min):
def denormalize_0_1(normalized, max, min):
    """Map values from the unit interval back onto the range [min, max].

    Parameters:
        normalized: scalar or array of values; anything outside [0, 1] is
            clipped to that interval first.
        max: upper end of the target range.
        min: lower end of the target range.

    Returns:
        clip(normalized, 0, 1) * (max - min) + min
    """
    span = max - min
    clipped = np.clip(normalized, 0, 1)
    return clipped * span + min


def plot_confusion_matrix(correct_labels, predict_labels, labels, tensor_name='confusion_matrix', normalize=False):
    """Render a confusion matrix as a figure and return it as a summary.

    Parameters:
        correct_labels: true class label of each sample.
        predict_labels: predicted class label of each sample, in the same
            order as correct_labels.
        labels: list of class labels. It is passed to confusion_matrix as
            `labels` and also supplies the tick text on both axes.
        tensor_name: tag given to the output summary.
        normalize: when True, each cell becomes count * 10 / row total,
            truncated to an integer; rows with no samples stay all zero.

    Returns:
        The summary produced by tfplot.figure.to_summary for the figure.

    Notes:
        - Figure size and font sizes are fixed; they may need adjusting for
          a different number of classes.
        - Some ticks may not line up exactly because of the label rotation.
    """
    # `import matplotlib` alone does not guarantee that the `figure`
    # submodule is loaded, so import the class explicitly.
    from matplotlib.figure import Figure

    cm = confusion_matrix(correct_labels, predict_labels, labels=labels)
    if normalize:
        # A class with no true samples has a zero row total; the 0/0 cells
        # become NaN and are zeroed just below, so silence the warnings.
        with np.errstate(divide='ignore', invalid='ignore'):
            cm = cm.astype('float') * 10 / cm.sum(axis=1)[:, np.newaxis]
        cm = np.nan_to_num(cm, copy=True)
        cm = cm.astype('int')

    # Kept from the original: this changes NumPy's global print precision.
    np.set_printoptions(precision=2)

    fig = Figure(figsize=(7, 7), dpi=320, facecolor='w', edgecolor='k')
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(cm, cmap='Oranges')

    # Insert spaces at camel-case boundaries, then wrap long names.
    classes = [re.sub(r'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))', r'\1 ', x) for x in labels]
    classes = ['\n'.join(wrap(name, 40)) for name in classes]

    tick_marks = np.arange(len(classes))

    ax.set_xlabel('Predicted', fontsize=7)
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(classes, fontsize=4, rotation=-90, ha='center')
    ax.xaxis.set_label_position('bottom')
    ax.xaxis.tick_bottom()

    ax.set_ylabel('True Label', fontsize=7)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(classes, fontsize=4, va='center')
    ax.yaxis.set_label_position('left')
    ax.yaxis.tick_left()

    # Write each cell's value into the grid; zero cells are shown as '.'.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        cell_text = format(cm[i, j], 'd') if cm[i, j] != 0 else '.'
        ax.text(j, i, cell_text, horizontalalignment="center", fontsize=6,
                verticalalignment='center', color="black")

    fig.set_tight_layout(True)
    summary = tfplot.figure.to_summary(fig, tag=tensor_name)
    return summary

0 comments on commit c1917cd

Please sign in to comment.