CI/tests: cache reference metrics | others/base (#2407)
* apply to base
* cachier ==3.0.0
* --show_hidden
* fix wrapper

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Borda and pre-commit-ci[bot] authored Feb 27, 2024
1 parent fc95ed8 commit 42eefbd
Showing 11 changed files with 53 additions and 87 deletions.
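The heart of this change: every reference-metric call in the shared testers now goes through a `_reference_cachier` wrapper imported from the `unittests` package, so expensive reference computations are evaluated once and replayed from an on-disk cache across CI runs. The wrapper's definition is not part of this diff; below is a minimal sketch of how such a wrapper can be built with cachier 3.0.0, where the cache location and hashing strategy are illustrative assumptions rather than the project's actual implementation. (cachier writes its pickle caches as dot-prefixed files, which is presumably why the CI listing gains `--show_hidden`.)

```python
# Hedged sketch of a cachier-backed reference-metric wrapper; the real
# `_reference_cachier` lives in the `unittests` package and is not shown
# in this diff. The cache path and hash strategy below are assumptions.
import hashlib
import os
import pickle

from cachier import cachier


def _reference_cachier(reference_fn):
    """Wrap ``reference_fn`` so repeated calls replay from an on-disk cache."""
    cache_dir = os.environ.get("PYTEST_REFERENCE_CACHE", "/var/tmp/reference-cache")

    def _hash(args, kwds):
        # Tensors/ndarrays are not hashable; key on a pickle digest instead.
        payload = pickle.dumps((reference_fn.__module__, reference_fn.__name__, args, kwds))
        return hashlib.sha256(payload).hexdigest()

    return cachier(cache_dir=cache_dir, hash_func=_hash)(reference_fn)
```

In the testers below it is applied per call, e.g. `_reference_cachier(reference_metric)(preds, target, **kwargs)`.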
2 changes: 1 addition & 1 deletion .azure/gpu-unittests.yml
@@ -126,7 +126,7 @@ jobs:
pip install -q py-tree
py-tree /var/tmp/torch
py-tree /var/tmp/hf
- py-tree $(PYTEST_REFERENCE_CACHE)
+ py-tree $(PYTEST_REFERENCE_CACHE) --show_hidden
displayName: "Show caches"
- bash: |
2 changes: 1 addition & 1 deletion requirements/_tests.txt
@@ -17,4 +17,4 @@ fire <=0.5.0
cloudpickle >1.3, <=3.0.0
scikit-learn >=1.1.1, <1.4.0
# todo: set proper version when all lands and it is released
- cachier @ https://github.com/python-cachier/cachier/archive/refs/heads/master.zip
+ cachier ==3.0.0
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/image/d_s.py
@@ -185,7 +185,7 @@ def _spatial_distortion_index_compute(
)
from torchvision.transforms.functional import resize

- from torchmetrics.functional.image.helper import _uniform_filter
+ from torchmetrics.functional.image.utils import _uniform_filter

pan_degraded = _uniform_filter(pan, window_size=window_size)
pan_degraded = resize(pan_degraded, size=ms.shape[-2:], antialias=False)
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/image/rase.py
@@ -17,8 +17,8 @@
import torch
from torch import Tensor

- from torchmetrics.functional.image.helper import _uniform_filter
from torchmetrics.functional.image.rmse_sw import _rmse_sw_compute, _rmse_sw_update
+ from torchmetrics.functional.image.utils import _uniform_filter


def _rase_update(
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/image/rmse_sw.py
@@ -17,7 +17,7 @@
import torch
from torch import Tensor

- from torchmetrics.functional.image.helper import _uniform_filter
+ from torchmetrics.functional.image.utils import _uniform_filter
from torchmetrics.utilities.checks import _check_same_shape


2 changes: 1 addition & 1 deletion src/torchmetrics/functional/image/ssim.py
@@ -18,7 +18,7 @@
from torch.nn import functional as F # noqa: N812
from typing_extensions import Literal

- from torchmetrics.functional.image.helper import _gaussian_kernel_2d, _gaussian_kernel_3d, _reflection_pad_3d
+ from torchmetrics.functional.image.utils import _gaussian_kernel_2d, _gaussian_kernel_3d, _reflection_pad_3d
from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.distributed import reduce

2 changes: 1 addition & 1 deletion src/torchmetrics/functional/image/uqi.py
@@ -17,7 +17,7 @@
from torch import Tensor, nn
from typing_extensions import Literal

- from torchmetrics.functional.image.helper import _gaussian_kernel_2d
+ from torchmetrics.functional.image.utils import _gaussian_kernel_2d
from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.distributed import reduce

src/torchmetrics/functional/image/{helper.py → utils.py} — file renamed without changes.
80 changes: 29 additions & 51 deletions tests/unittests/helpers/testers.py
@@ -25,7 +25,7 @@
from torchmetrics import Metric
from torchmetrics.utilities.data import _flatten

- from unittests import NUM_PROCESSES
+ from unittests import NUM_PROCESSES, _reference_cachier


def _assert_allclose(tm_result: Any, ref_result: Any, atol: float = 1e-8, key: Optional[str] = None) -> None:
@@ -35,8 +35,8 @@ def _assert_allclose(tm_result: Any, ref_result: Any, atol: float = 1e-8, key: O
assert np.allclose(tm_result.detach().cpu().numpy(), ref_result, atol=atol, equal_nan=True)
# multi output compare
elif isinstance(tm_result, Sequence):
- for pl_res, sk_res in zip(tm_result, ref_result):
-     _assert_allclose(pl_res, sk_res, atol=atol)
+ for pl_res, ref_res in zip(tm_result, ref_result):
+     _assert_allclose(pl_res, ref_res, atol=atol)
elif isinstance(tm_result, Dict):
if key is None:
raise KeyError("Provide Key for Dict based metric results.")
@@ -167,7 +167,7 @@ def _class_test(
k: torch.cat([v[i + r] for r in range(world_size)]).cpu() if isinstance(v, Tensor) else v
for k, v in (kwargs_update if fragment_kwargs else batch_kwargs_update).items()
}
- ref_batch_result = reference_metric(ddp_preds, ddp_target, **ddp_kwargs_upd)
+ ref_batch_result = _reference_cachier(reference_metric)(ddp_preds, ddp_target, **ddp_kwargs_upd)
if isinstance(batch_result, dict):
for key in batch_result:
_assert_allclose(batch_result, ref_batch_result[key].numpy(), atol=atol, key=key)
@@ -181,7 +181,7 @@ def _class_test(
}
preds_ = preds[i].cpu() if isinstance(preds, Tensor) else preds[i]
target_ = target[i].cpu() if isinstance(target, Tensor) else target[i]
- ref_batch_result = reference_metric(preds_, target_, **batch_kwargs_update)
+ ref_batch_result = _reference_cachier(reference_metric)(preds_, target_, **batch_kwargs_update)
if isinstance(batch_result, dict):
for key in batch_result:
_assert_allclose(batch_result, ref_batch_result[key].numpy(), atol=atol, key=key)
@@ -218,14 +218,14 @@ def _class_test(
k: torch.cat([v[i] for i in range(num_batches)]).cpu() if isinstance(v, Tensor) else v
for k, v in kwargs_update.items()
}
- sk_result = reference_metric(total_preds, total_target, **total_kwargs_update)
+ ref_result = _reference_cachier(reference_metric)(total_preds, total_target, **total_kwargs_update)

# assert after aggregation
- if isinstance(sk_result, dict):
-     for key in sk_result:
-         _assert_allclose(result, sk_result[key].numpy(), atol=atol, key=key)
+ if isinstance(ref_result, dict):
+     for key in ref_result:
+         _assert_allclose(result, ref_result[key].numpy(), atol=atol, key=key)
else:
-     _assert_allclose(result, sk_result, atol=atol)
+     _assert_allclose(result, ref_result, atol=atol)


def _functional_test(
@@ -282,7 +282,7 @@ def _functional_test(
k: v.cpu() if isinstance(v, Tensor) else v
for k, v in (extra_kwargs if fragment_kwargs else kwargs_update).items()
}
- ref_result = reference_metric(
+ ref_result = _reference_cachier(reference_metric)(
preds[i].cpu() if isinstance(preds, Tensor) else preds[i],
target[i].cpu() if isinstance(target, Tensor) else target[i],
**extra_kwargs,
@@ -424,53 +424,31 @@ def run_class_metric_test(
target when running update on the metric.
"""
- atol = atol or self.atol
- metric_args = metric_args or {}
+ common_kwargs = {
+     "preds": preds,
+     "target": target,
+     "metric_class": metric_class,
+     "reference_metric": reference_metric,
+     "metric_args": metric_args or {},
+     "atol": atol or self.atol,
+     "device": "cuda" if torch.cuda.is_available() else "cpu",
+     "dist_sync_on_step": dist_sync_on_step,
+     "check_dist_sync_on_step": check_dist_sync_on_step,
+     "check_batch": check_batch,
+     "fragment_kwargs": fragment_kwargs,
+     "check_scriptable": check_scriptable,
+     "check_state_dict": check_state_dict,
+ }

if ddp and hasattr(pytest, "pool"):
if sys.platform == "win32":
pytest.skip("DDP not supported on windows")

pytest.pool.starmap(
- partial(
-     _class_test,
-     preds=preds,
-     target=target,
-     metric_class=metric_class,
-     reference_metric=reference_metric,
-     dist_sync_on_step=dist_sync_on_step,
-     metric_args=metric_args,
-     check_dist_sync_on_step=check_dist_sync_on_step,
-     check_batch=check_batch,
-     atol=atol,
-     device="cuda" if torch.cuda.is_available() else "cpu",
-     fragment_kwargs=fragment_kwargs,
-     check_scriptable=check_scriptable,
-     check_state_dict=check_state_dict,
-     **kwargs_update,
- ),
+ partial(_class_test, **common_kwargs, **kwargs_update),
[(rank, NUM_PROCESSES) for rank in range(NUM_PROCESSES)],
)
else:
device = "cuda" if torch.cuda.is_available() else "cpu"

_class_test(
rank=0,
world_size=1,
preds=preds,
target=target,
metric_class=metric_class,
reference_metric=reference_metric,
dist_sync_on_step=dist_sync_on_step,
metric_args=metric_args,
check_dist_sync_on_step=check_dist_sync_on_step,
check_batch=check_batch,
atol=atol,
device=device,
fragment_kwargs=fragment_kwargs,
check_scriptable=check_scriptable,
check_state_dict=check_state_dict,
**kwargs_update,
)
_class_test(rank=0, world_size=1, **common_kwargs, **kwargs_update)

@staticmethod
def run_precision_test_cpu(
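The `run_class_metric_test` refactor above deduplicates the two call sites into one `common_kwargs` dict: the DDP branch pre-binds it with `functools.partial` so `pytest.pool.starmap` only has to supply `(rank, world_size)` positionally, while the single-process branch unpacks it directly. A self-contained illustration of that calling pattern, with `_class_test_stub` as a hypothetical stand-in for the real `_class_test`:

```python
# Illustration of the partial + starmap pattern from the refactor above;
# `_class_test_stub` is a hypothetical stand-in, not the real tester.
from functools import partial

NUM_PROCESSES = 2


def _class_test_stub(rank: int, world_size: int, **kwargs) -> None:
    print(f"rank={rank}/{world_size}", sorted(kwargs))


common_kwargs = {"device": "cpu", "atol": 1e-8}
bound = partial(_class_test_stub, **common_kwargs)

# pytest.pool.starmap(bound, [...]) effectively performs these calls,
# one per worker process:
for args in [(rank, NUM_PROCESSES) for rank in range(NUM_PROCESSES)]:
    bound(*args)
```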
2 changes: 1 addition & 1 deletion tests/unittests/image/test_rase.py
@@ -19,7 +19,7 @@
import torch
from torch import Tensor
from torchmetrics.functional import relative_average_spectral_error
- from torchmetrics.functional.image.helper import _uniform_filter
+ from torchmetrics.functional.image.utils import _uniform_filter
from torchmetrics.image import RelativeAverageSpectralError

from unittests import BATCH_SIZE
44 changes: 16 additions & 28 deletions tests/unittests/image/test_scc.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ from functools import partial

import numpy as np
import pytest
@@ -35,36 +36,26 @@
_kernels = [torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]])]


- def _reference_scc(preds, target):
-     """Reference implementation of scc from sewar."""
+ def _reference_sewar_scc(preds, target, hp_filter, window_size, reduction):
+     """Wrapper around reference implementation of scc from sewar."""
preds = torch.movedim(preds, 1, -1)
target = torch.movedim(target, 1, -1)
preds = preds.cpu().numpy()
target = target.cpu().numpy()
- hp_filter = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]])
- window_size = 8
scc = [
sewar_scc(GT=target[batch], P=preds[batch], win=hp_filter, ws=window_size) for batch in range(preds.shape[0])
]
- return np.mean(scc)
+ if reduction == "mean":
+     return np.mean(scc)
+ if reduction == "none":
+     return scc
+ return None

- def _wrapped_reference_scc(win, ws, reduction):
-     """Wrapper around reference implementation of scc from sewar."""
-
-     def _wrapped(preds, target):
-         preds = torch.movedim(preds, 1, -1)
-         target = torch.movedim(target, 1, -1)
-         preds = preds.cpu().numpy()
-         target = target.cpu().numpy()
-         scc = [sewar_scc(GT=target[batch], P=preds[batch], win=win, ws=ws) for batch in range(preds.shape[0])]
-         if reduction == "mean":
-             return np.mean(scc)
-         if reduction == "none":
-             return scc
-         return None
-
-     return _wrapped
+ def _reference_sewar_scc_simple(preds, target):
+     """Reference implementation of SCC from sewar."""
+     hp_filter = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]])
+     return _reference_sewar_scc(preds, target, hp_filter, window_size=8, reduction="mean")


@pytest.mark.parametrize("preds, target", [(i.preds, i.target) for i in _inputs])
@@ -77,22 +68,19 @@ class TestSpatialCorrelationCoefficient(MetricTester):
def test_scc(self, preds, target, ddp):
"""Test SpatialCorrelationCoefficient class usage."""
self.run_class_metric_test(
- ddp, preds, target, metric_class=SpatialCorrelationCoefficient, reference_metric=_reference_scc
+ ddp, preds, target, metric_class=SpatialCorrelationCoefficient, reference_metric=_reference_sewar_scc_simple
)

@pytest.mark.parametrize("hp_filter", _kernels)
@pytest.mark.parametrize("window_size", [8, 11])
@pytest.mark.parametrize("reduction", ["mean", "none"])
def test_scc_functional(self, preds, target, hp_filter, window_size, reduction):
"""Test SpatialCorrelationCoefficient functional usage."""
+ kwargs = {"hp_filter": hp_filter, "window_size": window_size, "reduction": reduction}
self.run_functional_metric_test(
preds,
target,
metric_functional=spatial_correlation_coefficient,
- reference_metric=_wrapped_reference_scc(hp_filter, window_size, reduction),
- metric_args={
-     "hp_filter": hp_filter,
-     "window_size": window_size,
-     "reduction": reduction,
- },
+ reference_metric=partial(_reference_sewar_scc, **kwargs),
+ metric_args=kwargs,
)
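One note on the SCC rewrite: the old `_wrapped_reference_scc` closure factory produced a fresh function object per parametrization, while the new flat `_reference_sewar_scc` plus `functools.partial` keeps a single stable function with explicit, introspectable keyword arguments — plausibly what the new reference cache needs to key results reliably (an inference from the commit's caching goal, not a documented rationale). A tiny sketch of the difference:

```python
# Closure vs. partial: partial exposes the target function and its bound
# kwargs, which a cache can use as a stable key; a closure hides both.
# `ref` is a toy stand-in for the sewar reference function.
from functools import partial


def ref(preds, target, window_size=8):
    return (len(preds), len(target), window_size)


def make_closure(window_size):
    def _wrapped(preds, target):  # new opaque object on every factory call
        return ref(preds, target, window_size=window_size)

    return _wrapped


bound = partial(ref, window_size=8)       # stable: func + kwargs are visible
print(bound.func is ref, bound.keywords)  # True {'window_size': 8}
print(bound([1, 2], [3, 4]))              # (2, 2, 8)
```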
