Skip to content

Commit

Permalink
added PySide GUI implementation for live updating plot and plotting s…
Browse files Browse the repository at this point in the history
…ub-epoch monitor updates
  • Loading branch information
AdamStone committed Jan 12, 2015
1 parent 105a63f commit b1f8029
Showing 1 changed file with 128 additions and 18 deletions.
146 changes: 128 additions & 18 deletions pylearn2/train_extensions/live_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,44 @@
Training extension for allowing querying of monitoring values while an
experiment executes.
"""
__authors__ = "Dustin Webb"
__authors__ = "Dustin Webb, Adam Stone"
__copyright__ = "Copyright 2010-2012, Universite de Montreal"
__credits__ = ["Dustin Webb"]
__credits__ = ["Dustin Webb, Adam Stone"]
__license__ = "3-clause BSD"
__maintainer__ = "LISA Lab"
__email__ = "pylearn-dev@googlegroups"

import copy
import logging
log = logging.getLogger('LiveMonitor')

try:
import zmq
zmq_available = True
except:
except Exception:
zmq_available = False

try:
from PySide import QtCore, QtGui

import sys
import matplotlib
import numpy as np
matplotlib.use('Qt4Agg')
matplotlib.rcParams['backend.qt4'] = 'PySide'

from matplotlib.backends.backend_qt4agg import (
FigureCanvasQTAgg as FigureCanvas)
from matplotlib.figure import Figure

qt_available = True
except Exception:
qt_available = False

try:
import matplotlib.pyplot as plt
pyplot_available = True
except:
except Exception:
pyplot_available = False

from functools import wraps
Expand Down Expand Up @@ -345,29 +364,120 @@ def update_channels(self, channel_list, start=-1, end=-1, step=1):
chan.time_record += rsp_chan.time_record
chan.val_record += rsp_chan.val_record

def follow_channels(self, channel_list):
def follow_channels(self, channel_list, use_qt=False):
"""
Tracks and plots a specified set of channels in real time.
Parameters
----------
channel_list : list
A list of the channels for which data has been requested.
use_qt : bool
Use a PySide GUI for plotting, if available.
"""
if not pyplot_available:
if use_qt:
if not qt_available:
log.warning(
'follow_channels called with use_qt=True, but PySide '
'is not available. Falling back on matplotlib ion().')
else:
# only create new qt app if running the first time in session
if not hasattr(self, 'gui'):
self.gui = LiveMonitorGUI(self, channel_list)

self.gui.channel_list = channel_list
self.gui.start()

elif not pyplot_available:
raise ImportError('pyplot needs to be installed for '
'this functionality.')
plt.clf()
plt.ion()
while True:
self.update_channel(channel_list)
else:
plt.clf()
for channel_name in self.channels:
plt.plot(
self.channels[channel_name].epoch_record,
self.channels[channel_name].val_record,
label=channel_name
)
plt.legend()
plt.ion()
plt.draw()
while True:
self.update_channels(channel_list)
plt.clf()
for channel_name in self.channels:
plt.plot(
self.channels[channel_name].epoch_record,
self.channels[channel_name].val_record,
label=channel_name
)
plt.legend()
plt.ion()
plt.draw()

if qt_available:
class LiveMonitorGUI(QtGui.QMainWindow):
def __init__(self, lm, channel_list):
"""
PySide GUI implementation for live monitoring channels.
Parameters
----------
lm : LiveMonitor instance
The LiveMonitor instance to which the GUI belongs.
channel_list : list
A list of the channels to display.
"""
self.app = QtGui.QApplication(sys.argv)

super(LiveMonitorGUI, self).__init__()
self.lm = lm
self.channel_list = channel_list
self.updaterThread = UpdaterThread(lm, channel_list)
self.updaterThread.updated.connect(self.refresh)
self.initUI()

def initUI(self):
self.resize(300, 200)
self.fig = Figure(figsize=(300, 200), dpi=72,
facecolor=(1, 1, 1), edgecolor=(0, 0, 0))
self.ax = self.fig.add_subplot(111)
self.canvas = FigureCanvas(self.fig)
self.setCentralWidget(self.canvas)

def refresh(self):
self.ax.cla() # clear previous plot
for channel_name in self.channel_list:

X = epoch_record = self.lm.channels[channel_name].epoch_record
Y = val_record = self.lm.channels[channel_name].val_record

indices = np.nonzero(np.diff(epoch_record))[0] + 1
epoch_record_split = np.split(epoch_record, indices)
val_record_split = np.split(val_record, indices)

X = np.zeros(len(epoch_record))
Y = np.zeros(len(epoch_record))

for i, epoch in enumerate(epoch_record_split):

j = i*len(epoch_record_split[0])
X[j: j + len(epoch)] = (
1.*np.arange(len(epoch)) / len(epoch) + epoch[0])
Y[j: j + len(epoch)] = val_record_split[i]

self.ax.plot(X, Y, label=channel_name)

self.ax.legend()
self.canvas.draw()
self.updaterThread.start()

def start(self):
self.show()
self.updaterThread.start()
self.app.exec_()

class UpdaterThread(QtCore.QThread):
updated = QtCore.Signal()

def __init__(self, lm, channel_list):
super(UpdaterThread, self).__init__()
self.lm = lm
self.channel_list = channel_list

def run(self):
self.lm.update_channels(self.channel_list) # blocking
self.updated.emit()

0 comments on commit b1f8029

Please sign in to comment.