Merge branch 'master' into use-pretty_errors
Borda authored Mar 18, 2024
2 parents 28978c8 + dfcdfe0 commit 925c016
Showing 25 changed files with 232 additions and 149 deletions.
5 changes: 3 additions & 2 deletions .azure/gpu-unittests.yml
@@ -136,14 +136,15 @@ jobs:
displayName: "Show caches"
- bash: |
python -m pytest torchmetrics -s --cov=torchmetrics \
python -m pytest torchmetrics --cov=torchmetrics \
--timeout=240 --durations=50 \
--reruns 2 --reruns-delay 1
# --numprocesses=5 --dist=loadfile
env:
DOCTEST_DOWNLOAD_TIMEOUT: "180"
SKIP_SLOW_DOCTEST: "1"
workingDirectory: src
timeoutInMinutes: "40"
displayName: "DocTesting"
- bash: |
@@ -154,7 +155,7 @@ jobs:
displayName: "Pull testing data from S3"
- bash: |
python -m pytest $(TEST_DIRS) -v \
python -m pytest $(TEST_DIRS) \
-m "not DDP" --numprocesses=5 --dist=loadfile \
--cov=torchmetrics --timeout=240 --durations=100 \
--reruns 3 --reruns-delay 1
2 changes: 1 addition & 1 deletion .github/workflows/ci-tests.yml
@@ -161,7 +161,7 @@ jobs:
if: ${{ env.TEST_DIRS != '' }}
working-directory: ./tests
run: |
python -m pytest -v \
python -m pytest \
$TEST_DIRS \
--cov=torchmetrics \
--durations=50 \
9 changes: 6 additions & 3 deletions .pre-commit-config.yaml
@@ -111,12 +111,15 @@ repos:
- id: text-unicode-replacement-char

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.0
rev: v0.3.2
hooks:
- id: ruff-format
args: ["--preview"]
# try to fix what is possible
- id: ruff
args: ["--fix"]
# perform formatting updates
- id: ruff-format
# validate if all is fine with preview mode
- id: ruff

- repo: https://github.com/tox-dev/pyproject-fmt
rev: 1.7.0
10 changes: 8 additions & 2 deletions CHANGELOG.md
@@ -30,13 +30,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed negative variance estimates in certain image metrics ([#2378](https://github.com/Lightning-AI/torchmetrics/pull/2378))
-


- Fixed dtype being changed by deepspeed for certain regression metrics ([#2379](https://github.com/Lightning-AI/torchmetrics/pull/2379))
## [1.3.2] - 2024-03-18

### Fixed

- Fixed negative variance estimates in certain image metrics ([#2378](https://github.com/Lightning-AI/torchmetrics/pull/2378))
- Fixed dtype being changed by deepspeed for certain regression metrics ([#2379](https://github.com/Lightning-AI/torchmetrics/pull/2379))
- Fixed plotting of metric collection when prefix/postfix is set ([#2429](https://github.com/Lightning-AI/torchmetrics/pull/2429))
- Fixed bug when `top_k>1` and `average="macro"` for classification metrics ([#2423](https://github.com/Lightning-AI/torchmetrics/pull/2423))
- Fixed case where label prediction tensors in classification metrics were not validated correctly ([#2427](https://github.com/Lightning-AI/torchmetrics/pull/2427))
- Fixed how auc scores are calculated in `PrecisionRecallCurve.plot` methods ([#2437](https://github.com/Lightning-AI/torchmetrics/pull/2437))


## [1.3.1] - 2024-02-12
58 changes: 26 additions & 32 deletions pyproject.toml
@@ -7,8 +7,12 @@ requires = [
[tool.ruff]
target-version = "py38"
line-length = 120
# Enable Pyflakes `E` and `F` codes by default.
lint.select = [

[tool.ruff.format]
preview = true

[tool.ruff.lint]
select = [
"E",
"W", # see: https://pypi.org/project/pycodestyle
"F", # see: https://pypi.org/project/pyflakes
@@ -18,7 +22,7 @@ lint.select = [
"S", # see: https://pypi.org/project/flake8-bandit
"UP", # see: https://docs.astral.sh/ruff/rules/#pyupgrade-up
]
lint.extend-select = [
extend-select = [
"A", # see: https://pypi.org/project/flake8-builtins
"B", # see: https://pypi.org/project/flake8-bugbear
"C4", # see: https://pypi.org/project/flake8-comprehensions
@@ -38,24 +42,31 @@ lint.extend-select = [
"PERF", # see: https://pypi.org/project/perflint/
"PYI", # see: https://pypi.org/project/flake8-pyi/
]
lint.ignore = [
ignore = [
"E731", # Do not assign a lambda expression, use a def
"D100", # todo: Missing docstring in public module
"D104", # todo: Missing docstring in public package
"D107", # Missing docstring in `__init__`
"ANN101", # Missing type annotation for `self` in method
"S301", # todo: `pickle` and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue # todo
"S310", # todo: Audit URL open for permitted schemes. Allowing use of `file:` or custom schemes is often unexpected. # todo
"B905", # todo: `zip()` without an explicit `strict=` parameter
]
lint.ignore-init-module-imports = true
lint.unfixable = ["F401"]
ignore-init-module-imports = true
unfixable = ["F401"]

[tool.ruff.lint.per-file-ignores]
"setup.py" = ["ANN202", "ANN401"]
"docs/source/conf.py" = ["A001", "D103"]
"src/**" = ["ANN401"]
"tests/**" = ["S101", "ANN001", "ANN201", "ANN202", "ANN401"]
"src/**" = [
"ANN401",
"S310", # todo: Audit URL open for permitted schemes. Allowing use of `file:` or custom schemes is often unexpected. # todo
]
"tests/**" = [
"ANN001",
"ANN201",
"ANN202",
"ANN401",
"S101",
"S301", # todo: `pickle` and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue # todo
]

[tool.ruff.lint.pydocstyle]
# Use Google-style docstrings.
@@ -106,7 +117,10 @@ addopts = [
markers = [
"DDP: mark a test as Distributed Data Parallel",
]
#filterwarnings = ["error::FutureWarning"] # ToDo
filterwarnings = [
"ignore::FutureWarning",
"default:::torchmetrics",
]
xfail_strict = true
junit_duration_report = "call"
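
A note on the new `filterwarnings` entries just above: pytest parses them with the same `action:message:category:module:lineno` syntax as Python's `-W` flag, so the pair appears to silence `FutureWarning` globally while keeping the default behaviour for warnings emitted from `torchmetrics` itself. A rough stdlib sketch of that reading (an illustration, not what pytest runs internally):

```python
import warnings

# Approximate equivalent of the two pytest `filterwarnings` entries above.
# The later call is inserted at the front of the filter list and therefore wins,
# mirroring pytest's rule that later entries take precedence.
warnings.filterwarnings("ignore", category=FutureWarning)  # "ignore::FutureWarning"
warnings.filterwarnings("default", module="torchmetrics")  # "default:::torchmetrics"
```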

@@ -133,26 +147,6 @@ disable_error_code = "attr-defined"
# style choices
warn_no_return = "False"

# Ignore mypy errors for these files
# TODO: the goal is for this to be empty
[[tool.mypy.overrides]]
module = [
"torchmetrics.classification.exact_match",
"torchmetrics.classification.f_beta",
"torchmetrics.classification.precision_recall",
"torchmetrics.classification.ranking",
"torchmetrics.classification.recall_at_fixed_precision",
"torchmetrics.classification.roc",
"torchmetrics.classification.stat_scores",
"torchmetrics.detection._mean_ap",
"torchmetrics.detection.mean_ap",
"torchmetrics.functional.image.psnr",
"torchmetrics.functional.image.ssim",
"torchmetrics.image.psnr",
"torchmetrics.image.ssim",
]
ignore_errors = "True"

[tool.typos.default]
extend-ignore-identifiers-re = [
# *sigh* this just isn't worth the cost of fixing
2 changes: 1 addition & 1 deletion src/torchmetrics/classification/exact_match.py
@@ -393,7 +393,7 @@ class ExactMatch(_ClassificationTaskWrapper):
"""

def __new__(
def __new__( # type: ignore[misc]
cls: Type["ExactMatch"],
task: Literal["binary", "multiclass", "multilabel"],
threshold: float = 0.5,
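
The `# type: ignore[misc]` added to `__new__` here (and on the other task wrappers below) is presumably needed because these wrappers act as factories: `__new__` returns an instance of a task-specific class rather than of `cls`, which mypy reports under its `misc` error code. A minimal sketch of the pattern with hypothetical classes (not the torchmetrics implementation):

```python
from typing import Literal, Union


class _BinaryThing:
    """Hypothetical task-specific implementation."""


class _MulticlassThing:
    """Hypothetical task-specific implementation."""


class TaskWrapper:
    """Factory-style wrapper: instantiating it returns a task-specific object."""

    def __new__(  # type: ignore[misc]
        cls, task: Literal["binary", "multiclass"]
    ) -> Union[_BinaryThing, _MulticlassThing]:
        # mypy expects __new__ to return an instance of cls (or a subclass),
        # so returning unrelated classes is flagged; the ignore silences that.
        if task == "binary":
            return _BinaryThing()
        return _MulticlassThing()


metric = TaskWrapper(task="binary")  # at runtime this is a _BinaryThing instance
```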
4 changes: 2 additions & 2 deletions src/torchmetrics/classification/f_beta.py
@@ -1058,7 +1058,7 @@ class FBetaScore(_ClassificationTaskWrapper):
"""

def __new__(
def __new__( # type: ignore[misc]
cls: Type["FBetaScore"],
task: Literal["binary", "multiclass", "multilabel"],
beta: float = 1.0,
@@ -1122,7 +1122,7 @@ class F1Score(_ClassificationTaskWrapper):
"""

def __new__(
def __new__( # type: ignore[misc]
cls: Type["F1Score"],
task: Literal["binary", "multiclass", "multilabel"],
threshold: float = 0.5,
6 changes: 3 additions & 3 deletions src/torchmetrics/classification/precision_recall.py
@@ -930,7 +930,7 @@ class Precision(_ClassificationTaskWrapper):
"""

def __new__(
def __new__( # type: ignore[misc]
cls: Type["Precision"],
task: Literal["binary", "multiclass", "multilabel"],
threshold: float = 0.5,
@@ -995,7 +995,7 @@ class Recall(_ClassificationTaskWrapper):
"""

def __new__(
def __new__( # type: ignore[misc]
cls: Type["Recall"],
task: Literal["binary", "multiclass", "multilabel"],
threshold: float = 0.5,
@@ -1028,4 +1028,4 @@ def __new__(
if not isinstance(num_labels, int):
raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")
return MultilabelRecall(num_labels, threshold, average, **kwargs)
return None
return None # type: ignore[return-value]
19 changes: 13 additions & 6 deletions src/torchmetrics/classification/precision_recall_curve.py
@@ -188,7 +188,8 @@ def plot(
curve: the output of either `metric.compute` or `metric.forward`. If no value is provided, will
automatically call `metric.compute` and plot that result.
score: Provide a area-under-the-curve score to be displayed on the plot. If `True` and no curve is provided,
will automatically compute the score.
will automatically compute the score. The score is computed by using the trapezoidal rule to compute the
area under the curve.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
@@ -215,7 +216,7 @@ def plot(
curve_computed = (curve_computed[1], curve_computed[0], curve_computed[2])

score = (
_auc_compute_without_check(curve_computed[0], curve_computed[1], 1.0)
_auc_compute_without_check(curve_computed[0], curve_computed[1], direction=-1.0)
if not curve and score is True
else None
)
@@ -390,7 +391,8 @@ def plot(
curve: the output of either `metric.compute` or `metric.forward`. If no value is provided, will
automatically call `metric.compute` and plot that result.
score: Provide a area-under-the-curve score to be displayed on the plot. If `True` and no curve is provided,
will automatically compute the score.
will automatically compute the score. The score is computed by using the trapezoidal rule to compute the
area under the curve.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
@@ -416,7 +418,9 @@ def plot(
# switch order as the standard way is recall along x-axis and precision along y-axis
curve_computed = (curve_computed[1], curve_computed[0], curve_computed[2])
score = (
_reduce_auroc(curve_computed[0], curve_computed[1], average=None) if not curve and score is True else None
_reduce_auroc(curve_computed[0], curve_computed[1], average=None, direction=-1.0)
if not curve and score is True
else None
)
return plot_curve(
curve_computed, score=score, ax=ax, label_names=("Recall", "Precision"), name=self.__class__.__name__
@@ -583,7 +587,8 @@ def plot(
curve: the output of either `metric.compute` or `metric.forward`. If no value is provided, will
automatically call `metric.compute` and plot that result.
score: Provide a area-under-the-curve score to be displayed on the plot. If `True` and no curve is provided,
will automatically compute the score.
will automatically compute the score. The score is computed by using the trapezoidal rule to compute the
area under the curve.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Expand All @@ -609,7 +614,9 @@ def plot(
# switch order as the standard way is recall along x-axis and precision along y-axis
curve_computed = (curve_computed[1], curve_computed[0], curve_computed[2])
score = (
_reduce_auroc(curve_computed[0], curve_computed[1], average=None) if not curve and score is True else None
_reduce_auroc(curve_computed[0], curve_computed[1], average=None, direction=-1.0)
if not curve and score is True
else None
)
return plot_curve(
curve_computed, score=score, ax=ax, label_names=("Recall", "Precision"), name=self.__class__.__name__
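
On the switch to `direction=-1.0` for the plotted score (see also the `PrecisionRecallCurve.plot` entry in the changelog, #2437): recall sits on the x-axis and its values typically arrive in decreasing order, so a plain trapezoidal integration comes out negative; a negative direction flips the sign back. A self-contained sketch of that arithmetic with `torch.trapz` (an illustration of the idea, not the internal `_auc_compute_without_check` helper):

```python
import torch


def trapezoidal_auc(x: torch.Tensor, y: torch.Tensor, direction: float = 1.0) -> torch.Tensor:
    """Area under y(x) by the trapezoidal rule, with a sign correction for decreasing x."""
    return direction * torch.trapz(y, x)


# Recall usually decreases as the decision threshold increases.
recall = torch.tensor([1.0, 0.8, 0.5, 0.0])
precision = torch.tensor([0.5, 0.6, 0.8, 1.0])

naive = trapezoidal_auc(recall, precision)                  # negative: x is decreasing
fixed = trapezoidal_auc(recall, precision, direction=-1.0)  # positive area under the PR curve
print(naive.item(), fixed.item())
```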
17 changes: 10 additions & 7 deletions src/torchmetrics/classification/roc.py
@@ -120,7 +120,7 @@ class BinaryROC(BinaryPrecisionRecallCurve):
def compute(self) -> Tuple[Tensor, Tensor, Tensor]:
"""Compute metric."""
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
return _binary_roc_compute(state, self.thresholds)
return _binary_roc_compute(state, self.thresholds) # type: ignore[arg-type]

def plot(
self,
@@ -134,7 +134,8 @@ def plot(
curve: the output of either `metric.compute` or `metric.forward`. If no value is provided, will
automatically call `metric.compute` and plot that result.
score: Provide a area-under-the-curve score to be displayed on the plot. If `True` and no curve is provided,
will automatically compute the score.
will automatically compute the score. The score is computed by using the trapezoidal rule to compute the
area under the curve.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
@@ -289,7 +290,7 @@ class MulticlassROC(MulticlassPrecisionRecallCurve):
def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
"""Compute metric."""
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
return _multiclass_roc_compute(state, self.num_classes, self.thresholds, self.average)
return _multiclass_roc_compute(state, self.num_classes, self.thresholds, self.average) # type: ignore[arg-type]

def plot(
self,
@@ -303,7 +304,8 @@ def plot(
curve: the output of either `metric.compute` or `metric.forward`. If no value is provided, will
automatically call `metric.compute` and plot that result.
score: Provide a area-under-the-curve score to be displayed on the plot. If `True` and no curve is provided,
will automatically compute the score.
will automatically compute the score. The score is computed by using the trapezoidal rule to compute the
area under the curve.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
@@ -447,7 +449,7 @@ class MultilabelROC(MultilabelPrecisionRecallCurve):
def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
"""Compute metric."""
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
return _multilabel_roc_compute(state, self.num_labels, self.thresholds, self.ignore_index)
return _multilabel_roc_compute(state, self.num_labels, self.thresholds, self.ignore_index) # type: ignore[arg-type]

def plot(
self,
@@ -461,7 +463,8 @@ def plot(
curve: the output of either `metric.compute` or `metric.forward`. If no value is provided, will
automatically call `metric.compute` and plot that result.
score: Provide a area-under-the-curve score to be displayed on the plot. If `True` and no curve is provided,
will automatically compute the score.
will automatically compute the score. The score is computed by using the trapezoidal rule to compute the
area under the curve.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
@@ -561,7 +564,7 @@ class ROC(_ClassificationTaskWrapper):
"""

def __new__(
def __new__( # type: ignore[misc]
cls: Type["ROC"],
task: Literal["binary", "multiclass", "multilabel"],
thresholds: Optional[Union[int, List[float], Tensor]] = None,
10 changes: 5 additions & 5 deletions src/torchmetrics/classification/stat_scores.py
@@ -69,10 +69,10 @@ def _create_state(
def _update_state(self, tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor) -> None:
"""Update states depending on multidim_average argument."""
if self.multidim_average == "samplewise":
self.tp.append(tp)
self.fp.append(fp)
self.tn.append(tn)
self.fn.append(fn)
self.tp.append(tp) # type: ignore[union-attr]
self.fp.append(fp) # type: ignore[union-attr]
self.tn.append(tn) # type: ignore[union-attr]
self.fn.append(fn) # type: ignore[union-attr]
else:
self.tp += tp
self.fp += fp
@@ -515,7 +515,7 @@ class StatScores(_ClassificationTaskWrapper):
"""

def __new__(
def __new__( # type: ignore[misc]
cls: Type["StatScores"],
task: Literal["binary", "multiclass", "multilabel"],
threshold: float = 0.5,
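
The `# type: ignore[arg-type]` comments in the ROC hunks and the `# type: ignore[union-attr]` comments here seem to share one root cause: the metric state is declared as either a `Tensor` (global accumulation) or a `List[Tensor]` (sample-wise accumulation), and mypy cannot narrow that union from the runtime flag that decides which one was created. A stripped-down sketch of the pattern with a hypothetical class (not the torchmetrics code):

```python
from typing import List, Union

import torch
from torch import Tensor


class _StatAccumulator:
    """Accumulates true positives globally (Tensor) or per sample (List[Tensor])."""

    def __init__(self, samplewise: bool) -> None:
        self.samplewise = samplewise
        self.tp: Union[Tensor, List[Tensor]] = [] if samplewise else torch.tensor(0)

    def update(self, tp: Tensor) -> None:
        if self.samplewise:
            # mypy still sees Union[Tensor, List[Tensor]] here (it cannot tie the
            # branch to the constructor flag), so .append is reported as union-attr.
            self.tp.append(tp)  # type: ignore[union-attr]
        else:
            self.tp = self.tp + tp  # here the state stays a single Tensor


acc = _StatAccumulator(samplewise=True)
acc.update(torch.tensor(3))
```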
3 changes: 1 addition & 2 deletions src/torchmetrics/collections.py
@@ -647,12 +647,11 @@ def plot(
f"Expected argument `ax` to be a sequence of matplotlib axis objects with the same length as the "
f"number of metrics in the collection, but got {type(ax)} with len {len(ax)} when `together=False`"
)

val = val or self.compute()
if together:
return plot_single_or_multi_val(val, ax=ax)
fig_axs = []
for i, (k, m) in enumerate(self.items(keep_base=True, copy_state=False)):
for i, (k, m) in enumerate(self.items(keep_base=False, copy_state=False)):
if isinstance(val, dict):
f, a = m.plot(val[k], ax=ax[i] if ax is not None else ax)
elif isinstance(val, Sequence):
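
The `keep_base=True` to `keep_base=False` switch lines up with the changelog entry about plotting a collection when a prefix/postfix is set (#2429): `compute()` keys its results by the prefixed names, so iterating over the base names meant the `val[k]` lookup could miss. A short usage sketch of the case the fix appears to target (standard torchmetrics API; the prefix and metric choice are illustrative, and `plot` needs matplotlib installed):

```python
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryAccuracy, BinaryPrecision

# With a prefix, computed results are keyed e.g. "train_BinaryAccuracy".
collection = MetricCollection([BinaryAccuracy(), BinaryPrecision()], prefix="train_")

preds = torch.rand(100)
target = torch.randint(0, 2, (100,))
collection.update(preds, target)

print(collection.compute())  # dict keys carry the "train_" prefix
fig_axs = collection.plot()  # plotting with a prefix set is the case fixed here
```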