Merge branch 'master' into bugfix/segmentation_dice_zero_overlap
Borda authored Jan 8, 2025
2 parents 5486901 + e690bbd commit ab5070d
Showing 222 changed files with 474 additions and 456 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-integrate.yml
@@ -36,7 +36,7 @@ jobs:
- { python-version: "3.10", requires: "latest", os: "ubuntu-22.04" }
# - { python-version: "3.10", requires: "latest", os: "macOS-14" } # M1 machine # todo: crashing for MPS out of memory
env:
PYTORCH_URL: "http://download.pytorch.org/whl/cpu/"
PYTORCH_URL: "https://download.pytorch.org/whl/cpu/"
FREEZE_REQUIREMENTS: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }}
PYPI_CACHE: "_ci-cache_PyPI"

2 changes: 1 addition & 1 deletion .github/workflows/ci-tests.yml
@@ -64,7 +64,7 @@ jobs:
FREEZE_REQUIREMENTS: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }}
TOKENIZERS_PARALLELISM: false
TEST_DIRS: ${{ needs.check-diff.outputs.test-dirs }}
PIP_EXTRA_INDEX_URL: "--extra-index-url=http://download.pytorch.org/whl/cpu/"
PIP_EXTRA_INDEX_URL: "--extra-index-url=https://download.pytorch.org/whl/cpu/"
UNITTEST_TIMEOUT: "" # by default, it is not set

# Timeout: https://stackoverflow.com/a/59076067/4521646
2 changes: 1 addition & 1 deletion .github/workflows/docs-build.yml
@@ -18,7 +18,7 @@ defaults:

env:
FREEZE_REQUIREMENTS: "1"
TORCH_URL: "http://download.pytorch.org/whl/cpu/"
TORCH_URL: "https://download.pytorch.org/whl/cpu/"
PYPI_CACHE: "_ci-cache_PyPI"
PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION: "python"
TOKENIZERS_PARALLELISM: false
17 changes: 8 additions & 9 deletions .pre-commit-config.yaml
@@ -23,7 +23,7 @@ ci:

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v5.0.0
hooks:
- id: end-of-file-fixer
- id: trailing-whitespace
@@ -46,11 +46,10 @@ repos:
exclude: pyproject.toml

- repo: https://github.com/crate-ci/typos
rev: v1.22.9
rev: dictgen-v0.3.1
hooks:
- id: typos
# empty to do not write fixes
args: []
args: [] # empty to do not write fixes
exclude: pyproject.toml

- repo: https://github.com/PyCQA/docformatter
@@ -61,12 +60,12 @@ repos:
args: ["--in-place"]

- repo: https://github.com/sphinx-contrib/sphinx-lint
rev: v0.9.1
rev: v1.0.0
hooks:
- id: sphinx-lint

- repo: https://github.com/executablebooks/mdformat
rev: 0.7.17
rev: 0.7.21
hooks:
- id: mdformat
args: ["--number"]
@@ -113,7 +112,7 @@ repos:
- id: text-unicode-replacement-char

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.0
rev: v0.8.6
hooks:
# try to fix what is possible
- id: ruff
@@ -124,11 +123,11 @@
- id: ruff

- repo: https://github.com/tox-dev/pyproject-fmt
rev: 2.1.3
rev: v2.5.0
hooks:
- id: pyproject-fmt
additional_dependencies: [tox]
- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.18
rev: v0.23
hooks:
- id: validate-pyproject
4 changes: 2 additions & 2 deletions CHANGELOG.md
@@ -842,7 +842,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* `PearsonCorrcoef`
* `SpearmanCorrcoef`
- Removed deprecated functions, and warnings in detection and pairwise ([#804](https://github.com/Lightning-AI/metrics/pull/804))
* `MAP` and `functional.pairwise.manhatten`
* `MAP` and `functional.pairwise.manhattan`
- Removed deprecated functions, and warnings in Audio ([#805](https://github.com/Lightning-AI/metrics/pull/805))
* `PESQ` and `functional.audio.pesq`
* `PIT` and `functional.audio.pit`
@@ -1032,7 +1032,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `pairwise_cosine_similarity`
- `pairwise_euclidean_distance`
- `pairwise_linear_similarity`
- `pairwise_manhatten_distance`
- `pairwise_manhattan_distance`

### Changed

1 change: 1 addition & 0 deletions _samples/bert_score-own_model.py
@@ -23,6 +23,7 @@
import torch
from torch import Tensor, nn
from torch.nn import Module

from torchmetrics.text.bert import BERTScore

_NUM_LAYERS = 2
1 change: 1 addition & 0 deletions _samples/detection_map.py
@@ -14,6 +14,7 @@
"""An example of how the predictions and target should be defined for the MAP object detection metric."""

from torch import BoolTensor, IntTensor, Tensor

from torchmetrics.detection.mean_ap import MeanAveragePrecision

# Preds should be a list of elements, where each element is a dict
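For context, a minimal usage sketch of MeanAveragePrecision with the prediction/target dict structure described above; the box coordinates, scores, and labels below are made-up illustration values, not part of the sample file:

import torch

from torchmetrics.detection.mean_ap import MeanAveragePrecision

# one dict per image: predicted boxes in (xmin, ymin, xmax, ymax) format,
# one confidence score per box, and one integer class label per box
preds = [
    {
        "boxes": torch.tensor([[258.0, 41.0, 606.0, 285.0]]),
        "scores": torch.tensor([0.536]),
        "labels": torch.tensor([0]),
    }
]
# the target uses the same structure, but without scores
target = [
    {
        "boxes": torch.tensor([[214.0, 41.0, 562.0, 285.0]]),
        "labels": torch.tensor([0]),
    }
]

metric = MeanAveragePrecision()
metric.update(preds, target)
print(metric.compute())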
3 changes: 2 additions & 1 deletion docs/source/conf.py
@@ -18,9 +18,10 @@
from typing import Optional

import lai_sphinx_theme
import torchmetrics
from lightning_utilities.docs.formatting import _linkcode_resolve, _transform_changelog

import torchmetrics

_PATH_HERE = os.path.abspath(os.path.dirname(__file__))
_PATH_ROOT = os.path.realpath(os.path.join(_PATH_HERE, "..", ".."))
sys.path.insert(0, os.path.abspath(_PATH_ROOT))
2 changes: 1 addition & 1 deletion docs/source/pages/implement.rst
@@ -30,7 +30,7 @@ We provide the remaining interface, such as ``reset()`` that will make sure to c
states that have been added using ``add_state``. You should therefore not implement ``reset()`` yourself, only in rare
cases where not all the state variables should be reset to their default value. Adding metric states with ``add_state``
will make sure that states are correctly synchronized in distributed settings (DDP). To see how metric states are
synchronized across distributed processes, refer to :meth:`~torchmetrics.Metric.add_state()` docs from the base
synchronized across distributed processes, refer to :meth:`~torchmetrics.Metric.add_state` docs from the base
:class:`~torchmetrics.Metric` class.

Below is a basic implementation of a custom accuracy metric. In the ``__init__`` method we add the metric states
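For context, a minimal sketch of such a custom accuracy metric, assuming only the public torchmetrics.Metric / add_state API described above:

import torch
from torch import Tensor

from torchmetrics import Metric


class MyAccuracy(Metric):
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        # add_state registers the state variables, so reset() and DDP
        # synchronization are handled by the base Metric class
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:
        # accumulate the number of correct predictions and samples seen
        self.correct += (preds == target).sum()
        self.total += target.numel()

    def compute(self) -> Tensor:
        # final value computed from the (synchronized) states
        return self.correct.float() / self.total

Because both states are added with add_state, they are reset and reduced across processes without any extra code in the subclass.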
1 change: 1 addition & 0 deletions docs/source/pyplots/binary_accuracy.py
@@ -1,5 +1,6 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
1 change: 1 addition & 0 deletions docs/source/pyplots/binary_accuracy_multistep.py
@@ -1,5 +1,6 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
1 change: 1 addition & 0 deletions docs/source/pyplots/collection_binary.py
@@ -1,5 +1,6 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
1 change: 1 addition & 0 deletions docs/source/pyplots/collection_binary_together.py
@@ -1,5 +1,6 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
1 change: 1 addition & 0 deletions docs/source/pyplots/confusion_matrix.py
@@ -1,5 +1,6 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
1 change: 1 addition & 0 deletions docs/source/pyplots/multiclass_accuracy.py
@@ -1,5 +1,6 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
1 change: 1 addition & 0 deletions docs/source/pyplots/tracker_binary.py
@@ -1,5 +1,6 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
1 change: 1 addition & 0 deletions examples/audio/pesq.py
@@ -19,6 +19,7 @@
import numpy as np
import torch
import torchaudio

from torchmetrics.audio import PerceptualEvaluationSpeechQuality

# %%
1 change: 1 addition & 0 deletions examples/audio/signal_to_noise_ratio.py
@@ -13,6 +13,7 @@
import matplotlib.pyplot as plt
import numpy as np
import torch

from torchmetrics.audio import SignalNoiseRatio

# %%
1 change: 1 addition & 0 deletions examples/image/clip_score.py
@@ -15,6 +15,7 @@
import torch
from matplotlib.table import Table
from skimage.data import astronaut, cat, coffee

from torchmetrics.multimodal import CLIPScore

# %%
1 change: 1 addition & 0 deletions examples/image/spatial_correlation_coef.py
@@ -16,6 +16,7 @@
import torch
from skimage.data import shepp_logan_phantom
from skimage.transform import iradon, radon, rescale

from torchmetrics.image import SpatialCorrelationCoefficient

# %%
3 changes: 2 additions & 1 deletion examples/text/bertscore.py
@@ -6,9 +6,10 @@
Let's consider a use case in natural language processing where BERTScore is used to evaluate the quality of a text generation model. In this case we are imagining that we are developing an automated news summarization system. The goal is to create concise summaries of news articles that accurately capture the key points of the original articles. To evaluate the performance of your summarization system, you need a metric that can compare the generated summaries to human-written summaries. This is where the BERTScore can be used.
"""

from torchmetrics.text import BERTScore, ROUGEScore
from transformers import AutoTokenizer, pipeline

from torchmetrics.text import BERTScore, ROUGEScore

pipe = pipeline("text-generation", model="openai-community/gpt2")
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

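For context, a rough sketch of how BERTScore could score generated summaries against human-written references; the strings are placeholders and relying on the default model is an assumption:

from torchmetrics.text import BERTScore

# hypothetical generated and reference summaries
preds = ["the unemployment rate fell to its lowest level in a decade"]
target = ["unemployment dropped to a ten-year low last month"]

# a specific model can be selected via model_name_or_path; here the default is used
bertscore = BERTScore()
print(bertscore(preds, target))  # precision, recall and f1 per summary pair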
3 changes: 2 additions & 1 deletion examples/text/perplexity.py
@@ -12,9 +12,10 @@
# Here's a hypothetical Python example demonstrating the usage of Perplexity to evaluate a generative language model

import torch
from torchmetrics.text import Perplexity
from transformers import AutoModelWithLMHead, AutoTokenizer

from torchmetrics.text import Perplexity

# %%
# Load the GPT-2 model and tokenizer

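For context, a hypothetical sketch of Perplexity on toy logits; the batch size, sequence length, and vocabulary size are made up for illustration:

import torch

from torchmetrics.text import Perplexity

gen = torch.manual_seed(42)
# preds are unnormalized logits over the vocabulary, target holds the true token ids
preds = torch.rand(2, 8, 5, generator=gen)        # (batch, seq_len, vocab_size)
target = torch.randint(5, (2, 8), generator=gen)  # (batch, seq_len)

perplexity = Perplexity()
print(perplexity(preds, target))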
3 changes: 2 additions & 1 deletion examples/text/rouge.py
@@ -9,9 +9,10 @@
# %%
# Here's a hypothetical Python example demonstrating the usage of unigram ROUGE F-score to evaluate a generative language model:

from torchmetrics.text import ROUGEScore
from transformers import AutoTokenizer, pipeline

from torchmetrics.text import ROUGEScore

pipe = pipeline("text-generation", model="openai-community/gpt2")
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

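For context, a rough sketch of unigram ROUGE scoring with ROUGEScore; the sentences are illustrative placeholders:

from torchmetrics.text import ROUGEScore

preds = "the quick brown fox jumped over the lazy dog"
target = "a quick brown fox jumps over the lazy dog"

rouge = ROUGEScore(rouge_keys="rouge1")  # restrict to unigram ROUGE-1
print(rouge(preds, target))  # rouge1 precision, recall and fmeasure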
2 changes: 1 addition & 1 deletion requirements/image.txt
@@ -1,6 +1,6 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

scipy >1.0.0, <1.15.0
scipy >1.0.0, <1.16.0
torchvision >=0.15.1, <0.22.0
torch-fidelity <=0.4.0 # bumping to allow install version from master, now used in testing
4 changes: 2 additions & 2 deletions requirements/nominal_test.txt
@@ -3,6 +3,6 @@

pandas >1.4.0, <=2.2.3 # cannot pin version due to numpy version incompatibility
dython ==0.7.6 ; python_version <"3.9"
dython ~=0.7.8 ; python_version > "3.8" # we do not use `> =`
scipy >1.0.0, <1.15.0 # cannot pin version due to some version conflicts with `oldest` CI configuration
dython ==0.7.9 ; python_version > "3.8" # we do not use `> =`
scipy >1.0.0, <1.16.0 # cannot pin version due to some version conflicts with `oldest` CI configuration
statsmodels >0.13.5, <0.15.0
2 changes: 1 addition & 1 deletion requirements/segmentation_test.txt
@@ -1,6 +1,6 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

scipy >1.0.0, <1.15.0
scipy >1.0.0, <1.16.0
monai ==1.3.2 ; python_version < "3.9"
monai ==1.4.0 ; python_version > "3.8"
2 changes: 1 addition & 1 deletion requirements/text_test.txt
@@ -5,7 +5,7 @@ jiwer >=2.3.0, <3.1.0
rouge-score >0.1.0, <=0.1.2
bert_score ==0.3.13
huggingface-hub <0.28
sacrebleu >=2.3.0, <2.5.0
sacrebleu >=2.3.0, <2.6.0

mecab-ko >=1.0.0, <1.1.0 ; python_version < "3.12" # strict # todo: unpin python_version
mecab-ko-dic >=1.0.0, <1.1.0 ; python_version < "3.12" # todo: unpin python_version
2 changes: 1 addition & 1 deletion src/conftest.py
@@ -11,7 +11,7 @@
MANUAL_SEED = doctest.register_optionflag("MANUAL_SEED")

@pytest.fixture(autouse=True)
def reset_random_seed(seed: int = 42) -> None: # noqa: PT004
def reset_random_seed(seed: int = 42) -> None:
"""Reset the random seed before running each doctest."""
import random

2 changes: 1 addition & 1 deletion src/torchmetrics/__init__.py
@@ -169,6 +169,7 @@

__all__ = [
"AUROC",
"ROC",
"Accuracy",
"AveragePrecision",
"BLEUScore",
@@ -229,7 +230,6 @@
"PrecisionAtFixedRecall",
"PrecisionRecallCurve",
"R2Score",
"ROC",
"Recall",
"RecallAtFixedPrecision",
"RelativeAverageSpectralError",
3 changes: 1 addition & 2 deletions src/torchmetrics/aggregation.py
@@ -65,8 +65,7 @@ def __init__(
allowed_nan_strategy = ("error", "warn", "ignore")
if nan_strategy not in allowed_nan_strategy and not isinstance(nan_strategy, float):
raise ValueError(
f"Arg `nan_strategy` should either be a float or one of {allowed_nan_strategy}"
f" but got {nan_strategy}."
f"Arg `nan_strategy` should either be a float or one of {allowed_nan_strategy} but got {nan_strategy}."
)

self.nan_strategy = nan_strategy
6 changes: 3 additions & 3 deletions src/torchmetrics/audio/__init__.py
@@ -41,13 +41,13 @@
scipy.signal.hamming = scipy.signal.windows.hamming

__all__ = [
"ComplexScaleInvariantSignalNoiseRatio",
"PermutationInvariantTraining",
"ScaleInvariantSignalDistortionRatio",
"SignalDistortionRatio",
"SourceAggregatedSignalDistortionRatio",
"ScaleInvariantSignalNoiseRatio",
"SignalDistortionRatio",
"SignalNoiseRatio",
"ComplexScaleInvariantSignalNoiseRatio",
"SourceAggregatedSignalDistortionRatio",
]

if _PESQ_AVAILABLE:
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/audio/srmr.py
@@ -58,7 +58,7 @@ class SpeechReverberationModulationEnergyRatio(Metric):
This implementation is experimental, and might not be consistent with the matlab
implementation `SRMRToolbox`_, especially the fast implementation.
The slow versions, a) fast=False, norm=False, max_cf=128, b) fast=False, norm=True, max_cf=30, have
a relatively small inconsistence.
a relatively small inconsistency.
Args:
fs: the sampling rate
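For context, a rough usage sketch of this metric, assuming the optional audio dependencies (e.g. torchaudio and gammatone) are installed; the waveform is random noise used only to show the call:

import torch

from torchmetrics.audio import SpeechReverberationModulationEnergyRatio

gen = torch.manual_seed(1)
preds = torch.randn(8000, generator=gen)  # one second of audio at fs=8000
srmr = SpeechReverberationModulationEnergyRatio(fs=8000)
print(srmr(preds))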