From aea28391eeb91ea1c4946244938efb307060863a Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 11 Nov 2024 12:46:51 +0100 Subject: [PATCH] fix tests --- tests/unittests/segmentation/inputs.py | 4 ++++ tests/unittests/segmentation/test_dice.py | 3 ++- tests/unittests/segmentation/test_generalized_dice_score.py | 3 ++- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/unittests/segmentation/inputs.py b/tests/unittests/segmentation/inputs.py index b773ba29ebd..ce46c8b597c 100644 --- a/tests/unittests/segmentation/inputs.py +++ b/tests/unittests/segmentation/inputs.py @@ -34,3 +34,7 @@ preds=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)), target=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)), ) +_input4 = _Input( + preds=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32)), + target=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32)), +) diff --git a/tests/unittests/segmentation/test_dice.py b/tests/unittests/segmentation/test_dice.py index d5bfc08b4ae..b009401f481 100644 --- a/tests/unittests/segmentation/test_dice.py +++ b/tests/unittests/segmentation/test_dice.py @@ -22,7 +22,7 @@ from unittests import NUM_CLASSES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester -from unittests.segmentation.inputs import _inputs1, _inputs2, _inputs3 +from unittests.segmentation.inputs import _input4, _inputs1, _inputs2, _inputs3 seed_all(42) @@ -55,6 +55,7 @@ def _reference_dice_score( (_inputs1.preds, _inputs1.target, "one-hot"), (_inputs2.preds, _inputs2.target, "one-hot"), (_inputs3.preds, _inputs3.target, "index"), + (_input4.preds, _input4.target, "index"), ], ) @pytest.mark.parametrize("include_background", [True, False]) diff --git a/tests/unittests/segmentation/test_generalized_dice_score.py b/tests/unittests/segmentation/test_generalized_dice_score.py index 6c353800379..c87fd6aa22e 100644 --- a/tests/unittests/segmentation/test_generalized_dice_score.py +++ b/tests/unittests/segmentation/test_generalized_dice_score.py @@ -24,7 +24,7 @@ from unittests import NUM_CLASSES from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester -from unittests.segmentation.inputs import _inputs1, _inputs2, _inputs3 +from unittests.segmentation.inputs import _input4, _inputs1, _inputs2, _inputs3 seed_all(42) @@ -53,6 +53,7 @@ def _reference_generalized_dice( (_inputs1.preds, _inputs1.target, "one-hot"), (_inputs2.preds, _inputs2.target, "one-hot"), (_inputs3.preds, _inputs3.target, "index"), + (_input4.preds, _input4.target, "index"), ], ) @pytest.mark.parametrize("include_background", [True, False])