Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/elizabeth/update-python-and-depe…
Browse files Browse the repository at this point in the history
…ndencies' into elizabeth/update-python-and-dependencies
  • Loading branch information
eberrigan committed Aug 2, 2024
2 parents 9af7619 + 51b13ba commit b42ad2b
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 11 deletions.
24 changes: 22 additions & 2 deletions sleap/gui/widgets/mpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@

from qtpy import QtWidgets
from matplotlib.figure import Figure
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as Canvas
from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as Canvas
import matplotlib

# Ensure using PyQt5 backend
matplotlib.use("QT5Agg")
matplotlib.use("QtAgg")


class MplCanvas(Canvas):
Expand All @@ -26,3 +26,23 @@ def __init__(self, width=5, height=4, dpi=100):
self, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding
)
Canvas.updateGeometry(self)
self.series = {}

def add_scatter(self, key, x, y, label, color, size):
if key in self.series:
self.series[key].remove()
scatter = self.axes.scatter(x, y, label=label, color=color, s=size)
self.series[key] = scatter
self.draw()

def add_line(self, key, x, y, label, color, width):
if key in self.series:
self.series[key].remove()
(line,) = self.axes.plot(x, y, label=label, color=color, linewidth=width)
self.series[key] = line
self.draw()

def clear(self):
self.axes.cla()
self.series.clear()
self.draw()
19 changes: 11 additions & 8 deletions sleap/gui/widgets/training_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@
from typing import Optional
from qtpy import QtCore, QtWidgets, QtGui
import attr
import matplotlib.pyplot as plt

plt.use("QtAgg")
from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as FigureCanvas
from sleap.gui.widgets.mpl import MplCanvas

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -82,16 +79,22 @@ def reset(
what: String identifier indicating which job type the current run
corresponds to.
"""
self.fig, self.ax = plt.subplots()
self.canvas = FigureCanvas(self.fig)
self.canvas = MplCanvas(width=5, height=4, dpi=100)
self.setCentralWidget(self.canvas)

self.series = dict()

COLOR_TRAIN = (18, 158, 220)
COLOR_VAL = (248, 167, 52)
COLOR_BEST_VAL = (151, 204, 89)

# Add scatter series for batch training loss
self.canvas.add_scatter(
key="batch",
x=x_batch,
y=y_batch,
label="Batch Training Loss",
color=COLOR_TRAIN + "80", # Color with alpha
size=64,
)
self.series["batch"] = QtCharts.QScatterSeries()
self.series["batch"].setName("Batch Training Loss")
self.series["batch"].setColor(QtGui.QColor(*COLOR_TRAIN, 48))
Expand Down
2 changes: 1 addition & 1 deletion sleap/nn/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ def setup_visualization(
callbacks = []

try:
matplotlib.use("Qt5Agg")
matplotlib.use("QtAgg")
except ImportError:
print(
"Unable to use Qt backend for matplotlib. "
Expand Down

0 comments on commit b42ad2b

Please sign in to comment.