Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki committed Nov 11, 2024
1 parent 15006ae commit aea2839
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 2 deletions.
4 changes: 4 additions & 0 deletions tests/unittests/segmentation/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,7 @@
preds=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)),
target=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32, 32)),
)
_input4 = _Input(
preds=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32)),
target=torch.randint(0, NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, 32)),
)
3 changes: 2 additions & 1 deletion tests/unittests/segmentation/test_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from unittests import NUM_CLASSES
from unittests._helpers import seed_all
from unittests._helpers.testers import MetricTester
from unittests.segmentation.inputs import _inputs1, _inputs2, _inputs3
from unittests.segmentation.inputs import _input4, _inputs1, _inputs2, _inputs3

seed_all(42)

Expand Down Expand Up @@ -55,6 +55,7 @@ def _reference_dice_score(
(_inputs1.preds, _inputs1.target, "one-hot"),
(_inputs2.preds, _inputs2.target, "one-hot"),
(_inputs3.preds, _inputs3.target, "index"),
(_input4.preds, _input4.target, "index"),
],
)
@pytest.mark.parametrize("include_background", [True, False])
Expand Down
3 changes: 2 additions & 1 deletion tests/unittests/segmentation/test_generalized_dice_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from unittests import NUM_CLASSES
from unittests._helpers import seed_all
from unittests._helpers.testers import MetricTester
from unittests.segmentation.inputs import _inputs1, _inputs2, _inputs3
from unittests.segmentation.inputs import _input4, _inputs1, _inputs2, _inputs3

seed_all(42)

Expand Down Expand Up @@ -53,6 +53,7 @@ def _reference_generalized_dice(
(_inputs1.preds, _inputs1.target, "one-hot"),
(_inputs2.preds, _inputs2.target, "one-hot"),
(_inputs3.preds, _inputs3.target, "index"),
(_input4.preds, _input4.target, "index"),
],
)
@pytest.mark.parametrize("include_background", [True, False])
Expand Down

0 comments on commit aea2839

Please sign in to comment.