diff --git a/mir_eval/chord.py b/mir_eval/chord.py index 0cf8ca2a..66cb7bb6 100644 --- a/mir_eval/chord.py +++ b/mir_eval/chord.py @@ -709,7 +709,7 @@ def weighted_accuracy(comparisons, weights): return np.sum(comparisons*normalized_weights) -def thirds(reference_labels, estimated_labels): +def thirds(reference_labels, estimated_labels, reduce_extended_chords=False): """Compare chords along root & third relationships. Examples @@ -736,6 +736,11 @@ def thirds(reference_labels, estimated_labels): Reference chord labels to score against. estimated_labels : list, len=n Estimated chord labels to score against. + reduce_extended_chords : bool + Whether to map the upper voicings of extended chords (9's, 11's, 13's) + to semitone extensions. + (Default value = False) + Returns ------- @@ -744,8 +749,8 @@ def thirds(reference_labels, estimated_labels): """ validate(reference_labels, estimated_labels) - ref_roots, ref_semitones = encode_many(reference_labels, False)[:2] - est_roots, est_semitones = encode_many(estimated_labels, False)[:2] + ref_roots, ref_semitones = encode_many(reference_labels, reduce_extended_chords=reduce_extended_chords)[:2] + est_roots, est_semitones = encode_many(estimated_labels, reduce_extended_chords=reduce_extended_chords)[:2] eq_roots = ref_roots == est_roots eq_thirds = ref_semitones[:, 3] == est_semitones[:, 3] @@ -756,7 +761,7 @@ def thirds(reference_labels, estimated_labels): return comparison_scores -def thirds_inv(reference_labels, estimated_labels): +def thirds_inv(reference_labels, estimated_labels, reduce_extended_chords=False): """Score chords along root, third, & bass relationships. Examples @@ -783,6 +788,11 @@ def thirds_inv(reference_labels, estimated_labels): Reference chord labels to score against. estimated_labels : list, len=n Estimated chord labels to score against. + reduce_extended_chords : bool + Whether to map the upper voicings of extended chords (9's, 11's, 13's) + to semitone extensions. + (Default value = False) + Returns ------- @@ -791,8 +801,8 @@ def thirds_inv(reference_labels, estimated_labels): """ validate(reference_labels, estimated_labels) - ref_roots, ref_semitones, ref_bass = encode_many(reference_labels, False) - est_roots, est_semitones, est_bass = encode_many(estimated_labels, False) + ref_roots, ref_semitones, ref_bass = encode_many(reference_labels, reduce_extended_chords=reduce_extended_chords) + est_roots, est_semitones, est_bass = encode_many(estimated_labels, reduce_extended_chords=reduce_extended_chords) eq_root = ref_roots == est_roots eq_bass = ref_bass == est_bass @@ -804,7 +814,7 @@ def thirds_inv(reference_labels, estimated_labels): return comparison_scores -def triads(reference_labels, estimated_labels): +def triads(reference_labels, estimated_labels, reduce_extended_chords=False): """Compare chords along triad (root & quality to #5) relationships. Examples @@ -831,6 +841,11 @@ def triads(reference_labels, estimated_labels): Reference chord labels to score against. estimated_labels : list, len=n Estimated chord labels to score against. + reduce_extended_chords : bool + Whether to map the upper voicings of extended chords (9's, 11's, 13's) + to semitone extensions. + (Default value = False) + Returns ------- @@ -839,8 +854,8 @@ def triads(reference_labels, estimated_labels): """ validate(reference_labels, estimated_labels) - ref_roots, ref_semitones = encode_many(reference_labels, False)[:2] - est_roots, est_semitones = encode_many(estimated_labels, False)[:2] + ref_roots, ref_semitones = encode_many(reference_labels, reduce_extended_chords=reduce_extended_chords)[:2] + est_roots, est_semitones = encode_many(estimated_labels, reduce_extended_chords=reduce_extended_chords)[:2] eq_roots = ref_roots == est_roots eq_semitones = np.all( @@ -852,7 +867,7 @@ def triads(reference_labels, estimated_labels): return comparison_scores -def triads_inv(reference_labels, estimated_labels): +def triads_inv(reference_labels, estimated_labels, reduce_extended_chords=False): """Score chords along triad (root, quality to #5, & bass) relationships. Examples @@ -879,6 +894,11 @@ def triads_inv(reference_labels, estimated_labels): Reference chord labels to score against. estimated_labels : list, len=n Estimated chord labels to score against. + reduce_extended_chords : bool + Whether to map the upper voicings of extended chords (9's, 11's, 13's) + to semitone extensions. + (Default value = False) + Returns ------- @@ -887,8 +907,8 @@ def triads_inv(reference_labels, estimated_labels): """ validate(reference_labels, estimated_labels) - ref_roots, ref_semitones, ref_bass = encode_many(reference_labels, False) - est_roots, est_semitones, est_bass = encode_many(estimated_labels, False) + ref_roots, ref_semitones, ref_bass = encode_many(reference_labels, reduce_extended_chords=reduce_extended_chords) + est_roots, est_semitones, est_bass = encode_many(estimated_labels, reduce_extended_chords=reduce_extended_chords) eq_roots = ref_roots == est_roots eq_basses = ref_bass == est_bass @@ -901,7 +921,7 @@ def triads_inv(reference_labels, estimated_labels): return comparison_scores -def tetrads(reference_labels, estimated_labels): +def tetrads(reference_labels, estimated_labels, reduce_extended_chords=False): """Compare chords along tetrad (root & full quality) relationships. Examples @@ -928,6 +948,11 @@ def tetrads(reference_labels, estimated_labels): Reference chord labels to score against. estimated_labels : list, len=n Estimated chord labels to score against. + reduce_extended_chords : bool + Whether to map the upper voicings of extended chords (9's, 11's, 13's) + to semitone extensions. + (Default value = False) + Returns ------- @@ -936,8 +961,8 @@ def tetrads(reference_labels, estimated_labels): """ validate(reference_labels, estimated_labels) - ref_roots, ref_semitones = encode_many(reference_labels, False)[:2] - est_roots, est_semitones = encode_many(estimated_labels, False)[:2] + ref_roots, ref_semitones = encode_many(reference_labels, reduce_extended_chords=reduce_extended_chords)[:2] + est_roots, est_semitones = encode_many(estimated_labels, reduce_extended_chords=reduce_extended_chords)[:2] eq_roots = ref_roots == est_roots eq_semitones = np.all(np.equal(ref_semitones, est_semitones), axis=1) @@ -948,7 +973,7 @@ def tetrads(reference_labels, estimated_labels): return comparison_scores -def tetrads_inv(reference_labels, estimated_labels): +def tetrads_inv(reference_labels, estimated_labels, reduce_extended_chords=False): """Compare chords along tetrad (root, full quality, & bass) relationships. Examples @@ -975,6 +1000,11 @@ def tetrads_inv(reference_labels, estimated_labels): Reference chord labels to score against. estimated_labels : list, len=n Estimated chord labels to score against. + reduce_extended_chords : bool + Whether to map the upper voicings of extended chords (9's, 11's, 13's) + to semitone extensions. + (Default value = False) + Returns ------- @@ -983,8 +1013,8 @@ def tetrads_inv(reference_labels, estimated_labels): """ validate(reference_labels, estimated_labels) - ref_roots, ref_semitones, ref_bass = encode_many(reference_labels, False) - est_roots, est_semitones, est_bass = encode_many(estimated_labels, False) + ref_roots, ref_semitones, ref_bass = encode_many(reference_labels, reduce_extended_chords=reduce_extended_chords) + est_roots, est_semitones, est_bass = encode_many(estimated_labels, reduce_extended_chords=reduce_extended_chords) eq_roots = ref_roots == est_roots eq_basses = ref_bass == est_bass @@ -996,7 +1026,7 @@ def tetrads_inv(reference_labels, estimated_labels): return comparison_scores -def root(reference_labels, estimated_labels): +def root(reference_labels, estimated_labels, reduce_extended_chords=False): """Compare chords according to roots. Examples @@ -1023,6 +1053,10 @@ def root(reference_labels, estimated_labels): Reference chord labels to score against. estimated_labels : list, len=n Estimated chord labels to score against. + reduce_extended_chords : bool + Whether to map the upper voicings of extended chords (9's, 11's, 13's) + to semitone extensions. + (Default value = False) Returns ------- @@ -1033,8 +1067,8 @@ def root(reference_labels, estimated_labels): """ validate(reference_labels, estimated_labels) - ref_roots, ref_semitones = encode_many(reference_labels, False)[:2] - est_roots = encode_many(estimated_labels, False)[0] + ref_roots, ref_semitones = encode_many(reference_labels, reduce_extended_chords=reduce_extended_chords)[:2] + est_roots = encode_many(estimated_labels, reduce_extended_chords=reduce_extended_chords)[0] comparison_scores = (ref_roots == est_roots).astype(np.float) # Ignore 'X' chords @@ -1042,7 +1076,7 @@ def root(reference_labels, estimated_labels): return comparison_scores -def mirex(reference_labels, estimated_labels): +def mirex(reference_labels, estimated_labels, reduce_extended_chords=False): """Compare chords along MIREX rules. Examples @@ -1069,6 +1103,11 @@ def mirex(reference_labels, estimated_labels): Reference chord labels to score against. estimated_labels : list, len=n Estimated chord labels to score against. + reduce_extended_chords : bool + Whether to map the upper voicings of extended chords (9's, 11's, 13's) + to semitone extensions. + (Default value = False) + Returns ------- @@ -1079,9 +1118,9 @@ def mirex(reference_labels, estimated_labels): validate(reference_labels, estimated_labels) # TODO(?): Should this be an argument? min_intersection = 3 - ref_data = encode_many(reference_labels, False) + ref_data = encode_many(reference_labels, reduce_extended_chords=reduce_extended_chords) ref_chroma = rotate_bitmaps_to_roots(ref_data[1], ref_data[0]) - est_data = encode_many(estimated_labels, False) + est_data = encode_many(estimated_labels, reduce_extended_chords=reduce_extended_chords) est_chroma = rotate_bitmaps_to_roots(est_data[1], est_data[0]) eq_chroma = (ref_chroma * est_chroma).sum(axis=-1) @@ -1104,7 +1143,7 @@ def mirex(reference_labels, estimated_labels): return comparison_scores -def majmin(reference_labels, estimated_labels): +def majmin(reference_labels, estimated_labels, reduce_extended_chords=False): """Compare chords along major-minor rules. Chords with qualities outside Major/minor/no-chord are ignored. @@ -1132,6 +1171,11 @@ def majmin(reference_labels, estimated_labels): Reference chord labels to score against. estimated_labels : list, len=n Estimated chord labels to score against. + reduce_extended_chords : bool + Whether to map the upper voicings of extended chords (9's, 11's, 13's) + to semitone extensions. + (Default value = False) + Returns ------- @@ -1144,8 +1188,8 @@ def majmin(reference_labels, estimated_labels): maj_semitones = np.array(QUALITIES['maj'][:8]) min_semitones = np.array(QUALITIES['min'][:8]) - ref_roots, ref_semitones, _ = encode_many(reference_labels, False) - est_roots, est_semitones, _ = encode_many(estimated_labels, False) + ref_roots, ref_semitones, _ = encode_many(reference_labels, reduce_extended_chords=reduce_extended_chords) + est_roots, est_semitones, _ = encode_many(estimated_labels, reduce_extended_chords=reduce_extended_chords) eq_root = ref_roots == est_roots eq_quality = np.all(np.equal(ref_semitones[:, :8], @@ -1170,7 +1214,7 @@ def majmin(reference_labels, estimated_labels): return comparison_scores -def majmin_inv(reference_labels, estimated_labels): +def majmin_inv(reference_labels, estimated_labels, reduce_extended_chords=False): """Compare chords along major-minor rules, with inversions. Chords with qualities outside Major/minor/no-chord are ignored, and the bass note must exist in the triad (bass in [1, 3, 5]). @@ -1199,6 +1243,11 @@ def majmin_inv(reference_labels, estimated_labels): Reference chord labels to score against. estimated_labels : list, len=n Estimated chord labels to score against. + reduce_extended_chords : bool + Whether to map the upper voicings of extended chords (9's, 11's, 13's) + to semitone extensions. + (Default value = False) + Returns ------- @@ -1211,8 +1260,8 @@ def majmin_inv(reference_labels, estimated_labels): maj_semitones = np.array(QUALITIES['maj'][:8]) min_semitones = np.array(QUALITIES['min'][:8]) - ref_roots, ref_semitones, ref_bass = encode_many(reference_labels, False) - est_roots, est_semitones, est_bass = encode_many(estimated_labels, False) + ref_roots, ref_semitones, ref_bass = encode_many(reference_labels, reduce_extended_chords=reduce_extended_chords) + est_roots, est_semitones, est_bass = encode_many(estimated_labels, reduce_extended_chords=reduce_extended_chords) eq_root_bass = (ref_roots == est_roots) * (ref_bass == est_bass) eq_semitones = np.all(np.equal(ref_semitones[:, :8], @@ -1235,7 +1284,7 @@ def majmin_inv(reference_labels, estimated_labels): return comparison_scores -def sevenths(reference_labels, estimated_labels): +def sevenths(reference_labels, estimated_labels, reduce_extended_chords=False): """Compare chords along MIREX 'sevenths' rules. Chords with qualities outside [maj, maj7, 7, min, min7, N] are ignored. @@ -1263,6 +1312,11 @@ def sevenths(reference_labels, estimated_labels): Reference chord labels to score against. estimated_labels : list, len=n Estimated chord labels to score against. + reduce_extended_chords : bool + Whether to map the upper voicings of extended chords (9's, 11's, 13's) + to semitone extensions. + (Default value = False) + Returns ------- @@ -1275,8 +1329,8 @@ def sevenths(reference_labels, estimated_labels): seventh_qualities = ['maj', 'min', 'maj7', '7', 'min7', ''] valid_semitones = np.array([QUALITIES[name] for name in seventh_qualities]) - ref_roots, ref_semitones = encode_many(reference_labels, False)[:2] - est_roots, est_semitones = encode_many(estimated_labels, False)[:2] + ref_roots, ref_semitones = encode_many(reference_labels, reduce_extended_chords=reduce_extended_chords)[:2] + est_roots, est_semitones = encode_many(estimated_labels, reduce_extended_chords=reduce_extended_chords)[:2] eq_root = ref_roots == est_roots eq_semitones = np.all(np.equal(ref_semitones, est_semitones), axis=1) @@ -1290,7 +1344,7 @@ def sevenths(reference_labels, estimated_labels): return comparison_scores -def sevenths_inv(reference_labels, estimated_labels): +def sevenths_inv(reference_labels, estimated_labels, reduce_extended_chords=False): """Compare chords along MIREX 'sevenths' rules. Chords with qualities outside [maj, maj7, 7, min, min7, N] are ignored. @@ -1318,6 +1372,11 @@ def sevenths_inv(reference_labels, estimated_labels): Reference chord labels to score against. estimated_labels : list, len=n Estimated chord labels to score against. + reduce_extended_chords : bool + Whether to map the upper voicings of extended chords (9's, 11's, 13's) + to semitone extensions. + (Default value = False) + Returns ------- @@ -1330,8 +1389,8 @@ def sevenths_inv(reference_labels, estimated_labels): seventh_qualities = ['maj', 'min', 'maj7', '7', 'min7', ''] valid_semitones = np.array([QUALITIES[name] for name in seventh_qualities]) - ref_roots, ref_semitones, ref_basses = encode_many(reference_labels, False) - est_roots, est_semitones, est_basses = encode_many(estimated_labels, False) + ref_roots, ref_semitones, ref_basses = encode_many(reference_labels, reduce_extended_chords=reduce_extended_chords) + est_roots, est_semitones, est_basses = encode_many(estimated_labels, reduce_extended_chords=reduce_extended_chords) eq_roots_basses = (ref_roots == est_roots) * (ref_basses == est_basses) eq_semitones = np.all(np.equal(ref_semitones, est_semitones), axis=1) @@ -1572,33 +1631,32 @@ def evaluate(ref_intervals, ref_labels, est_intervals, est_labels, **kwargs): # Store scores for each comparison function scores = collections.OrderedDict() - scores['thirds'] = weighted_accuracy(thirds(ref_labels, est_labels), + scores['thirds'] = weighted_accuracy(util.filter_kwargs(thirds, ref_labels, est_labels, **kwargs), durations) - scores['thirds_inv'] = weighted_accuracy(thirds_inv(ref_labels, - est_labels), durations) - scores['triads'] = weighted_accuracy(triads(ref_labels, est_labels), + scores['thirds_inv'] = weighted_accuracy(util.filter_kwargs(thirds_inv, ref_labels, est_labels, **kwargs), + durations) + scores['triads'] = weighted_accuracy(util.filter_kwargs(triads, ref_labels, est_labels, **kwargs), durations) - scores['triads_inv'] = weighted_accuracy(triads_inv(ref_labels, - est_labels), durations) - scores['tetrads'] = weighted_accuracy(tetrads(ref_labels, est_labels), + scores['triads_inv'] = weighted_accuracy(util.filter_kwargs(triads_inv, ref_labels, est_labels, **kwargs), + durations) + scores['tetrads'] = weighted_accuracy(util.filter_kwargs(tetrads, ref_labels, est_labels, **kwargs), durations) - scores['tetrads_inv'] = weighted_accuracy(tetrads_inv(ref_labels, - est_labels), + scores['tetrads_inv'] = weighted_accuracy(util.filter_kwargs(tetrads_inv, ref_labels, est_labels, **kwargs), durations) - scores['root'] = weighted_accuracy(root(ref_labels, est_labels), durations) - scores['mirex'] = weighted_accuracy(mirex(ref_labels, est_labels), + scores['root'] = weighted_accuracy(util.filter_kwargs(root, ref_labels, est_labels, **kwargs), + durations) + scores['mirex'] = weighted_accuracy(util.filter_kwargs(mirex, ref_labels, est_labels, **kwargs), durations) - scores['majmin'] = weighted_accuracy(majmin(ref_labels, est_labels), + scores['majmin'] = weighted_accuracy(util.filter_kwargs(majmin, ref_labels, est_labels, **kwargs), durations) - scores['majmin_inv'] = weighted_accuracy(majmin_inv(ref_labels, - est_labels), durations) - scores['sevenths'] = weighted_accuracy(sevenths(ref_labels, est_labels), + scores['majmin_inv'] = weighted_accuracy(util.filter_kwargs(majmin_inv, ref_labels, est_labels, **kwargs), + durations) + scores['sevenths'] = weighted_accuracy(util.filter_kwargs(sevenths, ref_labels, est_labels, **kwargs), durations) - scores['sevenths_inv'] = weighted_accuracy(sevenths_inv(ref_labels, - est_labels), + scores['sevenths_inv'] = weighted_accuracy(util.filter_kwargs(sevenths_inv, ref_labels, est_labels, **kwargs), durations) - scores['underseg'] = underseg(merged_ref_intervals, merged_est_intervals) - scores['overseg'] = overseg(merged_ref_intervals, merged_est_intervals) + scores['underseg'] = util.filter_kwargs(underseg, merged_ref_intervals, merged_est_intervals, **kwargs) + scores['overseg'] = util.filter_kwargs(overseg, merged_ref_intervals, merged_est_intervals, **kwargs) scores['seg'] = min(scores['overseg'], scores['underseg']) return scores diff --git a/tests/test_chord.py b/tests/test_chord.py index 8455fa8c..64c7b700 100644 --- a/tests/test_chord.py +++ b/tests/test_chord.py @@ -591,3 +591,47 @@ def test_validate(): # Test that error is thrown on different-length labels nose.tools.assert_raises( ValueError, mir_eval.chord.validate, [], ['C']) + + +def test_chord_eval_reduce(): + # Test whether chord evaluation with extension reductions behaves as expected + # https://github.com/craffel/mir_eval/issues/274 + + # The following two chords are enharmonically equivalent when extensions + # are reduced, and inequivalent otherwise. + # We should get perfect agreement on all metrics if reduce_extended_chords=True + # and disagreement otherwise + intervals = np.asarray([[0, 1]]) + + # chord1 encodes to (1, array([1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0]), 0) + chord1 = ['C#:7(b9,#9)'] + + # chord2 encodes to (1, array([1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0]), 0) + chord2 = ['C#:min7(b9,b11)'] + + # With extension reductions, both chords encode as + # (1, array([1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0]), 0) + + # without reduction, we have agreement at root, fifth, and seventh + # metrics with perfect score: root, mirex, seg, underseg, overseg + # metrics with 0 score: everything else + score_default = mir_eval.chord.evaluate(intervals, chord1, intervals, chord2) + score_no_reduction = mir_eval.chord.evaluate(intervals, chord1, intervals, chord2, reduce_extended_chords=False) + assert score_default == score_no_reduction + + perfects = set(['root', 'mirex', 'seg', 'underseg', 'overseg']) + for metric in score_no_reduction: + if metric in perfects: + assert score_no_reduction[metric] == 1 + else: + assert score_no_reduction[metric] == 0 + + score_reduction = mir_eval.chord.evaluate(intervals, chord1, intervals, chord2, reduce_extended_chords=True) + + # with reduction, all scores should be 1 except majmin, majmin_inv + print(score_reduction) + for metric in score_reduction: + if 'majmin' in metric: + assert score_reduction[metric] == 0, metric + else: + assert score_reduction[metric] == 1, metric