From 724a2347ebe3d2eb06c51cbf58a2fd1a6a424d9a Mon Sep 17 00:00:00 2001 From: Brian McFee Date: Tue, 26 Mar 2024 12:42:58 -0400 Subject: [PATCH 01/10] validation warnings set stacklevel=2, fixes #372 --- mir_eval/beat.py | 4 ++-- mir_eval/chord.py | 4 ++-- mir_eval/hierarchy.py | 2 +- mir_eval/melody.py | 12 ++++++------ mir_eval/multipitch.py | 8 ++++---- mir_eval/onset.py | 4 ++-- mir_eval/pattern.py | 4 ++-- mir_eval/segment.py | 8 ++++---- mir_eval/separation.py | 4 ++-- mir_eval/tempo.py | 2 +- mir_eval/transcription.py | 4 ++-- 11 files changed, 28 insertions(+), 28 deletions(-) diff --git a/mir_eval/beat.py b/mir_eval/beat.py index 21211a07..83221e72 100644 --- a/mir_eval/beat.py +++ b/mir_eval/beat.py @@ -88,9 +88,9 @@ def validate(reference_beats, estimated_beats): # If reference or estimated beats are empty, # warn because metric will be 0 if reference_beats.size == 0: - warnings.warn("Reference beats are empty.") + warnings.warn("Reference beats are empty.", stacklevel=2) if estimated_beats.size == 0: - warnings.warn("Estimated beats are empty.") + warnings.warn("Estimated beats are empty.", stacklevel=2) for beats in [reference_beats, estimated_beats]: util.validate_events(beats, MAX_TIME) diff --git a/mir_eval/chord.py b/mir_eval/chord.py index 69939e7a..5a5a9a23 100644 --- a/mir_eval/chord.py +++ b/mir_eval/chord.py @@ -642,9 +642,9 @@ def validate(reference_labels, estimated_labels): validate_chord_label(chord_label) # When either label list is empty, warn the user if len(reference_labels) == 0: - warnings.warn("Reference labels are empty") + warnings.warn("Reference labels are empty", stacklevel=2) if len(estimated_labels) == 0: - warnings.warn("Estimated labels are empty") + warnings.warn("Estimated labels are empty", stacklevel=2) def weighted_accuracy(comparisons, weights): diff --git a/mir_eval/hierarchy.py b/mir_eval/hierarchy.py index c5ad5be5..9df30a9c 100644 --- a/mir_eval/hierarchy.py +++ b/mir_eval/hierarchy.py @@ -459,7 +459,7 @@ def validate_hier_intervals(intervals_hier): if boundaries - new_bounds: warnings.warn( - "Segment hierarchy is inconsistent " "at level {:d}".format(level) + "Segment hierarchy is inconsistent " "at level {:d}".format(level), stacklevel=2 ) boundaries |= new_bounds diff --git a/mir_eval/melody.py b/mir_eval/melody.py index ab4d6e42..ec2632f2 100644 --- a/mir_eval/melody.py +++ b/mir_eval/melody.py @@ -82,13 +82,13 @@ def validate_voicing(ref_voicing, est_voicing): """ if ref_voicing.size == 0: - warnings.warn("Reference voicing array is empty.") + warnings.warn("Reference voicing array is empty.", stacklevel=2) if est_voicing.size == 0: - warnings.warn("Estimated voicing array is empty.") + warnings.warn("Estimated voicing array is empty.", stacklevel=2) if ref_voicing.sum() == 0: - warnings.warn("Reference melody has no voiced frames.") + warnings.warn("Reference melody has no voiced frames.", stacklevel=2) if est_voicing.sum() == 0: - warnings.warn("Estimated melody has no voiced frames.") + warnings.warn("Estimated melody has no voiced frames.", stacklevel=2) # Make sure they're the same length if ref_voicing.shape[0] != est_voicing.shape[0]: raise ValueError( @@ -117,9 +117,9 @@ def validate(ref_voicing, ref_cent, est_voicing, est_cent): """ if ref_cent.size == 0: - warnings.warn("Reference frequency array is empty.") + warnings.warn("Reference frequency array is empty.", stacklevel=2) if est_cent.size == 0: - warnings.warn("Estimated frequency array is empty.") + warnings.warn("Estimated frequency array is empty.", stacklevel=2) # Make sure they're the same length if ( ref_voicing.shape[0] != ref_cent.shape[0] diff --git a/mir_eval/multipitch.py b/mir_eval/multipitch.py index bc74f3de..456417bf 100644 --- a/mir_eval/multipitch.py +++ b/mir_eval/multipitch.py @@ -72,17 +72,17 @@ def validate(ref_time, ref_freqs, est_time, est_freqs): util.validate_events(est_time, max_time=MAX_TIME) if ref_time.size == 0: - warnings.warn("Reference times are empty.") + warnings.warn("Reference times are empty.", stacklevel=2) if ref_time.ndim != 1: raise ValueError("Reference times have invalid dimension") if len(ref_freqs) == 0: - warnings.warn("Reference frequencies are empty.") + warnings.warn("Reference frequencies are empty.", stacklevel=2) if est_time.size == 0: - warnings.warn("Estimated times are empty.") + warnings.warn("Estimated times are empty.", stacklevel=2) if est_time.ndim != 1: raise ValueError("Estimated times have invalid dimension") if len(est_freqs) == 0: - warnings.warn("Estimated frequencies are empty.") + warnings.warn("Estimated frequencies are empty.", stacklevel=2) if ref_time.size != len(ref_freqs): raise ValueError("Reference times and frequencies have unequal " "lengths.") if est_time.size != len(est_freqs): diff --git a/mir_eval/onset.py b/mir_eval/onset.py index d3437a33..a315f23b 100644 --- a/mir_eval/onset.py +++ b/mir_eval/onset.py @@ -46,9 +46,9 @@ def validate(reference_onsets, estimated_onsets): """ # If reference or estimated onsets are empty, warn because metric will be 0 if reference_onsets.size == 0: - warnings.warn("Reference onsets are empty.") + warnings.warn("Reference onsets are empty.", stacklevel=2) if estimated_onsets.size == 0: - warnings.warn("Estimated onsets are empty.") + warnings.warn("Estimated onsets are empty.", stacklevel=2) for onsets in [reference_onsets, estimated_onsets]: util.validate_events(onsets, MAX_TIME) diff --git a/mir_eval/pattern.py b/mir_eval/pattern.py index 1dc75436..cc05fc8f 100644 --- a/mir_eval/pattern.py +++ b/mir_eval/pattern.py @@ -92,9 +92,9 @@ def validate(reference_patterns, estimated_patterns): """ # Warn if pattern lists are empty if _n_onset_midi(reference_patterns) == 0: - warnings.warn("Reference patterns are empty.") + warnings.warn("Reference patterns are empty.", stacklevel=2) if _n_onset_midi(estimated_patterns) == 0: - warnings.warn("Estimated patterns are empty.") + warnings.warn("Estimated patterns are empty.", stacklevel=2) for patterns in [reference_patterns, estimated_patterns]: for pattern in patterns: if len(pattern) <= 0: diff --git a/mir_eval/segment.py b/mir_eval/segment.py index 7a49d6ff..7b03f596 100644 --- a/mir_eval/segment.py +++ b/mir_eval/segment.py @@ -110,10 +110,10 @@ def validate_boundary(reference_intervals, estimated_intervals, trim): min_size = 1 if len(reference_intervals) < min_size: - warnings.warn("Reference intervals are empty.") + warnings.warn("Reference intervals are empty.", stacklevel=2) if len(estimated_intervals) < min_size: - warnings.warn("Estimated intervals are empty.") + warnings.warn("Estimated intervals are empty.", stacklevel=2) for intervals in [reference_intervals, estimated_intervals]: util.validate_intervals(intervals) @@ -157,9 +157,9 @@ def validate_structure( raise ValueError("Segment intervals do not start at 0") if reference_intervals.size == 0: - warnings.warn("Reference intervals are empty.") + warnings.warn("Reference intervals are empty.", stacklevel=2) if estimated_intervals.size == 0: - warnings.warn("Estimated intervals are empty.") + warnings.warn("Estimated intervals are empty.", stacklevel=2) # Check only when intervals are non-empty if reference_intervals.size > 0 and estimated_intervals.size > 0: if not np.allclose(reference_intervals.max(), estimated_intervals.max()): diff --git a/mir_eval/separation.py b/mir_eval/separation.py index 0bb0704e..fdd9bbfd 100644 --- a/mir_eval/separation.py +++ b/mir_eval/separation.py @@ -91,7 +91,7 @@ def validate(reference_sources, estimated_sources): warnings.warn( "reference_sources is empty, should be of size " "(nsrc, nsample). sdr, sir, sar, and perm will all " - "be empty np.ndarrays" + "be empty np.ndarrays", stacklevel=2 ) elif _any_source_silent(reference_sources): raise ValueError( @@ -106,7 +106,7 @@ def validate(reference_sources, estimated_sources): warnings.warn( "estimated_sources is empty, should be of size " "(nsrc, nsample). sdr, sir, sar, and perm will all " - "be empty np.ndarrays" + "be empty np.ndarrays", stacklevel=2 ) elif _any_source_silent(estimated_sources): raise ValueError( diff --git a/mir_eval/tempo.py b/mir_eval/tempo.py index aa090da8..5fb3e1f7 100644 --- a/mir_eval/tempo.py +++ b/mir_eval/tempo.py @@ -114,7 +114,7 @@ def detection(reference_tempi, reference_weight, estimated_tempi, tol=0.08): "invalid tolerance {}: must lie in the range " "[0, 1]".format(tol) ) if tol == 0.0: - warnings.warn("A tolerance of 0.0 may not " "lead to the results you expect.") + warnings.warn("A tolerance of 0.0 may not " "lead to the results you expect.", stacklevel=2) hits = [False, False] diff --git a/mir_eval/transcription.py b/mir_eval/transcription.py index 65504279..bdb436ae 100644 --- a/mir_eval/transcription.py +++ b/mir_eval/transcription.py @@ -158,9 +158,9 @@ def validate_intervals(ref_intervals, est_intervals): """ # If reference or estimated notes are empty, warn if ref_intervals.size == 0: - warnings.warn("Reference notes are empty.") + warnings.warn("Reference notes are empty.", stacklevel=2) if est_intervals.size == 0: - warnings.warn("Estimated notes are empty.") + warnings.warn("Estimated notes are empty.", stacklevel=2) # Validate intervals util.validate_intervals(ref_intervals) From 692b84ede7ae51dabd3eb11264622dfcc903c6b5 Mon Sep 17 00:00:00 2001 From: Brian McFee Date: Tue, 26 Mar 2024 12:56:54 -0400 Subject: [PATCH 02/10] blacked the package --- mir_eval/hierarchy.py | 3 ++- mir_eval/separation.py | 6 ++++-- mir_eval/tempo.py | 5 ++++- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/mir_eval/hierarchy.py b/mir_eval/hierarchy.py index 9df30a9c..1e79ba1a 100644 --- a/mir_eval/hierarchy.py +++ b/mir_eval/hierarchy.py @@ -459,7 +459,8 @@ def validate_hier_intervals(intervals_hier): if boundaries - new_bounds: warnings.warn( - "Segment hierarchy is inconsistent " "at level {:d}".format(level), stacklevel=2 + "Segment hierarchy is inconsistent " "at level {:d}".format(level), + stacklevel=2, ) boundaries |= new_bounds diff --git a/mir_eval/separation.py b/mir_eval/separation.py index fdd9bbfd..facb98db 100644 --- a/mir_eval/separation.py +++ b/mir_eval/separation.py @@ -91,7 +91,8 @@ def validate(reference_sources, estimated_sources): warnings.warn( "reference_sources is empty, should be of size " "(nsrc, nsample). sdr, sir, sar, and perm will all " - "be empty np.ndarrays", stacklevel=2 + "be empty np.ndarrays", + stacklevel=2, ) elif _any_source_silent(reference_sources): raise ValueError( @@ -106,7 +107,8 @@ def validate(reference_sources, estimated_sources): warnings.warn( "estimated_sources is empty, should be of size " "(nsrc, nsample). sdr, sir, sar, and perm will all " - "be empty np.ndarrays", stacklevel=2 + "be empty np.ndarrays", + stacklevel=2, ) elif _any_source_silent(estimated_sources): raise ValueError( diff --git a/mir_eval/tempo.py b/mir_eval/tempo.py index 5fb3e1f7..15050364 100644 --- a/mir_eval/tempo.py +++ b/mir_eval/tempo.py @@ -114,7 +114,10 @@ def detection(reference_tempi, reference_weight, estimated_tempi, tol=0.08): "invalid tolerance {}: must lie in the range " "[0, 1]".format(tol) ) if tol == 0.0: - warnings.warn("A tolerance of 0.0 may not " "lead to the results you expect.", stacklevel=2) + warnings.warn( + "A tolerance of 0.0 may not " "lead to the results you expect.", + stacklevel=2, + ) hits = [False, False] From 51ee846be2c9f001f2eef5d6a55241fcb54795f1 Mon Sep 17 00:00:00 2001 From: Brian McFee Date: Tue, 26 Mar 2024 16:01:11 -0400 Subject: [PATCH 03/10] pulling safe= parameter through evaluators --- mir_eval/alignment.py | 41 +++++++-- mir_eval/beat.py | 71 ++++++++++---- mir_eval/chord.py | 208 ++++++++++++++++++++++++++++++------------ 3 files changed, 238 insertions(+), 82 deletions(-) diff --git a/mir_eval/alignment.py b/mir_eval/alignment.py index e8bf22fe..cd231579 100644 --- a/mir_eval/alignment.py +++ b/mir_eval/alignment.py @@ -141,7 +141,9 @@ def absolute_error(reference_timestamps, estimated_timestamps): return np.median(deviations), np.mean(deviations) -def percentage_correct(reference_timestamps, estimated_timestamps, window=0.3): +def percentage_correct( + reference_timestamps, estimated_timestamps, window=0.3, safe=True +): """Compute the percentage of correctly predicted timestamps. A timestamp is predicted correctly if its position doesn't deviate more than the window parameter from the ground truth timestamp. @@ -161,19 +163,27 @@ def percentage_correct(reference_timestamps, estimated_timestamps, window=0.3): window : float Window size, in seconds (Default value = .3) + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- pc : float Percentage of correct timestamps """ - validate(reference_timestamps, estimated_timestamps) + if safe: + validate(reference_timestamps, estimated_timestamps) + deviations = np.abs(reference_timestamps - estimated_timestamps) return np.mean(deviations <= window) def percentage_correct_segments( - reference_timestamps, estimated_timestamps, duration: Optional[float] = None + reference_timestamps, + estimated_timestamps, + duration: Optional[float] = None, + safe=True, ): """Calculate the percentage of correct segments (PCS) metric. @@ -215,13 +225,17 @@ def percentage_correct_segments( duration : float Optional. Total duration of audio (seconds). WARNING: Metric is computed differently depending on whether this is provided or not - see documentation above! + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- pcs : float Percentage of time where ground truth and predicted segments overlap """ - validate(reference_timestamps, estimated_timestamps) + if safe: + validate(reference_timestamps, estimated_timestamps) if duration is not None: duration = float(duration) if duration <= 0: @@ -266,7 +280,7 @@ def percentage_correct_segments( return overlap_duration / duration -def karaoke_perceptual_metric(reference_timestamps, estimated_timestamps): +def karaoke_perceptual_metric(reference_timestamps, estimated_timestamps, safe=True): """Metric based on human synchronicity perception as measured in the paper "User-centered evaluation of lyrics to audio alignment" [#lizemasclef2021] @@ -288,13 +302,17 @@ def karaoke_perceptual_metric(reference_timestamps, estimated_timestamps): reference timestamps, in seconds estimated_timestamps : np.ndarray estimated timestamps, in seconds + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- perceptual_score : float Perceptual score, averaged over all timestamps """ - validate(reference_timestamps, estimated_timestamps) + if safe: + validate(reference_timestamps, estimated_timestamps) offsets = estimated_timestamps - reference_timestamps # Score offsets using a certain skewed normal distribution @@ -309,7 +327,7 @@ def karaoke_perceptual_metric(reference_timestamps, estimated_timestamps): return np.mean(perceptual_scores) -def evaluate(reference_timestamps, estimated_timestamps, **kwargs): +def evaluate(reference_timestamps, estimated_timestamps, safe=True, **kwargs): """Compute all metrics for the given reference and estimated annotations. Examples @@ -328,6 +346,9 @@ def evaluate(reference_timestamps, estimated_timestamps, **kwargs): **kwargs Additional keyword arguments which will be passed to the appropriate metric or preprocessing functions. + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -335,6 +356,12 @@ def evaluate(reference_timestamps, estimated_timestamps, **kwargs): Dictionary of scores, where the key is the metric name (str) and the value is the (float) score achieved. """ + if safe: + validate(reference_timestamps, estimated_timestamps) + + # We can now bypass validation + kwargs["safe"] = False + # Compute all metrics scores = collections.OrderedDict() diff --git a/mir_eval/beat.py b/mir_eval/beat.py index 83221e72..ce298e7a 100644 --- a/mir_eval/beat.py +++ b/mir_eval/beat.py @@ -133,7 +133,7 @@ def _get_reference_beat_variations(reference_beats): ) -def f_measure(reference_beats, estimated_beats, f_measure_threshold=0.07): +def f_measure(reference_beats, estimated_beats, f_measure_threshold=0.07, safe=True): """Compute the F-measure of correct vs incorrectly predicted beats. "Correctness" is determined over a small window. @@ -144,7 +144,7 @@ def f_measure(reference_beats, estimated_beats, f_measure_threshold=0.07): >>> estimated_beats = mir_eval.io.load_events('estimated.txt') >>> estimated_beats = mir_eval.beat.trim_beats(estimated_beats) >>> f_measure = mir_eval.beat.f_measure(reference_beats, - estimated_beats) + ... estimated_beats) Parameters ---------- @@ -155,6 +155,9 @@ def f_measure(reference_beats, estimated_beats, f_measure_threshold=0.07): f_measure_threshold : float Window size, in seconds (Default value = 0.07) + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -162,7 +165,8 @@ def f_measure(reference_beats, estimated_beats, f_measure_threshold=0.07): The computed F-measure score """ - validate(reference_beats, estimated_beats) + if safe: + validate(reference_beats, estimated_beats) # When estimated beats are empty, no beats are correct; metric is 0 if estimated_beats.size == 0 or reference_beats.size == 0: return 0.0 @@ -174,7 +178,7 @@ def f_measure(reference_beats, estimated_beats, f_measure_threshold=0.07): return util.f_measure(precision, recall) -def cemgil(reference_beats, estimated_beats, cemgil_sigma=0.04): +def cemgil(reference_beats, estimated_beats, cemgil_sigma=0.04, safe=True): """Cemgil's score, computes a gaussian error of each estimated beat. Compares against the original beat times and all metrical variations. @@ -185,7 +189,7 @@ def cemgil(reference_beats, estimated_beats, cemgil_sigma=0.04): >>> estimated_beats = mir_eval.io.load_events('estimated.txt') >>> estimated_beats = mir_eval.beat.trim_beats(estimated_beats) >>> cemgil_score, cemgil_max = mir_eval.beat.cemgil(reference_beats, - estimated_beats) + ... estimated_beats) Parameters ---------- @@ -196,6 +200,9 @@ def cemgil(reference_beats, estimated_beats, cemgil_sigma=0.04): cemgil_sigma : float Sigma parameter of gaussian error windows (Default value = 0.04) + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -204,7 +211,8 @@ def cemgil(reference_beats, estimated_beats, cemgil_sigma=0.04): cemgil_max : float The best Cemgil score for all metrical variations """ - validate(reference_beats, estimated_beats) + if safe: + validate(reference_beats, estimated_beats) # When estimated beats are empty, no beats are correct; metric is 0 if estimated_beats.size == 0 or reference_beats.size == 0: return 0.0, 0.0 @@ -228,7 +236,12 @@ def cemgil(reference_beats, estimated_beats, cemgil_sigma=0.04): def goto( - reference_beats, estimated_beats, goto_threshold=0.35, goto_mu=0.2, goto_sigma=0.2 + reference_beats, + estimated_beats, + goto_threshold=0.35, + goto_mu=0.2, + goto_sigma=0.2, + safe=True, ): """Calculate Goto's score, a binary 1 or 0 depending on some specific heuristic criteria @@ -258,13 +271,17 @@ def goto( The std of the beat errors in the continuously correct track must be less than this (Default value = 0.2) + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- goto_score : float Either 1.0 or 0.0 if some specific criteria are met """ - validate(reference_beats, estimated_beats) + if safe: + validate(reference_beats, estimated_beats) # When estimated beats are empty, no beats are correct; metric is 0 if estimated_beats.size == 0 or reference_beats.size == 0: return 0.0 @@ -327,7 +344,7 @@ def goto( return 1.0 * (goto_criteria == 3) -def p_score(reference_beats, estimated_beats, p_score_threshold=0.2): +def p_score(reference_beats, estimated_beats, p_score_threshold=0.2, safe=True): """Get McKinney's P-score. Based on the autocorrelation of the reference and estimated beats @@ -349,6 +366,9 @@ def p_score(reference_beats, estimated_beats, p_score_threshold=0.2): Window size will be ``p_score_threshold*np.median(inter_annotation_intervals)``, (Default value = 0.2) + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -356,7 +376,8 @@ def p_score(reference_beats, estimated_beats, p_score_threshold=0.2): McKinney's P-score """ - validate(reference_beats, estimated_beats) + if safe: + validate(reference_beats, estimated_beats) # Warn when only one beat is provided for either estimated or reference, # report a warning if reference_beats.size == 1: @@ -412,6 +433,7 @@ def continuity( estimated_beats, continuity_phase_threshold=0.175, continuity_period_threshold=0.175, + safe=True, ): """Get metrics based on how much of the estimated beat sequence is continually correct. @@ -439,6 +461,9 @@ def continuity( Allowable distance between the inter-beat-interval and the inter-annotation-interval (Default value = 0.175) + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -451,7 +476,8 @@ def continuity( AMLt : float Any metric level, total accuracy (continuity not required) """ - validate(reference_beats, estimated_beats) + if safe: + validate(reference_beats, estimated_beats) # Warn when only one beat is provided for either estimated or reference, # report a warning if reference_beats.size == 1: @@ -583,7 +609,7 @@ def continuity( ) -def information_gain(reference_beats, estimated_beats, bins=41): +def information_gain(reference_beats, estimated_beats, bins=41, safe=True): """Get the information gain - K-L divergence of the beat error histogram to a uniform histogram @@ -594,7 +620,7 @@ def information_gain(reference_beats, estimated_beats, bins=41): >>> estimated_beats = mir_eval.io.load_events('estimated.txt') >>> estimated_beats = mir_eval.beat.trim_beats(estimated_beats) >>> information_gain = mir_eval.beat.information_gain(reference_beats, - estimated_beats) + ... estimated_beats) Parameters ---------- @@ -605,13 +631,17 @@ def information_gain(reference_beats, estimated_beats, bins=41): bins : int Number of bins in the beat error histogram (Default value = 41) + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- information_gain_score : float Entropy of beat error histogram """ - validate(reference_beats, estimated_beats) + if safe: + validate(reference_beats, estimated_beats) # If an even number of bins is provided, # there will be no bin centered at zero, so warn the user. if not bins % 2: @@ -712,7 +742,7 @@ def _get_entropy(reference_beats, estimated_beats, bins): return -np.sum(raw_bin_values * np.log2(raw_bin_values)) -def evaluate(reference_beats, estimated_beats, **kwargs): +def evaluate(reference_beats, estimated_beats, safe=True, **kwargs): """Compute all metrics for the given reference and estimated annotations. Examples @@ -727,6 +757,9 @@ def evaluate(reference_beats, estimated_beats, **kwargs): Reference beat times, in seconds estimated_beats : np.ndarray Query beat times, in seconds + safe : bool + If True, validate inputs. + If False, skip validation of inputs. **kwargs Additional keyword arguments which will be passed to the appropriate metric or preprocessing functions. @@ -742,8 +775,14 @@ def evaluate(reference_beats, estimated_beats, **kwargs): reference_beats = util.filter_kwargs(trim_beats, reference_beats, **kwargs) estimated_beats = util.filter_kwargs(trim_beats, estimated_beats, **kwargs) - # Now compute all the metrics + # Validate inputs + if safe: + validate(reference_beats, estimated_beats) + + # We can now bypass validation + kwargs["safe"] = False + # Now compute all the metrics scores = collections.OrderedDict() # F-Measure diff --git a/mir_eval/chord.py b/mir_eval/chord.py index 5a5a9a23..49398ebe 100644 --- a/mir_eval/chord.py +++ b/mir_eval/chord.py @@ -340,6 +340,15 @@ def reduce_extended_quality(quality): # --- Chord Label Parsing --- +# This monster regexp is pulled from the JAMS chord namespace, +# which is in turn derived from the context-free grammar of +# Harte et al., 2005. +# Just compile this regexp once +CHORD_PATTERN = re.compile( + r"""^((N|X)|(([A-G](b*|#*))((:(maj|min|dim|aug|1|5|sus2|sus4|maj6|min6|7|maj7|min7|dim7|hdim7|minmaj7|aug7|9|maj9|min9|11|maj11|min11|13|maj13|min13)(\((\*?((b*|#*)([1-9]|1[0-3]?))(,\*?((b*|#*)([1-9]|1[0-3]?)))*)\))?)|(:\((\*?((b*|#*)([1-9]|1[0-3]?))(,\*?((b*|#*)([1-9]|1[0-3]?)))*)\)))?((/((b*|#*)([1-9]|1[0-3]?)))?)?))$""" +) # nopep8 + + def validate_chord_label(chord_label): """Test for well-formedness of a chord label. @@ -348,17 +357,8 @@ def validate_chord_label(chord_label): chord_label : str Chord label to validate. """ - # This monster regexp is pulled from the JAMS chord namespace, - # which is in turn derived from the context-free grammar of - # Harte et al., 2005. - - pattern = re.compile( - r"""^((N|X)|(([A-G](b*|#*))((:(maj|min|dim|aug|1|5|sus2|sus4|maj6|min6|7|maj7|min7|dim7|hdim7|minmaj7|aug7|9|maj9|min9|11|maj11|min11|13|maj13|min13)(\((\*?((b*|#*)([1-9]|1[0-3]?))(,\*?((b*|#*)([1-9]|1[0-3]?)))*)\))?)|(:\((\*?((b*|#*)([1-9]|1[0-3]?))(,\*?((b*|#*)([1-9]|1[0-3]?)))*)\)))?((/((b*|#*)([1-9]|1[0-3]?)))?)?))$""" - ) # nopep8 - - if not pattern.match(chord_label): + if not CHORD_PATTERN.match(chord_label): raise InvalidChordException("Invalid chord label: " "{}".format(chord_label)) - pass def split(chord_label, reduce_extended_chords=False): @@ -715,7 +715,7 @@ def weighted_accuracy(comparisons, weights): return np.sum(comparisons * normalized_weights) -def thirds(reference_labels, estimated_labels): +def thirds(reference_labels, estimated_labels, safe=True): """Compare chords along root & third relationships. Examples @@ -742,6 +742,9 @@ def thirds(reference_labels, estimated_labels): Reference chord labels to score against. estimated_labels : list, len=n Estimated chord labels to score against. + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -749,7 +752,8 @@ def thirds(reference_labels, estimated_labels): Comparison scores, in [0.0, 1.0] """ - validate(reference_labels, estimated_labels) + if safe: + validate(reference_labels, estimated_labels) ref_roots, ref_semitones = encode_many(reference_labels, False)[:2] est_roots, est_semitones = encode_many(estimated_labels, False)[:2] @@ -762,7 +766,7 @@ def thirds(reference_labels, estimated_labels): return comparison_scores -def thirds_inv(reference_labels, estimated_labels): +def thirds_inv(reference_labels, estimated_labels, safe=True): """Score chords along root, third, & bass relationships. Examples @@ -789,6 +793,9 @@ def thirds_inv(reference_labels, estimated_labels): Reference chord labels to score against. estimated_labels : list, len=n Estimated chord labels to score against. + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -796,7 +803,8 @@ def thirds_inv(reference_labels, estimated_labels): Comparison scores, in [0.0, 1.0] """ - validate(reference_labels, estimated_labels) + if safe: + validate(reference_labels, estimated_labels) ref_roots, ref_semitones, ref_bass = encode_many(reference_labels, False) est_roots, est_semitones, est_bass = encode_many(estimated_labels, False) @@ -810,7 +818,7 @@ def thirds_inv(reference_labels, estimated_labels): return comparison_scores -def triads(reference_labels, estimated_labels): +def triads(reference_labels, estimated_labels, safe=True): """Compare chords along triad (root & quality to #5) relationships. Examples @@ -837,6 +845,9 @@ def triads(reference_labels, estimated_labels): Reference chord labels to score against. estimated_labels : list, len=n Estimated chord labels to score against. + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -844,7 +855,8 @@ def triads(reference_labels, estimated_labels): Comparison scores, in [0.0, 1.0] """ - validate(reference_labels, estimated_labels) + if safe: + validate(reference_labels, estimated_labels) ref_roots, ref_semitones = encode_many(reference_labels, False)[:2] est_roots, est_semitones = encode_many(estimated_labels, False)[:2] @@ -857,7 +869,7 @@ def triads(reference_labels, estimated_labels): return comparison_scores -def triads_inv(reference_labels, estimated_labels): +def triads_inv(reference_labels, estimated_labels, safe=True): """Score chords along triad (root, quality to #5, & bass) relationships. Examples @@ -884,6 +896,9 @@ def triads_inv(reference_labels, estimated_labels): Reference chord labels to score against. estimated_labels : list, len=n Estimated chord labels to score against. + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -891,7 +906,8 @@ def triads_inv(reference_labels, estimated_labels): Comparison scores, in [0.0, 1.0] """ - validate(reference_labels, estimated_labels) + if safe: + validate(reference_labels, estimated_labels) ref_roots, ref_semitones, ref_bass = encode_many(reference_labels, False) est_roots, est_semitones, est_bass = encode_many(estimated_labels, False) @@ -905,7 +921,7 @@ def triads_inv(reference_labels, estimated_labels): return comparison_scores -def tetrads(reference_labels, estimated_labels): +def tetrads(reference_labels, estimated_labels, safe=True): """Compare chords along tetrad (root & full quality) relationships. Examples @@ -932,6 +948,9 @@ def tetrads(reference_labels, estimated_labels): Reference chord labels to score against. estimated_labels : list, len=n Estimated chord labels to score against. + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -939,7 +958,8 @@ def tetrads(reference_labels, estimated_labels): Comparison scores, in [0.0, 1.0] """ - validate(reference_labels, estimated_labels) + if safe: + validate(reference_labels, estimated_labels) ref_roots, ref_semitones = encode_many(reference_labels, False)[:2] est_roots, est_semitones = encode_many(estimated_labels, False)[:2] @@ -952,7 +972,7 @@ def tetrads(reference_labels, estimated_labels): return comparison_scores -def tetrads_inv(reference_labels, estimated_labels): +def tetrads_inv(reference_labels, estimated_labels, safe=True): """Compare chords along tetrad (root, full quality, & bass) relationships. Examples @@ -979,6 +999,9 @@ def tetrads_inv(reference_labels, estimated_labels): Reference chord labels to score against. estimated_labels : list, len=n Estimated chord labels to score against. + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -986,7 +1009,8 @@ def tetrads_inv(reference_labels, estimated_labels): Comparison scores, in [0.0, 1.0] """ - validate(reference_labels, estimated_labels) + if safe: + validate(reference_labels, estimated_labels) ref_roots, ref_semitones, ref_bass = encode_many(reference_labels, False) est_roots, est_semitones, est_bass = encode_many(estimated_labels, False) @@ -1000,7 +1024,7 @@ def tetrads_inv(reference_labels, estimated_labels): return comparison_scores -def root(reference_labels, estimated_labels): +def root(reference_labels, estimated_labels, safe=True): """Compare chords according to roots. Examples @@ -1027,6 +1051,9 @@ def root(reference_labels, estimated_labels): Reference chord labels to score against. estimated_labels : list, len=n Estimated chord labels to score against. + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -1034,7 +1061,8 @@ def root(reference_labels, estimated_labels): Comparison scores, in [0.0, 1.0], or -1 if the comparison is out of gamut. """ - validate(reference_labels, estimated_labels) + if safe: + validate(reference_labels, estimated_labels) ref_roots, ref_semitones = encode_many(reference_labels, False)[:2] est_roots = encode_many(estimated_labels, False)[0] comparison_scores = (ref_roots == est_roots).astype(np.float64) @@ -1044,7 +1072,7 @@ def root(reference_labels, estimated_labels): return comparison_scores -def mirex(reference_labels, estimated_labels): +def mirex(reference_labels, estimated_labels, safe=True): """Compare chords along MIREX rules. Examples @@ -1071,6 +1099,9 @@ def mirex(reference_labels, estimated_labels): Reference chord labels to score against. estimated_labels : list, len=n Estimated chord labels to score against. + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -1078,7 +1109,8 @@ def mirex(reference_labels, estimated_labels): Comparison scores, in [0.0, 1.0] """ - validate(reference_labels, estimated_labels) + if safe: + validate(reference_labels, estimated_labels) # TODO(?): Should this be an argument? min_intersection = 3 ref_data = encode_many(reference_labels, False) @@ -1107,7 +1139,7 @@ def mirex(reference_labels, estimated_labels): return comparison_scores -def majmin(reference_labels, estimated_labels): +def majmin(reference_labels, estimated_labels, safe=True): """Compare chords along major-minor rules. Chords with qualities outside Major/minor/no-chord are ignored. @@ -1135,6 +1167,9 @@ def majmin(reference_labels, estimated_labels): Reference chord labels to score against. estimated_labels : list, len=n Estimated chord labels to score against. + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -1143,7 +1178,8 @@ def majmin(reference_labels, estimated_labels): gamut. """ - validate(reference_labels, estimated_labels) + if safe: + validate(reference_labels, estimated_labels) maj_semitones = np.array(QUALITIES["maj"][:8]) min_semitones = np.array(QUALITIES["min"][:8]) @@ -1172,7 +1208,7 @@ def majmin(reference_labels, estimated_labels): return comparison_scores -def majmin_inv(reference_labels, estimated_labels): +def majmin_inv(reference_labels, estimated_labels, safe=True): """Compare chords along major-minor rules, with inversions. Chords with qualities outside Major/minor/no-chord are ignored, and the bass note must exist in the triad (bass in [1, 3, 5]). @@ -1201,6 +1237,9 @@ def majmin_inv(reference_labels, estimated_labels): Reference chord labels to score against. estimated_labels : list, len=n Estimated chord labels to score against. + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -1209,7 +1248,8 @@ def majmin_inv(reference_labels, estimated_labels): gamut. """ - validate(reference_labels, estimated_labels) + if safe: + validate(reference_labels, estimated_labels) maj_semitones = np.array(QUALITIES["maj"][:8]) min_semitones = np.array(QUALITIES["min"][:8]) @@ -1236,7 +1276,7 @@ def majmin_inv(reference_labels, estimated_labels): return comparison_scores -def sevenths(reference_labels, estimated_labels): +def sevenths(reference_labels, estimated_labels, safe=True): """Compare chords along MIREX 'sevenths' rules. Chords with qualities outside [maj, maj7, 7, min, min7, N] are ignored. @@ -1264,6 +1304,9 @@ def sevenths(reference_labels, estimated_labels): Reference chord labels to score against. estimated_labels : list, len=n Estimated chord labels to score against. + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -1272,7 +1315,8 @@ def sevenths(reference_labels, estimated_labels): gamut. """ - validate(reference_labels, estimated_labels) + if safe: + validate(reference_labels, estimated_labels) seventh_qualities = ["maj", "min", "maj7", "7", "min7", ""] valid_semitones = np.array([QUALITIES[name] for name in seventh_qualities]) @@ -1295,7 +1339,7 @@ def sevenths(reference_labels, estimated_labels): return comparison_scores -def sevenths_inv(reference_labels, estimated_labels): +def sevenths_inv(reference_labels, estimated_labels, safe=True): """Compare chords along MIREX 'sevenths' rules. Chords with qualities outside [maj, maj7, 7, min, min7, N] are ignored. @@ -1323,6 +1367,9 @@ def sevenths_inv(reference_labels, estimated_labels): Reference chord labels to score against. estimated_labels : list, len=n Estimated chord labels to score against. + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -1331,7 +1378,8 @@ def sevenths_inv(reference_labels, estimated_labels): gamut. """ - validate(reference_labels, estimated_labels) + if safe: + validate(reference_labels, estimated_labels) seventh_qualities = ["maj", "min", "maj7", "7", "min7", ""] valid_semitones = np.array([QUALITIES[name] for name in seventh_qualities]) @@ -1359,7 +1407,7 @@ def sevenths_inv(reference_labels, estimated_labels): return comparison_scores -def directional_hamming_distance(reference_intervals, estimated_intervals): +def directional_hamming_distance(reference_intervals, estimated_intervals, safe=True): """Compute the directional hamming distance between reference and estimated intervals as defined by [#harte2010towards]_ and used for MIREX 'OverSeg', 'UnderSeg' and 'MeanSeg' measures. @@ -1382,6 +1430,9 @@ def directional_hamming_distance(reference_intervals, estimated_intervals): Reference chord intervals to score against. estimated_intervals : np.ndarray, shape=(m, 2), dtype=float Estimated chord intervals to score against. + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -1389,8 +1440,9 @@ def directional_hamming_distance(reference_intervals, estimated_intervals): directional hamming distance between reference intervals and estimated intervals. """ - util.validate_intervals(estimated_intervals) - util.validate_intervals(reference_intervals) + if safe: + util.validate_intervals(estimated_intervals) + util.validate_intervals(reference_intervals) # make sure chord intervals do not overlap if ( @@ -1409,7 +1461,7 @@ def directional_hamming_distance(reference_intervals, estimated_intervals): return seg / (reference_intervals[-1, 1] - reference_intervals[0, 0]) -def overseg(reference_intervals, estimated_intervals): +def overseg(reference_intervals, estimated_intervals, safe=True): """Compute the MIREX 'OverSeg' score. Examples @@ -1426,16 +1478,21 @@ def overseg(reference_intervals, estimated_intervals): Reference chord intervals to score against. estimated_intervals : np.ndarray, shape=(m, 2), dtype=float Estimated chord intervals to score against. + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- oversegmentation score : float Comparison score, in [0.0, 1.0], where 1.0 means no oversegmentation. """ - return 1 - directional_hamming_distance(reference_intervals, estimated_intervals) + return 1 - directional_hamming_distance( + reference_intervals, estimated_intervals, safe=safe + ) -def underseg(reference_intervals, estimated_intervals): +def underseg(reference_intervals, estimated_intervals, safe=True): """Compute the MIREX 'UnderSeg' score. Examples @@ -1452,16 +1509,21 @@ def underseg(reference_intervals, estimated_intervals): Reference chord intervals to score against. estimated_intervals : np.ndarray, shape=(m, 2), dtype=float Estimated chord intervals to score against. + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- undersegmentation score : float Comparison score, in [0.0, 1.0], where 1.0 means no undersegmentation. """ - return 1 - directional_hamming_distance(estimated_intervals, reference_intervals) + return 1 - directional_hamming_distance( + estimated_intervals, reference_intervals, safe=safe + ) -def seg(reference_intervals, estimated_intervals): +def seg(reference_intervals, estimated_intervals, safe=True): """Compute the MIREX 'MeanSeg' score. Examples @@ -1478,6 +1540,9 @@ def seg(reference_intervals, estimated_intervals): Reference chord intervals to score against. estimated_intervals : np.ndarray, shape=(m, 2), dtype=float Estimated chord intervals to score against. + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -1485,8 +1550,8 @@ def seg(reference_intervals, estimated_intervals): Comparison score, in [0.0, 1.0], where 1.0 means perfect segmentation. """ return min( - underseg(reference_intervals, estimated_intervals), - overseg(reference_intervals, estimated_intervals), + underseg(reference_intervals, estimated_intervals, safe=safe), + overseg(reference_intervals, estimated_intervals, safe=safe), ) @@ -1525,7 +1590,7 @@ def merge_chord_intervals(intervals, labels): return np.array(merged_ivs) -def evaluate(ref_intervals, ref_labels, est_intervals, est_labels, **kwargs): +def evaluate(ref_intervals, ref_labels, est_intervals, est_labels, safe=True, **kwargs): """Compute weighted accuracy for all comparison functions for the given reference and estimated annotations. @@ -1552,6 +1617,9 @@ def evaluate(ref_intervals, ref_labels, est_intervals, est_labels, **kwargs): est_labels : list, shape=(m,) estimated chord labels, in the format returned by :func:`mir_eval.io.load_labeled_intervals`. + safe : bool + If True, validate inputs. + If False, skip validation of inputs. **kwargs Additional keyword arguments which will be passed to the appropriate metric or preprocessing functions. @@ -1582,33 +1650,55 @@ def evaluate(ref_intervals, ref_labels, est_intervals, est_labels, **kwargs): # Convert intervals to durations (used as weights) durations = util.intervals_to_durations(intervals) + # Validate the data up front + if safe: + validate(ref_labels, est_labels) + util.validate_intervals(ref_intervals) + util.validate_intervals(est_intervals) + # Store scores for each comparison function scores = collections.OrderedDict() - scores["thirds"] = weighted_accuracy(thirds(ref_labels, est_labels), durations) + scores["thirds"] = weighted_accuracy( + thirds(ref_labels, est_labels, safe=False), durations + ) scores["thirds_inv"] = weighted_accuracy( - thirds_inv(ref_labels, est_labels), durations + thirds_inv(ref_labels, est_labels, safe=False), durations + ) + scores["triads"] = weighted_accuracy( + triads(ref_labels, est_labels, safe=False), durations ) - scores["triads"] = weighted_accuracy(triads(ref_labels, est_labels), durations) scores["triads_inv"] = weighted_accuracy( - triads_inv(ref_labels, est_labels), durations + triads_inv(ref_labels, est_labels, safe=False), durations + ) + scores["tetrads"] = weighted_accuracy( + tetrads(ref_labels, est_labels, safe=False), durations ) - scores["tetrads"] = weighted_accuracy(tetrads(ref_labels, est_labels), durations) scores["tetrads_inv"] = weighted_accuracy( - tetrads_inv(ref_labels, est_labels), durations + tetrads_inv(ref_labels, est_labels, safe=False), durations + ) + scores["root"] = weighted_accuracy( + root(ref_labels, est_labels, safe=False), durations + ) + scores["mirex"] = weighted_accuracy( + mirex(ref_labels, est_labels, safe=False), durations + ) + scores["majmin"] = weighted_accuracy( + majmin(ref_labels, est_labels, safe=False), durations ) - scores["root"] = weighted_accuracy(root(ref_labels, est_labels), durations) - scores["mirex"] = weighted_accuracy(mirex(ref_labels, est_labels), durations) - scores["majmin"] = weighted_accuracy(majmin(ref_labels, est_labels), durations) scores["majmin_inv"] = weighted_accuracy( - majmin_inv(ref_labels, est_labels), durations + majmin_inv(ref_labels, est_labels, safe=False), durations + ) + scores["sevenths"] = weighted_accuracy( + sevenths(ref_labels, est_labels, safe=False), durations ) - scores["sevenths"] = weighted_accuracy(sevenths(ref_labels, est_labels), durations) scores["sevenths_inv"] = weighted_accuracy( - sevenths_inv(ref_labels, est_labels), durations + sevenths_inv(ref_labels, est_labels, safe=False), durations + ) + scores["underseg"] = underseg( + merged_ref_intervals, merged_est_intervals, safe=False ) - scores["underseg"] = underseg(merged_ref_intervals, merged_est_intervals) - scores["overseg"] = overseg(merged_ref_intervals, merged_est_intervals) + scores["overseg"] = overseg(merged_ref_intervals, merged_est_intervals, safe=False) scores["seg"] = min(scores["overseg"], scores["underseg"]) return scores From 089bd79154db67bc1f8e469b67922d4a178de3d3 Mon Sep 17 00:00:00 2001 From: Brian McFee Date: Tue, 26 Mar 2024 16:04:19 -0400 Subject: [PATCH 04/10] safety param in hierarchy --- mir_eval/hierarchy.py | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/mir_eval/hierarchy.py b/mir_eval/hierarchy.py index 1e79ba1a..f46b96b1 100644 --- a/mir_eval/hierarchy.py +++ b/mir_eval/hierarchy.py @@ -472,6 +472,7 @@ def tmeasure( window=15.0, frame_size=0.1, beta=1.0, + safe=True, ): """Compute the tree measures for hierarchical segment annotations. @@ -494,6 +495,9 @@ def tmeasure( the window. beta : float > 0 beta parameter for the F-measure. + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -531,8 +535,9 @@ def tmeasure( window_frames = int(_round(window, frame_size) / frame_size) # Validate the hierarchical segmentations - validate_hier_intervals(reference_intervals_hier) - validate_hier_intervals(estimated_intervals_hier) + if safe: + validate_hier_intervals(reference_intervals_hier) + validate_hier_intervals(estimated_intervals_hier) # Build the least common ancestor matrices ref_lca = _lca(reference_intervals_hier, frame_size) @@ -554,6 +559,7 @@ def lmeasure( estimated_labels_hier, frame_size=0.1, beta=1.0, + safe=True, ): """Compute the tree measures for hierarchical segment annotations. @@ -576,6 +582,9 @@ def lmeasure( the window. beta : float > 0 beta parameter for the F-measure. + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -602,8 +611,9 @@ def lmeasure( ) # Validate the hierarchical segmentations - validate_hier_intervals(reference_intervals_hier) - validate_hier_intervals(estimated_intervals_hier) + if safe: + validate_hier_intervals(reference_intervals_hier) + validate_hier_intervals(estimated_intervals_hier) # Build the least common ancestor matrices ref_meet = _meet(reference_intervals_hier, reference_labels_hier, frame_size) @@ -619,7 +629,12 @@ def lmeasure( def evaluate( - ref_intervals_hier, ref_labels_hier, est_intervals_hier, est_labels_hier, **kwargs + ref_intervals_hier, + ref_labels_hier, + est_intervals_hier, + est_labels_hier, + safe=True, + **kwargs ): r"""Compute all hierarchical structure metrics for the given reference and estimated annotations. @@ -678,6 +693,9 @@ def evaluate( of segmentations. Each segmentation itself is a list (or list-like) of intervals (\*_intervals_hier) and a list of lists of labels (\*_labels_hier). + safe : bool + If True, validate inputs. + If False, skip validation of inputs. **kwargs additional keyword arguments to the evaluation metrics. @@ -707,6 +725,12 @@ def evaluate( est_intervals_hier, est_labels_hier = _align_intervals( est_intervals_hier, est_labels_hier, t_min=0.0, t_max=t_end ) + if safe: + validate_hier_intervals(ref_intervals_hier) + validate_hier_intervals(est_intervals_hier) + + # We can now bypass further validation checks + kwargs["safe"] = False scores = collections.OrderedDict() From 17dd57ba9706a10abb5491d472fe636e68146ae0 Mon Sep 17 00:00:00 2001 From: Brian McFee Date: Tue, 26 Mar 2024 16:08:21 -0400 Subject: [PATCH 05/10] more safety params --- mir_eval/key.py | 17 ++++++++++--- mir_eval/melody.py | 59 ++++++++++++++++++++++++++++++++++++---------- 2 files changed, 61 insertions(+), 15 deletions(-) diff --git a/mir_eval/key.py b/mir_eval/key.py index c0302385..0aa6d208 100644 --- a/mir_eval/key.py +++ b/mir_eval/key.py @@ -113,7 +113,7 @@ def split_key_string(key): return KEY_TO_SEMITONE[key.lower()], mode -def weighted_score(reference_key, estimated_key): +def weighted_score(reference_key, estimated_key, safe=True): """Compute a heuristic score which is weighted according to the relationship of the reference and estimated key, as follows: @@ -143,13 +143,17 @@ def weighted_score(reference_key, estimated_key): Reference key string. estimated_key : str Estimated key string. + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- score : float Score representing how closely related the keys are. """ - validate(reference_key, estimated_key) + if safe: + validate(reference_key, estimated_key) reference_key, reference_mode = split_key_string(reference_key) estimated_key, estimated_mode = split_key_string(estimated_key) # If keys are the same, return 1. @@ -181,7 +185,7 @@ def weighted_score(reference_key, estimated_key): return 0.0 -def evaluate(reference_key, estimated_key, **kwargs): +def evaluate(reference_key, estimated_key, safe=True, **kwargs): """Compute all metrics for the given reference and estimated annotations. Examples @@ -196,6 +200,9 @@ def evaluate(reference_key, estimated_key, **kwargs): Reference key string. estimated_key : str Estimated key string. + safe : bool + If True, validate inputs. + If False, skip validation of inputs. **kwargs Additional keyword arguments which will be passed to the appropriate metric or preprocessing functions. @@ -206,6 +213,10 @@ def evaluate(reference_key, estimated_key, **kwargs): Dictionary of scores, where the key is the metric name (str) and the value is the (float) score achieved. """ + if safe: + validate(reference_key, estimated_key) + + kwargs["safe"] = False # Compute all metrics scores = collections.OrderedDict() diff --git a/mir_eval/melody.py b/mir_eval/melody.py index ec2632f2..f4d28d64 100644 --- a/mir_eval/melody.py +++ b/mir_eval/melody.py @@ -494,7 +494,7 @@ def voicing_false_alarm(ref_voicing, est_voicing): return np.sum(est_voicing * ref_indicator) / np.sum(ref_indicator) -def voicing_measures(ref_voicing, est_voicing): +def voicing_measures(ref_voicing, est_voicing, safe=True): """Compute the voicing recall and false alarm rates given two voicing indicator sequences, one as reference (truth) and the other as the estimate (prediction). The sequences must be of the same length. @@ -517,6 +517,9 @@ def voicing_measures(ref_voicing, est_voicing): Reference boolean voicing array est_voicing : np.ndarray Estimated boolean voicing array + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -527,13 +530,16 @@ def voicing_measures(ref_voicing, est_voicing): Voicing false alarm rate, the fraction of unvoiced frames in ref indicated as voiced in est """ - validate_voicing(ref_voicing, est_voicing) + if safe: + validate_voicing(ref_voicing, est_voicing) vx_recall = voicing_recall(ref_voicing, est_voicing) vx_false_alm = voicing_false_alarm(ref_voicing, est_voicing) return vx_recall, vx_false_alm -def raw_pitch_accuracy(ref_voicing, ref_cent, est_voicing, est_cent, cent_tolerance=50): +def raw_pitch_accuracy( + ref_voicing, ref_cent, est_voicing, est_cent, cent_tolerance=50, safe=True +): """Compute the raw pitch accuracy given two pitch (frequency) sequences in cents and matching voicing indicator sequences. The first pitch and voicing arrays are treated as the reference (truth), and the second two as the @@ -566,6 +572,9 @@ def raw_pitch_accuracy(ref_voicing, ref_cent, est_voicing, est_cent, cent_tolera Maximum absolute deviation in cents for a frequency value to be considered correct (Default value = 50) + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -574,8 +583,9 @@ def raw_pitch_accuracy(ref_voicing, ref_cent, est_voicing, est_cent, cent_tolera which est_cent provides a correct frequency values (within cent_tolerance cents). """ - validate_voicing(ref_voicing, est_voicing) - validate(ref_voicing, ref_cent, est_voicing, est_cent) + if safe: + validate_voicing(ref_voicing, est_voicing) + validate(ref_voicing, ref_cent, est_voicing, est_cent) # When input arrays are empty, return 0 by special case # If there are no voiced frames in reference, metric is 0 if ( @@ -602,7 +612,7 @@ def raw_pitch_accuracy(ref_voicing, ref_cent, est_voicing, est_cent, cent_tolera def raw_chroma_accuracy( - ref_voicing, ref_cent, est_voicing, est_cent, cent_tolerance=50 + ref_voicing, ref_cent, est_voicing, est_cent, cent_tolerance=50, safe=True ): """Compute the raw chroma accuracy given two pitch (frequency) sequences in cents and matching voicing indicator sequences. The first pitch and @@ -636,6 +646,9 @@ def raw_chroma_accuracy( Maximum absolute deviation in cents for a frequency value to be considered correct (Default value = 50) + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -645,8 +658,9 @@ def raw_chroma_accuracy( cent_tolerance cents), ignoring octave errors """ - validate_voicing(ref_voicing, est_voicing) - validate(ref_voicing, ref_cent, est_voicing, est_cent) + if safe: + validate_voicing(ref_voicing, est_voicing) + validate(ref_voicing, ref_cent, est_voicing, est_cent) # When input arrays are empty, return 0 by special case # If there are no voiced frames in reference, metric is 0 if ( @@ -670,7 +684,9 @@ def raw_chroma_accuracy( return rca -def overall_accuracy(ref_voicing, ref_cent, est_voicing, est_cent, cent_tolerance=50): +def overall_accuracy( + ref_voicing, ref_cent, est_voicing, est_cent, cent_tolerance=50, safe=True +): """Compute the overall accuracy given two pitch (frequency) sequences in cents and matching voicing indicator sequences. The first pitch and voicing arrays are treated as the reference (truth), and the second two @@ -703,6 +719,9 @@ def overall_accuracy(ref_voicing, ref_cent, est_voicing, est_cent, cent_toleranc Maximum absolute deviation in cents for a frequency value to be considered correct (Default value = 50) + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -711,8 +730,9 @@ def overall_accuracy(ref_voicing, ref_cent, est_voicing, est_cent, cent_toleranc where provides a correct frequency values (within cent_tolerance). """ - validate_voicing(ref_voicing, est_voicing) - validate(ref_voicing, ref_cent, est_voicing, est_cent) + if safe: + validate_voicing(ref_voicing, est_voicing) + validate(ref_voicing, ref_cent, est_voicing, est_cent) # When input arrays are empty, return 0 by special case if ( @@ -750,7 +770,14 @@ def overall_accuracy(ref_voicing, ref_cent, est_voicing, est_cent, cent_toleranc def evaluate( - ref_time, ref_freq, est_time, est_freq, est_voicing=None, ref_reward=None, **kwargs + ref_time, + ref_freq, + est_time, + est_freq, + est_voicing=None, + ref_reward=None, + safe=True, + **kwargs ): """Evaluate two melody (predominant f0) transcriptions, where the first is treated as the reference (ground truth) and the second as the estimate to @@ -781,6 +808,9 @@ def evaluate( ref_reward : np.ndarray Reference pitch estimation reward. Default None, which means all frames are weighted equally. + safe : bool + If True, validate inputs. + If False, skip validation of inputs. **kwargs Additional keyword arguments which will be passed to the appropriate metric or preprocessing functions. @@ -819,6 +849,11 @@ def evaluate( ref_reward, **kwargs ) + if safe: + validate_voicing(ref_voicing, est_voicing) + validate(ref_voicing, ref_cent, est_voicing, est_cent) + + kwargs["safe"] = False # Compute metrics scores = collections.OrderedDict() From 97d1a1f2088f84137a20d566fbead8065fac443d Mon Sep 17 00:00:00 2001 From: Brian McFee Date: Tue, 26 Mar 2024 16:11:34 -0400 Subject: [PATCH 06/10] more safety params --- mir_eval/multipitch.py | 17 +++++++++++++---- mir_eval/onset.py | 18 +++++++++++++++--- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/mir_eval/multipitch.py b/mir_eval/multipitch.py index 456417bf..d55edde0 100644 --- a/mir_eval/multipitch.py +++ b/mir_eval/multipitch.py @@ -346,7 +346,7 @@ def compute_err_score(true_positives, n_ref, n_est): return e_sub, e_miss, e_fa, e_tot -def metrics(ref_time, ref_freqs, est_time, est_freqs, **kwargs): +def metrics(ref_time, ref_freqs, est_time, est_freqs, safe=True, **kwargs): """Compute multipitch metrics. All metrics are computed at the 'macro' level such that the frame true positive/false positive/false negative rates are summed across time and the metrics are computed on the combined values. @@ -370,6 +370,9 @@ def metrics(ref_time, ref_freqs, est_time, est_freqs, **kwargs): Time of each estimated frequency value est_freqs : list of np.ndarray List of np.ndarrays of estimate frequency values + safe : bool + If True, validate inputs. + If False, skip validation of inputs. **kwargs Additional keyword arguments which will be passed to the appropriate metric or preprocessing functions. @@ -406,7 +409,8 @@ def metrics(ref_time, ref_freqs, est_time, est_freqs, **kwargs): Chroma total error """ - validate(ref_time, ref_freqs, est_time, est_freqs) + if safe: + validate(ref_time, ref_freqs, est_time, est_freqs) # resample est_freqs if est_times is different from ref_times if est_time.size != ref_time.size or not np.allclose(est_time, ref_time): @@ -476,7 +480,7 @@ def metrics(ref_time, ref_freqs, est_time, est_freqs, **kwargs): ) -def evaluate(ref_time, ref_freqs, est_time, est_freqs, **kwargs): +def evaluate(ref_time, ref_freqs, est_time, est_freqs, safe=True, **kwargs): """Evaluate two multipitch (multi-f0) transcriptions, where the first is treated as the reference (ground truth) and the second as the estimate to be evaluated (prediction). @@ -498,6 +502,9 @@ def evaluate(ref_time, ref_freqs, est_time, est_freqs, **kwargs): Time of each estimated frequency value est_freqs : list of np.ndarray List of np.ndarrays of estimate frequency values + safe : bool + If True, validate inputs. + If False, skip validation of inputs. **kwargs Additional keyword arguments which will be passed to the appropriate metric or preprocessing functions. @@ -526,6 +533,8 @@ def evaluate(ref_time, ref_freqs, est_time, est_freqs, **kwargs): scores["Chroma Miss Error"], scores["Chroma False Alarm Error"], scores["Chroma Total Error"], - ) = util.filter_kwargs(metrics, ref_time, ref_freqs, est_time, est_freqs, **kwargs) + ) = util.filter_kwargs( + metrics, ref_time, ref_freqs, est_time, est_freqs, safe=safe, **kwargs + ) return scores diff --git a/mir_eval/onset.py b/mir_eval/onset.py index a315f23b..49eca818 100644 --- a/mir_eval/onset.py +++ b/mir_eval/onset.py @@ -53,7 +53,7 @@ def validate(reference_onsets, estimated_onsets): util.validate_events(onsets, MAX_TIME) -def f_measure(reference_onsets, estimated_onsets, window=0.05): +def f_measure(reference_onsets, estimated_onsets, window=0.05, safe=True): """Compute the F-measure of correct vs incorrectly predicted onsets. "Correctness" is determined over a small window. @@ -73,6 +73,9 @@ def f_measure(reference_onsets, estimated_onsets, window=0.05): window : float Window size, in seconds (Default value = .05) + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -84,7 +87,8 @@ def f_measure(reference_onsets, estimated_onsets, window=0.05): (# true positives)/(# true positives + # false negatives) """ - validate(reference_onsets, estimated_onsets) + if safe: + validate(reference_onsets, estimated_onsets) # If either list is empty, return 0s if reference_onsets.size == 0 or estimated_onsets.size == 0: return 0.0, 0.0, 0.0 @@ -98,7 +102,7 @@ def f_measure(reference_onsets, estimated_onsets, window=0.05): return util.f_measure(precision, recall), precision, recall -def evaluate(reference_onsets, estimated_onsets, **kwargs): +def evaluate(reference_onsets, estimated_onsets, safe=True, **kwargs): """Compute all metrics for the given reference and estimated annotations. Examples @@ -114,6 +118,9 @@ def evaluate(reference_onsets, estimated_onsets, **kwargs): reference onset locations, in seconds estimated_onsets : np.ndarray estimated onset locations, in seconds + safe : bool + If True, validate inputs. + If False, skip validation of inputs. **kwargs Additional keyword arguments which will be passed to the appropriate metric or preprocessing functions. @@ -125,6 +132,11 @@ def evaluate(reference_onsets, estimated_onsets, **kwargs): the value is the (float) score achieved. """ + if safe: + validate(reference_onsets, estimated_onsets) + + kwargs["safe"] = False + # Compute all metrics scores = collections.OrderedDict() From 021acda78f0ef4043bf3e08643ce4225e4661522 Mon Sep 17 00:00:00 2001 From: Brian McFee Date: Tue, 26 Mar 2024 16:14:33 -0400 Subject: [PATCH 07/10] more safety params --- mir_eval/pattern.py | 62 ++++++++++++++++++++++++++++++++++++--------- 1 file changed, 50 insertions(+), 12 deletions(-) diff --git a/mir_eval/pattern.py b/mir_eval/pattern.py index cc05fc8f..32d403fa 100644 --- a/mir_eval/pattern.py +++ b/mir_eval/pattern.py @@ -168,7 +168,7 @@ def _compute_score_matrix(P, Q, similarity_metric="cardinality_score"): return sm -def standard_FPR(reference_patterns, estimated_patterns, tol=1e-5): +def standard_FPR(reference_patterns, estimated_patterns, tol=1e-5, safe=True): """Compute the standard F1 Score, Precision and Recall. This metric checks if the prototype patterns of the reference match @@ -194,6 +194,9 @@ def standard_FPR(reference_patterns, estimated_patterns, tol=1e-5): Default parameter is the one found in the original matlab code by Tom Collins used for MIREX 2013. (Default value = 1e-5) + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -205,7 +208,8 @@ def standard_FPR(reference_patterns, estimated_patterns, tol=1e-5): The standard Recall """ - validate(reference_patterns, estimated_patterns) + if safe: + validate(reference_patterns, estimated_patterns) nP = len(reference_patterns) # Number of patterns in the reference nQ = len(estimated_patterns) # Number of patterns in the estimation k = 0 # Number of patterns that match @@ -236,7 +240,10 @@ def standard_FPR(reference_patterns, estimated_patterns, tol=1e-5): def establishment_FPR( - reference_patterns, estimated_patterns, similarity_metric="cardinality_score" + reference_patterns, + estimated_patterns, + similarity_metric="cardinality_score", + safe=True, ): """Compute the establishment F1 Score, Precision and Recall. @@ -265,6 +272,10 @@ def establishment_FPR( (Default value = "cardinality_score") + safe : bool + If True, validate inputs. + If False, skip validation of inputs. + Returns ------- f_measure : float @@ -275,7 +286,8 @@ def establishment_FPR( The establishment Recall """ - validate(reference_patterns, estimated_patterns) + if safe: + validate(reference_patterns, estimated_patterns) nP = len(reference_patterns) # Number of elements in reference nQ = len(estimated_patterns) # Number of elements in estimation S = np.zeros((nP, nQ)) # Establishment matrix @@ -301,6 +313,7 @@ def occurrence_FPR( estimated_patterns, thres=0.75, similarity_metric="cardinality_score", + safe=True, ): """Compute the occurrence F1 Score, Precision and Recall. @@ -334,6 +347,10 @@ def occurrence_FPR( (Default value = "cardinality_score") + safe : bool + If True, validate inputs. + If False, skip validation of inputs. + Returns ------- f_measure : float @@ -343,7 +360,8 @@ def occurrence_FPR( recall : float The occurrence Recall """ - validate(reference_patterns, estimated_patterns) + if safe: + validate(reference_patterns, estimated_patterns) # Number of elements in reference nP = len(reference_patterns) # Number of elements in estimation @@ -379,7 +397,7 @@ def occurrence_FPR( return f_measure, precision, recall -def three_layer_FPR(reference_patterns, estimated_patterns): +def three_layer_FPR(reference_patterns, estimated_patterns, safe=True): """Three Layer F1 Score, Precision and Recall. As described by Meridith. Examples @@ -396,6 +414,9 @@ def three_layer_FPR(reference_patterns, estimated_patterns): :func:`mir_eval.io.load_patterns()` estimated_patterns : list The estimated patterns in the same format + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -407,7 +428,8 @@ def three_layer_FPR(reference_patterns, estimated_patterns): The three-layer Recall """ - validate(reference_patterns, estimated_patterns) + if safe: + validate(reference_patterns, estimated_patterns) def compute_first_layer_PR(ref_occs, est_occs): """Compute the first layer Precision and Recall values given the @@ -506,7 +528,7 @@ def compute_layer(ref_elements, est_elements, layer=1): return f_measure_3, precision_3, recall_3 -def first_n_three_layer_P(reference_patterns, estimated_patterns, n=5): +def first_n_three_layer_P(reference_patterns, estimated_patterns, n=5, safe=True): """First n three-layer precision. This metric is basically the same as the three-layer FPR but it is only @@ -531,13 +553,17 @@ def first_n_three_layer_P(reference_patterns, estimated_patterns, n=5): Number of patterns to consider from the estimated results, in the order they appear in the matrix (Default value = 5) + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- precision : float The first n three-layer Precision """ - validate(reference_patterns, estimated_patterns) + if safe: + validate(reference_patterns, estimated_patterns) # If no patterns were provided, metric is zero if _n_onset_midi(reference_patterns) == 0 or _n_onset_midi(estimated_patterns) == 0: return 0.0, 0.0, 0.0 @@ -551,7 +577,7 @@ def first_n_three_layer_P(reference_patterns, estimated_patterns, n=5): return P # Return the precision only -def first_n_target_proportion_R(reference_patterns, estimated_patterns, n=5): +def first_n_target_proportion_R(reference_patterns, estimated_patterns, n=5, safe=True): """First n target proportion establishment recall metric. This metric is similar is similar to the establishment FPR score, but it @@ -576,13 +602,17 @@ def first_n_target_proportion_R(reference_patterns, estimated_patterns, n=5): Number of patterns to consider from the estimated results, in the order they appear in the matrix. (Default value = 5) + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- recall : float The first n target proportion Recall. """ - validate(reference_patterns, estimated_patterns) + if safe: + validate(reference_patterns, estimated_patterns) # If no patterns were provided, metric is zero if _n_onset_midi(reference_patterns) == 0 or _n_onset_midi(estimated_patterns) == 0: return 0.0, 0.0, 0.0 @@ -594,7 +624,7 @@ def first_n_target_proportion_R(reference_patterns, estimated_patterns, n=5): return R -def evaluate(ref_patterns, est_patterns, **kwargs): +def evaluate(ref_patterns, est_patterns, safe=True, **kwargs): """Load data and perform the evaluation. Examples @@ -610,6 +640,9 @@ def evaluate(ref_patterns, est_patterns, **kwargs): :func:`mir_eval.io.load_patterns()` est_patterns : list The estimated patterns in the same format + safe : bool + If True, validate inputs. + If False, skip validation of inputs. **kwargs Additional keyword arguments which will be passed to the appropriate metric or preprocessing functions. @@ -620,6 +653,11 @@ def evaluate(ref_patterns, est_patterns, **kwargs): Dictionary of scores, where the key is the metric name (str) and the value is the (float) score achieved. """ + if safe: + validate(ref_patterns, est_patterns) + + kwargs["safe"] = False + # Compute all the metrics scores = collections.OrderedDict() From a188f7837fd6ff957ab582e0708d0921de742a9b Mon Sep 17 00:00:00 2001 From: Brian McFee Date: Tue, 26 Mar 2024 16:20:00 -0400 Subject: [PATCH 08/10] safety params in segment --- mir_eval/segment.py | 90 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 70 insertions(+), 20 deletions(-) diff --git a/mir_eval/segment.py b/mir_eval/segment.py index 7b03f596..92bb4582 100644 --- a/mir_eval/segment.py +++ b/mir_eval/segment.py @@ -167,7 +167,12 @@ def validate_structure( def detection( - reference_intervals, estimated_intervals, window=0.5, beta=1.0, trim=False + reference_intervals, + estimated_intervals, + window=0.5, + beta=1.0, + trim=False, + safe=True, ): """Boundary detection hit-rate. @@ -215,6 +220,9 @@ def detection( if ``True``, the first and last boundary times are ignored. Typically, these denote start (0) and end-markers. (Default value = False) + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -225,7 +233,8 @@ def detection( f_measure : float F-measure (weighted harmonic mean of ``precision`` and ``recall``) """ - validate_boundary(reference_intervals, estimated_intervals, trim) + if safe: + validate_boundary(reference_intervals, estimated_intervals, trim) # Convert intervals to boundaries reference_boundaries = util.intervals_to_boundaries(reference_intervals) @@ -250,7 +259,7 @@ def detection( return precision, recall, f_measure -def deviation(reference_intervals, estimated_intervals, trim=False): +def deviation(reference_intervals, estimated_intervals, trim=False, safe=True): """Compute the median deviations between reference and estimated boundary times. @@ -275,6 +284,9 @@ def deviation(reference_intervals, estimated_intervals, trim=False): if ``True``, the first and last intervals are ignored. Typically, these denote start (0.0) and end-of-track markers. (Default value = False) + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -285,7 +297,8 @@ def deviation(reference_intervals, estimated_intervals, trim=False): median time from each estimated boundary to the closest reference boundary """ - validate_boundary(reference_intervals, estimated_intervals, trim) + if safe: + validate_boundary(reference_intervals, estimated_intervals, trim) # Convert intervals to boundaries reference_boundaries = util.intervals_to_boundaries(reference_intervals) @@ -315,6 +328,7 @@ def pairwise( estimated_labels, frame_size=0.1, beta=1.0, + safe=True, ): """Frame-clustering segmentation evaluation by pair-wise agreement. @@ -357,6 +371,9 @@ def pairwise( beta : float > 0 beta value for F-measure (Default value = 1.0) + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -368,9 +385,10 @@ def pairwise( F-measure of detecting whether frames belong in the same cluster """ - validate_structure( - reference_intervals, reference_labels, estimated_intervals, estimated_labels - ) + if safe: + validate_structure( + reference_intervals, reference_labels, estimated_intervals, estimated_labels + ) # Check for empty annotations. Don't need to check labels because # validate_structure makes sure they're the same size as intervals @@ -418,6 +436,7 @@ def rand_index( estimated_labels, frame_size=0.1, beta=1.0, + safe=True, ): """(Non-adjusted) Rand index. @@ -460,15 +479,19 @@ def rand_index( beta : float > 0 beta value for F-measure (Default value = 1.0) + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- rand_index : float > 0 Rand index """ - validate_structure( - reference_intervals, reference_labels, estimated_intervals, estimated_labels - ) + if safe: + validate_structure( + reference_intervals, reference_labels, estimated_intervals, estimated_labels + ) # Check for empty annotations. Don't need to check labels because # validate_structure makes sure they're the same size as intervals @@ -593,6 +616,7 @@ def ari( estimated_intervals, estimated_labels, frame_size=0.1, + safe=True, ): """Compute the Adjusted Rand Index (ARI) for frame clustering segmentation evaluation. @@ -630,6 +654,9 @@ def ari( frame_size : float > 0 length (in seconds) of frames for clustering (Default value = 0.1) + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -637,9 +664,10 @@ def ari( Adjusted Rand index between segmentations. """ - validate_structure( - reference_intervals, reference_labels, estimated_intervals, estimated_labels - ) + if safe: + validate_structure( + reference_intervals, reference_labels, estimated_intervals, estimated_labels + ) # Check for empty annotations. Don't need to check labels because # validate_structure makes sure they're the same size as intervals @@ -873,6 +901,7 @@ def mutual_information( estimated_intervals, estimated_labels, frame_size=0.1, + safe=True, ): """Frame-clustering segmentation: mutual information metrics. @@ -912,6 +941,9 @@ def mutual_information( frame_size : float > 0 length (in seconds) of frames for clustering (Default value = 0.1) + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -923,9 +955,10 @@ def mutual_information( Normalize mutual information between segmentations """ - validate_structure( - reference_intervals, reference_labels, estimated_intervals, estimated_labels - ) + if safe: + validate_structure( + reference_intervals, reference_labels, estimated_intervals, estimated_labels + ) # Check for empty annotations. Don't need to check labels because # validate_structure makes sure they're the same size as intervals @@ -966,6 +999,7 @@ def nce( frame_size=0.1, beta=1.0, marginal=False, + safe=True, ): """Frame-clustering segmentation: normalized conditional entropy @@ -1015,6 +1049,9 @@ def nce( If `False`, normalize conditional entropy by uniform entropy. If `True`, normalize conditional entropy by the marginal entropy. (Default value = False) + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -1037,9 +1074,10 @@ def nce( S_F F-measure for (S_over, S_under) """ - validate_structure( - reference_intervals, reference_labels, estimated_intervals, estimated_labels - ) + if safe: + validate_structure( + reference_intervals, reference_labels, estimated_intervals, estimated_labels + ) # Check for empty annotations. Don't need to check labels because # validate_structure makes sure they're the same size as intervals @@ -1105,6 +1143,7 @@ def vmeasure( estimated_labels, frame_size=0.1, beta=1.0, + safe=True, ): """Frame-clustering segmentation: v-measure @@ -1152,6 +1191,9 @@ def vmeasure( beta : float > 0 beta for F-measure (Default value = 1.0) + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -1176,10 +1218,11 @@ def vmeasure( frame_size=frame_size, beta=beta, marginal=True, + safe=safe, ) -def evaluate(ref_intervals, ref_labels, est_intervals, est_labels, **kwargs): +def evaluate(ref_intervals, ref_labels, est_intervals, est_labels, safe=True, **kwargs): """Compute all metrics for the given reference and estimated annotations. Examples @@ -1205,6 +1248,9 @@ def evaluate(ref_intervals, ref_labels, est_intervals, est_labels, **kwargs): est_labels : list, shape=(m,) estimated segment labels, in the format returned by :func:`mir_eval.io.load_labeled_intervals`. + safe : bool + If True, validate inputs. + If False, skip validation of inputs. **kwargs Additional keyword arguments which will be passed to the appropriate metric or preprocessing functions. @@ -1224,6 +1270,10 @@ def evaluate(ref_intervals, ref_labels, est_intervals, est_labels, **kwargs): est_intervals, labels=est_labels, t_min=0.0, t_max=ref_intervals.max() ) + if safe: + validate_boundary(ref_intervals, est_intervals, trim=kwargs.get("trim", False)) + validate_structure(ref_intervals, ref_labels, est_intervals, est_labels) + # Now compute all the metrics scores = collections.OrderedDict() From d62243fa6239a2c2830168db8ba7ba86385bcee2 Mon Sep 17 00:00:00 2001 From: Brian McFee Date: Tue, 26 Mar 2024 20:37:39 -0400 Subject: [PATCH 09/10] safety in separation --- mir_eval/separation.py | 48 ++++++++++++++++++++++++++++++++++++------ mir_eval/tempo.py | 20 ++++++++++++++---- 2 files changed, 57 insertions(+), 11 deletions(-) diff --git a/mir_eval/separation.py b/mir_eval/separation.py index facb98db..c66e3b25 100644 --- a/mir_eval/separation.py +++ b/mir_eval/separation.py @@ -143,7 +143,9 @@ def _any_source_silent(sources): ) -def bss_eval_sources(reference_sources, estimated_sources, compute_permutation=True): +def bss_eval_sources( + reference_sources, estimated_sources, compute_permutation=True, safe=True +): """ Ordering and measurement of the separation quality for estimated source signals in terms of filtered true source, interference and artifacts. @@ -175,6 +177,9 @@ def bss_eval_sources(reference_sources, estimated_sources, compute_permutation=T reference_sources) compute_permutation : bool, optional compute permutation of estimate/source combinations (True by default) + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -205,7 +210,8 @@ def bss_eval_sources(reference_sources, estimated_sources, compute_permutation=T if reference_sources.ndim == 1: reference_sources = reference_sources[np.newaxis, :] - validate(reference_sources, estimated_sources) + if safe: + validate(reference_sources, estimated_sources) # If empty matrices were supplied, return empty lists (special case) if reference_sources.size == 0 or estimated_sources.size == 0: return np.array([]), np.array([]), np.array([]), np.array([]) @@ -259,6 +265,7 @@ def bss_eval_sources_framewise( window=30 * 44100, hop=15 * 44100, compute_permutation=False, + safe=True, ): """Framewise computation of bss_eval_sources @@ -305,6 +312,9 @@ def bss_eval_sources_framewise( compute_permutation : bool, optional compute permutation of estimate/source combinations for all windows (False by default) + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -326,7 +336,8 @@ def bss_eval_sources_framewise( if reference_sources.ndim == 1: reference_sources = reference_sources[np.newaxis, :] - validate(reference_sources, estimated_sources) + if safe: + validate(reference_sources, estimated_sources) # If empty matrices were supplied, return empty lists (special case) if reference_sources.size == 0 or estimated_sources.size == 0: return np.array([]), np.array([]), np.array([]), np.array([]) @@ -364,7 +375,9 @@ def bss_eval_sources_framewise( return sdr, sir, sar, perm -def bss_eval_images(reference_sources, estimated_sources, compute_permutation=True): +def bss_eval_images( + reference_sources, estimated_sources, compute_permutation=True, safe=True +): """Compute the bss_eval_images function from the BSS_EVAL Matlab toolbox. @@ -397,6 +410,9 @@ def bss_eval_images(reference_sources, estimated_sources, compute_permutation=Tr matrix containing estimated sources compute_permutation : bool, optional compute permutation of estimate/source combinations (True by default) + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -428,7 +444,8 @@ def bss_eval_images(reference_sources, estimated_sources, compute_permutation=Tr reference_sources = np.atleast_3d(reference_sources) # we will ensure input doesn't have more than 3 dimensions in validate - validate(reference_sources, estimated_sources) + if safe: + validate(reference_sources, estimated_sources) # If empty matrices were supplied, return empty lists (special case) if reference_sources.size == 0 or estimated_sources.size == 0: return np.array([]), np.array([]), np.array([]), np.array([]), np.array([]) @@ -504,6 +521,7 @@ def bss_eval_images_framewise( window=30 * 44100, hop=15 * 44100, compute_permutation=False, + safe=True, ): """Framewise computation of bss_eval_images @@ -549,6 +567,9 @@ def bss_eval_images_framewise( compute_permutation : bool, optional compute permutation of estimate/source combinations for all windows (False by default) + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -573,7 +594,8 @@ def bss_eval_images_framewise( reference_sources = np.atleast_3d(reference_sources) # we will ensure input doesn't have more than 3 dimensions in validate - validate(reference_sources, estimated_sources) + if safe: + validate(reference_sources, estimated_sources) # If empty matrices were supplied, return empty lists (special case) if reference_sources.size == 0 or estimated_sources.size == 0: return np.array([]), np.array([]), np.array([]), np.array([]) @@ -843,7 +865,7 @@ def _safe_db(num, den): return 10 * np.log10(num / den) -def evaluate(reference_sources, estimated_sources, **kwargs): +def evaluate(reference_sources, estimated_sources, safe=True, **kwargs): """Compute all metrics for the given reference and estimated signals. NOTE: This will always compute :func:`mir_eval.separation.bss_eval_images` @@ -865,6 +887,9 @@ def evaluate(reference_sources, estimated_sources, **kwargs): matrix containing true sources estimated_sources : np.ndarray, shape=(nsrc, nsampl[, nchan]) matrix containing estimated sources + safe : bool + If True, validate inputs. + If False, skip validation of inputs. **kwargs Additional keyword arguments which will be passed to the appropriate metric or preprocessing functions. @@ -876,6 +901,15 @@ def evaluate(reference_sources, estimated_sources, **kwargs): the value is the (float) score achieved. """ + if estimated_sources.ndim == 1: + estimated_sources = estimated_sources[np.newaxis, :] + if reference_sources.ndim == 1: + reference_sources = reference_sources[np.newaxis, :] + if safe: + validate(reference_sources, estimated_sources) + + kwargs["safe"] = False + # Compute all the metrics scores = collections.OrderedDict() diff --git a/mir_eval/tempo.py b/mir_eval/tempo.py index 15050364..bfac67be 100644 --- a/mir_eval/tempo.py +++ b/mir_eval/tempo.py @@ -70,7 +70,7 @@ def validate(reference_tempi, reference_weight, estimated_tempi): raise ValueError("Reference weight must lie in range [0, 1]") -def detection(reference_tempi, reference_weight, estimated_tempi, tol=0.08): +def detection(reference_tempi, reference_weight, estimated_tempi, tol=0.08, safe=True): """Compute the tempo detection accuracy metric. Parameters @@ -87,6 +87,9 @@ def detection(reference_tempi, reference_weight, estimated_tempi, tol=0.08): count as a hit. ``|est_t - ref_t| <= tol * ref_t`` (Default value = 0.08) + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -107,7 +110,8 @@ def detection(reference_tempi, reference_weight, estimated_tempi, tol=0.08): If ``tol < 0`` or ``tol > 1``. """ - validate(reference_tempi, reference_weight, estimated_tempi) + if safe: + validate(reference_tempi, reference_weight, estimated_tempi) if tol < 0 or tol > 1: raise ValueError( @@ -138,7 +142,7 @@ def detection(reference_tempi, reference_weight, estimated_tempi, tol=0.08): return p_score, one_correct, both_correct -def evaluate(reference_tempi, reference_weight, estimated_tempi, **kwargs): +def evaluate(reference_tempi, reference_weight, estimated_tempi, safe=True, **kwargs): """Compute all metrics for the given reference and estimated annotations. Parameters @@ -150,6 +154,9 @@ def evaluate(reference_tempi, reference_weight, estimated_tempi, **kwargs): ``reference_tempi[1]``. estimated_tempi : np.ndarray, shape=(2,) Two non-negative estimated tempi. + safe : bool + If True, validate inputs. + If False, skip validation of inputs. **kwargs Additional keyword arguments which will be passed to the appropriate metric or preprocessing functions. @@ -168,7 +175,12 @@ def evaluate(reference_tempi, reference_weight, estimated_tempi, **kwargs): scores["One-correct"], scores["Both-correct"], ) = util.filter_kwargs( - detection, reference_tempi, reference_weight, estimated_tempi, **kwargs + detection, + reference_tempi, + reference_weight, + estimated_tempi, + safe=safe, + **kwargs ) return scores From fd2b8809e00ab4d303bad11cc30c1b79df02df80 Mon Sep 17 00:00:00 2001 From: Brian McFee Date: Tue, 26 Mar 2024 20:41:15 -0400 Subject: [PATCH 10/10] finished safety param threading --- mir_eval/transcription.py | 19 +++++++++++++-- mir_eval/transcription_velocity.py | 37 +++++++++++++++++++++++------- 2 files changed, 46 insertions(+), 10 deletions(-) diff --git a/mir_eval/transcription.py b/mir_eval/transcription.py index bdb436ae..09dd2fb3 100644 --- a/mir_eval/transcription.py +++ b/mir_eval/transcription.py @@ -485,6 +485,7 @@ def precision_recall_f1_overlap( offset_min_tolerance=0.05, strict=False, beta=1.0, + safe=True, ): """Compute the Precision, Recall and F-measure of correct vs incorrectly transcribed notes, and the Average Overlap Ratio for correctly transcribed @@ -551,6 +552,9 @@ def precision_recall_f1_overlap( than). beta : float > 0 Weighting factor for f-measure (default value = 1.0). + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -563,7 +567,8 @@ def precision_recall_f1_overlap( avg_overlap_ratio : float The computed Average Overlap Ratio score """ - validate(ref_intervals, ref_pitches, est_intervals, est_pitches) + if safe: + validate(ref_intervals, ref_pitches, est_intervals, est_pitches) # When reference notes are empty, metrics are undefined, return 0's if len(ref_pitches) == 0 or len(est_pitches) == 0: return 0.0, 0.0, 0.0, 0.0 @@ -782,7 +787,9 @@ def offset_precision_recall_f1( return offset_precision, offset_recall, offset_f_measure -def evaluate(ref_intervals, ref_pitches, est_intervals, est_pitches, **kwargs): +def evaluate( + ref_intervals, ref_pitches, est_intervals, est_pitches, safe=True, **kwargs +): """Compute all metrics for the given reference and estimated annotations. Examples @@ -804,6 +811,9 @@ def evaluate(ref_intervals, ref_pitches, est_intervals, est_pitches, **kwargs): Array of estimated notes time intervals (onset and offset times) est_pitches : np.ndarray, shape=(m,) Array of estimated pitch values in Hertz + safe : bool + If True, validate inputs. + If False, skip validation of inputs. **kwargs Additional keyword arguments which will be passed to the appropriate metric or preprocessing functions. @@ -814,6 +824,11 @@ def evaluate(ref_intervals, ref_pitches, est_intervals, est_pitches, **kwargs): Dictionary of scores, where the key is the metric name (str) and the value is the (float) score achieved. """ + if safe: + validate(ref_intervals, ref_pitches, est_intervals, est_pitches) + + kwargs["safe"] = False + # Compute all the metrics scores = collections.OrderedDict() diff --git a/mir_eval/transcription_velocity.py b/mir_eval/transcription_velocity.py index 866ac97e..438121cc 100644 --- a/mir_eval/transcription_velocity.py +++ b/mir_eval/transcription_velocity.py @@ -241,6 +241,7 @@ def precision_recall_f1_overlap( strict=False, velocity_tolerance=0.1, beta=1.0, + safe=True, ): """Compute the Precision, Recall and F-measure of correct vs incorrectly transcribed notes, and the Average Overlap Ratio for correctly transcribed @@ -306,6 +307,9 @@ def precision_recall_f1_overlap( matched reference note. beta : float > 0 Weighting factor for f-measure (default value = 1.0). + safe : bool + If True, validate inputs. + If False, skip validation of inputs. Returns ------- @@ -318,14 +322,15 @@ def precision_recall_f1_overlap( avg_overlap_ratio : float The computed Average Overlap Ratio score """ - validate( - ref_intervals, - ref_pitches, - ref_velocities, - est_intervals, - est_pitches, - est_velocities, - ) + if safe: + validate( + ref_intervals, + ref_pitches, + ref_velocities, + est_intervals, + est_pitches, + est_velocities, + ) # When reference notes are empty, metrics are undefined, return 0's if len(ref_pitches) == 0 or len(est_pitches) == 0: return 0.0, 0.0, 0.0, 0.0 @@ -363,6 +368,7 @@ def evaluate( est_intervals, est_pitches, est_velocities, + safe=True, **kwargs ): """Compute all metrics for the given reference and estimated annotations. @@ -381,6 +387,9 @@ def evaluate( Array of estimated pitch values in Hertz est_velocities : np.ndarray, shape=(n,) Array of MIDI velocities (i.e. between 0 and 127) of estimated notes + safe : bool + If True, validate inputs. + If False, skip validation of inputs. **kwargs Additional keyword arguments which will be passed to the appropriate metric or preprocessing functions. @@ -391,6 +400,18 @@ def evaluate( Dictionary of scores, where the key is the metric name (str) and the value is the (float) score achieved. """ + if safe: + validate( + ref_intervals, + ref_pitches, + ref_velocities, + est_intervals, + est_pitches, + est_velocities, + ) + + kwargs["safe"] = False + # Compute all the metrics scores = collections.OrderedDict()