From 11df0ebe897e6dba726ee9e1269daae19557ca5a Mon Sep 17 00:00:00 2001 From: dominicgkerr Date: Sat, 6 Apr 2024 23:37:17 +0100 Subject: [PATCH 01/17] Clear (i.e. delete) list state items, not simply overwrite. Previous behaviour produced memory leak from list[Tensor] states --- src/torchmetrics/metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index c5e6999e89e..573a8d808b9 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -681,7 +681,7 @@ def reset(self) -> None: if isinstance(default, Tensor): setattr(self, attr, default.detach().clone().to(current_val.device)) else: - setattr(self, attr, []) + getattr(self, attr).clear() # delete/free list items # reset internal states self._cache = None From 1fa7077087844a7e61ff512dfd9474e2a1197b94 Mon Sep 17 00:00:00 2001 From: dominicgkerr Date: Sun, 7 Apr 2024 02:14:00 +0100 Subject: [PATCH 02/17] Added test to check list states elements are deleted (even when referenced, and hence not automatically garbage collected). 
Fixed failing test (want to check list state, but assigned Tensor) --- tests/unittests/bases/test_metric.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/unittests/bases/test_metric.py b/tests/unittests/bases/test_metric.py index 22257c0d0c9..0870988beb3 100644 --- a/tests/unittests/bases/test_metric.py +++ b/tests/unittests/bases/test_metric.py @@ -124,11 +124,17 @@ class B(DummyListMetric): metric = B() assert isinstance(metric.x, list) assert len(metric.x) == 0 - metric.x = tensor(5) + metric.x = [tensor(5)] metric.reset() assert isinstance(metric.x, list) assert len(metric.x) == 0 + metric = B() + metric.x = [1, 2, 3] + reference = metric.x # prevents garbage collection + metric.reset() + assert len(reference) == 0 # check list state is freed + def test_reset_compute(): """Test that `reset`+`compute` methods works as expected.""" From 4b5c09915f47cf6b950e87b86fee00be94679643 Mon Sep 17 00:00:00 2001 From: dominicgkerr Date: Tue, 9 Apr 2024 01:18:57 +0100 Subject: [PATCH 03/17] Updated documentation - highlighted reset clears list states, and that care must be taken when referencing them --- docs/source/pages/implement.rst | 11 +++++++++++ src/torchmetrics/metric.py | 5 +++++ 2 files changed, 16 insertions(+) diff --git a/docs/source/pages/implement.rst b/docs/source/pages/implement.rst index d61b68d97ef..65133f20229 100644 --- a/docs/source/pages/implement.rst +++ b/docs/source/pages/implement.rst @@ -124,6 +124,17 @@ A few important things to note for this example: ``dim_zero_cat`` helper function which will standardize the list states to be a single concatenate tensor regardless of the mode. +* Calling the ``reset`` method will clear the list state, deleting any values inserted into it. For this reason, care + must be taken when referencing list states. If you require the values after your metric is reset, you must first + copy the attribute to another object: + + .. 
testcode:: + + x = metric.list_state # referenced (and deleted by reset) + + from deepcopy import copy + y = copy(metric.list_state) # copied (and unchanged by reset) + ***************** Metric attributes ***************** diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 573a8d808b9..2d2796cd358 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -238,6 +238,11 @@ def add_state( When passing a custom function to ``dist_reduce_fx``, expect the synchronized metric state to follow the format discussed in the above note. + Note: + The values inserted into a list state are deleted whenever :meth:`~Metric.reset` is called. This allows + device memory to be automatically reallocated, but may produce unexpected effects when referencing list + states. To retain such values after `~Metric.reset` is called, you must first copy them to another object + Raises: ValueError: If ``default`` is not a ``tensor`` or an ``empty list``. From 8bf151b88c280ce304f8559ebd531c21b9133f1f Mon Sep 17 00:00:00 2001 From: Dominic Kerr Date: Tue, 9 Apr 2024 02:01:08 +0100 Subject: [PATCH 04/17] Add missing method (sphinx) role --- src/torchmetrics/metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 2d2796cd358..4e5ffd82d44 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -241,7 +241,7 @@ def add_state( Note: The values inserted into a list state are deleted whenever :meth:`~Metric.reset` is called. This allows device memory to be automatically reallocated, but may produce unexpected effects when referencing list - states. To retain such values after `~Metric.reset` is called, you must first copy them to another object + states. 
To retain such values after :meth:`~Metric.reset` is called, you must first copy them to another object Raises: ValueError: From 82f808b0ef544f376cddb73a9d3df50d5a3efe67 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Fri, 12 Apr 2024 16:49:01 +0200 Subject: [PATCH 05/17] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 40ed1b62c79..4e2e1c251a6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed axis names with Precision-Recall curve ([#2462](https://github.com/Lightning-AI/torchmetrics/pull/2462)) +- Fixed memory leak in metrics using list states ([#2492](https://github.com/Lightning-AI/torchmetrics/pull/2492)) + + ## [1.3.2] - 2024-03-18 ### Fixed From 65b02fa2ff4173c2f30c982fb9454a535e668cee Mon Sep 17 00:00:00 2001 From: dominicgkerr Date: Sat, 13 Apr 2024 00:39:22 +0100 Subject: [PATCH 06/17] Remove failing testcode example (fixing introduces too much complexity) --- docs/source/pages/implement.rst | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/docs/source/pages/implement.rst b/docs/source/pages/implement.rst index 65133f20229..f7da4aa8ba7 100644 --- a/docs/source/pages/implement.rst +++ b/docs/source/pages/implement.rst @@ -126,14 +126,7 @@ A few important things to note for this example: * Calling the ``reset`` method will clear the list state, deleting any values inserted into it. For this reason, care must be taken when referencing list states. If you require the values after your metric is reset, you must first - copy the attribute to another object: - - .. testcode:: - - x = metric.list_state # referenced (and deleted by reset) - - from deepcopy import copy - y = copy(metric.list_state) # copied (and unchanged by reset) + copy the attribute to another object (e.g. using `copy.deepcopy`). 
***************** Metric attributes From 6241a6b27b1d896ea8a9d391c793b6db66f2a039 Mon Sep 17 00:00:00 2001 From: dominicgkerr Date: Sat, 13 Apr 2024 14:47:32 +0100 Subject: [PATCH 07/17] Linting - Line break docstring --- src/torchmetrics/metric.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 11b8b61d282..534d3badebb 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -241,7 +241,8 @@ def add_state( Note: The values inserted into a list state are deleted whenever :meth:`~Metric.reset` is called. This allows device memory to be automatically reallocated, but may produce unexpected effects when referencing list - states. To retain such values after :meth:`~Metric.reset` is called, you must first copy them to another object + states. To retain such values after :meth:`~Metric.reset` is called, you must first copy them to another + object. Raises: ValueError: From 57589773c28cebd6de20ec5bc76a3c4a0099b1c0 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Sat, 13 Apr 2024 17:06:01 +0200 Subject: [PATCH 08/17] copy internal states in forward --- src/torchmetrics/metric.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 534d3badebb..fa4d5b81dd7 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -331,7 +331,7 @@ def _forward_full_state_update(self, *args: Any, **kwargs: Any) -> Any: self.compute_on_cpu = False # save context before switch - cache = {attr: getattr(self, attr) for attr in self._defaults} + cache = deepcopy({attr: getattr(self, attr) for attr in self._defaults}) # call reset, update, compute, on single batch self._enable_grad = True # allow grads for batch computation @@ -364,7 +364,7 @@ def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any: """ # store global state and reset to default - global_state = {attr: getattr(self, 
attr) for attr in self._defaults} + global_state = deepcopy({attr: getattr(self, attr) for attr in self._defaults}) _update_count = self._update_count self.reset() @@ -531,7 +531,7 @@ def sync( dist_sync_fn = gather_all_tensors # cache prior to syncing - self._cache = {attr: getattr(self, attr) for attr in self._defaults} + self._cache = deepcopy({attr: getattr(self, attr) for attr in self._defaults}) # sync self._sync_dist(dist_sync_fn, process_group=process_group) From c9d2a8665f15ea3653b62a36a9654f9b4d72a1af Mon Sep 17 00:00:00 2001 From: dominicgkerr Date: Sat, 13 Apr 2024 18:17:11 +0100 Subject: [PATCH 09/17] Detach Tensor | list[Tensor] state values before copying. --- src/torchmetrics/metric.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index fa4d5b81dd7..660240e77c9 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -331,7 +331,7 @@ def _forward_full_state_update(self, *args: Any, **kwargs: Any) -> Any: self.compute_on_cpu = False # save context before switch - cache = deepcopy({attr: getattr(self, attr) for attr in self._defaults}) + cache = self._copy_state_dict() # call reset, update, compute, on single batch self._enable_grad = True # allow grads for batch computation @@ -364,7 +364,7 @@ def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any: """ # store global state and reset to default - global_state = deepcopy({attr: getattr(self, attr) for attr in self._defaults}) + global_state = self._copy_state_dict() _update_count = self._update_count self.reset() @@ -531,7 +531,7 @@ def sync( dist_sync_fn = gather_all_tensors # cache prior to syncing - self._cache = deepcopy({attr: getattr(self, attr) for attr in self._defaults}) + self._cache = self._copy_state_dict() # sync self._sync_dist(dist_sync_fn, process_group=process_group) @@ -876,6 +876,22 @@ def state_dict( # type: ignore[override] # todo 
destination[prefix + key] = deepcopy(current_val) return destination + def _copy_state_dict(self) -> dict[str, Tensor | List[Any]]: + """Copy the current state values""" + cache = {} + for attr in self._defaults: + current_value = getattr(self, attr) + + if isinstance(current_value, Tensor): + cache[attr] = current_value.detach().clone().to(current_value.device) + else: + cache[attr] = [ # safely copy (non-graph leaf) Tensor elements + _.detach().clone().to(_.device) if isinstance(_, Tensor) else deepcopy(_) + for _ in current_value + ] + + return cache + def _load_from_state_dict( self, state_dict: dict, From e1872deddeee30bf675e9b3b207164f0ce1db10f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 13 Apr 2024 17:18:30 +0000 Subject: [PATCH 10/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/metric.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 660240e77c9..c4250979c6d 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -877,7 +877,7 @@ def state_dict( # type: ignore[override] # todo return destination def _copy_state_dict(self) -> dict[str, Tensor | List[Any]]: - """Copy the current state values""" + """Copy the current state values.""" cache = {} for attr in self._defaults: current_value = getattr(self, attr) @@ -886,8 +886,7 @@ def _copy_state_dict(self) -> dict[str, Tensor | List[Any]]: cache[attr] = current_value.detach().clone().to(current_value.device) else: cache[attr] = [ # safely copy (non-graph leaf) Tensor elements - _.detach().clone().to(_.device) if isinstance(_, Tensor) else deepcopy(_) - for _ in current_value + _.detach().clone().to(_.device) if isinstance(_, Tensor) else deepcopy(_) for _ in current_value ] return cache From afdd4c5d40a000ecee2b4ea3acfe64cdbf6d8f1e Mon Sep 17 00:00:00 2001 
From: dominicgkerr Date: Sat, 13 Apr 2024 18:26:23 +0100 Subject: [PATCH 11/17] Use 'typing' type hints --- src/torchmetrics/metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 660240e77c9..7c6101c4df7 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -876,7 +876,7 @@ def state_dict( # type: ignore[override] # todo destination[prefix + key] = deepcopy(current_val) return destination - def _copy_state_dict(self) -> dict[str, Tensor | List[Any]]: + def _copy_state_dict(self) -> Dict[str, Union[Tensor, List[Any]]]: """Copy the current state values""" cache = {} for attr in self._defaults: From ef2721575cfe5993ceba94c455903f3bf1356831 Mon Sep 17 00:00:00 2001 From: dominicgkerr Date: Sat, 13 Apr 2024 20:33:42 +0100 Subject: [PATCH 12/17] DO not clone (when caching) Tensor states, but retain references to avoid memory leakage --- src/torchmetrics/metric.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 12cf33cdd82..7803c0e5f71 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -331,7 +331,7 @@ def _forward_full_state_update(self, *args: Any, **kwargs: Any) -> Any: self.compute_on_cpu = False # save context before switch - cache = self._copy_state_dict() + cache = self._save_state_dict() # call reset, update, compute, on single batch self._enable_grad = True # allow grads for batch computation @@ -364,7 +364,7 @@ def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any: """ # store global state and reset to default - global_state = self._copy_state_dict() + global_state = self._save_state_dict() _update_count = self._update_count self.reset() @@ -531,7 +531,7 @@ def sync( dist_sync_fn = gather_all_tensors # cache prior to syncing - self._cache = self._copy_state_dict() + self._cache = self._save_state_dict() # sync 
self._sync_dist(dist_sync_fn, process_group=process_group) @@ -876,17 +876,19 @@ def state_dict( # type: ignore[override] # todo destination[prefix + key] = deepcopy(current_val) return destination - def _copy_state_dict(self) -> Dict[str, Union[Tensor, List[Any]]]: - """Copy the current state values.""" + def _save_state_dict(self) -> Dict[str, Union[Tensor, List[Any]]]: + """Save the current state values, retaining references to Tensor values""" + + # do not .clone() Tensor values, as new objects leak memory cache = {} for attr in self._defaults: current_value = getattr(self, attr) if isinstance(current_value, Tensor): - cache[attr] = current_value.detach().clone().to(current_value.device) + cache[attr] = current_value.detach().to(current_value.device) else: cache[attr] = [ # safely copy (non-graph leaf) Tensor elements - _.detach().clone().to(_.device) if isinstance(_, Tensor) else deepcopy(_) for _ in current_value + _.detach().to(_.device) if isinstance(_, Tensor) else deepcopy(_) for _ in current_value ] return cache From e1045874af0ae4a40391066340d924c0605934b2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 13 Apr 2024 19:34:05 +0000 Subject: [PATCH 13/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/metric.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 7803c0e5f71..2cd571059ec 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -877,8 +877,7 @@ def state_dict( # type: ignore[override] # todo return destination def _save_state_dict(self) -> Dict[str, Union[Tensor, List[Any]]]: - """Save the current state values, retaining references to Tensor values""" - + """Save the current state values, retaining references to Tensor values.""" # do not .clone() Tensor values, as new objects leak memory cache = {} for 
attr in self._defaults: From 21c7970081371138ff0e151f6c191ce748687137 Mon Sep 17 00:00:00 2001 From: dominicgkerr Date: Sat, 13 Apr 2024 22:20:51 +0100 Subject: [PATCH 14/17] Revert "DO not clone (when caching) Tensor states, but retain references to avoid memory leakage" This reverts commit ef2721575cfe5993ceba94c455903f3bf1356831. --- src/torchmetrics/metric.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 2cd571059ec..12cf33cdd82 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -331,7 +331,7 @@ def _forward_full_state_update(self, *args: Any, **kwargs: Any) -> Any: self.compute_on_cpu = False # save context before switch - cache = self._save_state_dict() + cache = self._copy_state_dict() # call reset, update, compute, on single batch self._enable_grad = True # allow grads for batch computation @@ -364,7 +364,7 @@ def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any: """ # store global state and reset to default - global_state = self._save_state_dict() + global_state = self._copy_state_dict() _update_count = self._update_count self.reset() @@ -531,7 +531,7 @@ def sync( dist_sync_fn = gather_all_tensors # cache prior to syncing - self._cache = self._save_state_dict() + self._cache = self._copy_state_dict() # sync self._sync_dist(dist_sync_fn, process_group=process_group) @@ -876,18 +876,17 @@ def state_dict( # type: ignore[override] # todo destination[prefix + key] = deepcopy(current_val) return destination - def _save_state_dict(self) -> Dict[str, Union[Tensor, List[Any]]]: - """Save the current state values, retaining references to Tensor values.""" - # do not .clone() Tensor values, as new objects leak memory + def _copy_state_dict(self) -> Dict[str, Union[Tensor, List[Any]]]: + """Copy the current state values.""" cache = {} for attr in self._defaults: current_value = getattr(self, attr) if isinstance(current_value, 
Tensor): - cache[attr] = current_value.detach().to(current_value.device) + cache[attr] = current_value.detach().clone().to(current_value.device) else: cache[attr] = [ # safely copy (non-graph leaf) Tensor elements - _.detach().to(_.device) if isinstance(_, Tensor) else deepcopy(_) for _ in current_value + _.detach().clone().to(_.device) if isinstance(_, Tensor) else deepcopy(_) for _ in current_value ] return cache From 51975a8f292ee32bc61b291c34091020e5644347 Mon Sep 17 00:00:00 2001 From: dominicgkerr Date: Sat, 13 Apr 2024 22:35:38 +0100 Subject: [PATCH 15/17] Added mypy type-hinting requirement/recommendation --- src/torchmetrics/metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 12cf33cdd82..6d64955198d 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -878,7 +878,7 @@ def state_dict( # type: ignore[override] # todo def _copy_state_dict(self) -> Dict[str, Union[Tensor, List[Any]]]: """Copy the current state values.""" - cache = {} + cache: Dict[str, Union[Tensor, List[Any]]] = {} for attr in self._defaults: current_value = getattr(self, attr) From 5954d02e58b9015aaf5ea66751f4a9d290765765 Mon Sep 17 00:00:00 2001 From: dominicgkerr Date: Sat, 13 Apr 2024 22:37:27 +0100 Subject: [PATCH 16/17] Moved update from test checking .__init__ memory leakage. 
Added test checking .reset clears memory allocated during update (memory should be allowed to grow, as long as discarded safely) --- tests/unittests/bases/test_metric.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/tests/unittests/bases/test_metric.py b/tests/unittests/bases/test_metric.py index 0870988beb3..512977ce74c 100644 --- a/tests/unittests/bases/test_metric.py +++ b/tests/unittests/bases/test_metric.py @@ -480,11 +480,8 @@ def test_constant_memory_on_repeat_init(): def mem(): return torch.cuda.memory_allocated() / 1024**2 - x = torch.randn(10000).cuda() - for i in range(100): - m = DummyListMetric(compute_with_cache=False).cuda() - m(x) + _ = DummyListMetric(compute_with_cache=False).cuda() if i == 0: after_one_iter = mem() @@ -492,6 +489,25 @@ def mem(): assert after_one_iter * 1.05 >= mem(), "memory increased too much above base level" +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.") +def test_freed_memory_on_reset(): + """Test that resetting a metric frees all the memory allocated when updating it.""" + + def mem(): + return torch.cuda.memory_allocated() / 1024**2 + + m = DummyListMetric().cuda() + after_init = mem() + + for i in range(100): + m(x=torch.randn(10000).cuda()) + + m.reset() + + # allow for 5% fluctuation due to measuring + assert after_init * 1.05 >= mem(), "memory increased too much above base level" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires gpu") def test_specific_error_on_wrong_device(): """Test that a specific error is raised if we detect input and metric are on different devices.""" From 3f013cfc54b1b057dae2719f72bf84d94be8cc88 Mon Sep 17 00:00:00 2001 From: stancld Date: Tue, 16 Apr 2024 10:39:06 +0200 Subject: [PATCH 17/17] Fix unused loop control variable for pre-commit --- tests/unittests/bases/test_metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittests/bases/test_metric.py 
b/tests/unittests/bases/test_metric.py index 512977ce74c..753150478e4 100644 --- a/tests/unittests/bases/test_metric.py +++ b/tests/unittests/bases/test_metric.py @@ -499,7 +499,7 @@ def mem(): m = DummyListMetric().cuda() after_init = mem() - for i in range(100): + for _ in range(100): m(x=torch.randn(10000).cuda()) m.reset()