Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CriticalSuccessIndex (CSI) metric #2257

Merged
merged 22 commits into from
Dec 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for logging `MultiTaskWrapper` directly with lightnings `log_dict` method ([#2213](https://github.com/Lightning-AI/torchmetrics/pull/2213))


- Added `aggregate`` argument to retrieval metrics ([#2220](https://github.com/Lightning-AI/torchmetrics/pull/2220))
- Added `aggregate` argument to retrieval metrics ([#2220](https://github.com/Lightning-AI/torchmetrics/pull/2220))


- Added `CriticalSuccessIndex` metric to image subpackage ([#2257](https://github.com/Lightning-AI/torchmetrics/pull/2257))

### Changed

- Changed minimum supported Pytorch version from 1.8 to 1.10 ([#2145](https://github.com/Lightning-AI/torchmetrics/pull/2145))
Expand Down
22 changes: 22 additions & 0 deletions docs/source/image/critical_success_index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
.. customcarditem::
:header: Critical Success Index (CSI)
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
:tags: Image

.. include:: ../links.rst

############################
Critical Success Index (CSI)
############################

Module Interface
________________

.. autoclass:: torchmetrics.image.CriticalSuccessIndex
:exclude-members: update, compute


Functional Interface
____________________

.. autofunction:: torchmetrics.functional.image.critical_success_index
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.image.csi import critical_success_index
from torchmetrics.functional.image.d_lambda import spectral_distortion_index
from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis
from torchmetrics.functional.image.gradients import image_gradients
Expand Down Expand Up @@ -45,4 +46,5 @@
"visual_information_fidelity",
"learned_perceptual_image_patch_similarity",
"perceptual_path_length",
"critical_success_index",
]
112 changes: 112 additions & 0 deletions src/torchmetrics/functional/image/csi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple

import torch
from torch import Tensor

from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.compute import _safe_divide


def _critical_success_index_update(
preds: Tensor, target: Tensor, threshold: float, keep_sequence_dim: Optional[int] = None
) -> Tuple[Tensor, Tensor, Tensor]:
"""Update and return variables required to compute Critical Success Index. Checks for same shape of tensors.

Args:
preds: Predicted tensor
target: Ground truth tensor
threshold: Values above or equal to threshold are replaced with 1, below by 0
keep_sequence_dim: Index of the sequence dimension if the inputs are sequences of images. If specified,
the score will be calculated separately for each image in the sequence. If ``None``, the score will be
calculated across all dimensions.

"""
_check_same_shape(preds, target)

if keep_sequence_dim is None:
sum_dims = None
elif not 0 <= keep_sequence_dim < preds.ndim:
raise ValueError(f"Expected keep_sequence dim to be in range [0, {preds.ndim}] but got {keep_sequence_dim}")
else:
sum_dims = tuple(i for i in range(preds.ndim) if i != keep_sequence_dim)

# binarize the tensors with the threshold
preds_bin = (preds >= threshold).bool()
target_bin = (target >= threshold).bool()

if keep_sequence_dim is None:
hits = torch.sum(preds_bin & target_bin).int()
misses = torch.sum((preds_bin ^ target_bin) & target_bin).int()
false_alarms = torch.sum((preds_bin ^ target_bin) & preds_bin).int()
else:
hits = torch.sum(preds_bin & target_bin, dim=sum_dims).int()
misses = torch.sum((preds_bin ^ target_bin) & target_bin, dim=sum_dims).int()
false_alarms = torch.sum((preds_bin ^ target_bin) & preds_bin, dim=sum_dims).int()
return hits, misses, false_alarms


def _critical_success_index_compute(hits: Tensor, misses: Tensor, false_alarms: Tensor) -> Tensor:
"""Compute critical success index.

Args:
hits: Number of true positives after binarization
misses: Number of false negatives after binarization
false_alarms: Number of false positives after binarization

Returns:
If input tensors are 5-dimensional and ``keep_sequence_dim=True``, the metric returns a ``(S,)`` vector
with CSI scores for each image in the sequence. Otherwise, it returns a scalar tensor with the CSI score.

"""
return _safe_divide(hits, hits + misses + false_alarms)


def critical_success_index(
preds: Tensor, target: Tensor, threshold: float, keep_sequence_dim: Optional[int] = None
) -> Tensor:
"""Compute critical success index.

Args:
preds: Predicted tensor
target: Ground truth tensor
threshold: Values above or equal to threshold are replaced with 1, below by 0
keep_sequence_dim: Index of the sequence dimension if the inputs are sequences of images. If specified,
the score will be calculated separately for each image in the sequence. If ``None``, the score will be
calculated across all dimensions.

Returns:
If ``keep_sequence_dim`` is specified, the metric returns a vector of with CSI scores for each image
in the sequence. Otherwise, it returns a scalar tensor with the CSI score.

Example:
>>> import torch
>>> from torchmetrics.functional.image.csi import critical_success_index
>>> x = torch.Tensor([[0.2, 0.7], [0.9, 0.3]])
>>> y = torch.Tensor([[0.4, 0.2], [0.8, 0.6]])
>>> critical_success_index(x, y, 0.5)
tensor(0.3333)

Example:
>>> import torch
>>> from torchmetrics.functional.image.csi import critical_success_index
>>> x = torch.Tensor([[[0.2, 0.7], [0.9, 0.3]], [[0.2, 0.7], [0.9, 0.3]]])
>>> y = torch.Tensor([[[0.4, 0.2], [0.8, 0.6]], [[0.4, 0.2], [0.8, 0.6]]])
>>> critical_success_index(x, y, 0.5, keep_sequence_dim=0)
tensor([0.3333, 0.3333])

"""
hits, misses, false_alarms = _critical_success_index_update(preds, target, threshold, keep_sequence_dim)
return _critical_success_index_compute(hits, misses, false_alarms)
2 changes: 2 additions & 0 deletions src/torchmetrics/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.image.csi import CriticalSuccessIndex
from torchmetrics.image.d_lambda import SpectralDistortionIndex
from torchmetrics.image.ergas import ErrorRelativeGlobalDimensionlessSynthesis
from torchmetrics.image.mifid import MemorizationInformedFrechetInceptionDistance
Expand Down Expand Up @@ -42,6 +43,7 @@
"UniversalImageQualityIndex",
"VisualInformationFidelity",
"TotalVariation",
"CriticalSuccessIndex",
]

if _TORCH_FIDELITY_AVAILABLE:
Expand Down
111 changes: 111 additions & 0 deletions src/torchmetrics/image/csi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional

import torch

from torchmetrics.functional.image.csi import _critical_success_index_compute, _critical_success_index_update
from torchmetrics.metric import Metric
from torchmetrics.utilities import dim_zero_cat


class CriticalSuccessIndex(Metric):
r"""Calculate critical success index (CSI).

Critical success index (also known as the threat score) is a statistic used weather forecasting that measures
forecast performance over inputs binarized at a specified threshold. It is defined as:

.. math:: \text{CSI} = \frac{\text{TP}}{\text{TP}+\text{FN}+\text{FP}}

Where :math:`\text{TP}`, :math:`\text{FN}` and :math:`\text{FP}` represent the number of true positives, false
negatives and false positives respectively after binarizing the input tensors.

Args:
threshold: Values above or equal to threshold are replaced with 1, below by 0
keep_sequence_dim: Index of the sequence dimension if the inputs are sequences of images. If specified,
the score will be calculated separately for each image in the sequence. If ``None``, the score will be
calculated across all dimensions.

Example:
>>> import torch
>>> from torchmetrics.image.csi import CriticalSuccessIndex
>>> x = torch.Tensor([[0.2, 0.7], [0.9, 0.3]])
>>> y = torch.Tensor([[0.4, 0.2], [0.8, 0.6]])
>>> csi = CriticalSuccessIndex(0.5)
>>> csi(x, y)
tensor(0.3333)

Example:
>>> import torch
>>> from torchmetrics.image.csi import CriticalSuccessIndex
>>> x = torch.Tensor([[[0.2, 0.7], [0.9, 0.3]], [[0.2, 0.7], [0.9, 0.3]]])
>>> y = torch.Tensor([[[0.4, 0.2], [0.8, 0.6]], [[0.4, 0.2], [0.8, 0.6]]])
>>> csi = CriticalSuccessIndex(0.5, keep_sequence_dim=0)
>>> csi(x, y)
tensor([0.3333, 0.3333])

"""

is_differentiable: bool = False
higher_is_better: bool = True

hits: torch.Tensor
misses: torch.Tensor
false_alarms: torch.Tensor
hits_list: List[torch.Tensor]
misses_list: List[torch.Tensor]
false_alarms_list: List[torch.Tensor]

def __init__(self, threshold: float, keep_sequence_dim: Optional[int] = None, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.threshold = float(threshold)

if keep_sequence_dim and (not isinstance(keep_sequence_dim, int) or keep_sequence_dim < 0):
raise ValueError(f"Expected keep_sequence_dim to be a non-negative integer but got {keep_sequence_dim}")
self.keep_sequence_dim = keep_sequence_dim

if keep_sequence_dim is None:
self.add_state("hits", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("misses", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("false_alarms", default=torch.tensor(0), dist_reduce_fx="sum")
else:
self.add_state("hits_list", default=[], dist_reduce_fx="cat")
self.add_state("misses_list", default=[], dist_reduce_fx="cat")
self.add_state("false_alarms_list", default=[], dist_reduce_fx="cat")

def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
"""Update state with predictions and targets."""
hits, misses, false_alarms = _critical_success_index_update(
preds, target, self.threshold, self.keep_sequence_dim
)
if self.keep_sequence_dim is None:
self.hits += hits
self.misses += misses
self.false_alarms += false_alarms
else:
self.hits_list.append(hits)
self.misses_list.append(misses)
self.false_alarms_list.append(false_alarms)

def compute(self) -> torch.Tensor:
"""Compute critical success index over state."""
if self.keep_sequence_dim is None:
hits = self.hits
misses = self.misses
false_alarms = self.false_alarms
else:
hits = dim_zero_cat(self.hits_list)
misses = dim_zero_cat(self.misses_list)
false_alarms = dim_zero_cat(self.false_alarms_list)
return _critical_success_index_compute(hits, misses, false_alarms)
93 changes: 93 additions & 0 deletions tests/unittests/image/test_csi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import pytest
import torch
from sklearn.metrics import jaccard_score
from torchmetrics.functional.image.csi import critical_success_index
from torchmetrics.image.csi import CriticalSuccessIndex

from unittests import BATCH_SIZE, NUM_BATCHES, _Input
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester

seed_all(42)


_inputs_1 = _Input(preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.rand(NUM_BATCHES, BATCH_SIZE))
_inputs_2 = _Input(preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.rand(NUM_BATCHES, BATCH_SIZE))


def _calculate_ref_metric(preds: torch.Tensor, target: torch.Tensor, threshold: float):
"""Calculate reference metric for `CriticalSuccessIndex`."""
preds, target = preds.numpy(), target.numpy()
preds = preds >= threshold
target = target >= threshold
return jaccard_score(preds.ravel(), target.ravel())


@pytest.mark.parametrize(
"preds, target",
[
(_inputs_1.preds, _inputs_1.target),
(_inputs_2.preds, _inputs_2.target),
],
)
@pytest.mark.parametrize("threshold", [0.5, 0.25, 0.75])
class TestCriticalSuccessIndex(MetricTester):
"""Test class for `CriticalSuccessIndex` metric."""

@pytest.mark.parametrize("ddp", [True, False])
def test_csi_class(self, preds, target, threshold, ddp):
"""Test class implementation of metric."""
self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=CriticalSuccessIndex,
reference_metric=partial(_calculate_ref_metric, threshold=threshold),
metric_args={"threshold": threshold},
)

def test_csi_functional(self, preds, target, threshold):
"""Test functional implementation of metric."""
self.run_functional_metric_test(
preds=preds,
target=target,
metric_functional=critical_success_index,
reference_metric=partial(_calculate_ref_metric, threshold=threshold),
metric_args={"threshold": threshold},
)

def test_csi_half_cpu(self, preds, target, threshold):
"""Test dtype support of the metric on CPU."""
self.run_precision_test_cpu(
preds=preds, target=target, metric_functional=critical_success_index, metric_args={"threshold": threshold}
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
def test_csi_half_gpu(self, preds, target, threshold):
"""Test dtype support of the metric on GPU."""
self.run_precision_test_gpu(
preds=preds, target=target, metric_functional=critical_success_index, metric_args={"threshold": threshold}
)


def test_error_on_different_shape():
"""Test that error is raised on different shapes of input."""
metric = CriticalSuccessIndex(0.5)
with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
metric(torch.randn(100), torch.randn(50))
Loading