Skip to content

Commit

Permalink
toward a working implementation of axes weakkeydict storage
Browse files Browse the repository at this point in the history
  • Loading branch information
bmcfee committed May 9, 2024
1 parent 81980ce commit ac9c9d4
Showing 1 changed file with 62 additions and 31 deletions.
93 changes: 62 additions & 31 deletions mir_eval/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""Display functions"""

from collections import defaultdict
from weakref import WeakKeyDictionary

import numpy as np
from scipy.signal import spectrogram
Expand All @@ -16,6 +17,11 @@
from .util import midi_to_hz, hz_to_midi


# This dictionary is used to track mir_eval-specific attributes
# attached to matplotlib axes
__AXMAP = WeakKeyDictionary()


def __expand_limits(ax, limits, which="x"):
"""Expand axis limits"""
if which == "x":
Expand All @@ -27,7 +33,6 @@ def __expand_limits(ax, limits, which="x"):

old_lims = getter()
new_lims = list(limits)

# infinite limits occur on new axis objects with no data
if np.isfinite(old_lims[0]):
new_lims[0] = min(old_lims[0], limits[0])
Expand Down Expand Up @@ -62,18 +67,21 @@ def __get_axes(ax=None, fig=None):
"""
new_axes = False

if ax is not None:
return ax, new_axes
if ax is None:
if fig is None:
import matplotlib.pyplot as plt

if fig is None:
import matplotlib.pyplot as plt
fig = plt.gcf()

fig = plt.gcf()
if not fig.get_axes():
new_axes = True
ax = fig.gca()

if not fig.get_axes():
new_axes = True
# Create a storage bucket for this axes in case we need it
if ax not in __AXMAP:
__AXMAP[ax] = dict()

return fig.gca(), new_axes
return ax, new_axes


def segments(
Expand Down Expand Up @@ -145,17 +153,21 @@ def segments(
if height is None:
height = ax.get_ylim()[1]

#cycler = ax._get_patches_for_fill.prop_cycler
# cycler = ax._get_patches_for_fill.prop_cycler

seg_map = dict()

for lab in labels:
if lab in seg_map:
continue

#style = next(cycler)
# style = next(cycler)
_bar = ax.bar([0], [0], visible=False)
style = {k: v for k, v in _bar[0].properties().items() if k in ["facecolor", "edgecolor", "linewidth"]}
style = {
k: v
for k, v in _bar[0].properties().items()
if k in ["facecolor", "edgecolor", "linewidth"]
}
_bar.remove()
seg_map[lab] = seg_def_style.copy()
seg_map[lab].update(style)
Expand Down Expand Up @@ -262,30 +274,34 @@ def labeled_intervals(
# Make sure we have a numpy array
intervals = np.atleast_2d(intervals)

if not new_axes:
if label_set is None:
# If we have non-empty pre-existing tick labels, use them
label_set = [_.get_text() for _ in ax.get_yticklabels()]
# If none of the label strings have content, treat it as empty
if not any(label_set):
label_set = []
else:
label_set = list(label_set)
if label_set is None:
# If we have non-empty pre-existing tick labels, use them
# If none of the label strings have content, treat it as empty
label_set = __AXMAP[ax].get("labels", [])
else:
label_set = []
label_set = list(label_set)

# Put additional labels at the end, in order
if extend_labels and not new_axes:
extended = False
if extend_labels:
ticks = label_set + sorted(set(labels) - set(label_set))
if ticks != label_set and len(label_set) > 0:
extended = True
elif label_set:
ticks = label_set
else:
ticks = sorted(set(labels))

# Push the ticks up into the axmap
__AXMAP[ax]["labels"] = ticks

style = dict(linewidth=1)

# TODO: now that we have axmap, could use an alternative cycler handle
# Swap color -> facecolor here so we preserve edgecolor on rects
_bar = ax.barh([0], [0], visible=False)
# XXX: phony bar plot here is located at 0.5, 0.5 to avoid triggering limit changes
# this is a kludge.
_bar = ax.barh([0.5], [0.5], visible=False)
style.update(facecolor=_bar.patches[0].get_facecolor())
_bar.remove()
style.update(kwargs)
Expand Down Expand Up @@ -316,7 +332,7 @@ def labeled_intervals(
style.pop("label", None)

# Draw a line separating the new labels from pre-existing labels
if label_set != ticks:
if extended:
ax.axhline(len(label_set), color="k", alpha=0.5)

if tick:
Expand Down Expand Up @@ -399,7 +415,7 @@ def hierarchy(intervals_hier, labels_hier, levels=None, ax=None, **kwargs):
# Reverse the patch ordering for anything we've added.
# This way, intervals are listed in the legend from top to bottom
# FIXME: this no longer works
#ax.patches[n_patches:] = ax.patches[n_patches:][::-1]
# ax.patches[n_patches:] = ax.patches[n_patches:][::-1]
return ax


Expand Down Expand Up @@ -462,7 +478,11 @@ def events(times, labels=None, base=None, height=None, ax=None, text_kw=None, **
raise ValueError("When specifying base or height, both must be provided.")

_plot = ax.plot([], [], visible=False)[0]
style = {k: v for k, v in _plot.properties().items() if k in ["color", "linestyle", "linewidth"]}
style = {
k: v
for k, v in _plot.properties().items()
if k in ["color", "linestyle", "linewidth"]
}
_plot.remove()
style.update(kwargs)

Expand Down Expand Up @@ -718,8 +738,17 @@ def piano_roll(intervals, pitches=None, midi=None, ax=None, **kwargs):
return ax


def separation(sources, fs=22050, labels=None, alpha=0.75, ax=None, rasterized=True,
edgecolors="None", shading="gouraud", **kwargs):
def separation(
sources,
fs=22050,
labels=None,
alpha=0.75,
ax=None,
rasterized=True,
edgecolors="None",
shading="gouraud",
**kwargs
):
"""Source-separation visualization
Parameters
Expand Down Expand Up @@ -785,7 +814,7 @@ def separation(sources, fs=22050, labels=None, alpha=0.75, ax=None, rasterized=T
# For each source, grab a new color from the cycler
# Then construct a colormap that interpolates from
# [transparent white -> new color]
#
#
# To access the cycler, we'll create a temporary bar plot,
# pull its facecolor, and then remove it from the axes.
_bar = ax.bar([times.min()], [freqs.min()], visible=False)
Expand All @@ -809,7 +838,9 @@ def separation(sources, fs=22050, labels=None, alpha=0.75, ax=None, rasterized=T

# Attach a 0x0 rect to the axis with the corresponding label
# This way, it will show up in the legend
ax.add_patch(Rectangle((times.min(), freqs.min()), 0, 0, color=color, label=labels[i]))
ax.add_patch(
Rectangle((times.min(), freqs.min()), 0, 0, color=color, label=labels[i])
)

if new_axes:
# Set the axis limits to match the spectrogram parameters
Expand Down

0 comments on commit ac9c9d4

Please sign in to comment.