diff --git a/mir_eval/display.py b/mir_eval/display.py index edd92ff9..8a276845 100644 --- a/mir_eval/display.py +++ b/mir_eval/display.py @@ -2,6 +2,7 @@ """Display functions""" from collections import defaultdict +from weakref import WeakKeyDictionary import numpy as np from scipy.signal import spectrogram @@ -16,6 +17,11 @@ from .util import midi_to_hz, hz_to_midi +# This dictionary is used to track mir_eval-specific attributes +# attached to matplotlib axes +__AXMAP = WeakKeyDictionary() + + def __expand_limits(ax, limits, which="x"): """Expand axis limits""" if which == "x": @@ -27,7 +33,6 @@ def __expand_limits(ax, limits, which="x"): old_lims = getter() new_lims = list(limits) - # infinite limits occur on new axis objects with no data if np.isfinite(old_lims[0]): new_lims[0] = min(old_lims[0], limits[0]) @@ -62,18 +67,21 @@ def __get_axes(ax=None, fig=None): """ new_axes = False - if ax is not None: - return ax, new_axes + if ax is None: + if fig is None: + import matplotlib.pyplot as plt - if fig is None: - import matplotlib.pyplot as plt + fig = plt.gcf() - fig = plt.gcf() + if not fig.get_axes(): + new_axes = True + ax = fig.gca() - if not fig.get_axes(): - new_axes = True + # Create a storage bucket for this axes in case we need it + if ax not in __AXMAP: + __AXMAP[ax] = dict() - return fig.gca(), new_axes + return ax, new_axes def segments( @@ -145,7 +153,7 @@ def segments( if height is None: height = ax.get_ylim()[1] - #cycler = ax._get_patches_for_fill.prop_cycler + # cycler = ax._get_patches_for_fill.prop_cycler seg_map = dict() @@ -153,9 +161,13 @@ def segments( if lab in seg_map: continue - #style = next(cycler) + # style = next(cycler) _bar = ax.bar([0], [0], visible=False) - style = {k: v for k, v in _bar[0].properties().items() if k in ["facecolor", "edgecolor", "linewidth"]} + style = { + k: v + for k, v in _bar[0].properties().items() + if k in ["facecolor", "edgecolor", "linewidth"] + } _bar.remove() seg_map[lab] = seg_def_style.copy() seg_map[lab].update(style) @@ -262,30 +274,34 @@ def labeled_intervals( # Make sure we have a numpy array intervals = np.atleast_2d(intervals) - if not new_axes: - if label_set is None: - # If we have non-empty pre-existing tick labels, use them - label_set = [_.get_text() for _ in ax.get_yticklabels()] - # If none of the label strings have content, treat it as empty - if not any(label_set): - label_set = [] - else: - label_set = list(label_set) + if label_set is None: + # If we have non-empty pre-existing tick labels, use them + # If none of the label strings have content, treat it as empty + label_set = __AXMAP[ax].get("labels", []) else: - label_set = [] + label_set = list(label_set) # Put additional labels at the end, in order - if extend_labels and not new_axes: + extended = False + if extend_labels: ticks = label_set + sorted(set(labels) - set(label_set)) + if ticks != label_set and len(label_set) > 0: + extended = True elif label_set: ticks = label_set else: ticks = sorted(set(labels)) + # Push the ticks up into the axmap + __AXMAP[ax]["labels"] = ticks + style = dict(linewidth=1) + # TODO: now that we have axmap, could use an alternative cycler handle # Swap color -> facecolor here so we preserve edgecolor on rects - _bar = ax.barh([0], [0], visible=False) + # XXX: phony bar plot here is located at 0.5, 0.5 to avoid triggering limit changes + # this is a kludge. + _bar = ax.barh([0.5], [0.5], visible=False) style.update(facecolor=_bar.patches[0].get_facecolor()) _bar.remove() style.update(kwargs) @@ -316,7 +332,7 @@ def labeled_intervals( style.pop("label", None) # Draw a line separating the new labels from pre-existing labels - if label_set != ticks: + if extended: ax.axhline(len(label_set), color="k", alpha=0.5) if tick: @@ -399,7 +415,7 @@ def hierarchy(intervals_hier, labels_hier, levels=None, ax=None, **kwargs): # Reverse the patch ordering for anything we've added. # This way, intervals are listed in the legend from top to bottom # FIXME: this no longer works - #ax.patches[n_patches:] = ax.patches[n_patches:][::-1] + # ax.patches[n_patches:] = ax.patches[n_patches:][::-1] return ax @@ -462,7 +478,11 @@ def events(times, labels=None, base=None, height=None, ax=None, text_kw=None, ** raise ValueError("When specifying base or height, both must be provided.") _plot = ax.plot([], [], visible=False)[0] - style = {k: v for k, v in _plot.properties().items() if k in ["color", "linestyle", "linewidth"]} + style = { + k: v + for k, v in _plot.properties().items() + if k in ["color", "linestyle", "linewidth"] + } _plot.remove() style.update(kwargs) @@ -718,8 +738,17 @@ def piano_roll(intervals, pitches=None, midi=None, ax=None, **kwargs): return ax -def separation(sources, fs=22050, labels=None, alpha=0.75, ax=None, rasterized=True, - edgecolors="None", shading="gouraud", **kwargs): +def separation( + sources, + fs=22050, + labels=None, + alpha=0.75, + ax=None, + rasterized=True, + edgecolors="None", + shading="gouraud", + **kwargs +): """Source-separation visualization Parameters @@ -785,7 +814,7 @@ def separation(sources, fs=22050, labels=None, alpha=0.75, ax=None, rasterized=T # For each source, grab a new color from the cycler # Then construct a colormap that interpolates from # [transparent white -> new color] - # + # # To access the cycler, we'll create a temporary bar plot, # pull its facecolor, and then remove it from the axes. _bar = ax.bar([times.min()], [freqs.min()], visible=False) @@ -809,7 +838,9 @@ def separation(sources, fs=22050, labels=None, alpha=0.75, ax=None, rasterized=T # Attach a 0x0 rect to the axis with the corresponding label # This way, it will show up in the legend - ax.add_patch(Rectangle((times.min(), freqs.min()), 0, 0, color=color, label=labels[i])) + ax.add_patch( + Rectangle((times.min(), freqs.min()), 0, 0, color=color, label=labels[i]) + ) if new_axes: # Set the axis limits to match the spectrogram parameters