Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rittik9 committed Jan 9, 2025
1 parent cd3663f commit 03efd65
Showing 1 changed file with 66 additions and 1 deletion.
67 changes: 66 additions & 1 deletion tests/unittests/multimodal/test_clip_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@
from transformers import CLIPModel as _CLIPModel
from transformers import CLIPProcessor as _CLIPProcessor

from torchmetrics.functional.multimodal.clip_score import clip_score
from torchmetrics.functional.multimodal.clip_score import (
_detect_modality,
_process_image_data,
_process_text_data,
clip_score,
)
from torchmetrics.multimodal.clip_score import CLIPScore
from torchmetrics.utilities.imports import _TRANSFORMERS_GREATER_EQUAL_4_10
from unittests._helpers import seed_all, skip_on_connection_issues
Expand Down Expand Up @@ -143,3 +148,63 @@ def test_warning_on_long_caption(self, inputs, model_name_or_path):
match="Encountered caption longer than max_position_embeddings=77. Will truncate captions to this length.*",
):
metric.update(preds[0], target[0])


@pytest.mark.parametrize(
("input_data", "expected"),
[
(torch.randn(3, 64, 64), "image"),
([torch.randn(3, 64, 64)], "image"),
("some text", "text"),
(["text1", "text2"], "text"),
],
)
def test_detect_modality(input_data, expected):
"""Test that modality detection works correctly."""
assert _detect_modality(input_data) == expected

with pytest.raises(ValueError, match="Empty input list"):
_detect_modality([])

with pytest.raises(ValueError, match="Could not automatically determine modality"):
_detect_modality(123)


@pytest.mark.parametrize(
("images", "expected_len", "should_raise"),
[
(torch.randn(3, 64, 64), 1, False),
(torch.randn(2, 3, 64, 64), 2, False),
([torch.randn(3, 64, 64)], 1, False),
([torch.randn(3, 64, 64), torch.randn(3, 64, 64)], 2, False),
(torch.randn(64, 64), 0, True),
([torch.randn(64, 64)], 0, True),
],
)
def test_process_image_data(images, expected_len, should_raise):
"""Test that image processing works correctly."""
if should_raise:
with pytest.raises(ValueError, match="Expected all images to be 3d"):
_process_image_data(images)
else:
processed = _process_image_data(images)
assert isinstance(processed, list)
assert len(processed) == expected_len
assert all(isinstance(img, Tensor) and img.ndim == 3 for img in processed)


@pytest.mark.parametrize(
("texts", "expected_len"),
[
("single text", 1),
(["text1", "text2"], 2),
([""], 1),
([], 0),
],
)
def test_process_text_data(texts, expected_len):
"""Test that text processing works correctly."""
processed = _process_text_data(texts)
assert isinstance(processed, list)
assert len(processed) == expected_len
assert all(isinstance(text, str) for text in processed)

0 comments on commit 03efd65

Please sign in to comment.