Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RMSE option to MSE code #249

Merged
merged 8 commits into from
May 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for unnormalized scores (e.g. logits) in `Accuracy`, `Precision`, `Recall`, `FBeta`, `F1`, `StatScore`, `Hamming`, `ConfusionMatrix` metrics ([#200](https://github.com/PyTorchLightning/metrics/pull/200))


- Added `squared` argument to `MeanSquaredError` for computing `RMSE` ([#249](https://github.com/PyTorchLightning/metrics/pull/249))


### Changed


Expand Down
41 changes: 25 additions & 16 deletions tests/regression/test_mean_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections import namedtuple
from functools import partial

Expand Down Expand Up @@ -43,16 +44,18 @@
)


def _single_target_sk_metric(preds, target, sk_fn=mean_squared_error):
def _single_target_sk_metric(preds, target, sk_fn, metric_args):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
return sk_fn(sk_preds, sk_target)
res = sk_fn(sk_preds, sk_target)
return math.sqrt(res) if (metric_args and not metric_args['squared']) else res


def _multi_target_sk_metric(preds, target, sk_fn, metric_args):
    """Compute the sklearn reference value for multi-target inputs.

    Args:
        preds: predictions tensor; reshaped to ``(-1, num_targets)``
        target: ground-truth tensor; reshaped to ``(-1, num_targets)``
        sk_fn: sklearn metric callable taking ``(preds, target)`` numpy arrays
        metric_args: kwargs forwarded to the torchmetrics metric under test;
            only ``'squared'`` is inspected here

    Returns:
        The sklearn score, square-rooted (MSE -> RMSE) when
        ``metric_args['squared']`` is False.
    """
    # `num_targets` is a module-level constant defined with the test fixtures.
    sk_preds = preds.view(-1, num_targets).numpy()
    sk_target = target.view(-1, num_targets).numpy()
    res = sk_fn(sk_preds, sk_target)
    # `.get` keeps MSE semantics both for an empty metric_args dict and for
    # dicts that simply omit 'squared' (the original indexing would KeyError).
    return math.sqrt(res) if not metric_args.get('squared', True) else res


@pytest.mark.parametrize(
Expand All @@ -63,55 +66,61 @@ def _multi_target_sk_metric(preds, target, sk_fn=mean_squared_error):
],
)
@pytest.mark.parametrize(
    "metric_class, metric_functional, sk_fn, metric_args",
    [
        (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {'squared': True}),
        (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {'squared': False}),
        (MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error, {}),
        (MeanSquaredLogError, mean_squared_log_error, sk_mean_squared_log_error, {}),
    ],
)
class TestMeanError(MetricTester):
    """Parametrized suite comparing the mean-error metrics (MSE/RMSE, MAE,
    MSLE) against their sklearn references in class, functional,
    differentiability and half-precision modes."""

    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    def test_mean_error_class(
        self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, metric_args, ddp, dist_sync_on_step
    ):
        # todo: `metric_functional` is unused
        self.run_class_metric_test(
            ddp=ddp,
            preds=preds,
            target=target,
            metric_class=metric_class,
            sk_metric=partial(sk_metric, sk_fn=sk_fn, metric_args=metric_args),
            dist_sync_on_step=dist_sync_on_step,
            metric_args=metric_args,
        )

    def test_mean_error_functional(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, metric_args):
        # todo: `metric_class` is unused
        self.run_functional_metric_test(
            preds=preds,
            target=target,
            metric_functional=metric_functional,
            sk_metric=partial(sk_metric, sk_fn=sk_fn, metric_args=metric_args),
            metric_args=metric_args,
        )

    def test_mean_error_differentiability(
        self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, metric_args
    ):
        self.run_differentiability_test(
            preds=preds,
            target=target,
            metric_module=metric_class,
            metric_functional=metric_functional,
            metric_args=metric_args,
        )

    @pytest.mark.skipif(
        not _TORCH_GREATER_EQUAL_1_6, reason='half support of core operations on not support before pytorch v1.6'
    )
    def test_mean_error_half_cpu(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, metric_args):
        if metric_class == MeanSquaredLogError:
            # MeanSquaredLogError half + cpu does not work due to missing support in torch.log
            pytest.xfail("MeanSquaredLogError metric does not support cpu + half precision")
        # NOTE(review): metric_args is not forwarded here, so the RMSE variant
        # is exercised at half precision with default args only — confirm
        # whether run_precision_test_cpu should receive metric_args.
        self.run_precision_test_cpu(preds, target, metric_class, metric_functional)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda')
    def test_mean_error_half_gpu(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, metric_args):
        self.run_precision_test_gpu(preds, target, metric_class, metric_functional)


Expand Down
9 changes: 5 additions & 4 deletions torchmetrics/functional/regression/mean_squared_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,18 @@ def _mean_squared_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, i
return sum_squared_error, n_obs


def _mean_squared_error_compute(sum_squared_error: Tensor, n_obs: int) -> Tensor:
return sum_squared_error / n_obs
def _mean_squared_error_compute(sum_squared_error: Tensor, n_obs: int, squared: bool = True) -> Tensor:
return sum_squared_error / n_obs if squared else torch.sqrt(sum_squared_error / n_obs)


def mean_squared_error(preds: Tensor, target: Tensor, squared: bool = True) -> Tensor:
    """
    Computes mean squared error.

    Args:
        preds: estimated labels
        target: ground truth labels
        squared: returns RMSE value if set to False

    Return:
        Tensor with MSE

    Example:
        >>> from torchmetrics.functional import mean_squared_error
        >>> x = torch.tensor([0., 1, 2, 3])
        >>> y = torch.tensor([0., 1, 2, 2])
        >>> mean_squared_error(x, y)
        tensor(0.2500)
    """
    sum_squared_error, n_obs = _mean_squared_error_update(preds, target)
    return _mean_squared_error_compute(sum_squared_error, n_obs, squared=squared)
6 changes: 5 additions & 1 deletion torchmetrics/regression/mean_squared_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class MeanSquaredError(Metric):
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
squared:
If True returns MSE value, if False returns RMSE value.

Example:
>>> from torchmetrics import MeanSquaredError
Expand All @@ -56,6 +58,7 @@ def __init__(
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
squared: bool = True,
):
super().__init__(
compute_on_step=compute_on_step,
Expand All @@ -66,6 +69,7 @@ def __init__(

self.add_state("sum_squared_error", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
self.squared = squared

def update(self, preds: Tensor, target: Tensor):
"""
Expand All @@ -84,7 +88,7 @@ def compute(self):
"""
Computes mean squared error over state.
"""
return _mean_squared_error_compute(self.sum_squared_error, self.total)
return _mean_squared_error_compute(self.sum_squared_error, self.total, squared=self.squared)

@property
def is_differentiable(self):
Expand Down