Skip to content

Commit

Permalink
array API support for mean_absolute_percentage_error (scikit-learn#29300
Browse files Browse the repository at this point in the history
)
  • Loading branch information
EmilyXinyi authored Jul 12, 2024
1 parent 1813b4a commit dc6c01c
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 6 deletions.
1 change: 1 addition & 0 deletions doc/modules/array_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ Metrics
- :func:`sklearn.metrics.d2_tweedie_score`
- :func:`sklearn.metrics.max_error`
- :func:`sklearn.metrics.mean_absolute_error`
- :func:`sklearn.metrics.mean_absolute_percentage_error`
- :func:`sklearn.metrics.mean_gamma_deviance`
- :func:`sklearn.metrics.mean_squared_error`
- :func:`sklearn.metrics.mean_tweedie_deviance`
Expand Down
3 changes: 2 additions & 1 deletion doc/whats_new/v1.6.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ See :ref:`array_api` for more details.
- :func:`sklearn.metrics.max_error` :pr:`29212` by :user:`Edoardo Abati <EdAbati>`;
- :func:`sklearn.metrics.mean_absolute_error` :pr:`27736` by :user:`Edoardo Abati <EdAbati>`
and :pr:`29143` by :user:`Tialo <Tialo>` and :user:`Loïc Estève <lesteve>`;
- :func:`sklearn.metrics.mean_gamma_deviance` :pr:`29239` by :usser:`Emily Chen <EmilyXinyi>`;
- :func:`sklearn.metrics.mean_absolute_percentage_error` :pr:`29300` by :user:`Emily Chen <EmilyXinyi>`;
- :func:`sklearn.metrics.mean_gamma_deviance` :pr:`29239` by :user:`Emily Chen <EmilyXinyi>`;
- :func:`sklearn.metrics.mean_squared_error` :pr:`29142` by :user:`Yaroslav Korobko <Tialo>`;
- :func:`sklearn.metrics.mean_tweedie_deviance` :pr:`28106` by :user:`Thomas Li <lithomas1>`;
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel` :pr:`29144` by :user:`Yaroslav Korobko <Tialo>`;
Expand Down
19 changes: 14 additions & 5 deletions sklearn/metrics/_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,21 +395,30 @@ def mean_absolute_percentage_error(
>>> mean_absolute_percentage_error(y_true, y_pred)
112589990684262.48
"""
input_arrays = [y_true, y_pred, sample_weight, multioutput]
xp, _ = get_namespace(*input_arrays)
dtype = _find_matching_floating_dtype(y_true, y_pred, sample_weight, xp=xp)

y_type, y_true, y_pred, multioutput = _check_reg_targets(
y_true, y_pred, multioutput
)
check_consistent_length(y_true, y_pred, sample_weight)
epsilon = np.finfo(np.float64).eps
mape = np.abs(y_pred - y_true) / np.maximum(np.abs(y_true), epsilon)
output_errors = np.average(mape, weights=sample_weight, axis=0)
epsilon = xp.asarray(xp.finfo(xp.float64).eps, dtype=dtype)
y_true_abs = xp.asarray(xp.abs(y_true), dtype=dtype)
mape = xp.asarray(xp.abs(y_pred - y_true), dtype=dtype) / xp.maximum(
y_true_abs, epsilon
)
output_errors = _average(mape, weights=sample_weight, axis=0)
if isinstance(multioutput, str):
if multioutput == "raw_values":
return output_errors
elif multioutput == "uniform_average":
# pass None as weights to np.average: uniform mean
# pass None as weights to _average: uniform mean
multioutput = None

return np.average(output_errors, weights=multioutput)
mean_absolute_percentage_error = _average(output_errors, weights=multioutput)
assert mean_absolute_percentage_error.shape == ()
return float(mean_absolute_percentage_error)


@validate_params(
Expand Down
4 changes: 4 additions & 0 deletions sklearn/metrics/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2016,6 +2016,10 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
additive_chi2_kernel: [check_array_api_metric_pairwise],
mean_gamma_deviance: [check_array_api_regression_metric],
max_error: [check_array_api_regression_metric],
mean_absolute_percentage_error: [
check_array_api_regression_metric,
check_array_api_regression_metric_multioutput,
],
chi2_kernel: [check_array_api_metric_pairwise],
cosine_distances: [check_array_api_metric_pairwise],
euclidean_distances: [check_array_api_metric_pairwise],
Expand Down

0 comments on commit dc6c01c

Please sign in to comment.