From 51b13bae1b32b2a2866052839a26bce6b557d25a Mon Sep 17 00:00:00 2001 From: Elizabeth Berrigan <berri104@gmail.com> Date: Thu, 1 Aug 2024 15:20:53 -0700 Subject: [PATCH] start replacing qtcharts with mpl --- sleap/gui/widgets/training_monitor.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/sleap/gui/widgets/training_monitor.py b/sleap/gui/widgets/training_monitor.py index 80c231f39..92eb342e8 100644 --- a/sleap/gui/widgets/training_monitor.py +++ b/sleap/gui/widgets/training_monitor.py @@ -9,10 +9,7 @@ from typing import Optional from qtpy import QtCore, QtWidgets, QtGui import attr -import matplotlib.pyplot as plt - -plt.use("QtAgg") -from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as FigureCanvas +from sleap.gui.widgets.mpl import MplCanvas logger = logging.getLogger(__name__) @@ -82,16 +79,22 @@ def reset( what: String identifier indicating which job type the current run corresponds to. """ - self.fig, self.ax = plt.subplots() - self.canvas = FigureCanvas(self.fig) + self.canvas = MplCanvas(width=5, height=4, dpi=100) self.setCentralWidget(self.canvas) - self.series = dict() - COLOR_TRAIN = (18, 158, 220) COLOR_VAL = (248, 167, 52) COLOR_BEST_VAL = (151, 204, 89) + # Add scatter series for batch training loss + self.canvas.add_scatter( + key="batch", + x=x_batch, + y=y_batch, + label="Batch Training Loss", + color=COLOR_TRAIN + "80", # Color with alpha + size=64, + ) self.series["batch"] = QtCharts.QScatterSeries() self.series["batch"].setName("Batch Training Loss") self.series["batch"].setColor(QtGui.QColor(*COLOR_TRAIN, 48))