diff --git a/sleap/gui/widgets/mpl.py b/sleap/gui/widgets/mpl.py index a9b7fc838..2c11488dc 100644 --- a/sleap/gui/widgets/mpl.py +++ b/sleap/gui/widgets/mpl.py @@ -6,11 +6,11 @@ from qtpy import QtWidgets from matplotlib.figure import Figure -from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as Canvas +from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as Canvas import matplotlib # Ensure using PyQt5 backend -matplotlib.use("QT5Agg") +matplotlib.use("QtAgg") class MplCanvas(Canvas): @@ -26,3 +26,23 @@ def __init__(self, width=5, height=4, dpi=100): self, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding ) Canvas.updateGeometry(self) + self.series = {} + + def add_scatter(self, key, x, y, label, color, size): + if key in self.series: + self.series[key].remove() + scatter = self.axes.scatter(x, y, label=label, color=color, s=size) + self.series[key] = scatter + self.draw() + + def add_line(self, key, x, y, label, color, width): + if key in self.series: + self.series[key].remove() + (line,) = self.axes.plot(x, y, label=label, color=color, linewidth=width) + self.series[key] = line + self.draw() + + def clear(self): + self.axes.cla() + self.series.clear() + self.draw() diff --git a/sleap/gui/widgets/training_monitor.py b/sleap/gui/widgets/training_monitor.py index 80c231f39..92eb342e8 100644 --- a/sleap/gui/widgets/training_monitor.py +++ b/sleap/gui/widgets/training_monitor.py @@ -9,10 +9,7 @@ from typing import Optional from qtpy import QtCore, QtWidgets, QtGui import attr -import matplotlib.pyplot as plt - -plt.use("QtAgg") -from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as FigureCanvas +from sleap.gui.widgets.mpl import MplCanvas logger = logging.getLogger(__name__) @@ -82,16 +79,22 @@ def reset( what: String identifier indicating which job type the current run corresponds to. """ - self.fig, self.ax = plt.subplots() - self.canvas = FigureCanvas(self.fig) + self.canvas = MplCanvas(width=5, height=4, dpi=100) self.setCentralWidget(self.canvas) - self.series = dict() - COLOR_TRAIN = (18, 158, 220) COLOR_VAL = (248, 167, 52) COLOR_BEST_VAL = (151, 204, 89) + # Add scatter series for batch training loss + self.canvas.add_scatter( + key="batch", + x=x_batch, + y=y_batch, + label="Batch Training Loss", + color=COLOR_TRAIN + "80", # Color with alpha + size=64, + ) self.series["batch"] = QtCharts.QScatterSeries() self.series["batch"].setName("Batch Training Loss") self.series["batch"].setColor(QtGui.QColor(*COLOR_TRAIN, 48)) diff --git a/sleap/nn/training.py b/sleap/nn/training.py index 6a64e43b6..97fe7573c 100644 --- a/sleap/nn/training.py +++ b/sleap/nn/training.py @@ -508,7 +508,7 @@ def setup_visualization( callbacks = [] try: - matplotlib.use("Qt5Agg") + matplotlib.use("QtAgg") except ImportError: print( "Unable to use Qt backend for matplotlib. "