diff --git a/tests/unittests/multimodal/test_clip_score.py b/tests/unittests/multimodal/test_clip_score.py
index 62d12c4dab1..b948e1889c7 100644
--- a/tests/unittests/multimodal/test_clip_score.py
+++ b/tests/unittests/multimodal/test_clip_score.py
@@ -22,7 +22,12 @@
 from transformers import CLIPModel as _CLIPModel
 from transformers import CLIPProcessor as _CLIPProcessor
 
-from torchmetrics.functional.multimodal.clip_score import clip_score
+from torchmetrics.functional.multimodal.clip_score import (
+    _detect_modality,
+    _process_image_data,
+    _process_text_data,
+    clip_score,
+)
 from torchmetrics.multimodal.clip_score import CLIPScore
 from torchmetrics.utilities.imports import _TRANSFORMERS_GREATER_EQUAL_4_10
 from unittests._helpers import seed_all, skip_on_connection_issues
@@ -143,3 +148,63 @@ def test_warning_on_long_caption(self, inputs, model_name_or_path):
             match="Encountered caption longer than max_position_embeddings=77. Will truncate captions to this length.*",
         ):
             metric.update(preds[0], target[0])
+
+
+@pytest.mark.parametrize(
+    ("input_data", "expected"),
+    [
+        (torch.randn(3, 64, 64), "image"),
+        ([torch.randn(3, 64, 64)], "image"),
+        ("some text", "text"),
+        (["text1", "text2"], "text"),
+    ],
+)
+def test_detect_modality(input_data, expected):
+    """Test that modality detection works correctly."""
+    assert _detect_modality(input_data) == expected
+
+    with pytest.raises(ValueError, match="Empty input list"):
+        _detect_modality([])
+
+    with pytest.raises(ValueError, match="Could not automatically determine modality"):
+        _detect_modality(123)
+
+
+@pytest.mark.parametrize(
+    ("images", "expected_len", "should_raise"),
+    [
+        (torch.randn(3, 64, 64), 1, False),
+        (torch.randn(2, 3, 64, 64), 2, False),
+        ([torch.randn(3, 64, 64)], 1, False),
+        ([torch.randn(3, 64, 64), torch.randn(3, 64, 64)], 2, False),
+        (torch.randn(64, 64), 0, True),
+        ([torch.randn(64, 64)], 0, True),
+    ],
+)
+def test_process_image_data(images, expected_len, should_raise):
+    """Test that image processing works correctly."""
+    if should_raise:
+        with pytest.raises(ValueError, match="Expected all images to be 3d"):
+            _process_image_data(images)
+    else:
+        processed = _process_image_data(images)
+        assert isinstance(processed, list)
+        assert len(processed) == expected_len
+        assert all(isinstance(img, Tensor) and img.ndim == 3 for img in processed)
+
+
+@pytest.mark.parametrize(
+    ("texts", "expected_len"),
+    [
+        ("single text", 1),
+        (["text1", "text2"], 2),
+        ([""], 1),
+        ([], 0),
+    ],
+)
+def test_process_text_data(texts, expected_len):
+    """Test that text processing works correctly."""
+    processed = _process_text_data(texts)
+    assert isinstance(processed, list)
+    assert len(processed) == expected_len
+    assert all(isinstance(text, str) for text in processed)
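
For orientation, below is a minimal sketch of helpers that would satisfy the behaviour the new tests assert. It is not the torchmetrics implementation: only the function names and the error-message fragments are taken from the tests above; everything else is an assumption.

# Hypothetical sketch only -- inferred from the tests in this diff, not copied from torchmetrics.
from typing import List, Union

from torch import Tensor


def _detect_modality(data: Union[Tensor, List[Tensor], str, List[str]]) -> str:
    """Classify the input as "image" (tensors) or "text" (strings)."""
    if isinstance(data, Tensor):
        return "image"
    if isinstance(data, str):
        return "text"
    if isinstance(data, (list, tuple)):
        if len(data) == 0:
            raise ValueError("Empty input list provided.")
        if isinstance(data[0], Tensor):
            return "image"
        if isinstance(data[0], str):
            return "text"
    raise ValueError("Could not automatically determine modality of the input.")


def _process_image_data(images: Union[Tensor, List[Tensor]]) -> List[Tensor]:
    """Normalize image input to a list of 3d (C, H, W) tensors; reject anything else."""
    if isinstance(images, Tensor):
        # A 4d batch tensor is split into per-image 3d tensors; a single 3d tensor is wrapped.
        images = list(images) if images.ndim == 4 else [images]
    if not all(isinstance(img, Tensor) and img.ndim == 3 for img in images):
        raise ValueError("Expected all images to be 3d (C, H, W) tensors.")
    return list(images)


def _process_text_data(texts: Union[str, List[str]]) -> List[str]:
    """Normalize text input to a list of strings."""
    return [texts] if isinstance(texts, str) else list(texts)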