diff --git a/sleap/gui/widgets/training_monitor.py b/sleap/gui/widgets/training_monitor.py index 80c231f39..92eb342e8 100644 --- a/sleap/gui/widgets/training_monitor.py +++ b/sleap/gui/widgets/training_monitor.py @@ -9,10 +9,7 @@ from typing import Optional from qtpy import QtCore, QtWidgets, QtGui import attr -import matplotlib.pyplot as plt - -plt.use("QtAgg") -from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as FigureCanvas +from sleap.gui.widgets.mpl import MplCanvas logger = logging.getLogger(__name__) @@ -82,16 +79,22 @@ def reset( what: String identifier indicating which job type the current run corresponds to. """ - self.fig, self.ax = plt.subplots() - self.canvas = FigureCanvas(self.fig) + self.canvas = MplCanvas(width=5, height=4, dpi=100) self.setCentralWidget(self.canvas) - self.series = dict() - COLOR_TRAIN = (18, 158, 220) COLOR_VAL = (248, 167, 52) COLOR_BEST_VAL = (151, 204, 89) + # Add scatter series for batch training loss + self.canvas.add_scatter( + key="batch", + x=x_batch, + y=y_batch, + label="Batch Training Loss", + color=COLOR_TRAIN + "80", # Color with alpha + size=64, + ) self.series["batch"] = QtCharts.QScatterSeries() self.series["batch"].setName("Batch Training Loss") self.series["batch"].setColor(QtGui.QColor(*COLOR_TRAIN, 48))