diff --git a/CHANGELOG.md b/CHANGELOG.md index bf3dfe9fb2b..0d7f4377d51 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added support for unnormalized scores (e.g. logits) in `Accuracy`, `Precision`, `Recall`, `FBeta`, `F1`, `StatScore`, `Hamming`, `ConfusionMatrix` metrics ([#200](https://github.com/PyTorchLightning/metrics/pull/200)) +- Added `squared` argument to `MeanSquaredError` for computing `RMSE` ([#249](https://github.com/PyTorchLightning/metrics/pull/249)) + + ### Changed diff --git a/tests/regression/test_mean_error.py b/tests/regression/test_mean_error.py index 7009d4fb71b..46dc8ec572e 100644 --- a/tests/regression/test_mean_error.py +++ b/tests/regression/test_mean_error.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import math from collections import namedtuple from functools import partial @@ -43,16 +44,18 @@ ) -def _single_target_sk_metric(preds, target, sk_fn=mean_squared_error): +def _single_target_sk_metric(preds, target, sk_fn, metric_args): sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() - return sk_fn(sk_preds, sk_target) + res = sk_fn(sk_preds, sk_target) + return math.sqrt(res) if (metric_args and not metric_args['squared']) else res -def _multi_target_sk_metric(preds, target, sk_fn=mean_squared_error): +def _multi_target_sk_metric(preds, target, sk_fn, metric_args): sk_preds = preds.view(-1, num_targets).numpy() sk_target = target.view(-1, num_targets).numpy() - return sk_fn(sk_preds, sk_target) + res = sk_fn(sk_preds, sk_target) + return math.sqrt(res) if (metric_args and not metric_args['squared']) else res @pytest.mark.parametrize( @@ -63,11 +66,12 @@ def _multi_target_sk_metric(preds, target, sk_fn=mean_squared_error): ], ) @pytest.mark.parametrize( - "metric_class, metric_functional, sk_fn", + "metric_class, metric_functional, sk_fn, metric_args", [ - (MeanSquaredError, mean_squared_error, sk_mean_squared_error), - (MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error), - (MeanSquaredLogError, mean_squared_log_error, sk_mean_squared_log_error), + (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {'squared': True}), + (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {'squared': False}), + (MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error, {}), + (MeanSquaredLogError, mean_squared_log_error, sk_mean_squared_log_error, {}), ], ) class TestMeanError(MetricTester): @@ -75,7 +79,7 @@ class TestMeanError(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_mean_error_class( - self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, ddp, dist_sync_on_step + self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, metric_args, ddp, dist_sync_on_step ): # todo: `metric_functional` is unused self.run_class_metric_test( @@ -83,35 +87,40 @@ def test_mean_error_class( preds=preds, target=target, metric_class=metric_class, - sk_metric=partial(sk_metric, sk_fn=sk_fn), + sk_metric=partial(sk_metric, sk_fn=sk_fn, metric_args=metric_args), dist_sync_on_step=dist_sync_on_step, + metric_args=metric_args ) - def test_mean_error_functional(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn): + def test_mean_error_functional(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, metric_args): # todo: `metric_class` is unused self.run_functional_metric_test( preds=preds, target=target, metric_functional=metric_functional, - sk_metric=partial(sk_metric, sk_fn=sk_fn), + sk_metric=partial(sk_metric, sk_fn=sk_fn, metric_args=metric_args), + metric_args=metric_args ) - def test_mean_error_differentiability(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn): + def test_mean_error_differentiability( + self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, metric_args + ): self.run_differentiability_test( - preds=preds, target=target, metric_module=metric_class, metric_functional=metric_functional + preds=preds, target=target, metric_module=metric_class, metric_functional=metric_functional, + metric_args=metric_args ) @pytest.mark.skipif( not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6' ) - def test_mean_error_half_cpu(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn): + def test_mean_error_half_cpu(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, metric_args): if metric_class == MeanSquaredLogError: # MeanSquaredLogError half + cpu does not work due to missing support in torch.log pytest.xfail("MeanSquaredLogError metric does not support cpu + half precision") self.run_precision_test_cpu(preds, target, metric_class, metric_functional) @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda') - def test_mean_error_half_gpu(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn): + def test_mean_error_half_gpu(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, metric_args): self.run_precision_test_gpu(preds, target, metric_class, metric_functional) diff --git a/torchmetrics/functional/regression/mean_squared_error.py b/torchmetrics/functional/regression/mean_squared_error.py index 225dd7dd509..291f42b6a4e 100644 --- a/torchmetrics/functional/regression/mean_squared_error.py +++ b/torchmetrics/functional/regression/mean_squared_error.py @@ -27,17 +27,18 @@ def _mean_squared_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, i return sum_squared_error, n_obs -def _mean_squared_error_compute(sum_squared_error: Tensor, n_obs: int) -> Tensor: - return sum_squared_error / n_obs +def _mean_squared_error_compute(sum_squared_error: Tensor, n_obs: int, squared: bool = True) -> Tensor: + return sum_squared_error / n_obs if squared else torch.sqrt(sum_squared_error / n_obs) -def mean_squared_error(preds: Tensor, target: Tensor) -> Tensor: +def mean_squared_error(preds: Tensor, target: Tensor, squared: bool = True) -> Tensor: """ Computes mean squared error Args: preds: estimated labels target: ground truth labels + squared: returns RMSE value if set to False Return: Tensor with MSE @@ -50,4 +51,4 @@ def mean_squared_error(preds: Tensor, target: Tensor) -> Tensor: tensor(0.2500) """ sum_squared_error, n_obs = _mean_squared_error_update(preds, target) - return _mean_squared_error_compute(sum_squared_error, n_obs) + return _mean_squared_error_compute(sum_squared_error, n_obs, squared=squared) diff --git a/torchmetrics/regression/mean_squared_error.py b/torchmetrics/regression/mean_squared_error.py index a1bd8a6a282..0eae049687e 100644 --- a/torchmetrics/regression/mean_squared_error.py +++ b/torchmetrics/regression/mean_squared_error.py @@ -39,6 +39,8 @@ class MeanSquaredError(Metric): before returning the value at the step. default: False process_group: Specify the process group on which synchronization is called. default: None (which selects the entire world) + squared: + If True returns MSE value, if False returns RMSE value. Example: >>> from torchmetrics import MeanSquaredError @@ -56,6 +58,7 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, dist_sync_fn: Callable = None, + squared: bool = True, ): super().__init__( compute_on_step=compute_on_step, @@ -66,6 +69,7 @@ def __init__( self.add_state("sum_squared_error", default=tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=tensor(0), dist_reduce_fx="sum") + self.squared = squared def update(self, preds: Tensor, target: Tensor): """ @@ -84,7 +88,7 @@ def compute(self): """ Computes mean squared error over state. """ - return _mean_squared_error_compute(self.sum_squared_error, self.total) + return _mean_squared_error_compute(self.sum_squared_error, self.total, squared=self.squared) @property def is_differentiable(self):