Skip to content

Commit

Permalink
Add support for SQ & RQ as well as per-class metrics (#2381)
Browse files Browse the repository at this point in the history
* Fix RQ and SQ
* Change return type and refactor flag name
* Fix typing
* changelog
* input/output docstring
* guard against older versions

---------

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
  • Loading branch information
3 people authored Mar 25, 2024
1 parent aead080 commit 6dcb61c
Show file tree
Hide file tree
Showing 8 changed files with 199 additions and 12 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `QualityWithNoReference` metric ([#2288](https://github.com/Lightning-AI/torchmetrics/pull/2288))


- Added support for calculating segmentation quality and recognition quality in `PanopticQuality` metric ([#2381](https://github.com/Lightning-AI/torchmetrics/pull/2381))


### Changed

- Made `__getattr__` and `__setattr__` of `ClasswiseWrapper` more general ([#2424](https://github.com/Lightning-AI/torchmetrics/pull/2424))
Expand Down
9 changes: 9 additions & 0 deletions src/torchmetrics/detection/_deprecated.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
from typing import Any, Collection

from torchmetrics.detection import ModifiedPanopticQuality, PanopticQuality
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12
from torchmetrics.utilities.prints import _deprecated_root_import_class

if not _TORCH_GREATER_EQUAL_1_12:
__doctest_skip__ = [
"_PanopticQuality",
"_PanopticQuality.*",
"_ModifiedPanopticQuality",
"_ModifiedPanopticQuality.*",
]


class _ModifiedPanopticQuality(ModifiedPanopticQuality):
"""Wrapper for deprecated import.
Expand Down
83 changes: 80 additions & 3 deletions src/torchmetrics/detection/panoptic_qualities.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,17 @@
_validate_inputs,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCH_GREATER_EQUAL_1_12
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["PanopticQuality.plot", "ModifiedPanopticQuality.plot"]


if not _TORCH_GREATER_EQUAL_1_12:
__doctest_skip__ = ["PanopticQuality", "PanopticQuality.*", "ModifiedPanopticQuality", "ModifiedPanopticQuality.*"]


class PanopticQuality(Metric):
r"""Compute the `Panoptic Quality`_ for panoptic segmentations.
Expand All @@ -47,6 +51,23 @@ class PanopticQuality(Metric):
Points in the target tensor that do not map to a known category ID are automatically ignored in the metric
computation.
As input to ``forward`` and ``update`` the metric accepts the following input:
- ``preds`` (:class:`~torch.Tensor`): An int tensor of shape ``(B, *spatial_dims, 2)``, where there needs to
be at least one spatial dimension.
- ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(B, *spatial_dims, 2)``, where there needs to
be at least one spatial dimension.
As output to ``forward`` and ``compute`` the metric returns the following output:
- ``quality`` (:class:`~torch.Tensor`): If ``return_sq_and_rq=False`` and ``return_per_class=False`` then a
single scalar tensor is returned with average panoptic quality over all classes. If ``return_sq_and_rq=True``
and ``return_per_class=False`` a tensor of length 3 is returned with panoptic, segmentation and recognition
quality (in that order). If ``return_sq_and_rq=False`` and ``return_per_class=True`` a tensor of shape
``(1, C)`` is returned, with panoptic quality for each of the ``C`` classes. Finally, if both arguments
are ``True`` a tensor of shape ``(C, 3)`` is returned with individual panoptic, segmentation and recognition
quality (in that order along the last dimension) for each class.
Args:
things:
Set of ``category_id`` for countable things.
Expand All @@ -55,6 +76,10 @@ class PanopticQuality(Metric):
allow_unknown_preds_category:
Boolean flag to specify if unknown categories in the predictions are to be ignored in the metric
computation or raise an exception when found.
return_sq_and_rq:
Boolean flag to specify if Segmentation Quality and Recognition Quality should also be returned.
return_per_class:
Boolean flag to specify if the per-class values should be returned or the class average.
Raises:
Expand All @@ -80,6 +105,40 @@ class PanopticQuality(Metric):
>>> panoptic_quality(preds, target)
tensor(0.5463, dtype=torch.float64)
You can also return the segmentation and recognition quality alongside the PQ
>>> from torch import tensor
>>> from torchmetrics.detection import PanopticQuality
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [7, 0], [6, 0], [1, 0]],
... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [1, 0]],
... [[0, 1], [7, 0], [1, 0], [1, 0]],
... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
>>> panoptic_quality = PanopticQuality(things = {0, 1}, stuffs = {6, 7}, return_sq_and_rq=True)
>>> panoptic_quality(preds, target)
tensor([0.5463, 0.6111, 0.6667], dtype=torch.float64)
You can also specify to return the per-class metrics
>>> from torch import tensor
>>> from torchmetrics.detection import PanopticQuality
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [7, 0], [6, 0], [1, 0]],
... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [1, 0]],
... [[0, 1], [7, 0], [1, 0], [1, 0]],
... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
>>> panoptic_quality = PanopticQuality(things = {0, 1}, stuffs = {6, 7}, return_per_class=True)
>>> panoptic_quality(preds, target)
tensor([[0.5185, 0.0000, 0.6667, 1.0000]], dtype=torch.float64)
"""

is_differentiable: bool = False
Expand All @@ -98,16 +157,22 @@ def __init__(
things: Collection[int],
stuffs: Collection[int],
allow_unknown_preds_category: bool = False,
return_sq_and_rq: bool = False,
return_per_class: bool = False,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
if not _TORCH_GREATER_EQUAL_1_12:
raise RuntimeError("Panoptic Quality metric requires PyTorch 1.12 or later")

things, stuffs = _parse_categories(things, stuffs)
self.things = things
self.stuffs = stuffs
self.void_color = _get_void_color(things, stuffs)
self.cat_id_to_continuous_id = _get_category_id_to_continuous_id(things, stuffs)
self.allow_unknown_preds_category = allow_unknown_preds_category
self.return_sq_and_rq = return_sq_and_rq
self.return_per_class = return_per_class

# per category intermediate metrics
num_categories = len(things) + len(stuffs)
Expand Down Expand Up @@ -154,7 +219,16 @@ def update(self, preds: Tensor, target: Tensor) -> None:

def compute(self) -> Tensor:
"""Compute panoptic quality based on inputs passed in to ``update`` previously."""
return _panoptic_quality_compute(self.iou_sum, self.true_positives, self.false_positives, self.false_negatives)
pq, sq, rq, pq_avg, sq_avg, rq_avg = _panoptic_quality_compute(
self.iou_sum, self.true_positives, self.false_positives, self.false_negatives
)
if self.return_per_class:
if self.return_sq_and_rq:
return torch.stack((pq, sq, rq), dim=-1)
return pq.view(1, -1)
if self.return_sq_and_rq:
return torch.stack((pq_avg, sq_avg, rq_avg), dim=0)
return pq_avg

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
Expand Down Expand Up @@ -337,7 +411,10 @@ def update(self, preds: Tensor, target: Tensor) -> None:

def compute(self) -> Tensor:
"""Compute panoptic quality based on inputs passed in to ``update`` previously."""
return _panoptic_quality_compute(self.iou_sum, self.true_positives, self.false_positives, self.false_negatives)
_, _, _, pq_avg, _, _ = _panoptic_quality_compute(
self.iou_sum, self.true_positives, self.false_positives, self.false_negatives
)
return pq_avg

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
Expand Down
4 changes: 4 additions & 0 deletions src/torchmetrics/functional/detection/_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
from torch import Tensor

from torchmetrics.functional.detection.panoptic_qualities import modified_panoptic_quality, panoptic_quality
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12
from torchmetrics.utilities.prints import _deprecated_root_import_func

if not _TORCH_GREATER_EQUAL_1_12:
__doctest_skip__ = ["_panoptic_quality", "_modified_panoptic_quality"]


def _modified_panoptic_quality(
preds: Tensor,
Expand Down
20 changes: 13 additions & 7 deletions src/torchmetrics/functional/detection/_panoptic_quality_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def _panoptic_quality_compute(
true_positives: Tensor,
false_positives: Tensor,
false_negatives: Tensor,
) -> Tensor:
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
"""Compute the final panoptic quality from interim values.
Args:
Expand All @@ -459,11 +459,17 @@ def _panoptic_quality_compute(
false_negatives: the FN value from the update step
Returns:
Panoptic quality as a tensor containing a single scalar.
A tuple containing the per-class panoptic, segmentation and recognition quality followed by the averages
"""
# per category calculation
denominator = (true_positives + 0.5 * false_positives + 0.5 * false_negatives).double()
panoptic_quality = torch.where(denominator > 0.0, iou_sum / denominator, 0.0)
# Reduce across categories. TODO: is it useful to have the option of returning per class metrics?
return torch.mean(panoptic_quality[denominator > 0])
# compute segmentation and recognition quality (per-class)
sq: Tensor = torch.where(true_positives > 0.0, iou_sum / true_positives, 0.0)
denominator: Tensor = true_positives + 0.5 * false_positives + 0.5 * false_negatives
rq: Tensor = torch.where(denominator > 0.0, true_positives / denominator, 0.0)
# compute per-class panoptic quality
pq: Tensor = sq * rq
# compute averages
pq_avg: Tensor = torch.mean(pq[denominator > 0])
sq_avg: Tensor = torch.mean(sq[denominator > 0])
rq_avg: Tensor = torch.mean(rq[denominator > 0])
return pq, sq, rq, pq_avg, sq_avg, rq_avg
80 changes: 78 additions & 2 deletions src/torchmetrics/functional/detection/panoptic_qualities.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from typing import Collection

import torch
from torch import Tensor

from torchmetrics.functional.detection._panoptic_quality_common import (
Expand All @@ -24,6 +25,10 @@
_prepocess_inputs,
_validate_inputs,
)
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12

if not _TORCH_GREATER_EQUAL_1_12:
__doctest_skip__ = ["panoptic_quality", "modified_panoptic_quality"]


def panoptic_quality(
Expand All @@ -32,6 +37,8 @@ def panoptic_quality(
things: Collection[int],
stuffs: Collection[int],
allow_unknown_preds_category: bool = False,
return_sq_and_rq: bool = False,
return_per_class: bool = False,
) -> Tensor:
r"""Compute `Panoptic Quality`_ for panoptic segmentations.
Expand Down Expand Up @@ -61,6 +68,10 @@ def panoptic_quality(
allow_unknown_preds_category:
Boolean flag to specify if unknown categories in the predictions are to be ignored in the metric
computation or raise an exception when found.
return_sq_and_rq:
Boolean flag to specify if Segmentation Quality and Recognition Quality should also be returned.
return_per_class:
Boolean flag to specify if the per-class values should be returned or the class average.
Raises:
ValueError:
Expand Down Expand Up @@ -91,7 +102,59 @@ def panoptic_quality(
>>> panoptic_quality(preds, target, things = {0, 1}, stuffs = {6, 7})
tensor(0.5463, dtype=torch.float64)
You can also return the segmentation and recognition quality alongside the PQ
>>> from torch import tensor
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [7, 0], [6, 0], [1, 0]],
... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [1, 0]],
... [[0, 1], [7, 0], [1, 0], [1, 0]],
... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
>>> panoptic_quality(preds, target, things = {0, 1}, stuffs = {6, 7}, return_sq_and_rq=True)
tensor([0.5463, 0.6111, 0.6667], dtype=torch.float64)
You can also specify to return the per-class metrics
>>> from torch import tensor
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [7, 0], [6, 0], [1, 0]],
... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [1, 0]],
... [[0, 1], [7, 0], [1, 0], [1, 0]],
... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
>>> panoptic_quality(preds, target, things = {0, 1}, stuffs = {6, 7}, return_per_class=True)
tensor([[0.5185, 0.0000, 0.6667, 1.0000]], dtype=torch.float64)
You can also specify to return the per-class metrics and the segmentation and recognition quality
>>> from torch import tensor
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [7, 0], [6, 0], [1, 0]],
... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [1, 0]],
... [[0, 1], [7, 0], [1, 0], [1, 0]],
... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
>>> panoptic_quality(preds, target, things = {0, 1}, stuffs = {6, 7},
... return_per_class=True, return_sq_and_rq=True)
tensor([[0.5185, 0.7778, 0.6667],
[0.0000, 0.0000, 0.0000],
[0.6667, 0.6667, 1.0000],
[1.0000, 1.0000, 1.0000]], dtype=torch.float64)
"""
if not _TORCH_GREATER_EQUAL_1_12:
raise RuntimeError("Panoptic Quality metric requires PyTorch 1.12 or later")

things, stuffs = _parse_categories(things, stuffs)
_validate_inputs(preds, target)
void_color = _get_void_color(things, stuffs)
Expand All @@ -101,7 +164,19 @@ def panoptic_quality(
iou_sum, true_positives, false_positives, false_negatives = _panoptic_quality_update(
flatten_preds, flatten_target, cat_id_to_continuous_id, void_color
)
return _panoptic_quality_compute(iou_sum, true_positives, false_positives, false_negatives)
pq, sq, rq, pq_avg, sq_avg, rq_avg = _panoptic_quality_compute(
iou_sum,
true_positives,
false_positives,
false_negatives,
)
if return_per_class:
if return_sq_and_rq:
return torch.stack((pq, sq, rq), dim=-1)
return pq.view(1, -1)
if return_sq_and_rq:
return torch.stack((pq_avg, sq_avg, rq_avg), dim=0)
return pq_avg


def modified_panoptic_quality(
Expand Down Expand Up @@ -177,4 +252,5 @@ def modified_panoptic_quality(
void_color,
modified_metric_stuffs=stuffs,
)
return _panoptic_quality_compute(iou_sum, true_positives, false_positives, false_negatives)
_, _, _, pq_avg, _, _ = _panoptic_quality_compute(iou_sum, true_positives, false_positives, false_negatives)
return pq_avg
6 changes: 6 additions & 0 deletions tests/unittests/detection/test_modified_panoptic_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import torch
from torchmetrics.detection import ModifiedPanopticQuality
from torchmetrics.functional.detection import modified_panoptic_quality
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12

from unittests import _Input
from unittests._helpers import seed_all
Expand Down Expand Up @@ -76,6 +77,7 @@ def _reference_fn_1_2(preds, target) -> np.ndarray:
return np.array([23 / 30])


@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12")
class TestModifiedPanopticQuality(MetricTester):
"""Test class for `ModifiedPanopticQuality` metric."""

Expand Down Expand Up @@ -111,6 +113,7 @@ def test_panoptic_quality_functional(self):
)


@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12")
def test_empty_metric():
"""Test empty metric."""
with pytest.raises(ValueError, match="At least one of `things` and `stuffs` must be non-empty"):
Expand All @@ -120,6 +123,7 @@ def test_empty_metric():
assert torch.isnan(metric.compute())


@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12")
def test_error_on_wrong_input():
"""Test class input validation."""
with pytest.raises(TypeError, match="Expected argument `stuffs` to contain `int` categories.*"):
Expand Down Expand Up @@ -162,6 +166,7 @@ def test_error_on_wrong_input():
metric.update(preds, preds)


@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12")
def test_extreme_values():
"""Test that the metric returns expected values in trivial cases."""
# Exact match between preds and target => metric is 1
Expand All @@ -170,6 +175,7 @@ def test_extreme_values():
assert modified_panoptic_quality(_INPUTS_0.target[0], _INPUTS_0.target[0] + 1, **_ARGS_0) == 0.0


@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="PanopticQuality metric only supports PyTorch >= 1.12")
@pytest.mark.parametrize(
("inputs", "args", "cat_dim"),
[
Expand Down
Loading

0 comments on commit 6dcb61c

Please sign in to comment.