From 51b13bae1b32b2a2866052839a26bce6b557d25a Mon Sep 17 00:00:00 2001
From: Elizabeth Berrigan <berri104@gmail.com>
Date: Thu, 1 Aug 2024 15:20:53 -0700
Subject: [PATCH] start replacing qtcharts with mpl

---
 sleap/gui/widgets/training_monitor.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/sleap/gui/widgets/training_monitor.py b/sleap/gui/widgets/training_monitor.py
index 80c231f39..92eb342e8 100644
--- a/sleap/gui/widgets/training_monitor.py
+++ b/sleap/gui/widgets/training_monitor.py
@@ -9,10 +9,7 @@
 from typing import Optional
 from qtpy import QtCore, QtWidgets, QtGui
 import attr
-import matplotlib.pyplot as plt
-
-plt.use("QtAgg")
-from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as FigureCanvas
+from sleap.gui.widgets.mpl import MplCanvas
 
 logger = logging.getLogger(__name__)
 
@@ -82,16 +79,22 @@ def reset(
             what: String identifier indicating which job type the current run
                 corresponds to.
         """
-        self.fig, self.ax = plt.subplots()
-        self.canvas = FigureCanvas(self.fig)
+        self.canvas = MplCanvas(width=5, height=4, dpi=100)
         self.setCentralWidget(self.canvas)
 
-        self.series = dict()
-
         COLOR_TRAIN = (18, 158, 220)
         COLOR_VAL = (248, 167, 52)
         COLOR_BEST_VAL = (151, 204, 89)
 
+        # Add scatter series for batch training loss
+        self.canvas.add_scatter(
+            key="batch",
+            x=x_batch,
+            y=y_batch,
+            label="Batch Training Loss",
+            color=COLOR_TRAIN + "80",  # Color with alpha
+            size=64,
+        )
         self.series["batch"] = QtCharts.QScatterSeries()
         self.series["batch"].setName("Batch Training Loss")
         self.series["batch"].setColor(QtGui.QColor(*COLOR_TRAIN, 48))