From f293b9b31afe60ef06e572767ba5eff04579f287 Mon Sep 17 00:00:00 2001 From: Bobak Hashemi Date: Wed, 20 Dec 2023 13:28:57 -0800 Subject: [PATCH] Fix warning in aggregation.mean (#187) Summary: Pull Request resolved: https://github.com/pytorch/torcheval/pull/187 This diff fixes the incorrect warning when running `mean.compute()` when the mean is exactly 0. Instead of checking for the weighted sum of elements to be 0, we instead check for the total sum of weights to be zero (meaning that the average can be 0 without error, but we throw a warning when dividing by zero) We also update the error message to reflect that the issue is no weight has been accumulated, since it is possible to call this function with only 0 weights. Addresses: https://github.com/pytorch/torcheval/issues/185 Reviewed By: JKSenthil Differential Revision: D50806243 fbshipit-source-id: 04d75826ae8c1a24cc3718967d86bdd982081538 --- tests/metrics/aggregation/test_mean.py | 4 ++++ torcheval/metrics/aggregation/mean.py | 12 ++++++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/metrics/aggregation/test_mean.py b/tests/metrics/aggregation/test_mean.py index 1b53fd17..72e7c765 100644 --- a/tests/metrics/aggregation/test_mean.py +++ b/tests/metrics/aggregation/test_mean.py @@ -54,6 +54,10 @@ def test_mean_class_compute_without_update(self) -> None: metric = Mean() self.assertEqual(metric.compute(), torch.tensor(0.0, dtype=torch.float64)) + metric = Mean() + metric.update(torch.tensor([0.0, 0.0]), weight=0) + self.assertEqual(metric.compute(), torch.tensor(0.0, dtype=torch.float64)) + def test_mean_class_update_input_valid_weight(self) -> None: update_value = [ torch.rand(BATCH_SIZE), diff --git a/torcheval/metrics/aggregation/mean.py b/torcheval/metrics/aggregation/mean.py index dbea9321..bb897e17 100644 --- a/torcheval/metrics/aggregation/mean.py +++ b/torcheval/metrics/aggregation/mean.py @@ -55,9 +55,11 @@ def __init__( device: Optional[torch.device] = None, ) -> None: super().__init__(device=device) + # weighted sum of values over the entire state self._add_state( "weighted_sum", torch.tensor(0.0, device=self.device, dtype=torch.float64) ) + # sum total of weights over the entire state self._add_state( "weights", torch.tensor(0.0, device=self.device, dtype=torch.float64) ) @@ -82,9 +84,9 @@ def update( ValueError: If value of weight is neither a ``float`` nor a ``int'' nor a ``torch.Tensor`` that matches the input tensor size. """ - weighted_sum, weights = _mean_update(input, weight) + weighted_sum, net_weight = _mean_update(input, weight) self.weighted_sum += weighted_sum - self.weights += weights + self.weights += net_weight return self @torch.inference_mode() @@ -93,8 +95,10 @@ def compute(self: TMean) -> torch.Tensor: If no calls to ``update()`` are made before ``compute()`` is called, the function throws a warning and returns 0.0. """ - if not self.weighted_sum: - logging.warning("No calls to update() have been made - returning 0.0") + if not torch.is_nonzero(self.weights): + logging.warning( + "There is no weight for the average, no samples with weight have been added (did you ever run update()?)- returning 0.0" + ) return torch.tensor(0.0, dtype=torch.float64) return self.weighted_sum / self.weights