# Functions used for evaluating all test data.
# In practice, only generate_eval should be used: given the mode string set
# in the main function, it returns the matching evaluation function.
# Pitch values are converted to cents and voicing arrays following the
# MIREX melody-extraction guidelines:
# http://craffel.github.io/mir_eval/#module-mir_eval.melody
# ----------------- Imports
from mir_eval import melody as mel_eval
import numpy as np

import helpers as hr


# ----------------- Evaluation Functions
def evaluate_model_voicing(test_guesses, test_labels):
    '''
    When only voicing is evaluated, only the voicing-related metrics apply.
    Inputs:
        1d boolean np.array containing all predictions made by the model
        1d boolean np.array containing all ground truth labels
    Output:
        Dict holding the voicing recall, voicing false alarm rate and
        overall accuracy
    '''
    ref_voicing = test_labels.astype(bool)
    est_voicing = test_guesses.astype(bool)

    print('Evaluating voicing...')
    vx_recall, vx_false_alarm = mel_eval.voicing_measures(ref_voicing,
                                                          est_voicing)

    print('Evaluating overall accuracy...')
    overall_accuracy = np.mean(ref_voicing == est_voicing)

    metrics = {
        'vx_recall': vx_recall,
        'vx_false_alarm': vx_false_alarm,
        'overall_accuracy': overall_accuracy
    }
    for m, v in metrics.items():
        print(m, ':', v)
    return metrics
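
# A minimal usage sketch (not part of the original pipeline; the arrays
# below are made up): both inputs are 1-d, frame-aligned and boolean-like.
#
#     guesses = np.array([1, 0, 1, 1], dtype=bool)
#     labels = np.array([1, 0, 0, 1], dtype=bool)
#     evaluate_model_voicing(guesses, labels)
#     # expected: vx_recall=1.0, vx_false_alarm=0.5, overall_accuracy=0.75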


def evaluate_model_melody(test_guesses, test_labels):
    '''
    Run standard pitch and chroma evaluations on all test data
    Inputs:
        1d np.array containing all predictions made by the model
        1d np.array containing all ground truth labels
    Outputs:
        Dict holding results of all evaluations
    '''
    ref_freq = hr.note_to_hz_zeros(test_labels)
    est_freq = hr.note_to_hz_zeros(test_guesses)
    ref_cent = mel_eval.hz2cents(ref_freq)
    est_cent = mel_eval.hz2cents(est_freq)
    # Treat every frame as voiced so that only pitch, not voicing, is scored
    all_voiced = np.ones(len(ref_cent), dtype=bool)

    print('Evaluating pitch...')
    raw_pitch = mel_eval.raw_pitch_accuracy(all_voiced, ref_cent,
                                            all_voiced, est_cent,
                                            cent_tolerance=50)

    print('Evaluating chroma...')
    raw_chroma = mel_eval.raw_chroma_accuracy(all_voiced, ref_cent,
                                              all_voiced, est_cent,
                                              cent_tolerance=50)

    metrics = {
        'raw_pitch': raw_pitch,
        'raw_chroma': raw_chroma,
    }
    for m, v in metrics.items():
        print(m, ':', v)
    return metrics


def evaluate_model_all(test_guesses, test_labels):
    '''
    Run the standard MIREX evaluations on all test data
    Inputs:
        1d np.array containing all predictions made by the model
        1d np.array containing all ground truth labels
    Outputs:
        Dict holding results of all evaluations
    '''
    print('Running conversions...')
    ref_freq = hr.note_to_hz_zeros(test_labels)     # Notes back to Hz...
    est_freq = hr.note_to_hz_zeros(test_guesses)
    ref_cent = mel_eval.hz2cents(ref_freq)          # ...then Hz to cents...
    est_cent = mel_eval.hz2cents(est_freq)
    ref_voicing = mel_eval.freq_to_voicing(ref_freq)[1]  # ...and voicings!
    est_voicing = mel_eval.freq_to_voicing(est_freq)[1]

    print('Evaluating voicing...')
    vx_recall, vx_false_alarm = mel_eval.voicing_measures(ref_voicing,
                                                          est_voicing)

    print('Evaluating pitch...')
    raw_pitch = mel_eval.raw_pitch_accuracy(ref_voicing, ref_cent,
                                            est_voicing, est_cent,
                                            cent_tolerance=50)

    print('Evaluating chroma...')
    raw_chroma = mel_eval.raw_chroma_accuracy(ref_voicing, ref_cent,
                                              est_voicing, est_cent,
                                              cent_tolerance=50)

    print('Evaluating overall accuracy...')
    overall_accuracy = mel_eval.overall_accuracy(ref_voicing, ref_cent,
                                                 est_voicing, est_cent,
                                                 cent_tolerance=50)

    metrics = {
        'vx_recall': vx_recall,
        'vx_false_alarm': vx_false_alarm,
        'raw_pitch': raw_pitch,
        'raw_chroma': raw_chroma,
        'overall_accuracy': overall_accuracy
    }
    for m, v in metrics.items():
        print(m, ':', v)
    return metrics
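
# Sketch of the conversion chain used above, on a toy frequency array.
# hr.note_to_hz_zeros is assumed to map note numbers to Hz while keeping 0
# for unvoiced frames, so a zero frequency marks an unvoiced frame here too:
#
#     freqs = np.array([440.0, 0.0, 220.0])          # Hz, 0 = unvoiced
#     cents = mel_eval.hz2cents(freqs)                # zeros stay zero
#     voicing = mel_eval.freq_to_voicing(freqs)[1]    # -> [True, False, True]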


# ----------------- Function Generator
def generate_eval(mode):
    '''
    Return the evaluation function matching the given mode string.
    Input: one of 'options' | 'voicing' | 'melody' | 'all'
    Output: the corresponding evaluation function, or, for 'options',
            a dict enumerating the available modes
    '''
    evaluations = {
        'voicing': evaluate_model_voicing,
        'melody': evaluate_model_melody,
        'all': evaluate_model_all
    }
    if mode == 'options':
        return {i: k for i, k in enumerate(evaluations)}
    return evaluations[mode]
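

# A minimal, self-contained sketch of the entry point described at the top
# of this file: main picks a mode string and generate_eval hands back the
# matching evaluation function. The arrays below are placeholders, not real
# model output; the melody/all modes additionally require the helpers module.
if __name__ == '__main__':
    print('Available modes:', generate_eval('options'))
    demo_guesses = np.array([1, 1, 0, 1, 0], dtype=bool)
    demo_labels = np.array([1, 0, 0, 1, 1], dtype=bool)
    run_eval = generate_eval('voicing')
    run_eval(demo_guesses, demo_labels)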