validation warnings set stacklevel=2 #377

Closed · wants to merge 10 commits
41 changes: 34 additions & 7 deletions mir_eval/alignment.py
@@ -141,7 +141,9 @@ def absolute_error(reference_timestamps, estimated_timestamps):
return np.median(deviations), np.mean(deviations)


def percentage_correct(reference_timestamps, estimated_timestamps, window=0.3):
def percentage_correct(
reference_timestamps, estimated_timestamps, window=0.3, safe=True
):
"""Compute the percentage of correctly predicted timestamps. A timestamp is predicted
correctly if its position doesn't deviate more than the window parameter from the ground
truth timestamp.
@@ -161,19 +163,27 @@ def percentage_correct(reference_timestamps, estimated_timestamps, window=0.3):
window : float
Window size, in seconds
(Default value = .3)
safe : bool
If True, validate inputs.
If False, skip validation of inputs.

Returns
-------
pc : float
Percentage of correct timestamps
"""
validate(reference_timestamps, estimated_timestamps)
if safe:
validate(reference_timestamps, estimated_timestamps)

deviations = np.abs(reference_timestamps - estimated_timestamps)
return np.mean(deviations <= window)
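
A minimal usage sketch of the new `safe` flag (array values here are illustrative): validation runs by default, and a caller that has already validated its arrays can opt out on repeated calls.

```python
import numpy as np
import mir_eval.alignment

reference = np.array([0.50, 1.20, 2.40])
estimated = np.array([0.52, 1.18, 2.71])

# Default behaviour: alignment.validate() checks the inputs first.
pc = mir_eval.alignment.percentage_correct(reference, estimated, window=0.3)

# If the same arrays were validated earlier (e.g. in a batch loop),
# safe=False skips the redundant check.
pc_fast = mir_eval.alignment.percentage_correct(
    reference, estimated, window=0.3, safe=False
)
```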


def percentage_correct_segments(
reference_timestamps, estimated_timestamps, duration: Optional[float] = None
reference_timestamps,
estimated_timestamps,
duration: Optional[float] = None,
safe=True,
):
"""Calculate the percentage of correct segments (PCS) metric.

@@ -215,13 +225,17 @@ def percentage_correct_segments(
duration : float
Optional. Total duration of audio (seconds). WARNING: Metric is computed differently
depending on whether this is provided or not - see documentation above!
safe : bool
If True, validate inputs.
If False, skip validation of inputs.

Returns
-------
pcs : float
Percentage of time where ground truth and predicted segments overlap
"""
validate(reference_timestamps, estimated_timestamps)
if safe:
validate(reference_timestamps, estimated_timestamps)
if duration is not None:
duration = float(duration)
if duration <= 0:
@@ -266,7 +280,7 @@ def percentage_correct_segments(
return overlap_duration / duration
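
For the `duration` argument that the docstring warns about, a short sketch of the two call forms (timestamps and duration are illustrative values, not from the test suite):

```python
import numpy as np
import mir_eval.alignment

reference = np.array([1.0, 2.5, 4.0])
estimated = np.array([1.1, 2.4, 4.2])

# One call form: duration omitted.
pcs = mir_eval.alignment.percentage_correct_segments(reference, estimated)

# The other form: total audio duration supplied; the metric is computed
# differently in this case (see the docstring warning above).
pcs_full = mir_eval.alignment.percentage_correct_segments(
    reference, estimated, duration=10.0
)
```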


def karaoke_perceptual_metric(reference_timestamps, estimated_timestamps):
def karaoke_perceptual_metric(reference_timestamps, estimated_timestamps, safe=True):
"""Metric based on human synchronicity perception as measured in the paper
"User-centered evaluation of lyrics to audio alignment" [#lizemasclef2021]

@@ -288,13 +302,17 @@ def karaoke_perceptual_metric(reference_timestamps, estimated_timestamps):
reference timestamps, in seconds
estimated_timestamps : np.ndarray
estimated timestamps, in seconds
safe : bool
If True, validate inputs.
If False, skip validation of inputs.

Returns
-------
perceptual_score : float
Perceptual score, averaged over all timestamps
"""
validate(reference_timestamps, estimated_timestamps)
if safe:
validate(reference_timestamps, estimated_timestamps)
offsets = estimated_timestamps - reference_timestamps

# Score offsets using a certain skewed normal distribution
@@ -309,7 +327,7 @@ def karaoke_perceptual_metric(reference_timestamps, estimated_timestamps):
return np.mean(perceptual_scores)


def evaluate(reference_timestamps, estimated_timestamps, **kwargs):
def evaluate(reference_timestamps, estimated_timestamps, safe=True, **kwargs):
"""Compute all metrics for the given reference and estimated annotations.

Examples
@@ -328,13 +346,22 @@ def evaluate(reference_timestamps, estimated_timestamps, **kwargs):
**kwargs
Additional keyword arguments which will be passed to the
appropriate metric or preprocessing functions.
safe : bool
If True, validate inputs.
If False, skip validation of inputs.

Returns
-------
scores : dict
Dictionary of scores, where the key is the metric name (str) and
the value is the (float) score achieved.
"""
if safe:
validate(reference_timestamps, estimated_timestamps)

# We can now bypass validation
kwargs["safe"] = False

# Compute all metrics
scores = collections.OrderedDict()

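Caller-side view of the pattern introduced in evaluate() above (file names follow the docstring examples in this package and are illustrative): the inputs are validated once, then safe=False is forced into the forwarded kwargs so each individual metric skips re-validation.

```python
import mir_eval.io
import mir_eval.alignment

reference = mir_eval.io.load_events('reference.txt')
estimated = mir_eval.io.load_events('estimated.txt')

# evaluate() calls validate() once, then sets kwargs["safe"] = False so the
# per-metric functions do not validate the same arrays again.
scores = mir_eval.alignment.evaluate(reference, estimated)
```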
75 changes: 57 additions & 18 deletions mir_eval/beat.py
@@ -88,9 +88,9 @@ def validate(reference_beats, estimated_beats):
# If reference or estimated beats are empty,
# warn because metric will be 0
if reference_beats.size == 0:
warnings.warn("Reference beats are empty.")
warnings.warn("Reference beats are empty.", stacklevel=2)
if estimated_beats.size == 0:
warnings.warn("Estimated beats are empty.")
warnings.warn("Estimated beats are empty.", stacklevel=2)
for beats in [reference_beats, estimated_beats]:
util.validate_events(beats, MAX_TIME)
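
Aside on the stacklevel=2 change above (this is what the PR title refers to): a self-contained illustration, independent of mir_eval, of where the warning gets attributed.

```python
import warnings

def validate(events):
    if len(events) == 0:
        # With the default stacklevel=1 the warning is reported against this
        # line inside validate(); stacklevel=2 attributes it to the caller.
        warnings.warn("Reference beats are empty.", stacklevel=2)

def some_metric(events):
    validate(events)  # stacklevel=2 points the warning at this call site
    return 0.0

some_metric([])
```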

@@ -133,7 +133,7 @@ def _get_reference_beat_variations(reference_beats):
)


def f_measure(reference_beats, estimated_beats, f_measure_threshold=0.07):
def f_measure(reference_beats, estimated_beats, f_measure_threshold=0.07, safe=True):
"""Compute the F-measure of correct vs incorrectly predicted beats.
"Correctness" is determined over a small window.

@@ -144,7 +144,7 @@ def f_measure(reference_beats, estimated_beats, f_measure_threshold=0.07):
>>> estimated_beats = mir_eval.io.load_events('estimated.txt')
>>> estimated_beats = mir_eval.beat.trim_beats(estimated_beats)
>>> f_measure = mir_eval.beat.f_measure(reference_beats,
estimated_beats)
... estimated_beats)

Parameters
----------
@@ -155,14 +155,18 @@ def f_measure(reference_beats, estimated_beats, f_measure_threshold=0.07):
f_measure_threshold : float
Window size, in seconds
(Default value = 0.07)
safe : bool
If True, validate inputs.
If False, skip validation of inputs.

Returns
-------
f_score : float
The computed F-measure score

"""
validate(reference_beats, estimated_beats)
if safe:
validate(reference_beats, estimated_beats)
# When estimated beats are empty, no beats are correct; metric is 0
if estimated_beats.size == 0 or reference_beats.size == 0:
return 0.0
@@ -174,7 +178,7 @@ def f_measure(reference_beats, estimated_beats, f_measure_threshold=0.07):
return util.f_measure(precision, recall)


def cemgil(reference_beats, estimated_beats, cemgil_sigma=0.04):
def cemgil(reference_beats, estimated_beats, cemgil_sigma=0.04, safe=True):
"""Cemgil's score, computes a gaussian error of each estimated beat.
Compares against the original beat times and all metrical variations.

@@ -185,7 +189,7 @@ def cemgil(reference_beats, estimated_beats, cemgil_sigma=0.04):
>>> estimated_beats = mir_eval.io.load_events('estimated.txt')
>>> estimated_beats = mir_eval.beat.trim_beats(estimated_beats)
>>> cemgil_score, cemgil_max = mir_eval.beat.cemgil(reference_beats,
estimated_beats)
... estimated_beats)

Parameters
----------
@@ -196,6 +200,9 @@ def cemgil(reference_beats, estimated_beats, cemgil_sigma=0.04):
cemgil_sigma : float
Sigma parameter of gaussian error windows
(Default value = 0.04)
safe : bool
If True, validate inputs.
If False, skip validation of inputs.

Returns
-------
@@ -204,7 +211,8 @@ def cemgil(reference_beats, estimated_beats, cemgil_sigma=0.04):
cemgil_max : float
The best Cemgil score for all metrical variations
"""
validate(reference_beats, estimated_beats)
if safe:
validate(reference_beats, estimated_beats)
# When estimated beats are empty, no beats are correct; metric is 0
if estimated_beats.size == 0 or reference_beats.size == 0:
return 0.0, 0.0
@@ -228,7 +236,12 @@ def cemgil(reference_beats, estimated_beats, cemgil_sigma=0.04):


def goto(
reference_beats, estimated_beats, goto_threshold=0.35, goto_mu=0.2, goto_sigma=0.2
reference_beats,
estimated_beats,
goto_threshold=0.35,
goto_mu=0.2,
goto_sigma=0.2,
safe=True,
):
"""Calculate Goto's score, a binary 1 or 0 depending on some specific
heuristic criteria
@@ -258,13 +271,17 @@ def goto(
The std of the beat errors in the continuously correct track must
be less than this
(Default value = 0.2)
safe : bool
If True, validate inputs.
If False, skip validation of inputs.

Returns
-------
goto_score : float
Either 1.0 or 0.0 if some specific criteria are met
"""
validate(reference_beats, estimated_beats)
if safe:
validate(reference_beats, estimated_beats)
# When estimated beats are empty, no beats are correct; metric is 0
if estimated_beats.size == 0 or reference_beats.size == 0:
return 0.0
@@ -327,7 +344,7 @@ def goto(
return 1.0 * (goto_criteria == 3)


def p_score(reference_beats, estimated_beats, p_score_threshold=0.2):
def p_score(reference_beats, estimated_beats, p_score_threshold=0.2, safe=True):
"""Get McKinney's P-score.
Based on the autocorrelation of the reference and estimated beats

@@ -349,14 +366,18 @@ def p_score(reference_beats, estimated_beats, p_score_threshold=0.2):
Window size will be
``p_score_threshold*np.median(inter_annotation_intervals)``,
(Default value = 0.2)
safe : bool
If True, validate inputs.
If False, skip validation of inputs.

Returns
-------
correlation : float
McKinney's P-score

"""
validate(reference_beats, estimated_beats)
if safe:
validate(reference_beats, estimated_beats)
# Warn when only one beat is provided for either estimated or reference,
# report a warning
if reference_beats.size == 1:
@@ -412,6 +433,7 @@ def continuity(
estimated_beats,
continuity_phase_threshold=0.175,
continuity_period_threshold=0.175,
safe=True,
):
"""Get metrics based on how much of the estimated beat sequence is
continually correct.
@@ -439,6 +461,9 @@ def continuity(
Allowable distance between the inter-beat-interval
and the inter-annotation-interval
(Default value = 0.175)
safe : bool
If True, validate inputs.
If False, skip validation of inputs.

Returns
-------
@@ -451,7 +476,8 @@ def continuity(
AMLt : float
Any metric level, total accuracy (continuity not required)
"""
validate(reference_beats, estimated_beats)
if safe:
validate(reference_beats, estimated_beats)
# Warn when only one beat is provided for either estimated or reference,
# report a warning
if reference_beats.size == 1:
@@ -583,7 +609,7 @@ def continuity(
)
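
continuity() returns four scores, listed in the docstring above; a short unpacking sketch following the docstring examples elsewhere in this module (file names are illustrative):

```python
import mir_eval.io
import mir_eval.beat

reference_beats = mir_eval.beat.trim_beats(mir_eval.io.load_events('reference.txt'))
estimated_beats = mir_eval.beat.trim_beats(mir_eval.io.load_events('estimated.txt'))

# Correct metric level (continuous/total), any metric level (continuous/total).
CMLc, CMLt, AMLc, AMLt = mir_eval.beat.continuity(reference_beats, estimated_beats)
```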


def information_gain(reference_beats, estimated_beats, bins=41):
def information_gain(reference_beats, estimated_beats, bins=41, safe=True):
"""Get the information gain - K-L divergence of the beat error histogram
to a uniform histogram

@@ -594,7 +620,7 @@ def information_gain(reference_beats, estimated_beats, bins=41):
>>> estimated_beats = mir_eval.io.load_events('estimated.txt')
>>> estimated_beats = mir_eval.beat.trim_beats(estimated_beats)
>>> information_gain = mir_eval.beat.information_gain(reference_beats,
estimated_beats)
... estimated_beats)

Parameters
----------
@@ -605,13 +631,17 @@ def information_gain(reference_beats, estimated_beats, bins=41):
bins : int
Number of bins in the beat error histogram
(Default value = 41)
safe : bool
If True, validate inputs.
If False, skip validation of inputs.

Returns
-------
information_gain_score : float
Entropy of beat error histogram
"""
validate(reference_beats, estimated_beats)
if safe:
validate(reference_beats, estimated_beats)
# If an even number of bins is provided,
# there will be no bin centered at zero, so warn the user.
if not bins % 2:
@@ -712,7 +742,7 @@ def _get_entropy(reference_beats, estimated_beats, bins):
return -np.sum(raw_bin_values * np.log2(raw_bin_values))


def evaluate(reference_beats, estimated_beats, **kwargs):
def evaluate(reference_beats, estimated_beats, safe=True, **kwargs):
"""Compute all metrics for the given reference and estimated annotations.

Examples
@@ -727,6 +757,9 @@ def evaluate(reference_beats, estimated_beats, **kwargs):
Reference beat times, in seconds
estimated_beats : np.ndarray
Query beat times, in seconds
safe : bool
If True, validate inputs.
If False, skip validation of inputs.
**kwargs
Additional keyword arguments which will be passed to the
appropriate metric or preprocessing functions.
@@ -742,8 +775,14 @@ def evaluate(reference_beats, estimated_beats, **kwargs):
reference_beats = util.filter_kwargs(trim_beats, reference_beats, **kwargs)
estimated_beats = util.filter_kwargs(trim_beats, estimated_beats, **kwargs)

# Now compute all the metrics
# Validate inputs
if safe:
validate(reference_beats, estimated_beats)

# We can now bypass validation
kwargs["safe"] = False

# Now compute all the metrics
scores = collections.OrderedDict()

# F-Measure
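
End-to-end usage of beat.evaluate() with the new flag, following the docstring examples in this module (file names are illustrative): beats are trimmed, validated once, and every metric then runs with safe=False.

```python
import mir_eval.io
import mir_eval.beat

reference_beats = mir_eval.io.load_events('reference.txt')
estimated_beats = mir_eval.io.load_events('estimated.txt')

# trim_beats() is applied to both sequences, validate() runs once (because
# safe defaults to True), and the individual metrics are called with
# safe=False so they skip re-validation.
scores = mir_eval.beat.evaluate(reference_beats, estimated_beats)
```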