diff --git a/pylearn2/train_extensions/live_monitoring.py b/pylearn2/train_extensions/live_monitoring.py index ea4a6e15fd..2959edefce 100644 --- a/pylearn2/train_extensions/live_monitoring.py +++ b/pylearn2/train_extensions/live_monitoring.py @@ -2,25 +2,44 @@ Training extension for allowing querying of monitoring values while an experiment executes. """ -__authors__ = "Dustin Webb" +__authors__ = "Dustin Webb, Adam Stone" __copyright__ = "Copyright 2010-2012, Universite de Montreal" -__credits__ = ["Dustin Webb"] +__credits__ = ["Dustin Webb, Adam Stone"] __license__ = "3-clause BSD" __maintainer__ = "LISA Lab" __email__ = "pylearn-dev@googlegroups" import copy +import logging +log = logging.getLogger('LiveMonitor') try: import zmq zmq_available = True -except: +except Exception: zmq_available = False +try: + from PySide import QtCore, QtGui + + import sys + import matplotlib + import numpy as np + matplotlib.use('Qt4Agg') + matplotlib.rcParams['backend.qt4'] = 'PySide' + + from matplotlib.backends.backend_qt4agg import ( + FigureCanvasQTAgg as FigureCanvas) + from matplotlib.figure import Figure + + qt_available = True +except Exception: + qt_available = False + try: import matplotlib.pyplot as plt pyplot_available = True -except: +except Exception: pyplot_available = False from functools import wraps @@ -345,7 +364,7 @@ def update_channels(self, channel_list, start=-1, end=-1, step=1): chan.time_record += rsp_chan.time_record chan.val_record += rsp_chan.val_record - def follow_channels(self, channel_list): + def follow_channels(self, channel_list, use_qt=False): """ Tracks and plots a specified set of channels in real time. @@ -353,21 +372,112 @@ def follow_channels(self, channel_list): ---------- channel_list : list A list of the channels for which data has been requested. + use_qt : bool + Use a PySide GUI for plotting, if available. """ - if not pyplot_available: + if use_qt: + if not qt_available: + log.warning( + 'follow_channels called with use_qt=True, but PySide ' + 'is not available. Falling back on matplotlib ion().') + else: + # only create new qt app if running the first time in session + if not hasattr(self, 'gui'): + self.gui = LiveMonitorGUI(self, channel_list) + + self.gui.channel_list = channel_list + self.gui.start() + + elif not pyplot_available: raise ImportError('pyplot needs to be installed for ' 'this functionality.') - plt.clf() - plt.ion() - while True: - self.update_channel(channel_list) + else: plt.clf() - for channel_name in self.channels: - plt.plot( - self.channels[channel_name].epoch_record, - self.channels[channel_name].val_record, - label=channel_name - ) - plt.legend() plt.ion() - plt.draw() + while True: + self.update_channels(channel_list) + plt.clf() + for channel_name in self.channels: + plt.plot( + self.channels[channel_name].epoch_record, + self.channels[channel_name].val_record, + label=channel_name + ) + plt.legend() + plt.ion() + plt.draw() + +if qt_available: + class LiveMonitorGUI(QtGui.QMainWindow): + def __init__(self, lm, channel_list): + """ + PySide GUI implementation for live monitoring channels. + + Parameters + ---------- + lm : LiveMonitor instance + The LiveMonitor instance to which the GUI belongs. + + channel_list : list + A list of the channels to display. + """ + self.app = QtGui.QApplication(sys.argv) + + super(LiveMonitorGUI, self).__init__() + self.lm = lm + self.channel_list = channel_list + self.updaterThread = UpdaterThread(lm, channel_list) + self.updaterThread.updated.connect(self.refresh) + self.initUI() + + def initUI(self): + self.resize(300, 200) + self.fig = Figure(figsize=(300, 200), dpi=72, + facecolor=(1, 1, 1), edgecolor=(0, 0, 0)) + self.ax = self.fig.add_subplot(111) + self.canvas = FigureCanvas(self.fig) + self.setCentralWidget(self.canvas) + + def refresh(self): + self.ax.cla() # clear previous plot + for channel_name in self.channel_list: + + X = epoch_record = self.lm.channels[channel_name].epoch_record + Y = val_record = self.lm.channels[channel_name].val_record + + indices = np.nonzero(np.diff(epoch_record))[0] + 1 + epoch_record_split = np.split(epoch_record, indices) + val_record_split = np.split(val_record, indices) + + X = np.zeros(len(epoch_record)) + Y = np.zeros(len(epoch_record)) + + for i, epoch in enumerate(epoch_record_split): + + j = i*len(epoch_record_split[0]) + X[j: j + len(epoch)] = ( + 1.*np.arange(len(epoch)) / len(epoch) + epoch[0]) + Y[j: j + len(epoch)] = val_record_split[i] + + self.ax.plot(X, Y, label=channel_name) + + self.ax.legend() + self.canvas.draw() + self.updaterThread.start() + + def start(self): + self.show() + self.updaterThread.start() + self.app.exec_() + + class UpdaterThread(QtCore.QThread): + updated = QtCore.Signal() + + def __init__(self, lm, channel_list): + super(UpdaterThread, self).__init__() + self.lm = lm + self.channel_list = channel_list + + def run(self): + self.lm.update_channels(self.channel_list) # blocking + self.updated.emit()