From 0278114e4c854c205faeae0d61a72addd4868638 Mon Sep 17 00:00:00 2001 From: Johannes Pitz Date: Fri, 14 May 2021 14:58:03 +0200 Subject: [PATCH 1/6] Add RMSE option to MSE code --- .../functional/regression/mean_squared_error.py | 12 ++++++++---- torchmetrics/regression/mean_squared_error.py | 6 +++++- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/torchmetrics/functional/regression/mean_squared_error.py b/torchmetrics/functional/regression/mean_squared_error.py index 225dd7dd509..72c95499158 100644 --- a/torchmetrics/functional/regression/mean_squared_error.py +++ b/torchmetrics/functional/regression/mean_squared_error.py @@ -19,10 +19,13 @@ from torchmetrics.utilities.checks import _check_same_shape -def _mean_squared_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]: +def _mean_squared_error_update(preds: Tensor, target: Tensor, squared: bool = True) -> Tuple[Tensor, int]: _check_same_shape(preds, target) diff = preds - target - sum_squared_error = torch.sum(diff * diff) + squared_error = diff * diff + if not squared: + squared_error = squared_error ** 0.5 + sum_squared_error = torch.sum(squared_error) n_obs = target.numel() return sum_squared_error, n_obs @@ -31,13 +34,14 @@ def _mean_squared_error_compute(sum_squared_error: Tensor, n_obs: int) -> Tensor return sum_squared_error / n_obs -def mean_squared_error(preds: Tensor, target: Tensor) -> Tensor: +def mean_squared_error(preds: Tensor, target: Tensor, squared: bool = True) -> Tensor: """ Computes mean squared error Args: preds: estimated labels target: ground truth labels + squared: returns RMSE value if set to False Return: Tensor with MSE @@ -49,5 +53,5 @@ def mean_squared_error(preds: Tensor, target: Tensor) -> Tensor: >>> mean_squared_error(x, y) tensor(0.2500) """ - sum_squared_error, n_obs = _mean_squared_error_update(preds, target) + sum_squared_error, n_obs = _mean_squared_error_update(preds, target, squared=squared) return _mean_squared_error_compute(sum_squared_error, n_obs) diff --git a/torchmetrics/regression/mean_squared_error.py b/torchmetrics/regression/mean_squared_error.py index a1bd8a6a282..66dd3df6a1a 100644 --- a/torchmetrics/regression/mean_squared_error.py +++ b/torchmetrics/regression/mean_squared_error.py @@ -39,6 +39,8 @@ class MeanSquaredError(Metric): before returning the value at the step. default: False process_group: Specify the process group on which synchronization is called. default: None (which selects the entire world) + squared: + If True returns MSE value, if False returns RMSE value. Example: >>> from torchmetrics import MeanSquaredError @@ -56,6 +58,7 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, dist_sync_fn: Callable = None, + squared: bool = True, ): super().__init__( compute_on_step=compute_on_step, @@ -66,6 +69,7 @@ def __init__( self.add_state("sum_squared_error", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") + self.squared = squared def update(self, preds: Tensor, target: Tensor): """ @@ -75,7 +79,7 @@ def update(self, preds: Tensor, target: Tensor): preds: Predictions from model target: Ground truth values """ - sum_squared_error, n_obs = _mean_squared_error_update(preds, target) + sum_squared_error, n_obs = _mean_squared_error_update(preds, target, squared=self.squared) self.sum_squared_error += sum_squared_error self.total += n_obs From b978383fa66b058535dc322727137903563bd9cd Mon Sep 17 00:00:00 2001 From: Johannes Pitz Date: Mon, 17 May 2021 11:20:59 +0200 Subject: [PATCH 2/6] Fix RMSE computation --- .../functional/regression/mean_squared_error.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/torchmetrics/functional/regression/mean_squared_error.py b/torchmetrics/functional/regression/mean_squared_error.py index 72c95499158..d5375535a6e 100644 --- a/torchmetrics/functional/regression/mean_squared_error.py +++ b/torchmetrics/functional/regression/mean_squared_error.py @@ -19,13 +19,10 @@ from torchmetrics.utilities.checks import _check_same_shape -def _mean_squared_error_update(preds: Tensor, target: Tensor, squared: bool = True) -> Tuple[Tensor, int]: +def _mean_squared_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]: _check_same_shape(preds, target) diff = preds - target - squared_error = diff * diff - if not squared: - squared_error = squared_error ** 0.5 - sum_squared_error = torch.sum(squared_error) + sum_squared_error = torch.sum(diff * diff) n_obs = target.numel() return sum_squared_error, n_obs @@ -53,5 +50,8 @@ def mean_squared_error(preds: Tensor, target: Tensor, squared: bool = True) -> T >>> mean_squared_error(x, y) tensor(0.2500) """ - sum_squared_error, n_obs = _mean_squared_error_update(preds, target, squared=squared) - return _mean_squared_error_compute(sum_squared_error, n_obs) + sum_squared_error, n_obs = _mean_squared_error_update(preds, target) + output_error = _mean_squared_error_compute(sum_squared_error, n_obs) + if not squared: + output_error = torch.sqrt(mean_squared_error) + return output_error From 09311487a6ddc533c9b76f300de02b053c6cdc3f Mon Sep 17 00:00:00 2001 From: Johannes Pitz Date: Mon, 17 May 2021 11:24:51 +0200 Subject: [PATCH 3/6] Move sqrt to compute --- torchmetrics/functional/regression/mean_squared_error.py | 9 ++++----- torchmetrics/regression/mean_squared_error.py | 4 ++-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/torchmetrics/functional/regression/mean_squared_error.py b/torchmetrics/functional/regression/mean_squared_error.py index d5375535a6e..f17c2eb2df2 100644 --- a/torchmetrics/functional/regression/mean_squared_error.py +++ b/torchmetrics/functional/regression/mean_squared_error.py @@ -27,7 +27,9 @@ def _mean_squared_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, i return sum_squared_error, n_obs -def _mean_squared_error_compute(sum_squared_error: Tensor, n_obs: int) -> Tensor: +def _mean_squared_error_compute(sum_squared_error: Tensor, n_obs: int, squared: bool = True) -> Tensor: + if not squared: + sum_squared_error = torch.sqrt(sum_squared_error) return sum_squared_error / n_obs @@ -51,7 +53,4 @@ def mean_squared_error(preds: Tensor, target: Tensor, squared: bool = True) -> T tensor(0.2500) """ sum_squared_error, n_obs = _mean_squared_error_update(preds, target) - output_error = _mean_squared_error_compute(sum_squared_error, n_obs) - if not squared: - output_error = torch.sqrt(mean_squared_error) - return output_error + return _mean_squared_error_compute(sum_squared_error, n_obs, squared=squared) diff --git a/torchmetrics/regression/mean_squared_error.py b/torchmetrics/regression/mean_squared_error.py index 66dd3df6a1a..0eae049687e 100644 --- a/torchmetrics/regression/mean_squared_error.py +++ b/torchmetrics/regression/mean_squared_error.py @@ -79,7 +79,7 @@ def update(self, preds: Tensor, target: Tensor): preds: Predictions from model target: Ground truth values """ - sum_squared_error, n_obs = _mean_squared_error_update(preds, target, squared=self.squared) + sum_squared_error, n_obs = _mean_squared_error_update(preds, target) self.sum_squared_error += sum_squared_error self.total += n_obs @@ -88,7 +88,7 @@ def compute(self): """ Computes mean squared error over state. """ - return _mean_squared_error_compute(self.sum_squared_error, self.total) + return _mean_squared_error_compute(self.sum_squared_error, self.total, squared=self.squared) @property def is_differentiable(self): From d8eb2566bd4d3e4d15998a30a5df96d003b70447 Mon Sep 17 00:00:00 2001 From: Johannes Pitz Date: Mon, 17 May 2021 11:32:17 +0200 Subject: [PATCH 4/6] Fix division by n_obs --- torchmetrics/functional/regression/mean_squared_error.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchmetrics/functional/regression/mean_squared_error.py b/torchmetrics/functional/regression/mean_squared_error.py index f17c2eb2df2..291f42b6a4e 100644 --- a/torchmetrics/functional/regression/mean_squared_error.py +++ b/torchmetrics/functional/regression/mean_squared_error.py @@ -28,9 +28,7 @@ def _mean_squared_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, i def _mean_squared_error_compute(sum_squared_error: Tensor, n_obs: int, squared: bool = True) -> Tensor: - if not squared: - sum_squared_error = torch.sqrt(sum_squared_error) - return sum_squared_error / n_obs + return sum_squared_error / n_obs if squared else torch.sqrt(sum_squared_error / n_obs) def mean_squared_error(preds: Tensor, target: Tensor, squared: bool = True) -> Tensor: From c503c55d070b81ea628ca42daef177d3e9441bce Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 17 May 2021 18:13:52 +0200 Subject: [PATCH 5/6] add tests --- CHANGELOG.md | 3 +++ tests/regression/test_mean_error.py | 41 ++++++++++++++++++----------- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bf3dfe9fb2b..0d7f4377d51 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added support for unnormalized scores (e.g. logits) in `Accuracy`, `Precision`, `Recall`, `FBeta`, `F1`, `StatScore`, `Hamming`, `ConfusionMatrix` metrics ([#200](https://github.com/PyTorchLightning/metrics/pull/200)) +- Added `squared` argument to `MeanSquaredError` for computing `RMSE` ([#249](https://github.com/PyTorchLightning/metrics/pull/249)) + + ### Changed diff --git a/tests/regression/test_mean_error.py b/tests/regression/test_mean_error.py index 7009d4fb71b..6c4a6445b71 100644 --- a/tests/regression/test_mean_error.py +++ b/tests/regression/test_mean_error.py @@ -13,6 +13,7 @@ # limitations under the License. from collections import namedtuple from functools import partial +import math import pytest import torch @@ -43,16 +44,18 @@ ) -def _single_target_sk_metric(preds, target, sk_fn=mean_squared_error): +def _single_target_sk_metric(preds, target, sk_fn, metric_args): sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() - return sk_fn(sk_preds, sk_target) + res = sk_fn(sk_preds, sk_target) + return math.sqrt(res) if (metric_args and not metric_args['squared']) else res -def _multi_target_sk_metric(preds, target, sk_fn=mean_squared_error): +def _multi_target_sk_metric(preds, target, sk_fn, metric_args): sk_preds = preds.view(-1, num_targets).numpy() sk_target = target.view(-1, num_targets).numpy() - return sk_fn(sk_preds, sk_target) + res = sk_fn(sk_preds, sk_target) + return math.sqrt(res) if (metric_args and not metric_args['squared']) else res @pytest.mark.parametrize( @@ -63,11 +66,12 @@ def _multi_target_sk_metric(preds, target, sk_fn=mean_squared_error): ], ) @pytest.mark.parametrize( - "metric_class, metric_functional, sk_fn", + "metric_class, metric_functional, sk_fn, metric_args", [ - (MeanSquaredError, mean_squared_error, sk_mean_squared_error), - (MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error), - (MeanSquaredLogError, mean_squared_log_error, sk_mean_squared_log_error), + (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {'squared': True}), + (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {'squared': False}), + (MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error, {}), + (MeanSquaredLogError, mean_squared_log_error, sk_mean_squared_log_error, {}), ], ) class TestMeanError(MetricTester): @@ -75,7 +79,7 @@ class TestMeanError(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_mean_error_class( - self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, ddp, dist_sync_on_step + self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, metric_args, ddp, dist_sync_on_step ): # todo: `metric_functional` is unused self.run_class_metric_test( @@ -83,35 +87,40 @@ def test_mean_error_class( preds=preds, target=target, metric_class=metric_class, - sk_metric=partial(sk_metric, sk_fn=sk_fn), + sk_metric=partial(sk_metric, sk_fn=sk_fn, metric_args=metric_args), dist_sync_on_step=dist_sync_on_step, + metric_args=metric_args ) - def test_mean_error_functional(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn): + def test_mean_error_functional(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, metric_args): # todo: `metric_class` is unused self.run_functional_metric_test( preds=preds, target=target, metric_functional=metric_functional, - sk_metric=partial(sk_metric, sk_fn=sk_fn), + sk_metric=partial(sk_metric, sk_fn=sk_fn, metric_args=metric_args), + metric_args=metric_args ) - def test_mean_error_differentiability(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn): + def test_mean_error_differentiability( + self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, metric_args + ): self.run_differentiability_test( - preds=preds, target=target, metric_module=metric_class, metric_functional=metric_functional + preds=preds, target=target, metric_module=metric_class, metric_functional=metric_functional, + metric_args=metric_args ) @pytest.mark.skipif( not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6' ) - def test_mean_error_half_cpu(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn): + def test_mean_error_half_cpu(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, metric_args): if metric_class == MeanSquaredLogError: # MeanSquaredLogError half + cpu does not work due to missing support in torch.log pytest.xfail("MeanSquaredLogError metric does not support cpu + half precision") self.run_precision_test_cpu(preds, target, metric_class, metric_functional) @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') - def test_mean_error_half_gpu(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn): + def test_mean_error_half_gpu(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, metric_args): self.run_precision_test_gpu(preds, target, metric_class, metric_functional) From 98d685afbcc39c02586db60e0442c6f3af1ccea5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 17 May 2021 16:14:43 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/regression/test_mean_error.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/regression/test_mean_error.py b/tests/regression/test_mean_error.py index 6c4a6445b71..46dc8ec572e 100644 --- a/tests/regression/test_mean_error.py +++ b/tests/regression/test_mean_error.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import math from collections import namedtuple from functools import partial -import math import pytest import torch