From db395b534a389933f91724f07cf5e3f9c9390d0f Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 23 Nov 2023 10:25:58 +0100 Subject: [PATCH 1/3] fix implementation --- src/torchmetrics/functional/image/lpips.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/image/lpips.py b/src/torchmetrics/functional/image/lpips.py index 1c6e1b58906..63a708969c0 100644 --- a/src/torchmetrics/functional/image/lpips.py +++ b/src/torchmetrics/functional/image/lpips.py @@ -426,6 +426,6 @@ def learned_perceptual_image_patch_similarity( tensor(0.1008, grad_fn=) """ - net = _NoTrainLpips(net=net_type) + net = _NoTrainLpips(net=net_type).to(device=img1.device, dtype=img1.dtype) loss, total = _lpips_update(img1, img2, net, normalize) return _lpips_compute(loss.sum(), total, reduction) From b3d9e0e95634445e62209a7e797ddd31f3aa291c Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 23 Nov 2023 10:27:02 +0100 Subject: [PATCH 2/3] tests --- tests/unittests/image/test_lpips.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/unittests/image/test_lpips.py b/tests/unittests/image/test_lpips.py index c29730be0b2..e7e535f191b 100644 --- a/tests/unittests/image/test_lpips.py +++ b/tests/unittests/image/test_lpips.py @@ -18,6 +18,7 @@ import torch from lpips import LPIPS as LPIPS_reference # noqa: N811 from torch import Tensor +from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity from torchmetrics.utilities.imports import _LPIPS_AVAILABLE @@ -68,6 +69,16 @@ def test_lpips(self, net_type, ddp): metric_args={"net_type": net_type}, ) + def test_lpips_functional(self): + """Test functional implementation of metric.""" + self.run_functional_metric_test( + preds=_inputs.img1, + target=_inputs.img2, + metric_functional=learned_perceptual_image_patch_similarity, + reference_metric=partial(_compare_fn, net_type="alex"), + metric_args={"net_type": "alex"}, + ) + def test_lpips_differentiability(self): """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" self.run_differentiability_test( From 660840a0c6f86e0f36a8ca0083ff9faf7e16cf0a Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 23 Nov 2023 10:29:28 +0100 Subject: [PATCH 3/3] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 72549469d59..5d0b6db3d7a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,6 +54,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed numerical stability issue in `UniversalImageQualityIndex` metric ([#2222](https://github.com/Lightning-AI/torchmetrics/pull/2222)) +- Fix device and dtype for `LearnedPerceptualImagePatchSimilarity` functional metric ([#2234](https://github.com/Lightning-AI/torchmetrics/pull/2234)) + + ## [1.2.0] - 2023-09-22 ### Added