Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pre-commit.ci] pre-commit suggestions #2902

Merged
merged 7 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ ci:

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v5.0.0
hooks:
- id: end-of-file-fixer
- id: trailing-whitespace
Expand All @@ -46,11 +46,10 @@ repos:
exclude: pyproject.toml

- repo: https://github.com/crate-ci/typos
rev: v1.22.9
rev: dictgen-v0.3.1
hooks:
- id: typos
# empty to do not write fixes
args: []
args: [] # empty to do not write fixes
exclude: pyproject.toml

- repo: https://github.com/PyCQA/docformatter
Expand All @@ -61,12 +60,12 @@ repos:
args: ["--in-place"]

- repo: https://github.com/sphinx-contrib/sphinx-lint
rev: v0.9.1
rev: v1.0.0
hooks:
- id: sphinx-lint

- repo: https://github.com/executablebooks/mdformat
rev: 0.7.17
rev: 0.7.21
hooks:
- id: mdformat
args: ["--number"]
Expand Down Expand Up @@ -113,7 +112,7 @@ repos:
- id: text-unicode-replacement-char

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.0
rev: v0.8.6
hooks:
# try to fix what is possible
- id: ruff
Expand All @@ -124,11 +123,11 @@ repos:
- id: ruff

- repo: https://github.com/tox-dev/pyproject-fmt
rev: 2.1.3
rev: v2.5.0
hooks:
- id: pyproject-fmt
additional_dependencies: [tox]
- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.18
rev: v0.23
hooks:
- id: validate-pyproject
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* `PearsonCorrcoef`
* `SpearmanCorrcoef`
- Removed deprecated functions, and warnings in detection and pairwise ([#804](https://github.com/Lightning-AI/metrics/pull/804))
* `MAP` and `functional.pairwise.manhatten`
* `MAP` and `functional.pairwise.manhattan`
- Removed deprecated functions, and warnings in Audio ([#805](https://github.com/Lightning-AI/metrics/pull/805))
* `PESQ` and `functional.audio.pesq`
* `PIT` and `functional.audio.pit`
Expand Down Expand Up @@ -1029,7 +1029,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `pairwise_cosine_similarity`
- `pairwise_euclidean_distance`
- `pairwise_linear_similarity`
- `pairwise_manhatten_distance`
- `pairwise_manhattan_distance`

### Changed

Expand Down
1 change: 1 addition & 0 deletions _samples/bert_score-own_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import torch
from torch import Tensor, nn
from torch.nn import Module

from torchmetrics.text.bert import BERTScore

_NUM_LAYERS = 2
Expand Down
1 change: 1 addition & 0 deletions _samples/detection_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""An example of how the predictions and target should be defined for the MAP object detection metric."""

from torch import BoolTensor, IntTensor, Tensor

from torchmetrics.detection.mean_ap import MeanAveragePrecision

# Preds should be a list of elements, where each element is a dict
Expand Down
3 changes: 2 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
from typing import Optional

import lai_sphinx_theme
import torchmetrics
from lightning_utilities.docs.formatting import _linkcode_resolve, _transform_changelog

import torchmetrics

_PATH_HERE = os.path.abspath(os.path.dirname(__file__))
_PATH_ROOT = os.path.realpath(os.path.join(_PATH_HERE, "..", ".."))
sys.path.insert(0, os.path.abspath(_PATH_ROOT))
Expand Down
2 changes: 1 addition & 1 deletion docs/source/pages/implement.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ We provide the remaining interface, such as ``reset()`` that will make sure to c
states that have been added using ``add_state``. You should therefore not implement ``reset()`` yourself, only in rare
cases where not all the state variables should be reset to their default value. Adding metric states with ``add_state``
will make sure that states are correctly synchronized in distributed settings (DDP). To see how metric states are
synchronized across distributed processes, refer to :meth:`~torchmetrics.Metric.add_state()` docs from the base
synchronized across distributed processes, refer to :meth:`~torchmetrics.Metric.add_state` docs from the base
:class:`~torchmetrics.Metric` class.

Below is a basic implementation of a custom accuracy metric. In the ``__init__`` method we add the metric states
Expand Down
1 change: 1 addition & 0 deletions docs/source/pyplots/binary_accuracy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
Expand Down
1 change: 1 addition & 0 deletions docs/source/pyplots/binary_accuracy_multistep.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
Expand Down
1 change: 1 addition & 0 deletions docs/source/pyplots/collection_binary.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
Expand Down
1 change: 1 addition & 0 deletions docs/source/pyplots/collection_binary_together.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
Expand Down
1 change: 1 addition & 0 deletions docs/source/pyplots/confusion_matrix.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
Expand Down
1 change: 1 addition & 0 deletions docs/source/pyplots/multiclass_accuracy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
Expand Down
1 change: 1 addition & 0 deletions docs/source/pyplots/tracker_binary.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
Expand Down
1 change: 1 addition & 0 deletions examples/audio/pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import numpy as np
import torch
import torchaudio

from torchmetrics.audio import PerceptualEvaluationSpeechQuality

# %%
Expand Down
1 change: 1 addition & 0 deletions examples/audio/signal_to_noise_ratio.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import matplotlib.pyplot as plt
import numpy as np
import torch

from torchmetrics.audio import SignalNoiseRatio

# %%
Expand Down
1 change: 1 addition & 0 deletions examples/image/clip_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import torch
from matplotlib.table import Table
from skimage.data import astronaut, cat, coffee

from torchmetrics.multimodal import CLIPScore

# %%
Expand Down
1 change: 1 addition & 0 deletions examples/image/spatial_correlation_coef.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import torch
from skimage.data import shepp_logan_phantom
from skimage.transform import iradon, radon, rescale

from torchmetrics.image import SpatialCorrelationCoefficient

# %%
Expand Down
3 changes: 2 additions & 1 deletion examples/text/bertscore.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
Let's consider a use case in natural language processing where BERTScore is used to evaluate the quality of a text generation model. In this case we are imaging that we are developing a automated news summarization system. The goal is to create concise summaries of news articles that accurately capture the key points of the original articles. To evaluate the performance of your summarization system, you need a metric that can compare the generated summaries to human-written summaries. This is where the BERTScore can be used.
"""

from torchmetrics.text import BERTScore, ROUGEScore
from transformers import AutoTokenizer, pipeline

from torchmetrics.text import BERTScore, ROUGEScore

pipe = pipeline("text-generation", model="openai-community/gpt2")
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

Expand Down
3 changes: 2 additions & 1 deletion examples/text/perplexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
# Here's a hypothetical Python example demonstrating the usage of Perplexity to evaluate a generative language model

import torch
from torchmetrics.text import Perplexity
from transformers import AutoModelWithLMHead, AutoTokenizer

from torchmetrics.text import Perplexity

# %%
# Load the GPT-2 model and tokenizer

Expand Down
3 changes: 2 additions & 1 deletion examples/text/rouge.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
# %%
# Here's a hypothetical Python example demonstrating the usage of unigram ROUGE F-score to evaluate a generative language model:

from torchmetrics.text import ROUGEScore
from transformers import AutoTokenizer, pipeline

from torchmetrics.text import ROUGEScore

pipe = pipeline("text-generation", model="openai-community/gpt2")
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

Expand Down
2 changes: 1 addition & 1 deletion src/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
MANUAL_SEED = doctest.register_optionflag("MANUAL_SEED")

@pytest.fixture(autouse=True)
def reset_random_seed(seed: int = 42) -> None: # noqa: PT004
def reset_random_seed(seed: int = 42) -> None:
"""Reset the random seed before running each doctest."""
import random

Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@

__all__ = [
"AUROC",
"ROC",
"Accuracy",
"AveragePrecision",
"BLEUScore",
Expand Down Expand Up @@ -229,7 +230,6 @@
"PrecisionAtFixedRecall",
"PrecisionRecallCurve",
"R2Score",
"ROC",
"Recall",
"RecallAtFixedPrecision",
"RelativeAverageSpectralError",
Expand Down
3 changes: 1 addition & 2 deletions src/torchmetrics/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ def __init__(
allowed_nan_strategy = ("error", "warn", "ignore")
if nan_strategy not in allowed_nan_strategy and not isinstance(nan_strategy, float):
raise ValueError(
f"Arg `nan_strategy` should either be a float or one of {allowed_nan_strategy}"
f" but got {nan_strategy}."
f"Arg `nan_strategy` should either be a float or one of {allowed_nan_strategy} but got {nan_strategy}."
)

self.nan_strategy = nan_strategy
Expand Down
6 changes: 3 additions & 3 deletions src/torchmetrics/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@
scipy.signal.hamming = scipy.signal.windows.hamming

__all__ = [
"ComplexScaleInvariantSignalNoiseRatio",
"PermutationInvariantTraining",
"ScaleInvariantSignalDistortionRatio",
"SignalDistortionRatio",
"SourceAggregatedSignalDistortionRatio",
"ScaleInvariantSignalNoiseRatio",
"SignalDistortionRatio",
"SignalNoiseRatio",
"ComplexScaleInvariantSignalNoiseRatio",
"SourceAggregatedSignalDistortionRatio",
]

if _PESQ_AVAILABLE:
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/audio/srmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class SpeechReverberationModulationEnergyRatio(Metric):
This implementation is experimental, and might not be consistent with the matlab
implementation `SRMRToolbox`_, especially the fast implementation.
The slow versions, a) fast=False, norm=False, max_cf=128, b) fast=False, norm=True, max_cf=30, have
a relatively small inconsistence.
a relatively small inconsistency.

Args:
fs: the sampling rate
Expand Down
Loading
Loading