diff --git a/CHANGELOG.md b/CHANGELOG.md
index efb575fa2d4..3a9cd6e4be8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed axis names with Precision-Recall curve ([#2462](https://github.com/Lightning-AI/torchmetrics/pull/2462))
 
 
+- Fixed memory leak in metrics using list states ([#2492](https://github.com/Lightning-AI/torchmetrics/pull/2492))
+
+
 - Fixed bug in computation of `ERGAS` metric ([#2498](https://github.com/Lightning-AI/torchmetrics/pull/2498))
 
 
diff --git a/docs/source/pages/implement.rst b/docs/source/pages/implement.rst
index d61b68d97ef..f7da4aa8ba7 100644
--- a/docs/source/pages/implement.rst
+++ b/docs/source/pages/implement.rst
@@ -124,6 +124,10 @@ A few important things to note for this example:
   ``dim_zero_cat`` helper function which will standardize the list states to be a single concatenate tensor
   regardless of the mode.
 
+* Calling the ``reset`` method will clear the list state, deleting any values inserted into it. For this reason, care
+  must be taken when referencing list states. If you require the values after your metric is reset, you must first
+  copy the attribute to another object (e.g. using ``copy.deepcopy``).
+
 *****************
 Metric attributes
 *****************
diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py
index 7437fc35f95..6d64955198d 100644
--- a/src/torchmetrics/metric.py
+++ b/src/torchmetrics/metric.py
@@ -238,6 +238,12 @@ def add_state(
             When passing a custom function to ``dist_reduce_fx``, expect the synchronized metric state to follow
             the format discussed in the above note.
 
+        Note:
+            The values inserted into a list state are deleted whenever :meth:`~Metric.reset` is called. This allows
+            device memory to be automatically reallocated, but may produce unexpected effects when referencing list
+            states. To retain such values after :meth:`~Metric.reset` is called, you must first copy them to another
+            object.
+
         Raises:
             ValueError:
                 If ``default`` is not a ``tensor`` or an ``empty list``.
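Editor's note: the docs and docstring additions above describe a copy-before-reset pattern without showing it. Below is a minimal sketch of that pattern; it is not part of the patch, and the `MyAccumulator` metric and its `values` state are hypothetical names used only for illustration.

```python
# Sketch only: copying a list state before reset (MyAccumulator / values are made-up names).
import copy

import torch
from torchmetrics import Metric
from torchmetrics.utilities import dim_zero_cat


class MyAccumulator(Metric):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # list state: values are appended during `update` and concatenated for `compute`
        self.add_state("values", default=[], dist_reduce_fx="cat")

    def update(self, preds: torch.Tensor) -> None:
        self.values.append(preds)

    def compute(self) -> torch.Tensor:
        return dim_zero_cat(self.values).mean()


metric = MyAccumulator()
metric.update(torch.tensor([1.0, 2.0]))

kept = copy.deepcopy(metric.values)  # copy *before* reset if the raw values are still needed

metric.reset()                # clears the list state in place, freeing the stored tensors
assert metric.values == []
assert len(kept) == 1         # the copied values survive the reset
```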
@@ -325,7 +331,7 @@ def _forward_full_state_update(self, *args: Any, **kwargs: Any) -> Any:
         self.compute_on_cpu = False
 
         # save context before switch
-        cache = {attr: getattr(self, attr) for attr in self._defaults}
+        cache = self._copy_state_dict()
 
         # call reset, update, compute, on single batch
         self._enable_grad = True  # allow grads for batch computation
@@ -358,7 +364,7 @@ def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any:
 
         """
         # store global state and reset to default
-        global_state = {attr: getattr(self, attr) for attr in self._defaults}
+        global_state = self._copy_state_dict()
         _update_count = self._update_count
         self.reset()
 
@@ -525,7 +531,7 @@ def sync(
             dist_sync_fn = gather_all_tensors
 
         # cache prior to syncing
-        self._cache = {attr: getattr(self, attr) for attr in self._defaults}
+        self._cache = self._copy_state_dict()
 
         # sync
         self._sync_dist(dist_sync_fn, process_group=process_group)
@@ -681,7 +687,7 @@ def reset(self) -> None:
             if isinstance(default, Tensor):
                 setattr(self, attr, default.detach().clone().to(current_val.device))
             else:
-                setattr(self, attr, [])
+                getattr(self, attr).clear()  # delete/free list items
 
         # reset internal states
         self._cache = None
@@ -870,6 +876,21 @@ def state_dict(  # type: ignore[override]  # todo
             destination[prefix + key] = deepcopy(current_val)
         return destination
 
+    def _copy_state_dict(self) -> Dict[str, Union[Tensor, List[Any]]]:
+        """Copy the current state values."""
+        cache: Dict[str, Union[Tensor, List[Any]]] = {}
+        for attr in self._defaults:
+            current_value = getattr(self, attr)
+
+            if isinstance(current_value, Tensor):
+                cache[attr] = current_value.detach().clone().to(current_value.device)
+            else:
+                cache[attr] = [  # safely copy (non-graph leaf) Tensor elements
+                    _.detach().clone().to(_.device) if isinstance(_, Tensor) else deepcopy(_) for _ in current_value
+                ]
+
+        return cache
+
     def _load_from_state_dict(
         self,
         state_dict: dict,
diff --git a/tests/unittests/bases/test_metric.py b/tests/unittests/bases/test_metric.py
index 22257c0d0c9..753150478e4 100644
--- a/tests/unittests/bases/test_metric.py
+++ b/tests/unittests/bases/test_metric.py
@@ -124,11 +124,17 @@ class B(DummyListMetric):
     metric = B()
     assert isinstance(metric.x, list)
     assert len(metric.x) == 0
-    metric.x = tensor(5)
+    metric.x = [tensor(5)]
     metric.reset()
     assert isinstance(metric.x, list)
     assert len(metric.x) == 0
 
+    metric = B()
+    metric.x = [1, 2, 3]
+    reference = metric.x  # prevents garbage collection
+    metric.reset()
+    assert len(reference) == 0  # check list state is freed
+
 
 def test_reset_compute():
     """Test that `reset`+`compute` methods works as expected."""
@@ -474,11 +480,8 @@ def test_constant_memory_on_repeat_init():
     def mem():
         return torch.cuda.memory_allocated() / 1024**2
 
-    x = torch.randn(10000).cuda()
-
     for i in range(100):
-        m = DummyListMetric(compute_with_cache=False).cuda()
-        m(x)
+        _ = DummyListMetric(compute_with_cache=False).cuda()
         if i == 0:
             after_one_iter = mem()
 
@@ -486,6 +489,25 @@ def mem():
         assert after_one_iter * 1.05 >= mem(), "memory increased too much above base level"
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
+def test_freed_memory_on_reset():
+    """Test that resetting a metric frees all the memory allocated when updating it."""
+
+    def mem():
+        return torch.cuda.memory_allocated() / 1024**2
+
+    m = DummyListMetric().cuda()
+    after_init = mem()
+
+    for _ in range(100):
+        m(x=torch.randn(10000).cuda())
+
+    m.reset()
+
+    # allow for 5% fluctuation due to measurement
+    assert after_init * 1.05 >= mem(), "memory increased too much above base level"
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires gpu")
 def test_specific_error_on_wrong_device():
     """Test that a specific error is raised if we detect input and metric are on different devices."""
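Editor's note: as a footnote on the `reset()` change above (`getattr(self, attr).clear()` instead of rebinding the attribute to a fresh `[]`), clearing the existing list object drops the tensor references held by any alias of that list, which is what lets the allocator reclaim device memory. Below is a CPU-only sketch of that difference using a plain Python class as a stand-in for an accumulated tensor; nothing in it is torchmetrics code.

```python
# Sketch only: why clearing a list in place frees its contents even when the list is aliased.
import gc
import weakref


class Payload:
    """Stand-in for a tensor stored in a list state."""


state = [Payload()]
alias = state                  # e.g. a reference to the list state cached elsewhere
probe = weakref.ref(state[0])  # observes when the payload becomes collectable

# rebinding to a new list keeps the payload alive through `alias`
state = []
gc.collect()
assert probe() is not None

# clearing the same list in place releases the payload for collection
alias.clear()
gc.collect()
assert probe() is None
```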