From 95556d4875ff4389b40533b3e7f4e98229772c35 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 7 Jan 2025 19:06:39 +0900 Subject: [PATCH 1/5] [pre-commit.ci] pre-commit suggestions (#2902) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [pre-commit.ci] pre-commit suggestions updates: - [github.com/pre-commit/pre-commit-hooks: v4.6.0 → v5.0.0](https://github.com/pre-commit/pre-commit-hooks/compare/v4.6.0...v5.0.0) - [github.com/crate-ci/typos: v1.22.9 → dictgen-v0.3.1](https://github.com/crate-ci/typos/compare/v1.22.9...dictgen-v0.3.1) - [github.com/PyCQA/docformatter: 06907d0267368b49b9180eed423fae5697c1e909 → v1.7.5](https://github.com/PyCQA/docformatter/compare/06907d0267368b49b9180eed423fae5697c1e909...v1.7.5) - [github.com/sphinx-contrib/sphinx-lint: v0.9.1 → v1.0.0](https://github.com/sphinx-contrib/sphinx-lint/compare/v0.9.1...v1.0.0) - [github.com/executablebooks/mdformat: 0.7.17 → 0.7.21](https://github.com/executablebooks/mdformat/compare/0.7.17...0.7.21) - [github.com/pre-commit/mirrors-prettier: v3.1.0 → v4.0.0-alpha.8](https://github.com/pre-commit/mirrors-prettier/compare/v3.1.0...v4.0.0-alpha.8) - [github.com/astral-sh/ruff-pre-commit: v0.5.0 → v0.8.6](https://github.com/astral-sh/ruff-pre-commit/compare/v0.5.0...v0.8.6) - [github.com/tox-dev/pyproject-fmt: 2.1.3 → v2.5.0](https://github.com/tox-dev/pyproject-fmt/compare/2.1.3...v2.5.0) - [github.com/abravalheri/validate-pyproject: v0.18 → v0.23](https://github.com/abravalheri/validate-pyproject/compare/v0.18...v0.23) * Apply suggestions from code review -------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> Co-authored-by: Jirka B --- .pre-commit-config.yaml | 17 +- CHANGELOG.md | 4 +- _samples/bert_score-own_model.py | 1 + _samples/detection_map.py | 1 + docs/source/conf.py | 3 +- docs/source/pages/implement.rst | 2 +- docs/source/pyplots/binary_accuracy.py | 1 + .../pyplots/binary_accuracy_multistep.py | 1 + docs/source/pyplots/collection_binary.py | 1 + .../pyplots/collection_binary_together.py | 1 + docs/source/pyplots/confusion_matrix.py | 1 + docs/source/pyplots/multiclass_accuracy.py | 1 + docs/source/pyplots/tracker_binary.py | 1 + examples/audio/pesq.py | 1 + examples/audio/signal_to_noise_ratio.py | 1 + examples/image/clip_score.py | 1 + examples/image/spatial_correlation_coef.py | 1 + examples/text/bertscore.py | 3 +- examples/text/perplexity.py | 3 +- examples/text/rouge.py | 3 +- src/conftest.py | 2 +- src/torchmetrics/__init__.py | 2 +- src/torchmetrics/aggregation.py | 3 +- src/torchmetrics/audio/__init__.py | 6 +- src/torchmetrics/audio/srmr.py | 2 +- src/torchmetrics/classification/__init__.py | 146 ++++++++--------- src/torchmetrics/classification/jaccard.py | 8 +- src/torchmetrics/collections.py | 2 +- src/torchmetrics/detection/__init__.py | 6 +- src/torchmetrics/detection/_mean_ap.py | 16 +- src/torchmetrics/functional/__init__.py | 2 +- src/torchmetrics/functional/audio/__init__.py | 6 +- src/torchmetrics/functional/audio/srmr.py | 2 +- .../functional/classification/__init__.py | 148 +++++++++--------- .../functional/classification/jaccard.py | 8 +- .../functional/detection/__init__.py | 4 +- .../detection/_panoptic_quality_common.py | 3 +- src/torchmetrics/functional/image/__init__.py | 14 +- src/torchmetrics/functional/image/ergas.py | 3 +- src/torchmetrics/functional/image/rase.py | 2 +- src/torchmetrics/functional/image/rmse_sw.py | 2 +- src/torchmetrics/functional/image/sam.py | 3 +- src/torchmetrics/functional/image/uqi.py | 3 +- .../functional/multimodal/__init__.py | 2 +- src/torchmetrics/functional/regression/r2.py | 3 +- .../functional/segmentation/__init__.py | 2 +- src/torchmetrics/functional/text/eed.py | 2 +- .../text/helper_embedding_metric.py | 2 +- src/torchmetrics/image/__init__.py | 14 +- src/torchmetrics/image/rase.py | 2 +- src/torchmetrics/image/rmse_sw.py | 2 +- src/torchmetrics/metric.py | 3 +- src/torchmetrics/multimodal/__init__.py | 2 +- src/torchmetrics/retrieval/__init__.py | 2 +- src/torchmetrics/segmentation/__init__.py | 2 +- src/torchmetrics/text/__init__.py | 4 +- src/torchmetrics/text/eed.py | 2 +- src/torchmetrics/utilities/__init__.py | 8 +- src/torchmetrics/utilities/prints.py | 2 +- src/torchmetrics/wrappers/__init__.py | 2 +- src/torchmetrics/wrappers/tracker.py | 3 +- tests/integrations/test_lightning.py | 5 +- tests/unittests/__init__.py | 4 +- tests/unittests/_helpers/testers.py | 4 +- tests/unittests/audio/test_c_si_snr.py | 2 +- tests/unittests/audio/test_dnsmos.py | 2 +- tests/unittests/audio/test_nisqa.py | 2 +- tests/unittests/audio/test_pesq.py | 2 +- tests/unittests/audio/test_pit.py | 2 +- tests/unittests/audio/test_sa_sdr.py | 2 +- tests/unittests/audio/test_sdr.py | 2 +- tests/unittests/audio/test_si_sdr.py | 2 +- tests/unittests/audio/test_si_snr.py | 2 +- tests/unittests/audio/test_snr.py | 2 +- tests/unittests/audio/test_srmr.py | 2 +- tests/unittests/audio/test_stoi.py | 2 +- tests/unittests/bases/test_aggregation.py | 2 +- tests/unittests/bases/test_collections.py | 2 +- tests/unittests/bases/test_composition.py | 1 + tests/unittests/bases/test_ddp.py | 16 +- tests/unittests/bases/test_metric.py | 2 +- tests/unittests/bases/test_saving_loading.py | 1 + .../unittests/classification/test_accuracy.py | 2 +- tests/unittests/classification/test_auc.py | 2 +- tests/unittests/classification/test_auroc.py | 2 +- .../classification/test_average_precision.py | 2 +- .../classification/test_calibration_error.py | 2 +- .../classification/test_cohen_kappa.py | 2 +- .../classification/test_confusion_matrix.py | 2 +- tests/unittests/classification/test_dice.py | 2 +- .../classification/test_exact_match.py | 2 +- tests/unittests/classification/test_f_beta.py | 14 +- .../classification/test_group_fairness.py | 8 +- .../classification/test_hamming_distance.py | 2 +- tests/unittests/classification/test_hinge.py | 2 +- .../unittests/classification/test_jaccard.py | 2 +- tests/unittests/classification/test_logauc.py | 2 +- .../classification/test_matthews_corrcoef.py | 2 +- .../test_negative_predictive_value.py | 2 +- .../test_precision_fixed_recall.py | 2 +- .../classification/test_precision_recall.py | 2 +- .../test_precision_recall_curve.py | 2 +- .../unittests/classification/test_ranking.py | 2 +- .../test_recall_fixed_precision.py | 2 +- tests/unittests/classification/test_roc.py | 2 +- .../test_sensitivity_specificity.py | 2 +- .../classification/test_specificity.py | 2 +- .../test_specificity_sensitivity.py | 2 +- .../classification/test_stat_scores.py | 4 +- .../test_adjusted_mutual_info_score.py | 2 +- .../clustering/test_adjusted_rand_score.py | 2 +- .../test_calinski_harabasz_score.py | 2 +- .../clustering/test_davies_bouldin_score.py | 2 +- tests/unittests/clustering/test_dunn_index.py | 2 +- .../clustering/test_fowlkes_mallows_index.py | 2 +- ...test_homogeneity_completeness_v_measure.py | 2 +- .../clustering/test_mutual_info_score.py | 2 +- .../test_normalized_mutual_info_score.py | 2 +- tests/unittests/clustering/test_rand_score.py | 2 +- tests/unittests/clustering/test_utils.py | 2 +- tests/unittests/conftest.py | 4 +- .../deprecations/root_class_imports.py | 1 + .../unittests/detection/test_intersection.py | 1 + tests/unittests/detection/test_map.py | 2 +- .../test_modified_panoptic_quality.py | 2 +- .../detection/test_panoptic_quality.py | 2 +- tests/unittests/image/test_csi.py | 2 +- tests/unittests/image/test_d_lambda.py | 2 +- tests/unittests/image/test_d_s.py | 2 +- tests/unittests/image/test_ergas.py | 2 +- tests/unittests/image/test_fid.py | 2 +- tests/unittests/image/test_image_gradients.py | 1 + tests/unittests/image/test_inception.py | 8 +- tests/unittests/image/test_kid.py | 2 +- tests/unittests/image/test_lpips.py | 2 +- tests/unittests/image/test_mifid.py | 2 +- tests/unittests/image/test_ms_ssim.py | 2 +- .../image/test_perceptual_path_length.py | 2 +- tests/unittests/image/test_psnr.py | 2 +- tests/unittests/image/test_psnrb.py | 2 +- tests/unittests/image/test_qnr.py | 2 +- tests/unittests/image/test_rase.py | 2 +- tests/unittests/image/test_rmse_sw.py | 2 +- tests/unittests/image/test_sam.py | 2 +- tests/unittests/image/test_scc.py | 2 +- tests/unittests/image/test_ssim.py | 2 +- tests/unittests/image/test_tv.py | 2 +- tests/unittests/image/test_uqi.py | 2 +- tests/unittests/image/test_vif.py | 2 +- tests/unittests/multimodal/test_clip_iqa.py | 4 +- tests/unittests/multimodal/test_clip_score.py | 6 +- tests/unittests/nominal/test_cramers.py | 4 +- tests/unittests/nominal/test_fleiss_kappa.py | 2 +- tests/unittests/nominal/test_pearson.py | 4 +- tests/unittests/nominal/test_theils_u.py | 4 +- tests/unittests/nominal/test_tschuprows.py | 4 +- .../pairwise/test_pairwise_distance.py | 4 +- .../unittests/regression/test_concordance.py | 2 +- .../regression/test_cosine_similarity.py | 2 +- .../regression/test_explained_variance.py | 2 +- tests/unittests/regression/test_kendall.py | 2 +- .../regression/test_kl_divergence.py | 2 +- .../regression/test_log_cosh_error.py | 2 +- tests/unittests/regression/test_mean_error.py | 2 +- .../regression/test_minkowski_distance.py | 2 +- tests/unittests/regression/test_pearson.py | 2 +- tests/unittests/regression/test_r2.py | 2 +- tests/unittests/regression/test_rse.py | 2 +- tests/unittests/regression/test_spearman.py | 2 +- .../regression/test_tweedie_deviance.py | 2 +- tests/unittests/retrieval/test_auroc.py | 4 +- tests/unittests/retrieval/test_fallout.py | 4 +- tests/unittests/retrieval/test_hit_rate.py | 4 +- tests/unittests/retrieval/test_map.py | 4 +- tests/unittests/retrieval/test_mrr.py | 4 +- tests/unittests/retrieval/test_ndcg.py | 4 +- tests/unittests/retrieval/test_precision.py | 6 +- .../retrieval/test_precision_recall_curve.py | 4 +- tests/unittests/retrieval/test_r_precision.py | 4 +- tests/unittests/retrieval/test_recall.py | 4 +- tests/unittests/segmentation/test_dice.py | 2 +- .../test_generalized_dice_score.py | 2 +- .../segmentation/test_hausdorff_distance.py | 2 +- tests/unittests/segmentation/test_mean_iou.py | 2 +- tests/unittests/segmentation/test_utils.py | 1 + tests/unittests/shape/test_procrustes.py | 2 +- tests/unittests/test_deprecated.py | 1 + tests/unittests/text/_helpers.py | 2 +- tests/unittests/text/test_bertscore.py | 4 +- tests/unittests/text/test_bleu.py | 2 +- tests/unittests/text/test_cer.py | 2 +- tests/unittests/text/test_chrf.py | 2 +- tests/unittests/text/test_edit.py | 2 +- tests/unittests/text/test_eed.py | 2 +- tests/unittests/text/test_infolm.py | 2 +- tests/unittests/text/test_mer.py | 2 +- tests/unittests/text/test_perplexity.py | 2 +- tests/unittests/text/test_rouge.py | 10 +- tests/unittests/text/test_sacre_bleu.py | 2 +- tests/unittests/text/test_squad.py | 2 +- tests/unittests/text/test_ter.py | 2 +- tests/unittests/text/test_wer.py | 2 +- tests/unittests/text/test_wil.py | 2 +- tests/unittests/text/test_wip.py | 2 +- tests/unittests/utilities/test_plot.py | 1 + tests/unittests/utilities/test_utilities.py | 1 + .../unittests/wrappers/test_bootstrapping.py | 2 +- tests/unittests/wrappers/test_classwise.py | 1 + .../unittests/wrappers/test_feature_share.py | 13 +- tests/unittests/wrappers/test_minmax.py | 2 +- tests/unittests/wrappers/test_multioutput.py | 2 +- tests/unittests/wrappers/test_multitask.py | 2 +- tests/unittests/wrappers/test_running.py | 10 +- tests/unittests/wrappers/test_tracker.py | 2 +- .../wrappers/test_transformations.py | 2 +- 215 files changed, 466 insertions(+), 448 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b36b25e59f4..9be6fbb7fc6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,7 +23,7 @@ ci: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: end-of-file-fixer - id: trailing-whitespace @@ -46,11 +46,10 @@ repos: exclude: pyproject.toml - repo: https://github.com/crate-ci/typos - rev: v1.22.9 + rev: dictgen-v0.3.1 hooks: - id: typos - # empty to do not write fixes - args: [] + args: [] # empty to do not write fixes exclude: pyproject.toml - repo: https://github.com/PyCQA/docformatter @@ -61,12 +60,12 @@ repos: args: ["--in-place"] - repo: https://github.com/sphinx-contrib/sphinx-lint - rev: v0.9.1 + rev: v1.0.0 hooks: - id: sphinx-lint - repo: https://github.com/executablebooks/mdformat - rev: 0.7.17 + rev: 0.7.21 hooks: - id: mdformat args: ["--number"] @@ -113,7 +112,7 @@ repos: - id: text-unicode-replacement-char - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.5.0 + rev: v0.8.6 hooks: # try to fix what is possible - id: ruff @@ -124,11 +123,11 @@ repos: - id: ruff - repo: https://github.com/tox-dev/pyproject-fmt - rev: 2.1.3 + rev: v2.5.0 hooks: - id: pyproject-fmt additional_dependencies: [tox] - repo: https://github.com/abravalheri/validate-pyproject - rev: v0.18 + rev: v0.23 hooks: - id: validate-pyproject diff --git a/CHANGELOG.md b/CHANGELOG.md index 75f0842b15a..bb5ba2fcfb9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -839,7 +839,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * `PearsonCorrcoef` * `SpearmanCorrcoef` - Removed deprecated functions, and warnings in detection and pairwise ([#804](https://github.com/Lightning-AI/metrics/pull/804)) - * `MAP` and `functional.pairwise.manhatten` + * `MAP` and `functional.pairwise.manhattan` - Removed deprecated functions, and warnings in Audio ([#805](https://github.com/Lightning-AI/metrics/pull/805)) * `PESQ` and `functional.audio.pesq` * `PIT` and `functional.audio.pit` @@ -1029,7 +1029,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `pairwise_cosine_similarity` - `pairwise_euclidean_distance` - `pairwise_linear_similarity` - - `pairwise_manhatten_distance` + - `pairwise_manhattan_distance` ### Changed diff --git a/_samples/bert_score-own_model.py b/_samples/bert_score-own_model.py index 982d7f63876..1a3ec618d53 100644 --- a/_samples/bert_score-own_model.py +++ b/_samples/bert_score-own_model.py @@ -23,6 +23,7 @@ import torch from torch import Tensor, nn from torch.nn import Module + from torchmetrics.text.bert import BERTScore _NUM_LAYERS = 2 diff --git a/_samples/detection_map.py b/_samples/detection_map.py index 2bdb907d461..83175ed4390 100644 --- a/_samples/detection_map.py +++ b/_samples/detection_map.py @@ -14,6 +14,7 @@ """An example of how the predictions and target should be defined for the MAP object detection metric.""" from torch import BoolTensor, IntTensor, Tensor + from torchmetrics.detection.mean_ap import MeanAveragePrecision # Preds should be a list of elements, where each element is a dict diff --git a/docs/source/conf.py b/docs/source/conf.py index 81f842e7a12..e7c65a549ca 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -18,9 +18,10 @@ from typing import Optional import lai_sphinx_theme -import torchmetrics from lightning_utilities.docs.formatting import _linkcode_resolve, _transform_changelog +import torchmetrics + _PATH_HERE = os.path.abspath(os.path.dirname(__file__)) _PATH_ROOT = os.path.realpath(os.path.join(_PATH_HERE, "..", "..")) sys.path.insert(0, os.path.abspath(_PATH_ROOT)) diff --git a/docs/source/pages/implement.rst b/docs/source/pages/implement.rst index 5ab044e1be1..1288b529dac 100644 --- a/docs/source/pages/implement.rst +++ b/docs/source/pages/implement.rst @@ -30,7 +30,7 @@ We provide the remaining interface, such as ``reset()`` that will make sure to c states that have been added using ``add_state``. You should therefore not implement ``reset()`` yourself, only in rare cases where not all the state variables should be reset to their default value. Adding metric states with ``add_state`` will make sure that states are correctly synchronized in distributed settings (DDP). To see how metric states are -synchronized across distributed processes, refer to :meth:`~torchmetrics.Metric.add_state()` docs from the base +synchronized across distributed processes, refer to :meth:`~torchmetrics.Metric.add_state` docs from the base :class:`~torchmetrics.Metric` class. Below is a basic implementation of a custom accuracy metric. In the ``__init__`` method we add the metric states diff --git a/docs/source/pyplots/binary_accuracy.py b/docs/source/pyplots/binary_accuracy.py index 8bca70c5f3d..90539961e4f 100644 --- a/docs/source/pyplots/binary_accuracy.py +++ b/docs/source/pyplots/binary_accuracy.py @@ -1,5 +1,6 @@ import matplotlib.pyplot as plt import torch + import torchmetrics N = 10 diff --git a/docs/source/pyplots/binary_accuracy_multistep.py b/docs/source/pyplots/binary_accuracy_multistep.py index 3d878495736..c2f29b6e2a4 100644 --- a/docs/source/pyplots/binary_accuracy_multistep.py +++ b/docs/source/pyplots/binary_accuracy_multistep.py @@ -1,5 +1,6 @@ import matplotlib.pyplot as plt import torch + import torchmetrics N = 10 diff --git a/docs/source/pyplots/collection_binary.py b/docs/source/pyplots/collection_binary.py index fadd8a15c71..9f4b34de74b 100644 --- a/docs/source/pyplots/collection_binary.py +++ b/docs/source/pyplots/collection_binary.py @@ -1,5 +1,6 @@ import matplotlib.pyplot as plt import torch + import torchmetrics N = 10 diff --git a/docs/source/pyplots/collection_binary_together.py b/docs/source/pyplots/collection_binary_together.py index 7f804a18320..f58f104f4e0 100644 --- a/docs/source/pyplots/collection_binary_together.py +++ b/docs/source/pyplots/collection_binary_together.py @@ -1,5 +1,6 @@ import matplotlib.pyplot as plt import torch + import torchmetrics N = 10 diff --git a/docs/source/pyplots/confusion_matrix.py b/docs/source/pyplots/confusion_matrix.py index b41e56add4c..1ab881bded2 100644 --- a/docs/source/pyplots/confusion_matrix.py +++ b/docs/source/pyplots/confusion_matrix.py @@ -1,5 +1,6 @@ import matplotlib.pyplot as plt import torch + import torchmetrics N = 10 diff --git a/docs/source/pyplots/multiclass_accuracy.py b/docs/source/pyplots/multiclass_accuracy.py index 9c26be67a0d..ba08c16859c 100644 --- a/docs/source/pyplots/multiclass_accuracy.py +++ b/docs/source/pyplots/multiclass_accuracy.py @@ -1,5 +1,6 @@ import matplotlib.pyplot as plt import torch + import torchmetrics N = 10 diff --git a/docs/source/pyplots/tracker_binary.py b/docs/source/pyplots/tracker_binary.py index 50197cd16de..64f93453417 100644 --- a/docs/source/pyplots/tracker_binary.py +++ b/docs/source/pyplots/tracker_binary.py @@ -1,5 +1,6 @@ import matplotlib.pyplot as plt import torch + import torchmetrics N = 10 diff --git a/examples/audio/pesq.py b/examples/audio/pesq.py index 6afde2bfdd5..75a705e17f9 100644 --- a/examples/audio/pesq.py +++ b/examples/audio/pesq.py @@ -19,6 +19,7 @@ import numpy as np import torch import torchaudio + from torchmetrics.audio import PerceptualEvaluationSpeechQuality # %% diff --git a/examples/audio/signal_to_noise_ratio.py b/examples/audio/signal_to_noise_ratio.py index 7099fc08d2b..56cf02383e4 100644 --- a/examples/audio/signal_to_noise_ratio.py +++ b/examples/audio/signal_to_noise_ratio.py @@ -13,6 +13,7 @@ import matplotlib.pyplot as plt import numpy as np import torch + from torchmetrics.audio import SignalNoiseRatio # %% diff --git a/examples/image/clip_score.py b/examples/image/clip_score.py index f73c5d68333..019eb5c8f56 100644 --- a/examples/image/clip_score.py +++ b/examples/image/clip_score.py @@ -15,6 +15,7 @@ import torch from matplotlib.table import Table from skimage.data import astronaut, cat, coffee + from torchmetrics.multimodal import CLIPScore # %% diff --git a/examples/image/spatial_correlation_coef.py b/examples/image/spatial_correlation_coef.py index d5a5296aa82..aea790ebdef 100644 --- a/examples/image/spatial_correlation_coef.py +++ b/examples/image/spatial_correlation_coef.py @@ -16,6 +16,7 @@ import torch from skimage.data import shepp_logan_phantom from skimage.transform import iradon, radon, rescale + from torchmetrics.image import SpatialCorrelationCoefficient # %% diff --git a/examples/text/bertscore.py b/examples/text/bertscore.py index 09e2fbff418..f8e7963f4a5 100644 --- a/examples/text/bertscore.py +++ b/examples/text/bertscore.py @@ -6,9 +6,10 @@ Let's consider a use case in natural language processing where BERTScore is used to evaluate the quality of a text generation model. In this case we are imaging that we are developing a automated news summarization system. The goal is to create concise summaries of news articles that accurately capture the key points of the original articles. To evaluate the performance of your summarization system, you need a metric that can compare the generated summaries to human-written summaries. This is where the BERTScore can be used. """ -from torchmetrics.text import BERTScore, ROUGEScore from transformers import AutoTokenizer, pipeline +from torchmetrics.text import BERTScore, ROUGEScore + pipe = pipeline("text-generation", model="openai-community/gpt2") tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") diff --git a/examples/text/perplexity.py b/examples/text/perplexity.py index a5cd3a4fa01..69b2b106d5f 100644 --- a/examples/text/perplexity.py +++ b/examples/text/perplexity.py @@ -12,9 +12,10 @@ # Here's a hypothetical Python example demonstrating the usage of Perplexity to evaluate a generative language model import torch -from torchmetrics.text import Perplexity from transformers import AutoModelWithLMHead, AutoTokenizer +from torchmetrics.text import Perplexity + # %% # Load the GPT-2 model and tokenizer diff --git a/examples/text/rouge.py b/examples/text/rouge.py index a76716bc92a..c85d419d646 100644 --- a/examples/text/rouge.py +++ b/examples/text/rouge.py @@ -9,9 +9,10 @@ # %% # Here's a hypothetical Python example demonstrating the usage of unigram ROUGE F-score to evaluate a generative language model: -from torchmetrics.text import ROUGEScore from transformers import AutoTokenizer, pipeline +from torchmetrics.text import ROUGEScore + pipe = pipeline("text-generation", model="openai-community/gpt2") tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") diff --git a/src/conftest.py b/src/conftest.py index 5f4a26123d3..f093fe19767 100644 --- a/src/conftest.py +++ b/src/conftest.py @@ -11,7 +11,7 @@ MANUAL_SEED = doctest.register_optionflag("MANUAL_SEED") @pytest.fixture(autouse=True) - def reset_random_seed(seed: int = 42) -> None: # noqa: PT004 + def reset_random_seed(seed: int = 42) -> None: """Reset the random seed before running each doctest.""" import random diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index 4ff997183f6..536db30b867 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -169,6 +169,7 @@ __all__ = [ "AUROC", + "ROC", "Accuracy", "AveragePrecision", "BLEUScore", @@ -229,7 +230,6 @@ "PrecisionAtFixedRecall", "PrecisionRecallCurve", "R2Score", - "ROC", "Recall", "RecallAtFixedPrecision", "RelativeAverageSpectralError", diff --git a/src/torchmetrics/aggregation.py b/src/torchmetrics/aggregation.py index 312197ccbc4..d4784a13c57 100644 --- a/src/torchmetrics/aggregation.py +++ b/src/torchmetrics/aggregation.py @@ -65,8 +65,7 @@ def __init__( allowed_nan_strategy = ("error", "warn", "ignore") if nan_strategy not in allowed_nan_strategy and not isinstance(nan_strategy, float): raise ValueError( - f"Arg `nan_strategy` should either be a float or one of {allowed_nan_strategy}" - f" but got {nan_strategy}." + f"Arg `nan_strategy` should either be a float or one of {allowed_nan_strategy} but got {nan_strategy}." ) self.nan_strategy = nan_strategy diff --git a/src/torchmetrics/audio/__init__.py b/src/torchmetrics/audio/__init__.py index 24ff9e737e8..da5271a4cda 100644 --- a/src/torchmetrics/audio/__init__.py +++ b/src/torchmetrics/audio/__init__.py @@ -41,13 +41,13 @@ scipy.signal.hamming = scipy.signal.windows.hamming __all__ = [ + "ComplexScaleInvariantSignalNoiseRatio", "PermutationInvariantTraining", "ScaleInvariantSignalDistortionRatio", - "SignalDistortionRatio", - "SourceAggregatedSignalDistortionRatio", "ScaleInvariantSignalNoiseRatio", + "SignalDistortionRatio", "SignalNoiseRatio", - "ComplexScaleInvariantSignalNoiseRatio", + "SourceAggregatedSignalDistortionRatio", ] if _PESQ_AVAILABLE: diff --git a/src/torchmetrics/audio/srmr.py b/src/torchmetrics/audio/srmr.py index 453f1bb7eab..d0d23341c28 100644 --- a/src/torchmetrics/audio/srmr.py +++ b/src/torchmetrics/audio/srmr.py @@ -58,7 +58,7 @@ class SpeechReverberationModulationEnergyRatio(Metric): This implementation is experimental, and might not be consistent with the matlab implementation `SRMRToolbox`_, especially the fast implementation. The slow versions, a) fast=False, norm=False, max_cf=128, b) fast=False, norm=True, max_cf=30, have - a relatively small inconsistence. + a relatively small inconsistency. Args: fs: the sampling rate diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py index bbc5321bf7a..a6e49e79a1b 100644 --- a/src/torchmetrics/classification/__init__.py +++ b/src/torchmetrics/classification/__init__.py @@ -130,106 +130,106 @@ ) __all__ = [ - "Accuracy", - "BinaryAccuracy", - "MulticlassAccuracy", - "MultilabelAccuracy", "AUROC", - "BinaryAUROC", - "MulticlassAUROC", - "MultilabelAUROC", + "ROC", + "Accuracy", "AveragePrecision", + "BinaryAUROC", + "BinaryAccuracy", "BinaryAveragePrecision", - "MulticlassAveragePrecision", - "MultilabelAveragePrecision", "BinaryCalibrationError", - "CalibrationError", - "MulticlassCalibrationError", "BinaryCohenKappa", - "CohenKappa", - "MulticlassCohenKappa", "BinaryConfusionMatrix", + "BinaryF1Score", + "BinaryFBetaScore", + "BinaryFairness", + "BinaryGroupStatRates", + "BinaryHammingDistance", + "BinaryHingeLoss", + "BinaryJaccardIndex", + "BinaryLogAUC", + "BinaryMatthewsCorrCoef", + "BinaryNegativePredictiveValue", + "BinaryPrecision", + "BinaryPrecisionAtFixedRecall", + "BinaryPrecisionRecallCurve", + "BinaryROC", + "BinaryRecall", + "BinaryRecallAtFixedPrecision", + "BinarySensitivityAtSpecificity", + "BinarySpecificity", + "BinarySpecificityAtSensitivity", + "BinaryStatScores", + "CalibrationError", + "CohenKappa", "ConfusionMatrix", - "MulticlassConfusionMatrix", - "MultilabelConfusionMatrix", "Dice", "ExactMatch", - "MulticlassExactMatch", - "MultilabelExactMatch", - "BinaryF1Score", - "BinaryFBetaScore", "F1Score", "FBetaScore", + "HammingDistance", + "HingeLoss", + "JaccardIndex", + "LogAUC", + "MatthewsCorrCoef", + "MulticlassAUROC", + "MulticlassAccuracy", + "MulticlassAveragePrecision", + "MulticlassCalibrationError", + "MulticlassCohenKappa", + "MulticlassConfusionMatrix", + "MulticlassExactMatch", "MulticlassF1Score", "MulticlassFBetaScore", - "MultilabelF1Score", - "MultilabelFBetaScore", - "BinaryFairness", - "BinaryGroupStatRates", - "BinaryHammingDistance", - "HammingDistance", "MulticlassHammingDistance", - "MultilabelHammingDistance", - "BinaryHingeLoss", - "HingeLoss", "MulticlassHingeLoss", - "BinaryJaccardIndex", - "JaccardIndex", "MulticlassJaccardIndex", - "MultilabelJaccardIndex", - "BinaryMatthewsCorrCoef", - "MatthewsCorrCoef", + "MulticlassLogAUC", "MulticlassMatthewsCorrCoef", - "MultilabelMatthewsCorrCoef", - "BinaryPrecision", - "BinaryRecall", + "MulticlassNegativePredictiveValue", "MulticlassPrecision", + "MulticlassPrecisionAtFixedRecall", + "MulticlassPrecisionRecallCurve", + "MulticlassROC", "MulticlassRecall", + "MulticlassRecallAtFixedPrecision", + "MulticlassSensitivityAtSpecificity", + "MulticlassSpecificity", + "MulticlassSpecificityAtSensitivity", + "MulticlassStatScores", + "MultilabelAUROC", + "MultilabelAccuracy", + "MultilabelAveragePrecision", + "MultilabelConfusionMatrix", + "MultilabelCoverageError", + "MultilabelExactMatch", + "MultilabelF1Score", + "MultilabelFBetaScore", + "MultilabelHammingDistance", + "MultilabelJaccardIndex", + "MultilabelLogAUC", + "MultilabelMatthewsCorrCoef", + "MultilabelNegativePredictiveValue", "MultilabelPrecision", - "MultilabelRecall", - "Precision", - "Recall", - "BinaryPrecisionRecallCurve", - "MulticlassPrecisionRecallCurve", + "MultilabelPrecisionAtFixedRecall", "MultilabelPrecisionRecallCurve", - "PrecisionRecallCurve", - "MultilabelCoverageError", + "MultilabelROC", "MultilabelRankingAveragePrecision", "MultilabelRankingLoss", - "RecallAtFixedPrecision", - "BinaryRecallAtFixedPrecision", - "MulticlassRecallAtFixedPrecision", + "MultilabelRecall", "MultilabelRecallAtFixedPrecision", - "BinaryROC", - "MulticlassROC", - "MultilabelROC", - "ROC", - "BinarySpecificity", - "MulticlassSpecificity", + "MultilabelSensitivityAtSpecificity", "MultilabelSpecificity", - "Specificity", - "BinarySpecificityAtSensitivity", - "MulticlassSpecificityAtSensitivity", "MultilabelSpecificityAtSensitivity", - "SpecificityAtSensitivity", - "BinaryStatScores", - "MulticlassStatScores", "MultilabelStatScores", - "StatScores", + "NegativePredictiveValue", + "Precision", "PrecisionAtFixedRecall", - "BinaryPrecisionAtFixedRecall", - "MulticlassPrecisionAtFixedRecall", - "MultilabelPrecisionAtFixedRecall", - "BinarySensitivityAtSpecificity", - "MulticlassSensitivityAtSpecificity", - "MultilabelSensitivityAtSpecificity", + "PrecisionRecallCurve", + "Recall", + "RecallAtFixedPrecision", "SensitivityAtSpecificity", - "BinaryLogAUC", - "LogAUC", - "MulticlassLogAUC", - "MultilabelLogAUC", - "BinaryNegativePredictiveValue", - "MulticlassNegativePredictiveValue", - "MultilabelNegativePredictiveValue", - "NegativePredictiveValue", + "Specificity", + "SpecificityAtSensitivity", + "StatScores", ] diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py index 5f9a15b2e4f..c5f2e41f37b 100644 --- a/src/torchmetrics/classification/jaccard.py +++ b/src/torchmetrics/classification/jaccard.py @@ -40,7 +40,7 @@ class BinaryJaccardIndex(BinaryConfusionMatrix): r"""Calculate the Jaccard index for binary tasks. - The `Jaccard index`_ (also known as the intersetion over union or jaccard similarity coefficient) is an statistic + The `Jaccard index`_ (also known as the intersection over union or jaccard similarity coefficient) is an statistic that can be used to determine the similarity and diversity of a sample set. It is defined as the size of the intersection divided by the union of the sample sets: @@ -158,7 +158,7 @@ def plot( # type: ignore[override] class MulticlassJaccardIndex(MulticlassConfusionMatrix): r"""Calculate the Jaccard index for multiclass tasks. - The `Jaccard index`_ (also known as the intersetion over union or jaccard similarity coefficient) is an statistic + The `Jaccard index`_ (also known as the intersection over union or jaccard similarity coefficient) is an statistic that can be used to determine the similarity and diversity of a sample set. It is defined as the size of the intersection divided by the union of the sample sets: @@ -295,7 +295,7 @@ def plot( # type: ignore[override] class MultilabelJaccardIndex(MultilabelConfusionMatrix): r"""Calculate the Jaccard index for multilabel tasks. - The `Jaccard index`_ (also known as the intersetion over union or jaccard similarity coefficient) is an statistic + The `Jaccard index`_ (also known as the intersection over union or jaccard similarity coefficient) is an statistic that can be used to determine the similarity and diversity of a sample set. It is defined as the size of the intersection divided by the union of the sample sets: @@ -434,7 +434,7 @@ def plot( # type: ignore[override] class JaccardIndex(_ClassificationTaskWrapper): r"""Calculate the Jaccard index for multilabel tasks. - The `Jaccard index`_ (also known as the intersetion over union or jaccard similarity coefficient) is an statistic + The `Jaccard index`_ (also known as the intersection over union or jaccard similarity coefficient) is an statistic that can be used to determine the similarity and diversity of a sample set. It is defined as the size of the intersection divided by the union of the sample sets: diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py index e034027fe0d..c9157425dea 100644 --- a/src/torchmetrics/collections.py +++ b/src/torchmetrics/collections.py @@ -311,7 +311,7 @@ def _equal_metric_states(metric1: Metric, metric2: Metric) -> bool: state1 = getattr(metric1, key) state2 = getattr(metric2, key) - if type(state1) != type(state2): + if type(state1) != type(state2): # noqa: E721 return False if isinstance(state1, Tensor) and isinstance(state2, Tensor): diff --git a/src/torchmetrics/detection/__init__.py b/src/torchmetrics/detection/__init__.py index 7932d4b33b8..968135b0428 100644 --- a/src/torchmetrics/detection/__init__.py +++ b/src/torchmetrics/detection/__init__.py @@ -24,9 +24,9 @@ from torchmetrics.detection.mean_ap import MeanAveragePrecision __all__ += [ - "MeanAveragePrecision", - "GeneralizedIntersectionOverUnion", - "IntersectionOverUnion", "CompleteIntersectionOverUnion", "DistanceIntersectionOverUnion", + "GeneralizedIntersectionOverUnion", + "IntersectionOverUnion", + "MeanAveragePrecision", ] diff --git a/src/torchmetrics/detection/_mean_ap.py b/src/torchmetrics/detection/_mean_ap.py index 9831842734d..9be56153568 100644 --- a/src/torchmetrics/detection/_mean_ap.py +++ b/src/torchmetrics/detection/_mean_ap.py @@ -95,13 +95,13 @@ def __delattr__(self, key: str) -> None: class MAPMetricResults(BaseMetricResults): """Class to wrap the final mAP results.""" - __slots__ = ("map", "map_50", "map_75", "map_small", "map_medium", "map_large", "classes") + __slots__ = ("classes", "map", "map_50", "map_75", "map_large", "map_medium", "map_small") class MARMetricResults(BaseMetricResults): """Class to wrap the final mAR results.""" - __slots__ = ("mar_1", "mar_10", "mar_100", "mar_small", "mar_medium", "mar_large") + __slots__ = ("mar_1", "mar_10", "mar_100", "mar_large", "mar_medium", "mar_small") class COCOMetricResults(BaseMetricResults): @@ -111,17 +111,17 @@ class COCOMetricResults(BaseMetricResults): "map", "map_50", "map_75", - "map_small", - "map_medium", "map_large", + "map_medium", + "map_per_class", + "map_small", "mar_1", "mar_10", "mar_100", - "mar_small", - "mar_medium", - "mar_large", - "map_per_class", "mar_100_per_class", + "mar_large", + "mar_medium", + "mar_small", ) diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index a4f175ce02d..ed772eeba26 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -170,8 +170,8 @@ "jaccard_index", "kendall_rank_corrcoef", "kl_divergence", - "logauc", "log_cosh_error", + "logauc", "match_error_rate", "matthews_corrcoef", "mean_absolute_error", diff --git a/src/torchmetrics/functional/audio/__init__.py b/src/torchmetrics/functional/audio/__init__.py index edd9e9283e0..09faa97334a 100644 --- a/src/torchmetrics/functional/audio/__init__.py +++ b/src/torchmetrics/functional/audio/__init__.py @@ -41,14 +41,14 @@ scipy.signal.hamming = scipy.signal.windows.hamming __all__ = [ + "complex_scale_invariant_signal_noise_ratio", "permutation_invariant_training", "pit_permutate", "scale_invariant_signal_distortion_ratio", - "source_aggregated_signal_distortion_ratio", - "signal_distortion_ratio", "scale_invariant_signal_noise_ratio", + "signal_distortion_ratio", "signal_noise_ratio", - "complex_scale_invariant_signal_noise_ratio", + "source_aggregated_signal_distortion_ratio", ] if _PESQ_AVAILABLE: diff --git a/src/torchmetrics/functional/audio/srmr.py b/src/torchmetrics/functional/audio/srmr.py index 20ad898fe8d..9495132b59c 100644 --- a/src/torchmetrics/functional/audio/srmr.py +++ b/src/torchmetrics/functional/audio/srmr.py @@ -211,7 +211,7 @@ def speech_reverberation_modulation_energy_ratio( This implementation is experimental, and might not be consistent with the matlab implementation `SRMRToolbox`_, especially the fast implementation. The slow versions, a) fast=False, norm=False, max_cf=128, b) fast=False, norm=True, max_cf=30, have - a relatively small inconsistence. + a relatively small inconsistency. Returns: Scalar tensor with srmr value with shape ``(...)`` diff --git a/src/torchmetrics/functional/classification/__init__.py b/src/torchmetrics/functional/classification/__init__.py index 925f977e419..247b70be99e 100644 --- a/src/torchmetrics/functional/classification/__init__.py +++ b/src/torchmetrics/functional/classification/__init__.py @@ -145,108 +145,108 @@ __all__ = [ "accuracy", - "binary_accuracy", - "multiclass_accuracy", - "multilabel_accuracy", "auroc", - "binary_auroc", - "multiclass_auroc", - "multilabel_auroc", "average_precision", + "binary_accuracy", + "binary_auroc", "binary_average_precision", - "multiclass_average_precision", - "multilabel_average_precision", "binary_calibration_error", - "calibration_error", - "multiclass_calibration_error", "binary_cohen_kappa", - "cohen_kappa", - "multiclass_cohen_kappa", "binary_confusion_matrix", + "binary_f1_score", + "binary_fairness", + "binary_fbeta_score", + "binary_groups_stat_rates", + "binary_hamming_distance", + "binary_hinge_loss", + "binary_jaccard_index", + "binary_logauc", + "binary_matthews_corrcoef", + "binary_negative_predictive_value", + "binary_precision", + "binary_precision_at_fixed_recall", + "binary_precision_recall_curve", + "binary_recall", + "binary_recall_at_fixed_precision", + "binary_roc", + "binary_sensitivity_at_specificity", + "binary_specificity", + "binary_specificity_at_sensitivity", + "binary_stat_scores", + "calibration_error", + "cohen_kappa", "confusion_matrix", - "multiclass_confusion_matrix", - "multilabel_confusion_matrix", - "generalized_dice_score", + "demographic_parity", "dice", + "equal_opportunity", "exact_match", - "multiclass_exact_match", - "multilabel_exact_match", - "binary_f1_score", - "binary_fbeta_score", "f1_score", "fbeta_score", + "generalized_dice_score", + "hamming_distance", + "hinge_loss", + "jaccard_index", + "logauc", + "matthews_corrcoef", + "multiclass_accuracy", + "multiclass_auroc", + "multiclass_average_precision", + "multiclass_calibration_error", + "multiclass_cohen_kappa", + "multiclass_confusion_matrix", + "multiclass_exact_match", "multiclass_f1_score", "multiclass_fbeta_score", - "multilabel_f1_score", - "multilabel_fbeta_score", - "binary_fairness", - "binary_groups_stat_rates", - "binary_hamming_distance", - "hamming_distance", "multiclass_hamming_distance", - "multilabel_hamming_distance", - "binary_hinge_loss", - "hinge_loss", "multiclass_hinge_loss", - "binary_jaccard_index", - "jaccard_index", "multiclass_jaccard_index", - "multilabel_jaccard_index", - "binary_matthews_corrcoef", - "matthews_corrcoef", + "multiclass_logauc", "multiclass_matthews_corrcoef", - "multilabel_matthews_corrcoef", - "binary_precision", - "binary_recall", + "multiclass_negative_predictive_value", "multiclass_precision", + "multiclass_precision_at_fixed_recall", + "multiclass_precision_recall_curve", "multiclass_recall", + "multiclass_recall_at_fixed_precision", + "multiclass_roc", + "multiclass_sensitivity_at_specificity", + "multiclass_specificity", + "multiclass_specificity_at_sensitivity", + "multiclass_stat_scores", + "multilabel_accuracy", + "multilabel_auroc", + "multilabel_average_precision", + "multilabel_confusion_matrix", + "multilabel_coverage_error", + "multilabel_exact_match", + "multilabel_f1_score", + "multilabel_fbeta_score", + "multilabel_hamming_distance", + "multilabel_jaccard_index", + "multilabel_logauc", + "multilabel_matthews_corrcoef", + "multilabel_negative_predictive_value", "multilabel_precision", - "multilabel_recall", - "precision", - "recall", - "binary_precision_recall_curve", - "multiclass_precision_recall_curve", + "multilabel_precision_at_fixed_recall", "multilabel_precision_recall_curve", - "precision_recall_curve", - "multilabel_coverage_error", "multilabel_ranking_average_precision", "multilabel_ranking_loss", - "recall_at_fixed_precision", - "binary_recall_at_fixed_precision", - "multiclass_recall_at_fixed_precision", + "multilabel_recall", "multilabel_recall_at_fixed_precision", - "binary_roc", - "multiclass_roc", "multilabel_roc", - "roc", - "binary_sensitivity_at_specificity", - "multiclass_sensitivity_at_specificity", "multilabel_sensitivity_at_specificity", - "sensitivity_at_specificity", - "binary_specificity", - "multiclass_specificity", "multilabel_specificity", - "specificity", - "binary_specificity_at_sensitivity", - "multiclass_specificity_at_sensitivity", "multilabel_specificity_at_sensitivity", - "specificity_at_sensitivity", - "binary_stat_scores", - "multiclass_stat_scores", "multilabel_stat_scores", - "stat_scores", - "binary_precision_at_fixed_recall", - "multilabel_precision_at_fixed_recall", - "multiclass_precision_at_fixed_recall", - "demographic_parity", - "equal_opportunity", - "precision_at_fixed_recall", - "binary_logauc", - "multiclass_logauc", - "multilabel_logauc", - "logauc", - "binary_negative_predictive_value", - "multiclass_negative_predictive_value", - "multilabel_negative_predictive_value", "negative_predictive_value", + "precision", + "precision_at_fixed_recall", + "precision_recall_curve", + "recall", + "recall_at_fixed_precision", + "roc", + "sensitivity_at_specificity", + "specificity", + "specificity_at_sensitivity", + "stat_scores", ] diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py index dfddd68255f..d3194a13c5f 100644 --- a/src/torchmetrics/functional/classification/jaccard.py +++ b/src/torchmetrics/functional/classification/jaccard.py @@ -107,7 +107,7 @@ def binary_jaccard_index( ) -> Tensor: r"""Calculate the Jaccard index for binary tasks. - The `Jaccard index`_ (also known as the intersetion over union or jaccard similarity coefficient) is an statistic + The `Jaccard index`_ (also known as the intersection over union or jaccard similarity coefficient) is an statistic that can be used to determine the similarity and diversity of a sample set. It is defined as the size of the intersection divided by the union of the sample sets: @@ -179,7 +179,7 @@ def multiclass_jaccard_index( ) -> Tensor: r"""Calculate the Jaccard index for multiclass tasks. - The `Jaccard index`_ (also known as the intersetion over union or jaccard similarity coefficient) is an statistic + The `Jaccard index`_ (also known as the intersection over union or jaccard similarity coefficient) is an statistic that can be used to determine the similarity and diversity of a sample set. It is defined as the size of the intersection divided by the union of the sample sets: @@ -264,7 +264,7 @@ def multilabel_jaccard_index( ) -> Tensor: r"""Calculate the Jaccard index for multilabel tasks. - The `Jaccard index`_ (also known as the intersetion over union or jaccard similarity coefficient) is an statistic + The `Jaccard index`_ (also known as the intersection over union or jaccard similarity coefficient) is an statistic that can be used to determine the similarity and diversity of a sample set. It is defined as the size of the intersection divided by the union of the sample sets: @@ -337,7 +337,7 @@ def jaccard_index( ) -> Tensor: r"""Calculate the Jaccard index. - The `Jaccard index`_ (also known as the intersetion over union or jaccard similarity coefficient) is an statistic + The `Jaccard index`_ (also known as the intersection over union or jaccard similarity coefficient) is an statistic that can be used to determine the similarity and diversity of a sample set. It is defined as the size of the intersection divided by the union of the sample sets: diff --git a/src/torchmetrics/functional/detection/__init__.py b/src/torchmetrics/functional/detection/__init__.py index fab5ccb5f91..a1dfbef7fcd 100644 --- a/src/torchmetrics/functional/detection/__init__.py +++ b/src/torchmetrics/functional/detection/__init__.py @@ -26,8 +26,8 @@ from torchmetrics.functional.detection.iou import intersection_over_union __all__ += [ - "generalized_intersection_over_union", - "intersection_over_union", "complete_intersection_over_union", "distance_intersection_over_union", + "generalized_intersection_over_union", + "intersection_over_union", ] diff --git a/src/torchmetrics/functional/detection/_panoptic_quality_common.py b/src/torchmetrics/functional/detection/_panoptic_quality_common.py index 16d0463ba45..1cf9b218d2f 100644 --- a/src/torchmetrics/functional/detection/_panoptic_quality_common.py +++ b/src/torchmetrics/functional/detection/_panoptic_quality_common.py @@ -112,8 +112,7 @@ def _validate_inputs(preds: Tensor, target: torch.Tensor) -> None: ) if preds.dim() < 3: raise ValueError( - "Expected argument `preds` to have at least one spatial dimension (B, *spatial_dims, 2), " - f"got {preds.shape}" + f"Expected argument `preds` to have at least one spatial dimension (B, *spatial_dims, 2), got {preds.shape}" ) if preds.shape[-1] != 2: raise ValueError( diff --git a/src/torchmetrics/functional/image/__init__.py b/src/torchmetrics/functional/image/__init__.py index d485179b246..4ab230c81eb 100644 --- a/src/torchmetrics/functional/image/__init__.py +++ b/src/torchmetrics/functional/image/__init__.py @@ -33,22 +33,22 @@ from torchmetrics.functional.image.vif import visual_information_fidelity __all__ = [ - "spectral_distortion_index", - "spatial_distortion_index", "error_relative_global_dimensionless_synthesis", "image_gradients", + "learned_perceptual_image_patch_similarity", + "multiscale_structural_similarity_index_measure", "peak_signal_noise_ratio", "peak_signal_noise_ratio_with_blocked_effect", + "perceptual_path_length", + "quality_with_no_reference", "relative_average_spectral_error", "root_mean_squared_error_using_sliding_window", + "spatial_correlation_coefficient", + "spatial_distortion_index", "spectral_angle_mapper", - "multiscale_structural_similarity_index_measure", + "spectral_distortion_index", "structural_similarity_index_measure", "total_variation", "universal_image_quality_index", "visual_information_fidelity", - "learned_perceptual_image_patch_similarity", - "perceptual_path_length", - "spatial_correlation_coefficient", - "quality_with_no_reference", ] diff --git a/src/torchmetrics/functional/image/ergas.py b/src/torchmetrics/functional/image/ergas.py index ae940250dbe..45d14ccaddc 100644 --- a/src/torchmetrics/functional/image/ergas.py +++ b/src/torchmetrics/functional/image/ergas.py @@ -36,8 +36,7 @@ def _ergas_update(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: _check_same_shape(preds, target) if len(preds.shape) != 4: raise ValueError( - "Expected `preds` and `target` to have BxCxHxW shape." - f" Got preds: {preds.shape} and target: {target.shape}." + f"Expected `preds` and `target` to have BxCxHxW shape. Got preds: {preds.shape} and target: {target.shape}." ) return preds, target diff --git a/src/torchmetrics/functional/image/rase.py b/src/torchmetrics/functional/image/rase.py index 832a759b562..51181852aa6 100644 --- a/src/torchmetrics/functional/image/rase.py +++ b/src/torchmetrics/functional/image/rase.py @@ -90,7 +90,7 @@ def relative_average_spectral_error(preds: Tensor, target: Tensor, window_size: ValueError: If ``window_size`` is not a positive integer. """ - if not isinstance(window_size, int) or isinstance(window_size, int) and window_size < 1: + if not isinstance(window_size, int) or (isinstance(window_size, int) and window_size < 1): raise ValueError("Argument `window_size` is expected to be a positive integer.") img_shape = target.shape[1:] # [num_channels, width, height] diff --git a/src/torchmetrics/functional/image/rmse_sw.py b/src/torchmetrics/functional/image/rmse_sw.py index 3b0eaf6221f..ba1e88710df 100644 --- a/src/torchmetrics/functional/image/rmse_sw.py +++ b/src/torchmetrics/functional/image/rmse_sw.py @@ -136,7 +136,7 @@ def root_mean_squared_error_using_sliding_window( ValueError: If ``window_size`` is not a positive integer. """ - if not isinstance(window_size, int) or isinstance(window_size, int) and window_size < 1: + if not isinstance(window_size, int) or (isinstance(window_size, int) and window_size < 1): raise ValueError("Argument `window_size` is expected to be a positive integer.") rmse_val_sum, rmse_map, total_images = _rmse_sw_update( diff --git a/src/torchmetrics/functional/image/sam.py b/src/torchmetrics/functional/image/sam.py index 21efde6f9c4..af5edb5f41e 100644 --- a/src/torchmetrics/functional/image/sam.py +++ b/src/torchmetrics/functional/image/sam.py @@ -36,8 +36,7 @@ def _sam_update(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: _check_same_shape(preds, target) if len(preds.shape) != 4: raise ValueError( - "Expected `preds` and `target` to have BxCxHxW shape." - f" Got preds: {preds.shape} and target: {target.shape}." + f"Expected `preds` and `target` to have BxCxHxW shape. Got preds: {preds.shape} and target: {target.shape}." ) if (preds.shape[1] <= 1) or (target.shape[1] <= 1): raise ValueError( diff --git a/src/torchmetrics/functional/image/uqi.py b/src/torchmetrics/functional/image/uqi.py index 30b5e781e55..2b8d6f3cfa7 100644 --- a/src/torchmetrics/functional/image/uqi.py +++ b/src/torchmetrics/functional/image/uqi.py @@ -39,8 +39,7 @@ def _uqi_update(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: _check_same_shape(preds, target) if len(preds.shape) != 4: raise ValueError( - "Expected `preds` and `target` to have BxCxHxW shape." - f" Got preds: {preds.shape} and target: {target.shape}." + f"Expected `preds` and `target` to have BxCxHxW shape. Got preds: {preds.shape} and target: {target.shape}." ) return preds, target diff --git a/src/torchmetrics/functional/multimodal/__init__.py b/src/torchmetrics/functional/multimodal/__init__.py index 812f5e69805..e32c32e3fbf 100644 --- a/src/torchmetrics/functional/multimodal/__init__.py +++ b/src/torchmetrics/functional/multimodal/__init__.py @@ -17,4 +17,4 @@ from torchmetrics.functional.multimodal.clip_iqa import clip_image_quality_assessment from torchmetrics.functional.multimodal.clip_score import clip_score - __all__ = ["clip_score", "clip_image_quality_assessment"] + __all__ = ["clip_image_quality_assessment", "clip_score"] diff --git a/src/torchmetrics/functional/regression/r2.py b/src/torchmetrics/functional/regression/r2.py index f52d082fc74..c8227ae5ccf 100644 --- a/src/torchmetrics/functional/regression/r2.py +++ b/src/torchmetrics/functional/regression/r2.py @@ -108,8 +108,7 @@ def _r2_score_compute( if adjusted != 0: if adjusted > num_obs - 1: rank_zero_warn( - "More independent regressions than data points in" - " adjusted r2 score. Falls back to standard r2 score.", + "More independent regressions than data points in adjusted r2 score. Falls back to standard r2 score.", UserWarning, ) elif adjusted == num_obs - 1: diff --git a/src/torchmetrics/functional/segmentation/__init__.py b/src/torchmetrics/functional/segmentation/__init__.py index d54807a3c18..5f65ea00d0f 100644 --- a/src/torchmetrics/functional/segmentation/__init__.py +++ b/src/torchmetrics/functional/segmentation/__init__.py @@ -16,4 +16,4 @@ from torchmetrics.functional.segmentation.hausdorff_distance import hausdorff_distance from torchmetrics.functional.segmentation.mean_iou import mean_iou -__all__ = ["generalized_dice_score", "mean_iou", "hausdorff_distance", "dice_score"] +__all__ = ["dice_score", "generalized_dice_score", "hausdorff_distance", "mean_iou"] diff --git a/src/torchmetrics/functional/text/eed.py b/src/torchmetrics/functional/text/eed.py index bde77680bea..20bc367d8f0 100644 --- a/src/torchmetrics/functional/text/eed.py +++ b/src/torchmetrics/functional/text/eed.py @@ -403,7 +403,7 @@ def extended_edit_distance( """ # input validation for parameters for param_name, param in zip(["alpha", "rho", "deletion", "insertion"], [alpha, rho, deletion, insertion]): - if not isinstance(param, float) or isinstance(param, float) and param < 0: + if not isinstance(param, float) or (isinstance(param, float) and param < 0): raise ValueError(f"Parameter `{param_name}` is expected to be a non-negative float.") sentence_level_scores = _eed_update(preds, target, language, alpha, rho, deletion, insertion) diff --git a/src/torchmetrics/functional/text/helper_embedding_metric.py b/src/torchmetrics/functional/text/helper_embedding_metric.py index 17c89558163..19c77f767df 100644 --- a/src/torchmetrics/functional/text/helper_embedding_metric.py +++ b/src/torchmetrics/functional/text/helper_embedding_metric.py @@ -143,7 +143,7 @@ def _get_progress_bar(dataloader: DataLoader, verbose: bool = False) -> Union[Da """Wrap dataloader in progressbar if asked for. Function will return either the dataloader itself when `verbose = False`, or it wraps the dataloader with - `tqdm.auto.tqdm`, when `verbose = True` to display a progress bar during the embbeddings calculation. + `tqdm.auto.tqdm`, when `verbose = True` to display a progress bar during the embeddings calculation. """ import tqdm diff --git a/src/torchmetrics/image/__init__.py b/src/torchmetrics/image/__init__.py index 1ac0e22e2bc..565ca527479 100644 --- a/src/torchmetrics/image/__init__.py +++ b/src/torchmetrics/image/__init__.py @@ -32,22 +32,22 @@ ) __all__ = [ - "SpectralDistortionIndex", - "SpatialDistortionIndex", "ErrorRelativeGlobalDimensionlessSynthesis", + "MemorizationInformedFrechetInceptionDistance", + "MultiScaleStructuralSimilarityIndexMeasure", "PeakSignalNoiseRatio", "PeakSignalNoiseRatioWithBlockedEffect", + "QualityWithNoReference", "RelativeAverageSpectralError", "RootMeanSquaredErrorUsingSlidingWindow", + "SpatialCorrelationCoefficient", + "SpatialDistortionIndex", "SpectralAngleMapper", - "MultiScaleStructuralSimilarityIndexMeasure", - "MemorizationInformedFrechetInceptionDistance", + "SpectralDistortionIndex", "StructuralSimilarityIndexMeasure", + "TotalVariation", "UniversalImageQualityIndex", "VisualInformationFidelity", - "TotalVariation", - "SpatialCorrelationCoefficient", - "QualityWithNoReference", ] if _TORCH_FIDELITY_AVAILABLE: diff --git a/src/torchmetrics/image/rase.py b/src/torchmetrics/image/rase.py index bca9504c1aa..26ecdccfac2 100644 --- a/src/torchmetrics/image/rase.py +++ b/src/torchmetrics/image/rase.py @@ -74,7 +74,7 @@ def __init__( ) -> None: super().__init__(**kwargs) - if not isinstance(window_size, int) or isinstance(window_size, int) and window_size < 1: + if not isinstance(window_size, int) or (isinstance(window_size, int) and window_size < 1): raise ValueError(f"Argument `window_size` is expected to be a positive integer, but got {window_size}") self.window_size = window_size diff --git a/src/torchmetrics/image/rmse_sw.py b/src/torchmetrics/image/rmse_sw.py index 6312174b1be..032705e66a8 100644 --- a/src/torchmetrics/image/rmse_sw.py +++ b/src/torchmetrics/image/rmse_sw.py @@ -72,7 +72,7 @@ def __init__( **kwargs: dict[str, Any], ) -> None: super().__init__(**kwargs) - if not isinstance(window_size, int) or isinstance(window_size, int) and window_size < 1: + if not isinstance(window_size, int) or (isinstance(window_size, int) and window_size < 1): raise ValueError("Argument `window_size` is expected to be a positive integer.") self.window_size = window_size diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index b270903eafd..f580f8f6165 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -306,8 +306,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: # check if states are already synced if self._is_synced: raise TorchMetricsUserError( - "The Metric shouldn't be synced when performing ``forward``. " - "HINT: Did you forget to call ``unsync`` ?." + "The Metric shouldn't be synced when performing ``forward``. HINT: Did you forget to call ``unsync`` ?." ) if self.full_state_update or self.full_state_update is None or self.dist_sync_on_step: diff --git a/src/torchmetrics/multimodal/__init__.py b/src/torchmetrics/multimodal/__init__.py index 4a4c77d8baa..5b745dd0095 100644 --- a/src/torchmetrics/multimodal/__init__.py +++ b/src/torchmetrics/multimodal/__init__.py @@ -17,4 +17,4 @@ from torchmetrics.multimodal.clip_iqa import CLIPImageQualityAssessment from torchmetrics.multimodal.clip_score import CLIPScore - __all__ = ["CLIPScore", "CLIPImageQualityAssessment"] + __all__ = ["CLIPImageQualityAssessment", "CLIPScore"] diff --git a/src/torchmetrics/retrieval/__init__.py b/src/torchmetrics/retrieval/__init__.py index 18a9df576af..b770bfe49fd 100644 --- a/src/torchmetrics/retrieval/__init__.py +++ b/src/torchmetrics/retrieval/__init__.py @@ -31,7 +31,7 @@ "RetrievalNormalizedDCG", "RetrievalPrecision", "RetrievalPrecisionRecallCurve", + "RetrievalRPrecision", "RetrievalRecall", "RetrievalRecallAtFixedPrecision", - "RetrievalRPrecision", ] diff --git a/src/torchmetrics/segmentation/__init__.py b/src/torchmetrics/segmentation/__init__.py index c492cccba46..b7513963ca8 100644 --- a/src/torchmetrics/segmentation/__init__.py +++ b/src/torchmetrics/segmentation/__init__.py @@ -16,4 +16,4 @@ from torchmetrics.segmentation.hausdorff_distance import HausdorffDistance from torchmetrics.segmentation.mean_iou import MeanIoU -__all__ = ["GeneralizedDiceScore", "MeanIoU", "HausdorffDistance", "DiceScore"] +__all__ = ["DiceScore", "GeneralizedDiceScore", "HausdorffDistance", "MeanIoU"] diff --git a/src/torchmetrics/text/__init__.py b/src/torchmetrics/text/__init__.py index 48807a98fc4..6af056246cd 100644 --- a/src/torchmetrics/text/__init__.py +++ b/src/torchmetrics/text/__init__.py @@ -29,15 +29,15 @@ __all__ = [ "BLEUScore", - "CharErrorRate", "CHRFScore", + "CharErrorRate", "EditDistance", "ExtendedEditDistance", "MatchErrorRate", "Perplexity", "ROUGEScore", - "SacreBLEUScore", "SQuAD", + "SacreBLEUScore", "TranslationEditRate", "WordErrorRate", "WordInfoLost", diff --git a/src/torchmetrics/text/eed.py b/src/torchmetrics/text/eed.py index c776eba2331..86f366c84cc 100644 --- a/src/torchmetrics/text/eed.py +++ b/src/torchmetrics/text/eed.py @@ -86,7 +86,7 @@ def __init__( # input validation for parameters for param_name, param in zip(["alpha", "rho", "deletion", "insertion"], [alpha, rho, deletion, insertion]): - if not isinstance(param, float) or isinstance(param, float) and param < 0: + if not isinstance(param, float) or (isinstance(param, float) and param < 0): raise ValueError(f"Parameter `{param_name}` is expected to be a non-negative float.") self.alpha = alpha diff --git a/src/torchmetrics/utilities/__init__.py b/src/torchmetrics/utilities/__init__.py index faf56efae0f..3225491739b 100644 --- a/src/torchmetrics/utilities/__init__.py +++ b/src/torchmetrics/utilities/__init__.py @@ -25,13 +25,13 @@ __all__ = [ "check_forward_full_state_property", "class_reduce", - "reduce", - "rank_zero_debug", - "rank_zero_info", - "rank_zero_warn", "dim_zero_cat", "dim_zero_max", "dim_zero_mean", "dim_zero_min", "dim_zero_sum", + "rank_zero_debug", + "rank_zero_info", + "rank_zero_warn", + "reduce", ] diff --git a/src/torchmetrics/utilities/prints.py b/src/torchmetrics/utilities/prints.py index 272ab650399..0824d06bea3 100644 --- a/src/torchmetrics/utilities/prints.py +++ b/src/torchmetrics/utilities/prints.py @@ -40,7 +40,7 @@ def wrapped_fn(*args: Any, **kwargs: Any) -> Any: def _warn(*args: Any, **kwargs: Any) -> None: - warnings.warn(*args, **kwargs) # noqa: B028 + warnings.warn(*args, **kwargs) def _info(*args: Any, **kwargs: Any) -> None: diff --git a/src/torchmetrics/wrappers/__init__.py b/src/torchmetrics/wrappers/__init__.py index d25aece83ed..f3a41f990a2 100644 --- a/src/torchmetrics/wrappers/__init__.py +++ b/src/torchmetrics/wrappers/__init__.py @@ -32,9 +32,9 @@ "FeatureShare", "LambdaInputTransformer", "MetricInputTransformer", + "MetricTracker", "MinMaxMetric", "MultioutputWrapper", "MultitaskWrapper", - "MetricTracker", "Running", ] diff --git a/src/torchmetrics/wrappers/tracker.py b/src/torchmetrics/wrappers/tracker.py index 0a8ca7eac1d..54cb6d44408 100644 --- a/src/torchmetrics/wrappers/tracker.py +++ b/src/torchmetrics/wrappers/tracker.py @@ -112,8 +112,7 @@ def __init__( super().__init__() if not isinstance(metric, (Metric, MetricCollection)): raise TypeError( - "Metric arg need to be an instance of a torchmetrics" - f" `Metric` or `MetricCollection` but got {metric}" + f"Metric arg need to be an instance of a torchmetrics `Metric` or `MetricCollection` but got {metric}" ) self._base_metric = metric diff --git a/tests/integrations/test_lightning.py b/tests/integrations/test_lightning.py index 05799b2711d..7afe917abac 100644 --- a/tests/integrations/test_lightning.py +++ b/tests/integrations/test_lightning.py @@ -25,6 +25,7 @@ from pytorch_lightning import LightningModule, Trainer, seed_everything from pytorch_lightning.loggers import CSVLogger +from integrations.lightning.boring_model import BoringModel from torchmetrics import MetricCollection from torchmetrics.aggregation import SumMetric from torchmetrics.classification import BinaryAccuracy, BinaryAveragePrecision, MulticlassAccuracy @@ -32,13 +33,11 @@ from torchmetrics.utilities.prints import rank_zero_only from torchmetrics.wrappers import ClasswiseWrapper, MinMaxMetric, MultitaskWrapper -from integrations.lightning.boring_model import BoringModel - seed_everything(42) class DiffMetric(SumMetric): - """DiffMetric inherited from `SumMetric` by overidding its `update` method.""" + """DiffMetric inherited from `SumMetric` by overriding its `update` method.""" def update(self, value): """Update state.""" diff --git a/tests/unittests/__init__.py b/tests/unittests/__init__.py index 6cc99ce84a4..74bc9a0e32c 100644 --- a/tests/unittests/__init__.py +++ b/tests/unittests/__init__.py @@ -55,12 +55,12 @@ class _GroupInput(NamedTuple): __all__ = [ "BATCH_SIZE", "EXTRA_DIM", - "_Input", - "_GroupInput", "NUM_BATCHES", "NUM_CLASSES", "NUM_PROCESSES", "THRESHOLD", "USE_PYTEST_POOL", + "_GroupInput", + "_Input", "setup_ddp", ] diff --git a/tests/unittests/_helpers/testers.py b/tests/unittests/_helpers/testers.py index 1622e4ad8a3..5223fef9acb 100644 --- a/tests/unittests/_helpers/testers.py +++ b/tests/unittests/_helpers/testers.py @@ -23,9 +23,9 @@ import torch from lightning_utilities import apply_to_collection from torch import Tensor, tensor + from torchmetrics import Metric from torchmetrics.utilities.data import _flatten - from unittests import NUM_PROCESSES, _reference_cachier @@ -146,7 +146,7 @@ def _class_test( # check that metric can be cloned clone = metric.clone() assert clone is not metric, "Clone is not a different object than the metric" - assert type(clone) == type(metric), "Type of clone did not match metric type" + assert type(clone) == type(metric), "Type of clone did not match metric type" # noqa: E721 # move to device metric = metric.to(device) diff --git a/tests/unittests/audio/test_c_si_snr.py b/tests/unittests/audio/test_c_si_snr.py index 2ed148aef65..216809f3f99 100644 --- a/tests/unittests/audio/test_c_si_snr.py +++ b/tests/unittests/audio/test_c_si_snr.py @@ -15,9 +15,9 @@ import pytest import torch from scipy.io import wavfile + from torchmetrics.audio import ComplexScaleInvariantSignalNoiseRatio from torchmetrics.functional.audio import complex_scale_invariant_signal_noise_ratio - from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/audio/test_dnsmos.py b/tests/unittests/audio/test_dnsmos.py index 80607057467..b89dd82fb5c 100644 --- a/tests/unittests/audio/test_dnsmos.py +++ b/tests/unittests/audio/test_dnsmos.py @@ -19,6 +19,7 @@ import pytest import torch from torch import Tensor + from torchmetrics.audio.dnsmos import DeepNoiseSuppressionMeanOpinionScore from torchmetrics.functional.audio.dnsmos import ( DNSMOS_DIR, @@ -30,7 +31,6 @@ _ONNXRUNTIME_AVAILABLE, _REQUESTS_AVAILABLE, ) - from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/audio/test_nisqa.py b/tests/unittests/audio/test_nisqa.py index 06eac64710c..ff40375ec95 100644 --- a/tests/unittests/audio/test_nisqa.py +++ b/tests/unittests/audio/test_nisqa.py @@ -18,9 +18,9 @@ import pytest import torch from torch import Tensor + from torchmetrics.audio.nisqa import NonIntrusiveSpeechQualityAssessment from torchmetrics.functional.audio.nisqa import non_intrusive_speech_quality_assessment - from unittests._helpers.testers import MetricTester # reference values below were calculated using the method described in https://github.com/gabrielmittag/NISQA/blob/master/README.md diff --git a/tests/unittests/audio/test_pesq.py b/tests/unittests/audio/test_pesq.py index dd0e3caba9c..a32a21db63d 100644 --- a/tests/unittests/audio/test_pesq.py +++ b/tests/unittests/audio/test_pesq.py @@ -18,9 +18,9 @@ from pesq import pesq as pesq_backend from scipy.io import wavfile from torch import Tensor + from torchmetrics.audio import PerceptualEvaluationSpeechQuality from torchmetrics.functional.audio import perceptual_evaluation_speech_quality - from unittests import _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/audio/test_pit.py b/tests/unittests/audio/test_pit.py index 70c5b44c6ab..182f78218b6 100644 --- a/tests/unittests/audio/test_pit.py +++ b/tests/unittests/audio/test_pit.py @@ -19,6 +19,7 @@ import torch from scipy.optimize import linear_sum_assignment from torch import Tensor + from torchmetrics.audio import PermutationInvariantTraining from torchmetrics.functional.audio import ( permutation_invariant_training, @@ -29,7 +30,6 @@ _find_best_perm_by_exhaustive_method, _find_best_perm_by_linear_sum_assignment, ) - from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/audio/test_sa_sdr.py b/tests/unittests/audio/test_sa_sdr.py index 3de3c4900cf..b52fb6c478f 100644 --- a/tests/unittests/audio/test_sa_sdr.py +++ b/tests/unittests/audio/test_sa_sdr.py @@ -16,13 +16,13 @@ import pytest import torch from torch import Tensor + from torchmetrics.audio import SourceAggregatedSignalDistortionRatio from torchmetrics.functional.audio import ( scale_invariant_signal_distortion_ratio, signal_noise_ratio, source_aggregated_signal_distortion_ratio, ) - from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/audio/test_sdr.py b/tests/unittests/audio/test_sdr.py index 8d5a8c7ab8f..dfd762410e2 100644 --- a/tests/unittests/audio/test_sdr.py +++ b/tests/unittests/audio/test_sdr.py @@ -19,9 +19,9 @@ from mir_eval.separation import bss_eval_sources from scipy.io import wavfile from torch import Tensor + from torchmetrics.audio import SignalDistortionRatio from torchmetrics.functional import signal_distortion_ratio - from unittests import _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/audio/test_si_sdr.py b/tests/unittests/audio/test_si_sdr.py index 0531d65683e..f78acb6b786 100644 --- a/tests/unittests/audio/test_si_sdr.py +++ b/tests/unittests/audio/test_si_sdr.py @@ -17,9 +17,9 @@ import pytest import torch from torch import Tensor + from torchmetrics.audio import ScaleInvariantSignalDistortionRatio from torchmetrics.functional.audio import scale_invariant_signal_distortion_ratio - from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/audio/test_si_snr.py b/tests/unittests/audio/test_si_snr.py index 8b6f35c54f1..c57d86c2806 100644 --- a/tests/unittests/audio/test_si_snr.py +++ b/tests/unittests/audio/test_si_snr.py @@ -17,9 +17,9 @@ import pytest import torch from torch import Tensor + from torchmetrics.audio import ScaleInvariantSignalNoiseRatio from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio - from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/audio/test_snr.py b/tests/unittests/audio/test_snr.py index 332707028ff..92824c0c101 100644 --- a/tests/unittests/audio/test_snr.py +++ b/tests/unittests/audio/test_snr.py @@ -17,9 +17,9 @@ import torch from mir_eval.separation import bss_eval_images as mir_eval_bss_eval_images from torch import Tensor + from torchmetrics.audio import SignalNoiseRatio from torchmetrics.functional.audio import signal_noise_ratio - from unittests import _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/audio/test_srmr.py b/tests/unittests/audio/test_srmr.py index d3a18cca357..ae5fab8f876 100644 --- a/tests/unittests/audio/test_srmr.py +++ b/tests/unittests/audio/test_srmr.py @@ -18,9 +18,9 @@ import torch from srmrpy import srmr as srmrpy_srmr from torch import Tensor + from torchmetrics.audio.srmr import SpeechReverberationModulationEnergyRatio from torchmetrics.functional.audio.srmr import speech_reverberation_modulation_energy_ratio - from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/audio/test_stoi.py b/tests/unittests/audio/test_stoi.py index d7998aaf8b2..178e466d3cb 100644 --- a/tests/unittests/audio/test_stoi.py +++ b/tests/unittests/audio/test_stoi.py @@ -18,9 +18,9 @@ from pystoi import stoi as stoi_backend from scipy.io import wavfile from torch import Tensor + from torchmetrics.audio import ShortTimeObjectiveIntelligibility from torchmetrics.functional.audio import short_time_objective_intelligibility - from unittests import _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/bases/test_aggregation.py b/tests/unittests/bases/test_aggregation.py index 0b593208e54..9768e5c4c0a 100644 --- a/tests/unittests/bases/test_aggregation.py +++ b/tests/unittests/bases/test_aggregation.py @@ -1,9 +1,9 @@ import numpy as np import pytest import torch + from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric from torchmetrics.collections import MetricCollection - from unittests import BATCH_SIZE, NUM_BATCHES from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/bases/test_collections.py b/tests/unittests/bases/test_collections.py index e498e18dc84..4873ea2e9d3 100644 --- a/tests/unittests/bases/test_collections.py +++ b/tests/unittests/bases/test_collections.py @@ -17,6 +17,7 @@ import pytest import torch + from torchmetrics import ClasswiseWrapper, Metric, MetricCollection from torchmetrics.classification import ( BinaryAccuracy, @@ -33,7 +34,6 @@ MultilabelAveragePrecision, ) from torchmetrics.utilities.checks import _allclose_recursive - from unittests._helpers import seed_all from unittests._helpers.testers import DummyMetricDiff, DummyMetricMultiOutputDict, DummyMetricSum diff --git a/tests/unittests/bases/test_composition.py b/tests/unittests/bases/test_composition.py index f33d37f2015..5d7e436f66e 100644 --- a/tests/unittests/bases/test_composition.py +++ b/tests/unittests/bases/test_composition.py @@ -17,6 +17,7 @@ import pytest import torch from torch import tensor + from torchmetrics.metric import CompositionalMetric, Metric diff --git a/tests/unittests/bases/test_ddp.py b/tests/unittests/bases/test_ddp.py index 07dee96f4da..8f1920b8bf2 100644 --- a/tests/unittests/bases/test_ddp.py +++ b/tests/unittests/bases/test_ddp.py @@ -19,11 +19,11 @@ import pytest import torch from torch import tensor + from torchmetrics import Metric from torchmetrics.utilities.distributed import gather_all_tensors from torchmetrics.utilities.exceptions import TorchMetricsUserError from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import NUM_PROCESSES, USE_PYTEST_POOL from unittests._helpers import seed_all from unittests._helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum @@ -87,7 +87,7 @@ def _test_ddp_compositional_tensor(rank: int, worldsize: int = NUM_PROCESSES) -> assert val == 2 * worldsize -@pytest.mark.DDP() +@pytest.mark.DDP @pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") @pytest.mark.skipif(not USE_PYTEST_POOL, reason="DDP pool is not available.") @pytest.mark.parametrize( @@ -136,7 +136,7 @@ def _test_ddp_gather_all_autograd_different_shape(rank: int, worldsize: int = NU assert torch.allclose(grad, a * torch.ones_like(x)) -@pytest.mark.DDP() +@pytest.mark.DDP @pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") @pytest.mark.skipif(not USE_PYTEST_POOL, reason="DDP pool is not available.") @pytest.mark.parametrize( @@ -166,7 +166,7 @@ def compute(self): metric.update(torch.randn(10, 5)[:, 0]) -@pytest.mark.DDP() +@pytest.mark.DDP @pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") @pytest.mark.skipif(not USE_PYTEST_POOL, reason="DDP pool is not available.") def test_non_contiguous_tensors(): @@ -274,7 +274,7 @@ def reload_state_dict(state_dict, expected_x, expected_c): torch.save(metric.state_dict(), filepath) -@pytest.mark.DDP() +@pytest.mark.DDP @pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") @pytest.mark.skipif(not USE_PYTEST_POOL, reason="DDP pool is not available.") def test_state_dict_is_synced(tmpdir): @@ -303,7 +303,7 @@ def _test_sync_on_compute_list_state(rank, sync_on_compute): assert val == [tensor(rank + 1)] -@pytest.mark.DDP() +@pytest.mark.DDP @pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") @pytest.mark.skipif(not USE_PYTEST_POOL, reason="DDP pool is not available.") @pytest.mark.parametrize("sync_on_compute", [True, False]) @@ -319,7 +319,7 @@ def _test_sync_with_empty_lists(rank): assert torch.allclose(val, tensor([])) -@pytest.mark.DDP() +@pytest.mark.DDP @pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions") @pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") @pytest.mark.skipif(not USE_PYTEST_POOL, reason="DDP pool is not available.") @@ -336,7 +336,7 @@ def _test_sync_with_unequal_size_lists(rank): assert torch.all(dummy.compute() == tensor([0.0, 0.0])) -@pytest.mark.DDP() +@pytest.mark.DDP @pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions") @pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") def test_sync_with_unequal_size_lists(): diff --git a/tests/unittests/bases/test_metric.py b/tests/unittests/bases/test_metric.py index 363b2d31a66..f4ed318e38f 100644 --- a/tests/unittests/bases/test_metric.py +++ b/tests/unittests/bases/test_metric.py @@ -24,12 +24,12 @@ import torch from torch import Tensor, tensor from torch.nn import Module, Parameter + from torchmetrics.aggregation import MeanMetric, SumMetric from torchmetrics.classification import BinaryAccuracy from torchmetrics.clustering import AdjustedRandScore from torchmetrics.image import StructuralSimilarityIndexMeasure from torchmetrics.regression import PearsonCorrCoef, R2Score - from unittests._helpers import seed_all from unittests._helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum diff --git a/tests/unittests/bases/test_saving_loading.py b/tests/unittests/bases/test_saving_loading.py index f808674367c..ff075b29ded 100644 --- a/tests/unittests/bases/test_saving_loading.py +++ b/tests/unittests/bases/test_saving_loading.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest import torch + from torchmetrics.classification import MulticlassAccuracy diff --git a/tests/unittests/classification/test_accuracy.py b/tests/unittests/classification/test_accuracy.py index 59b466d33c1..49e5e6e4be4 100644 --- a/tests/unittests/classification/test_accuracy.py +++ b/tests/unittests/classification/test_accuracy.py @@ -19,6 +19,7 @@ from scipy.special import expit as sigmoid from sklearn.metrics import accuracy_score as sk_accuracy from sklearn.metrics import confusion_matrix as sk_confusion_matrix + from torchmetrics.classification.accuracy import Accuracy, BinaryAccuracy, MulticlassAccuracy, MultilabelAccuracy from torchmetrics.functional.classification.accuracy import ( accuracy, @@ -28,7 +29,6 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import NUM_CLASSES, THRESHOLD from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index diff --git a/tests/unittests/classification/test_auc.py b/tests/unittests/classification/test_auc.py index 8644de105a0..e4a23623071 100644 --- a/tests/unittests/classification/test_auc.py +++ b/tests/unittests/classification/test_auc.py @@ -18,8 +18,8 @@ import pytest from sklearn.metrics import auc as _sk_auc from torch import Tensor, tensor -from torchmetrics.utilities.compute import auc +from torchmetrics.utilities.compute import auc from unittests import NUM_BATCHES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/classification/test_auroc.py b/tests/unittests/classification/test_auroc.py index c7fdb54d6c1..7a03a7a2156 100644 --- a/tests/unittests/classification/test_auroc.py +++ b/tests/unittests/classification/test_auroc.py @@ -20,12 +20,12 @@ from scipy.special import expit as sigmoid from scipy.special import softmax from sklearn.metrics import roc_auc_score as sk_roc_auc_score + from torchmetrics.classification.auroc import AUROC, BinaryAUROC, MulticlassAUROC, MultilabelAUROC from torchmetrics.functional.classification.auroc import binary_auroc, multiclass_auroc, multilabel_auroc from torchmetrics.functional.classification.roc import binary_roc from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import NUM_CLASSES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index diff --git a/tests/unittests/classification/test_average_precision.py b/tests/unittests/classification/test_average_precision.py index cf37360e832..0ee1cbee56a 100644 --- a/tests/unittests/classification/test_average_precision.py +++ b/tests/unittests/classification/test_average_precision.py @@ -20,6 +20,7 @@ from scipy.special import expit as sigmoid from scipy.special import softmax from sklearn.metrics import average_precision_score as sk_average_precision_score + from torchmetrics.classification.average_precision import ( AveragePrecision, BinaryAveragePrecision, @@ -34,7 +35,6 @@ from torchmetrics.functional.classification.precision_recall_curve import binary_precision_recall_curve from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import NUM_CLASSES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index diff --git a/tests/unittests/classification/test_calibration_error.py b/tests/unittests/classification/test_calibration_error.py index 8e2556c0533..0f6ef71a6d2 100644 --- a/tests/unittests/classification/test_calibration_error.py +++ b/tests/unittests/classification/test_calibration_error.py @@ -19,6 +19,7 @@ from netcal.metrics import ECE, MCE from scipy.special import expit as sigmoid from scipy.special import softmax + from torchmetrics.classification.calibration_error import ( BinaryCalibrationError, CalibrationError, @@ -30,7 +31,6 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import NUM_CLASSES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index diff --git a/tests/unittests/classification/test_cohen_kappa.py b/tests/unittests/classification/test_cohen_kappa.py index 4c4a411aab7..b48f622dc76 100644 --- a/tests/unittests/classification/test_cohen_kappa.py +++ b/tests/unittests/classification/test_cohen_kappa.py @@ -18,11 +18,11 @@ import torch from scipy.special import expit as sigmoid from sklearn.metrics import cohen_kappa_score as sk_cohen_kappa + from torchmetrics.classification.cohen_kappa import BinaryCohenKappa, CohenKappa, MulticlassCohenKappa from torchmetrics.functional.classification.cohen_kappa import binary_cohen_kappa, multiclass_cohen_kappa from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import NUM_CLASSES, THRESHOLD from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index diff --git a/tests/unittests/classification/test_confusion_matrix.py b/tests/unittests/classification/test_confusion_matrix.py index 7d7a5f28cb0..4a71f11890f 100644 --- a/tests/unittests/classification/test_confusion_matrix.py +++ b/tests/unittests/classification/test_confusion_matrix.py @@ -18,6 +18,7 @@ import torch from scipy.special import expit as sigmoid from sklearn.metrics import confusion_matrix as sk_confusion_matrix + from torchmetrics.classification.confusion_matrix import ( BinaryConfusionMatrix, ConfusionMatrix, @@ -31,7 +32,6 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import NUM_CLASSES, THRESHOLD from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index diff --git a/tests/unittests/classification/test_dice.py b/tests/unittests/classification/test_dice.py index 6854265d3d9..5f5a4aeadc5 100644 --- a/tests/unittests/classification/test_dice.py +++ b/tests/unittests/classification/test_dice.py @@ -17,12 +17,12 @@ import pytest from scipy.spatial.distance import dice as sc_dice from torch import Tensor, tensor + from torchmetrics.classification import Dice from torchmetrics.functional import dice from torchmetrics.functional.classification.stat_scores import _del_column from torchmetrics.utilities.checks import _input_format_classification from torchmetrics.utilities.enums import DataType - from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester from unittests.classification._inputs import _input_binary, _input_binary_logits, _input_binary_prob diff --git a/tests/unittests/classification/test_exact_match.py b/tests/unittests/classification/test_exact_match.py index 3cb8caa2061..a8c02353e2a 100644 --- a/tests/unittests/classification/test_exact_match.py +++ b/tests/unittests/classification/test_exact_match.py @@ -17,11 +17,11 @@ import pytest import torch from scipy.special import expit as sigmoid + from torchmetrics.classification.exact_match import ExactMatch, MulticlassExactMatch, MultilabelExactMatch from torchmetrics.functional.classification.exact_match import multiclass_exact_match, multilabel_exact_match from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import NUM_CLASSES, THRESHOLD from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index diff --git a/tests/unittests/classification/test_f_beta.py b/tests/unittests/classification/test_f_beta.py index 73e988dc36f..b57cafcb00d 100644 --- a/tests/unittests/classification/test_f_beta.py +++ b/tests/unittests/classification/test_f_beta.py @@ -21,6 +21,7 @@ from sklearn.metrics import f1_score as sk_f1_score from sklearn.metrics import fbeta_score as sk_fbeta_score from torch import Tensor + from torchmetrics.classification.f_beta import ( BinaryF1Score, BinaryFBetaScore, @@ -41,7 +42,6 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import NUM_CLASSES, THRESHOLD from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index @@ -478,9 +478,9 @@ def test_multiclassf1score_with_top_k(num_classes): previous_score = score if k == num_classes: - assert torch.isclose( - score, torch.tensor(1.0) - ), f"F1 score is not 1 for top_k={k} when num_classes={num_classes}" + assert torch.isclose(score, torch.tensor(1.0)), ( + f"F1 score is not 1 for top_k={k} when num_classes={num_classes}" + ) def test_multiclass_f1_score_top_k_equivalence(): @@ -506,9 +506,9 @@ def test_multiclass_f1_score_top_k_equivalence(): score_top3 = f1_val_top3(preds, target) score_corrected = f1_val_top1(pred_corrected_top3, target) - assert torch.isclose( - score_top3, score_corrected - ), f"Top-3 F1 score ({score_top3}) does not match corrected top-1 F1 score ({score_corrected})" + assert torch.isclose(score_top3, score_corrected), ( + f"Top-3 F1 score ({score_top3}) does not match corrected top-1 F1 score ({score_corrected})" + ) def _reference_sklearn_fbeta_score_multilabel_global(preds, target, sk_fn, ignore_index, average, zero_division): diff --git a/tests/unittests/classification/test_group_fairness.py b/tests/unittests/classification/test_group_fairness.py index d1899831d7e..c8cdae99bef 100644 --- a/tests/unittests/classification/test_group_fairness.py +++ b/tests/unittests/classification/test_group_fairness.py @@ -23,11 +23,11 @@ from fairlearn.metrics import MetricFrame, selection_rate, true_positive_rate from scipy.special import expit as sigmoid from torch import Tensor + from torchmetrics import Metric from torchmetrics.classification.group_fairness import BinaryFairness from torchmetrics.functional.classification.group_fairness import binary_fairness from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import THRESHOLD from unittests._helpers import seed_all from unittests._helpers.testers import ( @@ -86,9 +86,9 @@ def _assert_allclose( # todo: unify with the general assert_allclose ) -> None: if isinstance(pl_result, dict) and key is None: for (pl_key, pl_val), (sk_key, sk_val) in zip(pl_result.items(), sk_result.items()): - assert np.allclose( - pl_val.detach().cpu().numpy(), sk_val.numpy(), atol=atol, equal_nan=True - ), f"{pl_key} != {sk_key}" + assert np.allclose(pl_val.detach().cpu().numpy(), sk_val.numpy(), atol=atol, equal_nan=True), ( + f"{pl_key} != {sk_key}" + ) else: _core_assert_allclose(pl_result, sk_result, atol, key) diff --git a/tests/unittests/classification/test_hamming_distance.py b/tests/unittests/classification/test_hamming_distance.py index 6d4f0f824cc..af7708d5e95 100644 --- a/tests/unittests/classification/test_hamming_distance.py +++ b/tests/unittests/classification/test_hamming_distance.py @@ -19,6 +19,7 @@ from scipy.special import expit as sigmoid from sklearn.metrics import confusion_matrix as sk_confusion_matrix from sklearn.metrics import hamming_loss as sk_hamming_loss + from torchmetrics.classification.hamming import ( BinaryHammingDistance, HammingDistance, @@ -32,7 +33,6 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import NUM_CLASSES, THRESHOLD from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index diff --git a/tests/unittests/classification/test_hinge.py b/tests/unittests/classification/test_hinge.py index 8963177d7a0..db156eb0902 100644 --- a/tests/unittests/classification/test_hinge.py +++ b/tests/unittests/classification/test_hinge.py @@ -20,10 +20,10 @@ from scipy.special import softmax from sklearn.metrics import hinge_loss as sk_hinge from sklearn.preprocessing import OneHotEncoder + from torchmetrics.classification.hinge import BinaryHingeLoss, HingeLoss, MulticlassHingeLoss from torchmetrics.functional.classification.hinge import binary_hinge_loss, multiclass_hinge_loss from torchmetrics.metric import Metric - from unittests import NUM_CLASSES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index diff --git a/tests/unittests/classification/test_jaccard.py b/tests/unittests/classification/test_jaccard.py index 606825f7e71..36166facd8c 100644 --- a/tests/unittests/classification/test_jaccard.py +++ b/tests/unittests/classification/test_jaccard.py @@ -19,6 +19,7 @@ from scipy.special import expit as sigmoid from sklearn.metrics import confusion_matrix as sk_confusion_matrix from sklearn.metrics import jaccard_score as sk_jaccard_index + from torchmetrics.classification.jaccard import ( BinaryJaccardIndex, JaccardIndex, @@ -33,7 +34,6 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import NUM_CLASSES, THRESHOLD from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases diff --git a/tests/unittests/classification/test_logauc.py b/tests/unittests/classification/test_logauc.py index 6494ac72372..0278ea35b8c 100644 --- a/tests/unittests/classification/test_logauc.py +++ b/tests/unittests/classification/test_logauc.py @@ -18,6 +18,7 @@ import torch from scipy.special import expit as sigmoid from scipy.special import softmax + from torchmetrics.utilities.imports import _PYTDC_AVAILABLE if _PYTDC_AVAILABLE: @@ -28,7 +29,6 @@ from torchmetrics.functional.classification.roc import binary_roc from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import NUM_CLASSES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index diff --git a/tests/unittests/classification/test_matthews_corrcoef.py b/tests/unittests/classification/test_matthews_corrcoef.py index fc4d762384b..30b5326b0cc 100644 --- a/tests/unittests/classification/test_matthews_corrcoef.py +++ b/tests/unittests/classification/test_matthews_corrcoef.py @@ -18,6 +18,7 @@ import torch from scipy.special import expit as sigmoid from sklearn.metrics import matthews_corrcoef as sk_matthews_corrcoef + from torchmetrics.classification.matthews_corrcoef import ( BinaryMatthewsCorrCoef, MatthewsCorrCoef, @@ -31,7 +32,6 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import NUM_CLASSES, THRESHOLD from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index diff --git a/tests/unittests/classification/test_negative_predictive_value.py b/tests/unittests/classification/test_negative_predictive_value.py index 2fb352bc74f..a3a288e70f2 100644 --- a/tests/unittests/classification/test_negative_predictive_value.py +++ b/tests/unittests/classification/test_negative_predictive_value.py @@ -19,6 +19,7 @@ from scipy.special import expit as sigmoid from sklearn.metrics import confusion_matrix as sk_confusion_matrix from torch import Tensor, tensor + from torchmetrics.classification.negative_predictive_value import ( BinaryNegativePredictiveValue, MulticlassNegativePredictiveValue, @@ -32,7 +33,6 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import NUM_CLASSES, THRESHOLD from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index diff --git a/tests/unittests/classification/test_precision_fixed_recall.py b/tests/unittests/classification/test_precision_fixed_recall.py index 03c8ee7654f..615526521eb 100644 --- a/tests/unittests/classification/test_precision_fixed_recall.py +++ b/tests/unittests/classification/test_precision_fixed_recall.py @@ -20,6 +20,7 @@ from scipy.special import expit as sigmoid from scipy.special import softmax from sklearn.metrics import precision_recall_curve as sk_precision_recall_curve + from torchmetrics.classification.precision_fixed_recall import ( BinaryPrecisionAtFixedRecall, MulticlassPrecisionAtFixedRecall, @@ -33,7 +34,6 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import NUM_CLASSES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index diff --git a/tests/unittests/classification/test_precision_recall.py b/tests/unittests/classification/test_precision_recall.py index c95ececa1dd..40a161ce831 100644 --- a/tests/unittests/classification/test_precision_recall.py +++ b/tests/unittests/classification/test_precision_recall.py @@ -21,6 +21,7 @@ from sklearn.metrics import precision_score as sk_precision_score from sklearn.metrics import recall_score as sk_recall_score from torch import Tensor, tensor + from torchmetrics.classification.precision_recall import ( BinaryPrecision, BinaryRecall, @@ -41,7 +42,6 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import NUM_CLASSES, THRESHOLD from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index diff --git a/tests/unittests/classification/test_precision_recall_curve.py b/tests/unittests/classification/test_precision_recall_curve.py index d6a79b9b5cb..972cb7ebd5c 100644 --- a/tests/unittests/classification/test_precision_recall_curve.py +++ b/tests/unittests/classification/test_precision_recall_curve.py @@ -20,6 +20,7 @@ from scipy.special import expit as sigmoid from scipy.special import softmax from sklearn.metrics import precision_recall_curve as sk_precision_recall_curve + from torchmetrics.classification.precision_recall_curve import ( BinaryPrecisionRecallCurve, MulticlassPrecisionRecallCurve, @@ -33,7 +34,6 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import NUM_CLASSES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index diff --git a/tests/unittests/classification/test_ranking.py b/tests/unittests/classification/test_ranking.py index 4727ab882a1..50b0f0e9e07 100644 --- a/tests/unittests/classification/test_ranking.py +++ b/tests/unittests/classification/test_ranking.py @@ -20,6 +20,7 @@ from sklearn.metrics import coverage_error as sk_coverage_error from sklearn.metrics import label_ranking_average_precision_score as sk_label_ranking from sklearn.metrics import label_ranking_loss as sk_label_ranking_loss + from torchmetrics.classification.ranking import ( MultilabelCoverageError, MultilabelRankingAveragePrecision, @@ -31,7 +32,6 @@ multilabel_ranking_loss, ) from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import NUM_CLASSES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index diff --git a/tests/unittests/classification/test_recall_fixed_precision.py b/tests/unittests/classification/test_recall_fixed_precision.py index 5bbf2e55e58..3c7347192ea 100644 --- a/tests/unittests/classification/test_recall_fixed_precision.py +++ b/tests/unittests/classification/test_recall_fixed_precision.py @@ -20,6 +20,7 @@ from scipy.special import expit as sigmoid from scipy.special import softmax from sklearn.metrics import precision_recall_curve as sk_precision_recall_curve + from torchmetrics.classification.recall_fixed_precision import ( BinaryRecallAtFixedPrecision, MulticlassRecallAtFixedPrecision, @@ -33,7 +34,6 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import NUM_CLASSES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index diff --git a/tests/unittests/classification/test_roc.py b/tests/unittests/classification/test_roc.py index 5ad6dee35fa..9aea2bf7f90 100644 --- a/tests/unittests/classification/test_roc.py +++ b/tests/unittests/classification/test_roc.py @@ -20,11 +20,11 @@ from scipy.special import expit as sigmoid from scipy.special import softmax from sklearn.metrics import roc_curve as sk_roc_curve + from torchmetrics.classification.roc import ROC, BinaryROC, MulticlassROC, MultilabelROC from torchmetrics.functional.classification.roc import binary_roc, multiclass_roc, multilabel_roc from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import NUM_CLASSES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index diff --git a/tests/unittests/classification/test_sensitivity_specificity.py b/tests/unittests/classification/test_sensitivity_specificity.py index cc85f6e4e28..4fca8f9f941 100644 --- a/tests/unittests/classification/test_sensitivity_specificity.py +++ b/tests/unittests/classification/test_sensitivity_specificity.py @@ -20,6 +20,7 @@ from scipy.special import expit as sigmoid from scipy.special import softmax from sklearn.metrics import roc_curve as sk_roc_curve + from torchmetrics.classification.sensitivity_specificity import ( BinarySensitivityAtSpecificity, MulticlassSensitivityAtSpecificity, @@ -34,7 +35,6 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _SKLEARN_GREATER_EQUAL_1_3, _TORCH_GREATER_EQUAL_2_1 - from unittests import NUM_CLASSES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index diff --git a/tests/unittests/classification/test_specificity.py b/tests/unittests/classification/test_specificity.py index 4c2c2023630..d7b69c75f4f 100644 --- a/tests/unittests/classification/test_specificity.py +++ b/tests/unittests/classification/test_specificity.py @@ -19,6 +19,7 @@ from scipy.special import expit as sigmoid from sklearn.metrics import confusion_matrix as sk_confusion_matrix from torch import Tensor + from torchmetrics.classification.specificity import ( BinarySpecificity, MulticlassSpecificity, @@ -32,7 +33,6 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import NUM_CLASSES, THRESHOLD from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index diff --git a/tests/unittests/classification/test_specificity_sensitivity.py b/tests/unittests/classification/test_specificity_sensitivity.py index 9e866dbabd9..4c31dd5d7ee 100644 --- a/tests/unittests/classification/test_specificity_sensitivity.py +++ b/tests/unittests/classification/test_specificity_sensitivity.py @@ -20,6 +20,7 @@ from scipy.special import expit as sigmoid from scipy.special import softmax from sklearn.metrics import roc_curve as sk_roc_curve + from torchmetrics.classification.specificity_sensitivity import ( BinarySpecificityAtSensitivity, MulticlassSpecificityAtSensitivity, @@ -34,7 +35,6 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import NUM_CLASSES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index diff --git a/tests/unittests/classification/test_stat_scores.py b/tests/unittests/classification/test_stat_scores.py index 47a7a7cab28..37f35063882 100644 --- a/tests/unittests/classification/test_stat_scores.py +++ b/tests/unittests/classification/test_stat_scores.py @@ -19,6 +19,7 @@ import torch from scipy.special import expit as sigmoid from sklearn.metrics import confusion_matrix as sk_confusion_matrix + from torchmetrics.classification.stat_scores import ( BinaryStatScores, MulticlassStatScores, @@ -33,7 +34,6 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import NUM_CLASSES, THRESHOLD from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index @@ -395,7 +395,7 @@ def test_refine_preds_oh(top_k, expected_result): result = _refine_preds_oh(preds, preds_oh, target, top_k) assert torch.equal(result, expected_result), ( - f"Test failed for top_k={top_k}. " f"Expected result: {expected_result}, but got: {result}" + f"Test failed for top_k={top_k}. Expected result: {expected_result}, but got: {result}" ) diff --git a/tests/unittests/clustering/test_adjusted_mutual_info_score.py b/tests/unittests/clustering/test_adjusted_mutual_info_score.py index 474e221d6a5..15ed75d2508 100644 --- a/tests/unittests/clustering/test_adjusted_mutual_info_score.py +++ b/tests/unittests/clustering/test_adjusted_mutual_info_score.py @@ -16,9 +16,9 @@ import pytest import torch from sklearn.metrics import adjusted_mutual_info_score as sklearn_ami + from torchmetrics.clustering.adjusted_mutual_info_score import AdjustedMutualInfoScore from torchmetrics.functional.clustering.adjusted_mutual_info_score import adjusted_mutual_info_score - from unittests import BATCH_SIZE, NUM_CLASSES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/clustering/test_adjusted_rand_score.py b/tests/unittests/clustering/test_adjusted_rand_score.py index b98536aad15..4542bf27425 100644 --- a/tests/unittests/clustering/test_adjusted_rand_score.py +++ b/tests/unittests/clustering/test_adjusted_rand_score.py @@ -14,9 +14,9 @@ import pytest import torch from sklearn.metrics import adjusted_rand_score as sklearn_adjusted_rand_score + from torchmetrics.clustering.adjusted_rand_score import AdjustedRandScore from torchmetrics.functional.clustering.adjusted_rand_score import adjusted_rand_score - from unittests._helpers.testers import MetricTester from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 diff --git a/tests/unittests/clustering/test_calinski_harabasz_score.py b/tests/unittests/clustering/test_calinski_harabasz_score.py index f81da592389..c2c23d87f4a 100644 --- a/tests/unittests/clustering/test_calinski_harabasz_score.py +++ b/tests/unittests/clustering/test_calinski_harabasz_score.py @@ -13,9 +13,9 @@ # limitations under the License. import pytest from sklearn.metrics import calinski_harabasz_score as sklearn_calinski_harabasz_score + from torchmetrics.clustering.calinski_harabasz_score import CalinskiHarabaszScore from torchmetrics.functional.clustering.calinski_harabasz_score import calinski_harabasz_score - from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester from unittests.clustering._inputs import _single_target_intrinsic1, _single_target_intrinsic2 diff --git a/tests/unittests/clustering/test_davies_bouldin_score.py b/tests/unittests/clustering/test_davies_bouldin_score.py index bea2018c2cc..230e65e5cf5 100644 --- a/tests/unittests/clustering/test_davies_bouldin_score.py +++ b/tests/unittests/clustering/test_davies_bouldin_score.py @@ -13,9 +13,9 @@ # limitations under the License. import pytest from sklearn.metrics import davies_bouldin_score as sklearn_davies_bouldin_score + from torchmetrics.clustering.davies_bouldin_score import DaviesBouldinScore from torchmetrics.functional.clustering.davies_bouldin_score import davies_bouldin_score - from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester from unittests.clustering._inputs import _single_target_intrinsic1, _single_target_intrinsic2 diff --git a/tests/unittests/clustering/test_dunn_index.py b/tests/unittests/clustering/test_dunn_index.py index c2e6adcd2cb..e2a5b8af9ab 100644 --- a/tests/unittests/clustering/test_dunn_index.py +++ b/tests/unittests/clustering/test_dunn_index.py @@ -16,9 +16,9 @@ import numpy as np import pytest + from torchmetrics.clustering.dunn_index import DunnIndex from torchmetrics.functional.clustering.dunn_index import dunn_index - from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester from unittests.clustering._inputs import ( diff --git a/tests/unittests/clustering/test_fowlkes_mallows_index.py b/tests/unittests/clustering/test_fowlkes_mallows_index.py index 6e5674ae337..22e775ecb76 100644 --- a/tests/unittests/clustering/test_fowlkes_mallows_index.py +++ b/tests/unittests/clustering/test_fowlkes_mallows_index.py @@ -13,9 +13,9 @@ # limitations under the License. import pytest from sklearn.metrics import fowlkes_mallows_score as sklearn_fowlkes_mallows_score + from torchmetrics.clustering import FowlkesMallowsIndex from torchmetrics.functional.clustering import fowlkes_mallows_index - from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester from unittests.clustering._inputs import _single_target_extrinsic1, _single_target_extrinsic2 diff --git a/tests/unittests/clustering/test_homogeneity_completeness_v_measure.py b/tests/unittests/clustering/test_homogeneity_completeness_v_measure.py index dd716182b4b..853955017f9 100644 --- a/tests/unittests/clustering/test_homogeneity_completeness_v_measure.py +++ b/tests/unittests/clustering/test_homogeneity_completeness_v_measure.py @@ -17,6 +17,7 @@ from sklearn.metrics import completeness_score as sklearn_completeness_score from sklearn.metrics import homogeneity_score as sklearn_homogeneity_score from sklearn.metrics import v_measure_score as sklearn_v_measure_score + from torchmetrics.clustering.homogeneity_completeness_v_measure import ( CompletenessScore, HomogeneityScore, @@ -27,7 +28,6 @@ homogeneity_score, v_measure_score, ) - from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 diff --git a/tests/unittests/clustering/test_mutual_info_score.py b/tests/unittests/clustering/test_mutual_info_score.py index ab9222b8082..054d9e711f5 100644 --- a/tests/unittests/clustering/test_mutual_info_score.py +++ b/tests/unittests/clustering/test_mutual_info_score.py @@ -14,9 +14,9 @@ import pytest import torch from sklearn.metrics import mutual_info_score as sklearn_mutual_info_score + from torchmetrics.clustering.mutual_info_score import MutualInfoScore from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score - from unittests import BATCH_SIZE, NUM_CLASSES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/clustering/test_normalized_mutual_info_score.py b/tests/unittests/clustering/test_normalized_mutual_info_score.py index 07109771b0f..13b83bd9899 100644 --- a/tests/unittests/clustering/test_normalized_mutual_info_score.py +++ b/tests/unittests/clustering/test_normalized_mutual_info_score.py @@ -16,9 +16,9 @@ import pytest import torch from sklearn.metrics import normalized_mutual_info_score as sklearn_nmi + from torchmetrics.clustering import NormalizedMutualInfoScore from torchmetrics.functional.clustering import normalized_mutual_info_score - from unittests import BATCH_SIZE, NUM_CLASSES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/clustering/test_rand_score.py b/tests/unittests/clustering/test_rand_score.py index 824f5e90fbd..cd1240420a6 100644 --- a/tests/unittests/clustering/test_rand_score.py +++ b/tests/unittests/clustering/test_rand_score.py @@ -14,9 +14,9 @@ import pytest import torch from sklearn.metrics import rand_score as sklearn_rand_score + from torchmetrics.clustering.rand_score import RandScore from torchmetrics.functional.clustering.rand_score import rand_score - from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 diff --git a/tests/unittests/clustering/test_utils.py b/tests/unittests/clustering/test_utils.py index 2b1e4dbc755..0cf01b86b1b 100644 --- a/tests/unittests/clustering/test_utils.py +++ b/tests/unittests/clustering/test_utils.py @@ -19,13 +19,13 @@ from sklearn.metrics.cluster import entropy as sklearn_entropy from sklearn.metrics.cluster import pair_confusion_matrix as sklearn_pair_confusion_matrix from sklearn.metrics.cluster._supervised import _generalized_average as sklearn_generalized_average + from torchmetrics.functional.clustering.utils import ( calculate_contingency_matrix, calculate_entropy, calculate_generalized_mean, calculate_pair_cluster_confusion_matrix, ) - from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests._helpers import seed_all diff --git a/tests/unittests/conftest.py b/tests/unittests/conftest.py index 58967ba2521..65c086cb9ba 100644 --- a/tests/unittests/conftest.py +++ b/tests/unittests/conftest.py @@ -36,8 +36,8 @@ USE_PYTEST_POOL = os.getenv("USE_PYTEST_POOL", "0") == "1" -@pytest.fixture() -def use_deterministic_algorithms(): # noqa: PT004 +@pytest.fixture +def use_deterministic_algorithms(): """Set deterministic algorithms for the test.""" torch.use_deterministic_algorithms(True) yield diff --git a/tests/unittests/deprecations/root_class_imports.py b/tests/unittests/deprecations/root_class_imports.py index 5c4aa1a7155..566f489d48f 100644 --- a/tests/unittests/deprecations/root_class_imports.py +++ b/tests/unittests/deprecations/root_class_imports.py @@ -3,6 +3,7 @@ from functools import partial import pytest + from torchmetrics import ( BLEUScore, CharErrorRate, diff --git a/tests/unittests/detection/test_intersection.py b/tests/unittests/detection/test_intersection.py index 88a6408c536..7906939a634 100644 --- a/tests/unittests/detection/test_intersection.py +++ b/tests/unittests/detection/test_intersection.py @@ -16,6 +16,7 @@ import pytest import torch from torch import IntTensor, Tensor + from torchmetrics.detection.ciou import CompleteIntersectionOverUnion from torchmetrics.detection.diou import DistanceIntersectionOverUnion from torchmetrics.detection.giou import GeneralizedIntersectionOverUnion diff --git a/tests/unittests/detection/test_map.py b/tests/unittests/detection/test_map.py index 221f1f87aef..f316bf0b9cd 100644 --- a/tests/unittests/detection/test_map.py +++ b/tests/unittests/detection/test_map.py @@ -25,12 +25,12 @@ from pycocotools.coco import COCO from pycocotools.cocoeval import COCOeval from torch import IntTensor, Tensor + from torchmetrics.detection.mean_ap import MeanAveragePrecision from torchmetrics.utilities.imports import ( _FASTER_COCO_EVAL_AVAILABLE, _PYCOCOTOOLS_AVAILABLE, ) - from unittests._helpers.testers import MetricTester from unittests.detection import _DETECTION_BBOX, _DETECTION_SEGM, _DETECTION_VAL diff --git a/tests/unittests/detection/test_modified_panoptic_quality.py b/tests/unittests/detection/test_modified_panoptic_quality.py index f4fe1d1ee06..dfc2411ad98 100644 --- a/tests/unittests/detection/test_modified_panoptic_quality.py +++ b/tests/unittests/detection/test_modified_panoptic_quality.py @@ -16,9 +16,9 @@ import numpy as np import pytest import torch + from torchmetrics.detection import ModifiedPanopticQuality from torchmetrics.functional.detection import modified_panoptic_quality - from unittests import _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/detection/test_panoptic_quality.py b/tests/unittests/detection/test_panoptic_quality.py index 58287aa7fcd..941d887466a 100644 --- a/tests/unittests/detection/test_panoptic_quality.py +++ b/tests/unittests/detection/test_panoptic_quality.py @@ -16,9 +16,9 @@ import numpy as np import pytest import torch + from torchmetrics.detection.panoptic_qualities import PanopticQuality from torchmetrics.functional.detection.panoptic_qualities import panoptic_quality - from unittests import _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/image/test_csi.py b/tests/unittests/image/test_csi.py index 85e370c008f..0bdb87bc1ea 100644 --- a/tests/unittests/image/test_csi.py +++ b/tests/unittests/image/test_csi.py @@ -17,9 +17,9 @@ import pytest import torch from sklearn.metrics import jaccard_score + from torchmetrics.functional.regression.csi import critical_success_index from torchmetrics.regression.csi import CriticalSuccessIndex - from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/image/test_d_lambda.py b/tests/unittests/image/test_d_lambda.py index a4eb684b5ac..9f80355741a 100644 --- a/tests/unittests/image/test_d_lambda.py +++ b/tests/unittests/image/test_d_lambda.py @@ -18,10 +18,10 @@ import pytest import torch from torch import Tensor + from torchmetrics.functional.image.d_lambda import spectral_distortion_index from torchmetrics.functional.image.uqi import universal_image_quality_index from torchmetrics.image.d_lambda import SpectralDistortionIndex - from unittests import BATCH_SIZE, NUM_BATCHES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/image/test_d_s.py b/tests/unittests/image/test_d_s.py index 09a9675d380..da9112e448c 100644 --- a/tests/unittests/image/test_d_s.py +++ b/tests/unittests/image/test_d_s.py @@ -21,10 +21,10 @@ from scipy.ndimage import uniform_filter from skimage.transform import resize from torch import Tensor + from torchmetrics.functional.image.d_s import spatial_distortion_index from torchmetrics.functional.image.uqi import universal_image_quality_index from torchmetrics.image.d_s import SpatialDistortionIndex - from unittests import BATCH_SIZE, NUM_BATCHES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/image/test_ergas.py b/tests/unittests/image/test_ergas.py index 0d712292a21..f96f21d1b8a 100644 --- a/tests/unittests/image/test_ergas.py +++ b/tests/unittests/image/test_ergas.py @@ -17,10 +17,10 @@ import pytest import torch from torch import Tensor + from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis from torchmetrics.image.ergas import ErrorRelativeGlobalDimensionlessSynthesis from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import BATCH_SIZE, NUM_BATCHES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/image/test_fid.py b/tests/unittests/image/test_fid.py index a55b2738ccf..75829cf1b2d 100644 --- a/tests/unittests/image/test_fid.py +++ b/tests/unittests/image/test_fid.py @@ -19,9 +19,9 @@ import torch from torch.nn import Module from torch.utils.data import Dataset + from torchmetrics.image.fid import FrechetInceptionDistance, NoTrainInceptionV3 from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE - from unittests._helpers import seed_all seed_all(42) diff --git a/tests/unittests/image/test_image_gradients.py b/tests/unittests/image/test_image_gradients.py index 1181c114fd8..8f2eac57691 100644 --- a/tests/unittests/image/test_image_gradients.py +++ b/tests/unittests/image/test_image_gradients.py @@ -14,6 +14,7 @@ import pytest import torch from torch import Tensor + from torchmetrics.functional import image_gradients diff --git a/tests/unittests/image/test_inception.py b/tests/unittests/image/test_inception.py index f5cb5ec5aed..a750298f15e 100644 --- a/tests/unittests/image/test_inception.py +++ b/tests/unittests/image/test_inception.py @@ -18,9 +18,9 @@ import torch from torch.nn import Module from torch.utils.data import Dataset + from torchmetrics.image.inception import InceptionScore from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE - from unittests._helpers import seed_all seed_all(42) @@ -41,9 +41,9 @@ def forward(self, x): model = MyModel() model.train() assert model.training - assert ( - not model.metric.inception.training - ), "InceptionScore metric was changed to training mode which should not happen" + assert not model.metric.inception.training, ( + "InceptionScore metric was changed to training mode which should not happen" + ) @pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity") diff --git a/tests/unittests/image/test_kid.py b/tests/unittests/image/test_kid.py index 0d69f9eda6e..b61f85e7b63 100644 --- a/tests/unittests/image/test_kid.py +++ b/tests/unittests/image/test_kid.py @@ -18,9 +18,9 @@ import torch from torch.nn import Module from torch.utils.data import Dataset + from torchmetrics.image.kid import KernelInceptionDistance from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE - from unittests._helpers import seed_all seed_all(42) diff --git a/tests/unittests/image/test_lpips.py b/tests/unittests/image/test_lpips.py index 5e148a0d984..49a011d0c82 100644 --- a/tests/unittests/image/test_lpips.py +++ b/tests/unittests/image/test_lpips.py @@ -17,10 +17,10 @@ import pytest import torch from torch import Tensor + from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE - from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/image/test_mifid.py b/tests/unittests/image/test_mifid.py index 247a8508f72..e294a4d2ca4 100644 --- a/tests/unittests/image/test_mifid.py +++ b/tests/unittests/image/test_mifid.py @@ -18,9 +18,9 @@ import pytest import torch from scipy.linalg import sqrtm + from torchmetrics.image.mifid import MemorizationInformedFrechetInceptionDistance, NoTrainInceptionV3 from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE - from unittests import _reference_cachier from unittests._helpers import seed_all diff --git a/tests/unittests/image/test_ms_ssim.py b/tests/unittests/image/test_ms_ssim.py index 8201e877332..031a575e4c1 100644 --- a/tests/unittests/image/test_ms_ssim.py +++ b/tests/unittests/image/test_ms_ssim.py @@ -15,9 +15,9 @@ import pytest import torch from pytorch_msssim import ms_ssim + from torchmetrics.functional.image.ssim import multiscale_structural_similarity_index_measure from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure - from unittests import NUM_BATCHES, _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/image/test_perceptual_path_length.py b/tests/unittests/image/test_perceptual_path_length.py index 26f48be51ae..0ce1f94bf66 100644 --- a/tests/unittests/image/test_perceptual_path_length.py +++ b/tests/unittests/image/test_perceptual_path_length.py @@ -19,11 +19,11 @@ from torch import nn from torch_fidelity.sample_similarity_lpips import SampleSimilarityLPIPS from torch_fidelity.utils import batch_interp + from torchmetrics.functional.image.lpips import _LPIPS from torchmetrics.functional.image.perceptual_path_length import _interpolate, perceptual_path_length from torchmetrics.image.perceptual_path_length import PerceptualPathLength from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE - from unittests._helpers import seed_all, skip_on_running_out_of_memory seed_all(42) diff --git a/tests/unittests/image/test_psnr.py b/tests/unittests/image/test_psnr.py index cdb0e58aac4..11e27ad7faa 100644 --- a/tests/unittests/image/test_psnr.py +++ b/tests/unittests/image/test_psnr.py @@ -18,10 +18,10 @@ import pytest import torch from skimage.metrics import peak_signal_noise_ratio as skimage_peak_signal_noise_ratio + from torchmetrics.functional import peak_signal_noise_ratio from torchmetrics.image import PeakSignalNoiseRatio from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/image/test_psnrb.py b/tests/unittests/image/test_psnrb.py index 2d59efa1f79..50b296ea532 100644 --- a/tests/unittests/image/test_psnrb.py +++ b/tests/unittests/image/test_psnrb.py @@ -17,9 +17,9 @@ import pytest import torch from sewar.utils import _compute_bef + from torchmetrics.functional.image.psnrb import peak_signal_noise_ratio_with_blocked_effect from torchmetrics.image import PeakSignalNoiseRatioWithBlockedEffect - from unittests import BATCH_SIZE, NUM_BATCHES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/image/test_qnr.py b/tests/unittests/image/test_qnr.py index 52d77896235..738df2422a5 100644 --- a/tests/unittests/image/test_qnr.py +++ b/tests/unittests/image/test_qnr.py @@ -18,9 +18,9 @@ import pytest import torch from torch import Tensor + from torchmetrics.functional.image.qnr import quality_with_no_reference from torchmetrics.image.qnr import QualityWithNoReference - from unittests import BATCH_SIZE, NUM_BATCHES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/image/test_rase.py b/tests/unittests/image/test_rase.py index f9227285e87..062bb6aac34 100644 --- a/tests/unittests/image/test_rase.py +++ b/tests/unittests/image/test_rase.py @@ -18,10 +18,10 @@ import sewar import torch from torch import Tensor + from torchmetrics.functional import relative_average_spectral_error from torchmetrics.functional.image.utils import _uniform_filter from torchmetrics.image import RelativeAverageSpectralError - from unittests import BATCH_SIZE from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/image/test_rmse_sw.py b/tests/unittests/image/test_rmse_sw.py index 307d66ac9b1..9c0a2ee21e9 100644 --- a/tests/unittests/image/test_rmse_sw.py +++ b/tests/unittests/image/test_rmse_sw.py @@ -18,9 +18,9 @@ import sewar import torch from torch import Tensor + from torchmetrics.functional import root_mean_squared_error_using_sliding_window from torchmetrics.image import RootMeanSquaredErrorUsingSlidingWindow - from unittests import BATCH_SIZE, NUM_BATCHES from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/image/test_sam.py b/tests/unittests/image/test_sam.py index e71b0b230e3..e1950b12ca1 100644 --- a/tests/unittests/image/test_sam.py +++ b/tests/unittests/image/test_sam.py @@ -17,10 +17,10 @@ import torch from torch import Tensor from torch.nn import functional as F # noqa: N812 + from torchmetrics.functional.image.sam import spectral_angle_mapper from torchmetrics.image.sam import SpectralAngleMapper from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/image/test_scc.py b/tests/unittests/image/test_scc.py index ecf6c355677..16c797cb38a 100644 --- a/tests/unittests/image/test_scc.py +++ b/tests/unittests/image/test_scc.py @@ -17,9 +17,9 @@ import pytest import torch from sewar.full_ref import scc as sewar_scc + from torchmetrics.functional.image import spatial_correlation_coefficient from torchmetrics.image import SpatialCorrelationCoefficient - from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/image/test_ssim.py b/tests/unittests/image/test_ssim.py index 49954f45cd7..f0c3d527ec2 100644 --- a/tests/unittests/image/test_ssim.py +++ b/tests/unittests/image/test_ssim.py @@ -19,9 +19,9 @@ from pytorch_msssim import ssim from skimage.metrics import structural_similarity from torch import Tensor + from torchmetrics.functional import structural_similarity_index_measure from torchmetrics.image import StructuralSimilarityIndexMeasure - from unittests import NUM_BATCHES, _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/image/test_tv.py b/tests/unittests/image/test_tv.py index add144897ae..5d2735a4ab3 100644 --- a/tests/unittests/image/test_tv.py +++ b/tests/unittests/image/test_tv.py @@ -17,9 +17,9 @@ import pytest import torch from kornia.losses import total_variation as kornia_total_variation + from torchmetrics.functional.image.tv import total_variation from torchmetrics.image.tv import TotalVariation - from unittests import _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/image/test_uqi.py b/tests/unittests/image/test_uqi.py index e79296cf2b5..ae5463b0445 100644 --- a/tests/unittests/image/test_uqi.py +++ b/tests/unittests/image/test_uqi.py @@ -18,10 +18,10 @@ import torch from skimage.metrics import structural_similarity from torch import Tensor + from torchmetrics.functional.image.uqi import universal_image_quality_index from torchmetrics.image.uqi import UniversalImageQualityIndex from torchmetrics.utilities.imports import _TORCH_LESS_THAN_2_6 - from unittests import BATCH_SIZE, NUM_BATCHES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/image/test_vif.py b/tests/unittests/image/test_vif.py index e1ea8eb1401..2ecaccd47b7 100644 --- a/tests/unittests/image/test_vif.py +++ b/tests/unittests/image/test_vif.py @@ -16,9 +16,9 @@ import pytest import torch from sewar.full_ref import vifp + from torchmetrics.functional.image.vif import visual_information_fidelity from torchmetrics.image.vif import VisualInformationFidelity - from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/multimodal/test_clip_iqa.py b/tests/unittests/multimodal/test_clip_iqa.py index 2e88840e81b..43c117a3008 100644 --- a/tests/unittests/multimodal/test_clip_iqa.py +++ b/tests/unittests/multimodal/test_clip_iqa.py @@ -21,11 +21,11 @@ import torch from PIL import Image from torch import Tensor +from torchvision.transforms import PILToTensor + from torchmetrics.functional.multimodal.clip_iqa import clip_image_quality_assessment from torchmetrics.multimodal.clip_iqa import CLIPImageQualityAssessment from torchmetrics.utilities.imports import _PIQ_GREATER_EQUAL_0_8, _TRANSFORMERS_GREATER_EQUAL_4_10 -from torchvision.transforms import PILToTensor - from unittests._helpers import skip_on_connection_issues from unittests._helpers.testers import MetricTester from unittests.image import _SAMPLE_IMAGE diff --git a/tests/unittests/multimodal/test_clip_score.py b/tests/unittests/multimodal/test_clip_score.py index 9e71a30ca0a..491d64a8a78 100644 --- a/tests/unittests/multimodal/test_clip_score.py +++ b/tests/unittests/multimodal/test_clip_score.py @@ -19,12 +19,12 @@ import pytest import torch from torch import Tensor -from torchmetrics.functional.multimodal.clip_score import clip_score -from torchmetrics.multimodal.clip_score import CLIPScore -from torchmetrics.utilities.imports import _TRANSFORMERS_GREATER_EQUAL_4_10 from transformers import CLIPModel as _CLIPModel from transformers import CLIPProcessor as _CLIPProcessor +from torchmetrics.functional.multimodal.clip_score import clip_score +from torchmetrics.multimodal.clip_score import CLIPScore +from torchmetrics.utilities.imports import _TRANSFORMERS_GREATER_EQUAL_4_10 from unittests._helpers import seed_all, skip_on_connection_issues from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/nominal/test_cramers.py b/tests/unittests/nominal/test_cramers.py index 78df22d21b9..3de0c37cbfd 100644 --- a/tests/unittests/nominal/test_cramers.py +++ b/tests/unittests/nominal/test_cramers.py @@ -16,9 +16,9 @@ import pytest import torch + from torchmetrics.functional.nominal.cramers import cramers_v, cramers_v_matrix from torchmetrics.nominal.cramers import CramersV - from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests._helpers.testers import MetricTester @@ -43,7 +43,7 @@ ) -@pytest.fixture() +@pytest.fixture def cramers_matrix_input(): """Define input in matrix format for the metric.""" matrix = torch.cat( diff --git a/tests/unittests/nominal/test_fleiss_kappa.py b/tests/unittests/nominal/test_fleiss_kappa.py index 911f2814ad5..659c9d30241 100644 --- a/tests/unittests/nominal/test_fleiss_kappa.py +++ b/tests/unittests/nominal/test_fleiss_kappa.py @@ -17,9 +17,9 @@ import pytest import torch from statsmodels.stats.inter_rater import fleiss_kappa as sk_fleiss_kappa + from torchmetrics.functional.nominal.fleiss_kappa import fleiss_kappa from torchmetrics.nominal.fleiss_kappa import FleissKappa - from unittests import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/nominal/test_pearson.py b/tests/unittests/nominal/test_pearson.py index 2f59461d276..29a2fbeb31d 100644 --- a/tests/unittests/nominal/test_pearson.py +++ b/tests/unittests/nominal/test_pearson.py @@ -16,12 +16,12 @@ import pandas as pd import pytest import torch + from torchmetrics.functional.nominal.pearson import ( pearsons_contingency_coefficient, pearsons_contingency_coefficient_matrix, ) from torchmetrics.nominal.pearson import PearsonsContingencyCoefficient - from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests._helpers.testers import MetricTester @@ -39,7 +39,7 @@ # No testing with replacing NaN's values is done as not supported in SciPy -@pytest.fixture() +@pytest.fixture def pearson_matrix_input(): """Define input in matrix format for the metric.""" return torch.cat( diff --git a/tests/unittests/nominal/test_theils_u.py b/tests/unittests/nominal/test_theils_u.py index a8bc2bc9952..fd4cdebb162 100644 --- a/tests/unittests/nominal/test_theils_u.py +++ b/tests/unittests/nominal/test_theils_u.py @@ -16,9 +16,9 @@ import pytest import torch + from torchmetrics.functional.nominal.theils_u import theils_u, theils_u_matrix from torchmetrics.nominal import TheilsU - from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests._helpers.testers import MetricTester @@ -43,7 +43,7 @@ ) -@pytest.fixture() +@pytest.fixture def theils_u_matrix_input(): """Define input in matrix format for the metric.""" matrix = torch.cat( diff --git a/tests/unittests/nominal/test_tschuprows.py b/tests/unittests/nominal/test_tschuprows.py index a6a8c1d2b39..94fdc3d7474 100644 --- a/tests/unittests/nominal/test_tschuprows.py +++ b/tests/unittests/nominal/test_tschuprows.py @@ -16,9 +16,9 @@ import pandas as pd import pytest import torch + from torchmetrics.functional.nominal.tschuprows import tschuprows_t, tschuprows_t_matrix from torchmetrics.nominal.tschuprows import TschuprowsT - from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests._helpers.testers import MetricTester @@ -36,7 +36,7 @@ # No testing with replacing NaN's values is done as not supported in SciPy -@pytest.fixture() +@pytest.fixture def tschuprows_matrix_input(): """Define input in matrix format for the metric.""" return torch.cat( diff --git a/tests/unittests/pairwise/test_pairwise_distance.py b/tests/unittests/pairwise/test_pairwise_distance.py index 6538423c592..13d48f7f368 100644 --- a/tests/unittests/pairwise/test_pairwise_distance.py +++ b/tests/unittests/pairwise/test_pairwise_distance.py @@ -24,6 +24,7 @@ pairwise_distances, ) from torch import Tensor + from torchmetrics.functional import ( pairwise_cosine_similarity, pairwise_euclidean_distance, @@ -31,7 +32,6 @@ pairwise_manhattan_distance, pairwise_minkowski_distance, ) - from unittests import BATCH_SIZE, NUM_BATCHES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester @@ -81,7 +81,7 @@ def _wrap_reduction(x, y, sk_fn, reduction): [ pytest.param(pairwise_cosine_similarity, cosine_similarity, id="cosine"), pytest.param(pairwise_euclidean_distance, euclidean_distances, id="euclidean"), - pytest.param(pairwise_manhattan_distance, manhattan_distances, id="manhatten"), + pytest.param(pairwise_manhattan_distance, manhattan_distances, id="manhattan"), pytest.param(pairwise_linear_similarity, linear_kernel, id="linear"), pytest.param( partial(pairwise_minkowski_distance, exponent=3), diff --git a/tests/unittests/regression/test_concordance.py b/tests/unittests/regression/test_concordance.py index 06493af0e0d..d5dd3378998 100644 --- a/tests/unittests/regression/test_concordance.py +++ b/tests/unittests/regression/test_concordance.py @@ -17,10 +17,10 @@ import pytest import torch from scipy.stats import pearsonr + from torchmetrics.functional.regression.concordance import concordance_corrcoef from torchmetrics.regression.concordance import ConcordanceCorrCoef from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/regression/test_cosine_similarity.py b/tests/unittests/regression/test_cosine_similarity.py index de689cabb45..3b67936b2dd 100644 --- a/tests/unittests/regression/test_cosine_similarity.py +++ b/tests/unittests/regression/test_cosine_similarity.py @@ -17,9 +17,9 @@ import pytest import torch from sklearn.metrics.pairwise import cosine_similarity as sk_cosine + from torchmetrics.functional.regression.cosine_similarity import cosine_similarity from torchmetrics.regression.cosine_similarity import CosineSimilarity - from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/regression/test_explained_variance.py b/tests/unittests/regression/test_explained_variance.py index 4838efa5df1..5d7f862a577 100644 --- a/tests/unittests/regression/test_explained_variance.py +++ b/tests/unittests/regression/test_explained_variance.py @@ -16,9 +16,9 @@ import pytest import torch from sklearn.metrics import explained_variance_score + from torchmetrics.functional import explained_variance from torchmetrics.regression import ExplainedVariance - from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/regression/test_kendall.py b/tests/unittests/regression/test_kendall.py index 69c32106aba..ffb9f0adc46 100644 --- a/tests/unittests/regression/test_kendall.py +++ b/tests/unittests/regression/test_kendall.py @@ -19,10 +19,10 @@ import torch from lightning_utilities.core.imports import compare_version from scipy.stats import kendalltau + from torchmetrics.functional.regression.kendall import kendall_rank_corrcoef from torchmetrics.regression.kendall import KendallRankCorrCoef from torchmetrics.utilities.imports import _SCIPY_GREATER_EQUAL_1_8 - from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/regression/test_kl_divergence.py b/tests/unittests/regression/test_kl_divergence.py index 53a3817fa82..b6ff80815f5 100644 --- a/tests/unittests/regression/test_kl_divergence.py +++ b/tests/unittests/regression/test_kl_divergence.py @@ -19,10 +19,10 @@ import torch from scipy.stats import entropy from torch import Tensor + from torchmetrics.functional.regression.kl_divergence import kl_divergence from torchmetrics.regression.kl_divergence import KLDivergence from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/regression/test_log_cosh_error.py b/tests/unittests/regression/test_log_cosh_error.py index 9931ec91c22..0d89e66bff9 100644 --- a/tests/unittests/regression/test_log_cosh_error.py +++ b/tests/unittests/regression/test_log_cosh_error.py @@ -16,9 +16,9 @@ import numpy as np import pytest import torch + from torchmetrics.functional.regression.log_cosh import log_cosh_error from torchmetrics.regression.log_cosh import LogCoshError - from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/regression/test_mean_error.py b/tests/unittests/regression/test_mean_error.py index 38c86817184..d8324e220f2 100644 --- a/tests/unittests/regression/test_mean_error.py +++ b/tests/unittests/regression/test_mean_error.py @@ -25,6 +25,7 @@ from sklearn.metrics import mean_squared_log_error as sk_mean_squared_log_error from sklearn.metrics._regression import _check_reg_targets from sklearn.utils import check_consistent_length + from torchmetrics.functional import ( mean_absolute_error, mean_absolute_percentage_error, @@ -43,7 +44,6 @@ ) from torchmetrics.regression.nrmse import NormalizedRootMeanSquaredError from torchmetrics.regression.symmetric_mape import SymmetricMeanAbsolutePercentageError - from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/regression/test_minkowski_distance.py b/tests/unittests/regression/test_minkowski_distance.py index 9c0a27b6e54..e7a43d1b9a0 100644 --- a/tests/unittests/regression/test_minkowski_distance.py +++ b/tests/unittests/regression/test_minkowski_distance.py @@ -3,10 +3,10 @@ import pytest import torch from scipy.spatial.distance import minkowski as scipy_minkowski + from torchmetrics.functional import minkowski_distance from torchmetrics.regression import MinkowskiDistance from torchmetrics.utilities.exceptions import TorchMetricsUserError - from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/regression/test_pearson.py b/tests/unittests/regression/test_pearson.py index 07cbf3fd65c..2e55b4c1af0 100644 --- a/tests/unittests/regression/test_pearson.py +++ b/tests/unittests/regression/test_pearson.py @@ -16,9 +16,9 @@ import pytest import torch from scipy.stats import pearsonr + from torchmetrics.functional.regression.pearson import pearson_corrcoef from torchmetrics.regression.pearson import PearsonCorrCoef, _final_aggregation - from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/regression/test_r2.py b/tests/unittests/regression/test_r2.py index 8649a3392e9..26ed9347e6c 100644 --- a/tests/unittests/regression/test_r2.py +++ b/tests/unittests/regression/test_r2.py @@ -16,9 +16,9 @@ import pytest import torch from sklearn.metrics import r2_score as sk_r2score + from torchmetrics.functional import r2_score from torchmetrics.regression import R2Score - from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/regression/test_rse.py b/tests/unittests/regression/test_rse.py index 4ec677aa9ff..2dab5074654 100644 --- a/tests/unittests/regression/test_rse.py +++ b/tests/unittests/regression/test_rse.py @@ -16,10 +16,10 @@ import numpy as np import pytest import torch + from torchmetrics.functional import relative_squared_error from torchmetrics.regression import RelativeSquaredError from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/regression/test_spearman.py b/tests/unittests/regression/test_spearman.py index b8d096e3d0e..8052673cfbd 100644 --- a/tests/unittests/regression/test_spearman.py +++ b/tests/unittests/regression/test_spearman.py @@ -16,10 +16,10 @@ import pytest import torch from scipy.stats import rankdata, spearmanr + from torchmetrics.functional.regression.spearman import _rank_data, spearman_corrcoef from torchmetrics.regression.spearman import SpearmanCorrCoef from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 - from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/regression/test_tweedie_deviance.py b/tests/unittests/regression/test_tweedie_deviance.py index ec45b8ceb49..88ba02f0969 100644 --- a/tests/unittests/regression/test_tweedie_deviance.py +++ b/tests/unittests/regression/test_tweedie_deviance.py @@ -17,9 +17,9 @@ import torch from sklearn.metrics import mean_tweedie_deviance from torch import Tensor + from torchmetrics.functional.regression.tweedie_deviance import tweedie_deviance_score from torchmetrics.regression.tweedie_deviance import TweedieDevianceScore - from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/retrieval/test_auroc.py b/tests/unittests/retrieval/test_auroc.py index 26dbf62a9aa..653d957c64c 100644 --- a/tests/unittests/retrieval/test_auroc.py +++ b/tests/unittests/retrieval/test_auroc.py @@ -17,10 +17,10 @@ import pytest from sklearn.metrics import roc_auc_score from torch import Tensor -from torchmetrics.functional.retrieval.auroc import retrieval_auroc -from torchmetrics.retrieval.auroc import RetrievalAUROC from typing_extensions import Literal +from torchmetrics.functional.retrieval.auroc import retrieval_auroc +from torchmetrics.retrieval.auroc import RetrievalAUROC from unittests._helpers import seed_all from unittests.retrieval.helpers import ( RetrievalMetricTester, diff --git a/tests/unittests/retrieval/test_fallout.py b/tests/unittests/retrieval/test_fallout.py index 9b0d8a3ebe0..20963940a02 100644 --- a/tests/unittests/retrieval/test_fallout.py +++ b/tests/unittests/retrieval/test_fallout.py @@ -16,10 +16,10 @@ import numpy as np import pytest from torch import Tensor -from torchmetrics.functional.retrieval.fall_out import retrieval_fall_out -from torchmetrics.retrieval.fall_out import RetrievalFallOut from typing_extensions import Literal +from torchmetrics.functional.retrieval.fall_out import retrieval_fall_out +from torchmetrics.retrieval.fall_out import RetrievalFallOut from unittests._helpers import seed_all from unittests.retrieval.helpers import ( RetrievalMetricTester, diff --git a/tests/unittests/retrieval/test_hit_rate.py b/tests/unittests/retrieval/test_hit_rate.py index 377c304ae5e..8d9dda583eb 100644 --- a/tests/unittests/retrieval/test_hit_rate.py +++ b/tests/unittests/retrieval/test_hit_rate.py @@ -16,10 +16,10 @@ import numpy as np import pytest from torch import Tensor -from torchmetrics.functional.retrieval.hit_rate import retrieval_hit_rate -from torchmetrics.retrieval.hit_rate import RetrievalHitRate from typing_extensions import Literal +from torchmetrics.functional.retrieval.hit_rate import retrieval_hit_rate +from torchmetrics.retrieval.hit_rate import RetrievalHitRate from unittests._helpers import seed_all from unittests.retrieval.helpers import ( RetrievalMetricTester, diff --git a/tests/unittests/retrieval/test_map.py b/tests/unittests/retrieval/test_map.py index f3ac6b9989d..ea2a9e6a3a2 100644 --- a/tests/unittests/retrieval/test_map.py +++ b/tests/unittests/retrieval/test_map.py @@ -17,10 +17,10 @@ import pytest from sklearn.metrics import average_precision_score as sk_average_precision_score from torch import Tensor -from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision -from torchmetrics.retrieval.average_precision import RetrievalMAP from typing_extensions import Literal +from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision +from torchmetrics.retrieval.average_precision import RetrievalMAP from unittests._helpers import seed_all from unittests.retrieval.helpers import ( RetrievalMetricTester, diff --git a/tests/unittests/retrieval/test_mrr.py b/tests/unittests/retrieval/test_mrr.py index 22cc946e8a8..992a29012a4 100644 --- a/tests/unittests/retrieval/test_mrr.py +++ b/tests/unittests/retrieval/test_mrr.py @@ -17,10 +17,10 @@ import pytest from sklearn.metrics import label_ranking_average_precision_score from torch import Tensor -from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank -from torchmetrics.retrieval.reciprocal_rank import RetrievalMRR from typing_extensions import Literal +from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank +from torchmetrics.retrieval.reciprocal_rank import RetrievalMRR from unittests._helpers import seed_all from unittests.retrieval.helpers import ( RetrievalMetricTester, diff --git a/tests/unittests/retrieval/test_ndcg.py b/tests/unittests/retrieval/test_ndcg.py index 1f68839eb4e..5c8fe20a2f7 100644 --- a/tests/unittests/retrieval/test_ndcg.py +++ b/tests/unittests/retrieval/test_ndcg.py @@ -18,10 +18,10 @@ import torch from sklearn.metrics import ndcg_score from torch import Tensor -from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg -from torchmetrics.retrieval.ndcg import RetrievalNormalizedDCG from typing_extensions import Literal +from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg +from torchmetrics.retrieval.ndcg import RetrievalNormalizedDCG from unittests._helpers import seed_all from unittests.retrieval.helpers import ( RetrievalMetricTester, diff --git a/tests/unittests/retrieval/test_precision.py b/tests/unittests/retrieval/test_precision.py index a6e18756cab..c96138c4256 100644 --- a/tests/unittests/retrieval/test_precision.py +++ b/tests/unittests/retrieval/test_precision.py @@ -16,10 +16,10 @@ import numpy as np import pytest from torch import Tensor -from torchmetrics.functional.retrieval.precision import retrieval_precision -from torchmetrics.retrieval.precision import RetrievalPrecision from typing_extensions import Literal +from torchmetrics.functional.retrieval.precision import retrieval_precision +from torchmetrics.retrieval.precision import RetrievalPrecision from unittests._helpers import seed_all from unittests.retrieval.helpers import ( RetrievalMetricTester, @@ -50,7 +50,7 @@ def _precision_at_k(target: np.ndarray, preds: np.ndarray, top_k: Optional[int] assert target.shape == preds.shape assert len(target.shape) == 1 # works only with single dimension inputs - if top_k is None or adaptive_k and top_k > len(preds): + if top_k is None or (adaptive_k and top_k > len(preds)): top_k = len(preds) if target.sum() > 0: diff --git a/tests/unittests/retrieval/test_precision_recall_curve.py b/tests/unittests/retrieval/test_precision_recall_curve.py index d8e9817bd51..915909cbc99 100644 --- a/tests/unittests/retrieval/test_precision_recall_curve.py +++ b/tests/unittests/retrieval/test_precision_recall_curve.py @@ -19,10 +19,10 @@ import torch from numpy import array from torch import Tensor, tensor -from torchmetrics.retrieval import RetrievalPrecisionRecallCurve -from torchmetrics.retrieval.base import _retrieval_aggregate from typing_extensions import Literal +from torchmetrics.retrieval import RetrievalPrecisionRecallCurve +from torchmetrics.retrieval.base import _retrieval_aggregate from unittests._helpers import seed_all from unittests._helpers.testers import Metric, MetricTester from unittests.retrieval.helpers import _custom_aggregate_fn, _default_metric_class_input_arguments, get_group_indexes diff --git a/tests/unittests/retrieval/test_r_precision.py b/tests/unittests/retrieval/test_r_precision.py index 50a0384a58d..27e7721b794 100644 --- a/tests/unittests/retrieval/test_r_precision.py +++ b/tests/unittests/retrieval/test_r_precision.py @@ -16,10 +16,10 @@ import numpy as np import pytest from torch import Tensor -from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision -from torchmetrics.retrieval.r_precision import RetrievalRPrecision from typing_extensions import Literal +from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision +from torchmetrics.retrieval.r_precision import RetrievalRPrecision from unittests._helpers import seed_all from unittests.retrieval.helpers import ( RetrievalMetricTester, diff --git a/tests/unittests/retrieval/test_recall.py b/tests/unittests/retrieval/test_recall.py index 24ff4b6a756..dd9ffa79afa 100644 --- a/tests/unittests/retrieval/test_recall.py +++ b/tests/unittests/retrieval/test_recall.py @@ -16,10 +16,10 @@ import numpy as np import pytest from torch import Tensor -from torchmetrics.functional.retrieval.recall import retrieval_recall -from torchmetrics.retrieval.recall import RetrievalRecall from typing_extensions import Literal +from torchmetrics.functional.retrieval.recall import retrieval_recall +from torchmetrics.retrieval.recall import RetrievalRecall from unittests._helpers import seed_all from unittests.retrieval.helpers import ( RetrievalMetricTester, diff --git a/tests/unittests/segmentation/test_dice.py b/tests/unittests/segmentation/test_dice.py index 2828faaf987..b974b15b135 100644 --- a/tests/unittests/segmentation/test_dice.py +++ b/tests/unittests/segmentation/test_dice.py @@ -16,10 +16,10 @@ import pytest import torch from sklearn.metrics import f1_score + from torchmetrics import MetricCollection from torchmetrics.functional.segmentation.dice import dice_score from torchmetrics.segmentation.dice import DiceScore - from unittests import NUM_CLASSES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/segmentation/test_generalized_dice_score.py b/tests/unittests/segmentation/test_generalized_dice_score.py index c87fd6aa22e..e2014be6496 100644 --- a/tests/unittests/segmentation/test_generalized_dice_score.py +++ b/tests/unittests/segmentation/test_generalized_dice_score.py @@ -18,9 +18,9 @@ import torch from lightning_utilities.core.imports import RequirementCache from monai.metrics.generalized_dice import compute_generalized_dice + from torchmetrics.functional.segmentation.generalized_dice import generalized_dice_score from torchmetrics.segmentation.generalized_dice import GeneralizedDiceScore - from unittests import NUM_CLASSES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/segmentation/test_hausdorff_distance.py b/tests/unittests/segmentation/test_hausdorff_distance.py index afd77c1f4b2..f4f86b3d9d1 100644 --- a/tests/unittests/segmentation/test_hausdorff_distance.py +++ b/tests/unittests/segmentation/test_hausdorff_distance.py @@ -17,9 +17,9 @@ import pytest import torch from monai.metrics.hausdorff_distance import compute_hausdorff_distance as monai_hausdorff_distance + from torchmetrics.functional.segmentation.hausdorff_distance import hausdorff_distance from torchmetrics.segmentation.hausdorff_distance import HausdorffDistance - from unittests import NUM_BATCHES, _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/segmentation/test_mean_iou.py b/tests/unittests/segmentation/test_mean_iou.py index 8c21d5c70c3..6aa5182562d 100644 --- a/tests/unittests/segmentation/test_mean_iou.py +++ b/tests/unittests/segmentation/test_mean_iou.py @@ -17,9 +17,9 @@ import pytest import torch from monai.metrics.meaniou import compute_iou + from torchmetrics.functional.segmentation.mean_iou import mean_iou from torchmetrics.segmentation.mean_iou import MeanIoU - from unittests import NUM_CLASSES from unittests._helpers.testers import MetricTester from unittests.segmentation.inputs import _inputs1, _inputs2, _inputs3 diff --git a/tests/unittests/segmentation/test_utils.py b/tests/unittests/segmentation/test_utils.py index 39cff09a2dd..30a6da7d954 100644 --- a/tests/unittests/segmentation/test_utils.py +++ b/tests/unittests/segmentation/test_utils.py @@ -21,6 +21,7 @@ from scipy.ndimage import distance_transform_cdt as scidistance_transform_cdt from scipy.ndimage import distance_transform_edt as scidistance_transform_edt from scipy.ndimage import generate_binary_structure as scigenerate_binary_structure + from torchmetrics.functional.segmentation.utils import ( binary_erosion, distance_transform, diff --git a/tests/unittests/shape/test_procrustes.py b/tests/unittests/shape/test_procrustes.py index a3b89e13eb7..107b6494851 100644 --- a/tests/unittests/shape/test_procrustes.py +++ b/tests/unittests/shape/test_procrustes.py @@ -17,9 +17,9 @@ import pytest import torch from scipy.spatial import procrustes as scipy_procrustes + from torchmetrics.functional.shape.procrustes import procrustes_disparity from torchmetrics.shape.procrustes import ProcrustesDisparity - from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/test_deprecated.py b/tests/unittests/test_deprecated.py index cadc860e349..8bc49ff3bfa 100644 --- a/tests/unittests/test_deprecated.py +++ b/tests/unittests/test_deprecated.py @@ -1,5 +1,6 @@ import pytest import torch + from torchmetrics.classification import Dice from torchmetrics.functional.classification import dice diff --git a/tests/unittests/text/_helpers.py b/tests/unittests/text/_helpers.py index 2b9f8381d42..7f80438e6f3 100644 --- a/tests/unittests/text/_helpers.py +++ b/tests/unittests/text/_helpers.py @@ -21,8 +21,8 @@ import pytest import torch from torch import Tensor -from torchmetrics import Metric +from torchmetrics import Metric from unittests import NUM_PROCESSES, USE_PYTEST_POOL, _reference_cachier from unittests._helpers import seed_all from unittests._helpers.testers import ( diff --git a/tests/unittests/text/test_bertscore.py b/tests/unittests/text/test_bertscore.py index d7e1fb22609..75f920ddb8f 100644 --- a/tests/unittests/text/test_bertscore.py +++ b/tests/unittests/text/test_bertscore.py @@ -17,11 +17,11 @@ import pytest from torch import Tensor +from typing_extensions import Literal + from torchmetrics.functional.text.bert import bert_score from torchmetrics.text.bert import BERTScore from torchmetrics.utilities.imports import _TRANSFORMERS_GREATER_EQUAL_4_4 -from typing_extensions import Literal - from unittests._helpers import skip_on_connection_issues from unittests.text._helpers import TextTester from unittests.text._inputs import _inputs_single_reference diff --git a/tests/unittests/text/test_bleu.py b/tests/unittests/text/test_bleu.py index 03ce0faba02..66cfc764a27 100644 --- a/tests/unittests/text/test_bleu.py +++ b/tests/unittests/text/test_bleu.py @@ -17,9 +17,9 @@ import pytest from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu from torch import tensor + from torchmetrics.functional.text.bleu import bleu_score from torchmetrics.text.bleu import BLEUScore - from unittests.text._helpers import TextTester from unittests.text._inputs import _inputs_multiple_references diff --git a/tests/unittests/text/test_cer.py b/tests/unittests/text/test_cer.py index 99de422ba26..855b01c39d2 100644 --- a/tests/unittests/text/test_cer.py +++ b/tests/unittests/text/test_cer.py @@ -14,9 +14,9 @@ from typing import Union import pytest + from torchmetrics.functional.text.cer import char_error_rate from torchmetrics.text.cer import CharErrorRate - from unittests.text._helpers import TextTester from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 diff --git a/tests/unittests/text/test_chrf.py b/tests/unittests/text/test_chrf.py index ca3995e2519..12169fac561 100644 --- a/tests/unittests/text/test_chrf.py +++ b/tests/unittests/text/test_chrf.py @@ -16,9 +16,9 @@ import pytest from torch import Tensor, tensor + from torchmetrics.functional.text.chrf import chrf_score from torchmetrics.text.chrf import CHRFScore - from unittests.text._helpers import TextTester from unittests.text._inputs import _inputs_multiple_references, _inputs_single_sentence_multiple_references diff --git a/tests/unittests/text/test_edit.py b/tests/unittests/text/test_edit.py index 457bcfa18ad..8ec976cda6b 100644 --- a/tests/unittests/text/test_edit.py +++ b/tests/unittests/text/test_edit.py @@ -15,9 +15,9 @@ import pytest from nltk.metrics.distance import edit_distance as nltk_edit_distance + from torchmetrics.functional.text.edit import edit_distance from torchmetrics.text.edit import EditDistance - from unittests.text._helpers import TextTester from unittests.text._inputs import _inputs_single_reference diff --git a/tests/unittests/text/test_eed.py b/tests/unittests/text/test_eed.py index a9c30d384de..dd4a2a02d70 100644 --- a/tests/unittests/text/test_eed.py +++ b/tests/unittests/text/test_eed.py @@ -16,9 +16,9 @@ import pytest from torch import Tensor, tensor + from torchmetrics.functional.text.eed import extended_edit_distance from torchmetrics.text.eed import ExtendedEditDistance - from unittests.text._helpers import TextTester from unittests.text._inputs import _inputs_single_reference, _inputs_single_sentence_multiple_references diff --git a/tests/unittests/text/test_infolm.py b/tests/unittests/text/test_infolm.py index b3fd26026ca..43675aa9f14 100644 --- a/tests/unittests/text/test_infolm.py +++ b/tests/unittests/text/test_infolm.py @@ -15,10 +15,10 @@ import pytest import torch + from torchmetrics.functional.text.infolm import infolm from torchmetrics.text.infolm import InfoLM from torchmetrics.utilities.imports import _TRANSFORMERS_GREATER_EQUAL_4_4 - from unittests._helpers import skip_on_connection_issues from unittests.text._helpers import TextTester from unittests.text._inputs import HYPOTHESIS_A, HYPOTHESIS_C, _inputs_single_reference diff --git a/tests/unittests/text/test_mer.py b/tests/unittests/text/test_mer.py index 56592c34762..ecef5825799 100644 --- a/tests/unittests/text/test_mer.py +++ b/tests/unittests/text/test_mer.py @@ -14,9 +14,9 @@ from typing import Union import pytest + from torchmetrics.functional.text.mer import match_error_rate from torchmetrics.text.mer import MatchErrorRate - from unittests._helpers import seed_all from unittests.text._helpers import TextTester from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 diff --git a/tests/unittests/text/test_perplexity.py b/tests/unittests/text/test_perplexity.py index 42930ec21ff..5a3c5fe48f9 100644 --- a/tests/unittests/text/test_perplexity.py +++ b/tests/unittests/text/test_perplexity.py @@ -16,10 +16,10 @@ import pytest import torch from torch.nn import functional as F # noqa: N812 + from torchmetrics.functional.text.perplexity import perplexity from torchmetrics.text.perplexity import Perplexity from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_2 - from unittests._helpers.testers import MetricTester from unittests.text._inputs import ( MASK_INDEX, diff --git a/tests/unittests/text/test_rouge.py b/tests/unittests/text/test_rouge.py index e371afae796..f993cbf4c82 100644 --- a/tests/unittests/text/test_rouge.py +++ b/tests/unittests/text/test_rouge.py @@ -20,11 +20,11 @@ import pytest import torch from torch import Tensor +from typing_extensions import Literal + from torchmetrics.functional.text.rouge import rouge_score from torchmetrics.text.rouge import ROUGEScore from torchmetrics.utilities.imports import _NLTK_AVAILABLE, _ROUGE_SCORE_AVAILABLE -from typing_extensions import Literal - from unittests._helpers import skip_on_connection_issues from unittests.text._helpers import TextTester from unittests.text._inputs import _Input, _inputs_multiple_references, _inputs_single_sentence_single_reference @@ -322,6 +322,6 @@ def test_rouge_score_accumulate_best(preds, references, expected_scores): # Assert each expected score for key in expected_scores: - assert torch.isclose( - result[key], torch.tensor(expected_scores[key]) - ), f"Expected {expected_scores[key]} for {key}, but got {result[key]}" + assert torch.isclose(result[key], torch.tensor(expected_scores[key])), ( + f"Expected {expected_scores[key]} for {key}, but got {result[key]}" + ) diff --git a/tests/unittests/text/test_sacre_bleu.py b/tests/unittests/text/test_sacre_bleu.py index 54da47a6a8a..a2fc3d95144 100644 --- a/tests/unittests/text/test_sacre_bleu.py +++ b/tests/unittests/text/test_sacre_bleu.py @@ -18,9 +18,9 @@ import pytest from lightning_utilities.core.imports import RequirementCache from torch import Tensor, tensor + from torchmetrics.functional.text.sacre_bleu import AVAILABLE_TOKENIZERS, _TokenizersLiteral, sacre_bleu_score from torchmetrics.text.sacre_bleu import SacreBLEUScore - from unittests._helpers import skip_on_connection_issues from unittests.text._helpers import TextTester from unittests.text._inputs import _inputs_multiple_references diff --git a/tests/unittests/text/test_squad.py b/tests/unittests/text/test_squad.py index 8a3d26af8ab..2a9f84e3b28 100644 --- a/tests/unittests/text/test_squad.py +++ b/tests/unittests/text/test_squad.py @@ -17,9 +17,9 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp + from torchmetrics.functional.text import squad from torchmetrics.text.squad import SQuAD - from unittests._helpers.testers import _assert_allclose, _assert_tensor from unittests.text._inputs import _inputs_squad_batch_match, _inputs_squad_exact_match, _inputs_squad_exact_mismatch diff --git a/tests/unittests/text/test_ter.py b/tests/unittests/text/test_ter.py index 1f896bdbbbe..ef7df3e688d 100644 --- a/tests/unittests/text/test_ter.py +++ b/tests/unittests/text/test_ter.py @@ -16,9 +16,9 @@ import pytest from torch import Tensor, tensor + from torchmetrics.functional.text.ter import translation_edit_rate from torchmetrics.text.ter import TranslationEditRate - from unittests.text._helpers import TextTester from unittests.text._inputs import _inputs_multiple_references, _inputs_single_sentence_multiple_references diff --git a/tests/unittests/text/test_wer.py b/tests/unittests/text/test_wer.py index e57c51af5fd..fa22d7c9db1 100644 --- a/tests/unittests/text/test_wer.py +++ b/tests/unittests/text/test_wer.py @@ -14,9 +14,9 @@ from typing import Union import pytest + from torchmetrics.functional.text.wer import word_error_rate from torchmetrics.text.wer import WordErrorRate - from unittests.text._helpers import TextTester from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 diff --git a/tests/unittests/text/test_wil.py b/tests/unittests/text/test_wil.py index 9b1615e5ee6..e7d37f22a6e 100644 --- a/tests/unittests/text/test_wil.py +++ b/tests/unittests/text/test_wil.py @@ -14,9 +14,9 @@ from typing import Union import pytest + from torchmetrics.functional.text.wil import word_information_lost from torchmetrics.text.wil import WordInfoLost - from unittests.text._helpers import TextTester from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 diff --git a/tests/unittests/text/test_wip.py b/tests/unittests/text/test_wip.py index b0393d12b29..1c19165b29f 100644 --- a/tests/unittests/text/test_wip.py +++ b/tests/unittests/text/test_wip.py @@ -14,9 +14,9 @@ from typing import Union import pytest + from torchmetrics.functional.text.wip import word_information_preserved from torchmetrics.text.wip import WordInfoPreserved - from unittests.text._helpers import TextTester from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 4ebd41fd300..4480090a23e 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -21,6 +21,7 @@ import pytest import torch from torch import tensor + from torchmetrics import MetricCollection from torchmetrics.aggregation import MaxMetric, MeanMetric, MinMetric, SumMetric from torchmetrics.audio import ( diff --git a/tests/unittests/utilities/test_utilities.py b/tests/unittests/utilities/test_utilities.py index c2ffd89047a..6f9f5d62843 100644 --- a/tests/unittests/utilities/test_utilities.py +++ b/tests/unittests/utilities/test_utilities.py @@ -18,6 +18,7 @@ import torch from lightning_utilities.test.warning import no_warning_call from torch import tensor + from torchmetrics.regression import MeanSquaredError, PearsonCorrCoef from torchmetrics.utilities import check_forward_full_state_property, rank_zero_debug, rank_zero_info, rank_zero_warn from torchmetrics.utilities.checks import _allclose_recursive diff --git a/tests/unittests/wrappers/test_bootstrapping.py b/tests/unittests/wrappers/test_bootstrapping.py index bd6e072ce12..138b7407839 100644 --- a/tests/unittests/wrappers/test_bootstrapping.py +++ b/tests/unittests/wrappers/test_bootstrapping.py @@ -21,10 +21,10 @@ from lightning_utilities import apply_to_collection from sklearn.metrics import mean_squared_error, precision_score, recall_score from torch import Tensor + from torchmetrics.classification import MulticlassF1Score, MulticlassPrecision, MulticlassRecall from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError from torchmetrics.wrappers.bootstrapping import BootStrapper, _bootstrap_sampler - from unittests._helpers import seed_all seed_all(42) diff --git a/tests/unittests/wrappers/test_classwise.py b/tests/unittests/wrappers/test_classwise.py index e6491903145..fd0a7fb2047 100644 --- a/tests/unittests/wrappers/test_classwise.py +++ b/tests/unittests/wrappers/test_classwise.py @@ -1,5 +1,6 @@ import pytest import torch + from torchmetrics import MetricCollection from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score, MulticlassRecall from torchmetrics.clustering import CalinskiHarabaszScore diff --git a/tests/unittests/wrappers/test_feature_share.py b/tests/unittests/wrappers/test_feature_share.py index d908d6e6d60..fe94002133f 100644 --- a/tests/unittests/wrappers/test_feature_share.py +++ b/tests/unittests/wrappers/test_feature_share.py @@ -15,6 +15,7 @@ import pytest import torch + from torchmetrics import MetricCollection from torchmetrics.image import ( FrechetInceptionDistance, @@ -111,12 +112,12 @@ def test_memory(): feature_share = FeatureShare([fid, inception, kid]).cuda() memory_after_fs = torch.cuda.memory_allocated() - assert ( - memory_after_fs > base_memory - ), "The memory usage should be higher after initializing the feature share wrapper." - assert ( - memory_after_fs < memory_before_fs - ), "The memory usage should be higher after initializing the feature share wrapper." + assert memory_after_fs > base_memory, ( + "The memory usage should be higher after initializing the feature share wrapper." + ) + assert memory_after_fs < memory_before_fs, ( + "The memory usage should be higher after initializing the feature share wrapper." + ) img1 = torch.randint(255, (50, 3, 220, 220), dtype=torch.uint8).to("cuda") img2 = torch.randint(255, (50, 3, 220, 220), dtype=torch.uint8).to("cuda") diff --git a/tests/unittests/wrappers/test_minmax.py b/tests/unittests/wrappers/test_minmax.py index 480fd1f8f7c..b01242b5774 100644 --- a/tests/unittests/wrappers/test_minmax.py +++ b/tests/unittests/wrappers/test_minmax.py @@ -18,10 +18,10 @@ import pytest import torch from torch import Tensor + from torchmetrics.classification import BinaryAccuracy, BinaryConfusionMatrix, MulticlassAccuracy from torchmetrics.regression import MeanSquaredError from torchmetrics.wrappers import MinMaxMetric - from unittests import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/wrappers/test_multioutput.py b/tests/unittests/wrappers/test_multioutput.py index 3d95b20d69a..dd837ec29df 100644 --- a/tests/unittests/wrappers/test_multioutput.py +++ b/tests/unittests/wrappers/test_multioutput.py @@ -19,11 +19,11 @@ from sklearn.metrics import accuracy_score from sklearn.metrics import r2_score as sk_r2score from torch import Tensor, tensor + from torchmetrics import Metric from torchmetrics.classification import ConfusionMatrix, MulticlassAccuracy from torchmetrics.regression import R2Score from torchmetrics.wrappers.multioutput import MultioutputWrapper - from unittests import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester diff --git a/tests/unittests/wrappers/test_multitask.py b/tests/unittests/wrappers/test_multitask.py index fb3ae8987cc..ee3b586c67a 100644 --- a/tests/unittests/wrappers/test_multitask.py +++ b/tests/unittests/wrappers/test_multitask.py @@ -16,12 +16,12 @@ import pytest import torch + from torchmetrics import MetricCollection from torchmetrics.classification import BinaryAccuracy, BinaryF1Score from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_5 from torchmetrics.wrappers import MultitaskWrapper - from unittests import BATCH_SIZE, NUM_BATCHES from unittests._helpers import seed_all diff --git a/tests/unittests/wrappers/test_running.py b/tests/unittests/wrappers/test_running.py index ecac751453d..4180c4840bf 100644 --- a/tests/unittests/wrappers/test_running.py +++ b/tests/unittests/wrappers/test_running.py @@ -16,12 +16,12 @@ import pytest import torch + from torchmetrics.aggregation import MeanMetric, SumMetric from torchmetrics.classification import BinaryAccuracy, BinaryConfusionMatrix from torchmetrics.collections import MetricCollection from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError, PearsonCorrCoef from torchmetrics.wrappers import Running - from unittests import NUM_PROCESSES, USE_PYTEST_POOL @@ -68,9 +68,9 @@ def test_forward(): for i in range(10): assert compare_metric(i) == metric(i) - assert metric.compute() == (i + max(i - 1, 0) + max(i - 2, 0)) / min( - i + 1, 3 - ), f"Running mean is not correct in step {i}" + assert metric.compute() == (i + max(i - 1, 0) + max(i - 2, 0)) / min(i + 1, 3), ( + f"Running mean is not correct in step {i}" + ) @pytest.mark.parametrize( @@ -143,7 +143,7 @@ def _test_ddp_running(rank, dist_sync_on_step, expected): assert metric.compute() == 6 -@pytest.mark.DDP() +@pytest.mark.DDP @pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") @pytest.mark.skipif(not USE_PYTEST_POOL, reason="DDP pool is not available.") @pytest.mark.parametrize(("dist_sync_on_step", "expected"), [(False, 1), (True, 2)]) diff --git a/tests/unittests/wrappers/test_tracker.py b/tests/unittests/wrappers/test_tracker.py index 97c2ae37234..75b9f596b81 100644 --- a/tests/unittests/wrappers/test_tracker.py +++ b/tests/unittests/wrappers/test_tracker.py @@ -16,6 +16,7 @@ import pytest import torch + from torchmetrics import Metric, MetricCollection from torchmetrics.classification import ( MulticlassAccuracy, @@ -26,7 +27,6 @@ from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError from torchmetrics.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_1_6 from torchmetrics.wrappers import MetricTracker, MultioutputWrapper - from unittests._helpers import seed_all seed_all(42) diff --git a/tests/unittests/wrappers/test_transformations.py b/tests/unittests/wrappers/test_transformations.py index 5a30ce1d1ff..ae661568b2a 100644 --- a/tests/unittests/wrappers/test_transformations.py +++ b/tests/unittests/wrappers/test_transformations.py @@ -15,11 +15,11 @@ import pytest from torch import Tensor + from torchmetrics.aggregation import MeanMetric from torchmetrics.classification import BinaryAccuracy from torchmetrics.retrieval import RetrievalMAP from torchmetrics.wrappers import BinaryTargetTransformer, LambdaInputTransformer, MetricInputTransformer - from unittests._helpers import seed_all seed_all(42) From 2dfa7b44d9e67051bcbd2da3ff669e5b0d844db5 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 7 Jan 2025 20:39:30 +0900 Subject: [PATCH 2/5] build(deps): update scipy requirement from <1.15.0,>1.0.0 to >1.0.0,<1.16.0 in /requirements (#2898) build(deps): update scipy requirement in /requirements Updates the requirements on [scipy](https://github.com/scipy/scipy) to permit the latest version. - [Release notes](https://github.com/scipy/scipy/releases) - [Commits](https://github.com/scipy/scipy/compare/v1.0.1...v1.15.0) --- updated-dependencies: - dependency-name: scipy dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements/image.txt | 2 +- requirements/nominal_test.txt | 2 +- requirements/segmentation_test.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements/image.txt b/requirements/image.txt index e594578d6e2..adbdfad90ad 100644 --- a/requirements/image.txt +++ b/requirements/image.txt @@ -1,6 +1,6 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -scipy >1.0.0, <1.15.0 +scipy >1.0.0, <1.16.0 torchvision >=0.15.1, <0.22.0 torch-fidelity <=0.4.0 # bumping to allow install version from master, now used in testing diff --git a/requirements/nominal_test.txt b/requirements/nominal_test.txt index 70beddada96..5b579d59a48 100644 --- a/requirements/nominal_test.txt +++ b/requirements/nominal_test.txt @@ -4,5 +4,5 @@ pandas >1.4.0, <=2.2.3 # cannot pin version due to numpy version incompatibility dython ==0.7.6 ; python_version <"3.9" dython ~=0.7.8 ; python_version > "3.8" # we do not use `> =` -scipy >1.0.0, <1.15.0 # cannot pin version due to some version conflicts with `oldest` CI configuration +scipy >1.0.0, <1.16.0 # cannot pin version due to some version conflicts with `oldest` CI configuration statsmodels >0.13.5, <0.15.0 diff --git a/requirements/segmentation_test.txt b/requirements/segmentation_test.txt index 75d7b97ac6c..1b492d053b8 100644 --- a/requirements/segmentation_test.txt +++ b/requirements/segmentation_test.txt @@ -1,6 +1,6 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -scipy >1.0.0, <1.15.0 +scipy >1.0.0, <1.16.0 monai ==1.3.2 ; python_version < "3.9" monai ==1.4.0 ; python_version > "3.8" From f218f6d04ffaf51eb5f3c0345a0ef6840e55127a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 7 Jan 2025 20:41:23 +0900 Subject: [PATCH 3/5] build(deps): update sacrebleu requirement from <2.5.0,>=2.3.0 to >=2.3.0,<2.6.0 in /requirements (#2899) build(deps): update sacrebleu requirement in /requirements Updates the requirements on [sacrebleu](https://github.com/mjpost/sacrebleu) to permit the latest version. - [Release notes](https://github.com/mjpost/sacrebleu/releases) - [Changelog](https://github.com/mjpost/sacrebleu/blob/master/CHANGELOG.md) - [Commits](https://github.com/mjpost/sacrebleu/compare/v2.3.0...v2.5.1) --- updated-dependencies: - dependency-name: sacrebleu dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements/text_test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/text_test.txt b/requirements/text_test.txt index 0080bb28fe9..175396df735 100644 --- a/requirements/text_test.txt +++ b/requirements/text_test.txt @@ -5,7 +5,7 @@ jiwer >=2.3.0, <3.1.0 rouge-score >0.1.0, <=0.1.2 bert_score ==0.3.13 huggingface-hub <0.28 -sacrebleu >=2.3.0, <2.5.0 +sacrebleu >=2.3.0, <2.6.0 mecab-ko >=1.0.0, <1.1.0 ; python_version < "3.12" # strict # todo: unpin python_version mecab-ko-dic >=1.0.0, <1.1.0 ; python_version < "3.12" # todo: unpin python_version From d9a62d40fbe5b4cad48b77a9c859e54525324da6 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 7 Jan 2025 20:46:25 +0900 Subject: [PATCH 4/5] build(deps): update dython requirement from ~=0.7.8 to ~=0.7.9 in /requirements (#2901) * build(deps): update dython requirement in /requirements Updates the requirements on [dython](https://github.com/shakedzy/dython) to permit the latest version. - [Release notes](https://github.com/shakedzy/dython/releases) - [Changelog](https://github.com/shakedzy/dython/blob/master/CHANGELOG.md) - [Commits](https://github.com/shakedzy/dython/compare/v0.7.8...v0.7.9) --- updated-dependencies: - dependency-name: dython dependency-type: direct:production ... Signed-off-by: dependabot[bot] * Apply suggestions from code review --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- requirements/nominal_test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/nominal_test.txt b/requirements/nominal_test.txt index 5b579d59a48..c73659aca4c 100644 --- a/requirements/nominal_test.txt +++ b/requirements/nominal_test.txt @@ -3,6 +3,6 @@ pandas >1.4.0, <=2.2.3 # cannot pin version due to numpy version incompatibility dython ==0.7.6 ; python_version <"3.9" -dython ~=0.7.8 ; python_version > "3.8" # we do not use `> =` +dython ==0.7.9 ; python_version > "3.8" # we do not use `> =` scipy >1.0.0, <1.16.0 # cannot pin version due to some version conflicts with `oldest` CI configuration statsmodels >0.13.5, <0.15.0 From e690bbda34c10689d3cb411289aa174f9a4653ff Mon Sep 17 00:00:00 2001 From: Jirka B Date: Tue, 7 Jan 2025 23:11:13 +0900 Subject: [PATCH 5/5] ci: update `torch` URL --- .github/workflows/ci-integrate.yml | 2 +- .github/workflows/ci-tests.yml | 2 +- .github/workflows/docs-build.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci-integrate.yml b/.github/workflows/ci-integrate.yml index bc8e47c72c6..8826cf49ac1 100644 --- a/.github/workflows/ci-integrate.yml +++ b/.github/workflows/ci-integrate.yml @@ -36,7 +36,7 @@ jobs: - { python-version: "3.10", requires: "latest", os: "ubuntu-22.04" } # - { python-version: "3.10", requires: "latest", os: "macOS-14" } # M1 machine # todo: crashing for MPS out of memory env: - PYTORCH_URL: "http://download.pytorch.org/whl/cpu/" + PYTORCH_URL: "https://download.pytorch.org/whl/cpu/" FREEZE_REQUIREMENTS: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }} PYPI_CACHE: "_ci-cache_PyPI" diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml index 65d0dd1e49f..0f6628277c0 100644 --- a/.github/workflows/ci-tests.yml +++ b/.github/workflows/ci-tests.yml @@ -64,7 +64,7 @@ jobs: FREEZE_REQUIREMENTS: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }} TOKENIZERS_PARALLELISM: false TEST_DIRS: ${{ needs.check-diff.outputs.test-dirs }} - PIP_EXTRA_INDEX_URL: "--extra-index-url=http://download.pytorch.org/whl/cpu/" + PIP_EXTRA_INDEX_URL: "--extra-index-url=https://download.pytorch.org/whl/cpu/" UNITTEST_TIMEOUT: "" # by default, it is not set # Timeout: https://stackoverflow.com/a/59076067/4521646 diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index db3bbadc2ce..a7c6130d93f 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -18,7 +18,7 @@ defaults: env: FREEZE_REQUIREMENTS: "1" - TORCH_URL: "http://download.pytorch.org/whl/cpu/" + TORCH_URL: "https://download.pytorch.org/whl/cpu/" PYPI_CACHE: "_ci-cache_PyPI" PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION: "python" TOKENIZERS_PARALLELISM: false