diff --git a/src/torchmetrics/functional/multimodal/clip_score.py b/src/torchmetrics/functional/multimodal/clip_score.py
index 070d81bf54c..fcdcea1d979 100644
--- a/src/torchmetrics/functional/multimodal/clip_score.py
+++ b/src/torchmetrics/functional/multimodal/clip_score.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, List, Union
+from typing import TYPE_CHECKING, List, Union, cast
 
 import torch
 from torch import Tensor
@@ -41,53 +41,140 @@ def _download_clip_for_clip_score() -> None:
     _CLIPProcessor = None
 
 
+def _detect_modality(input_data: Union[Tensor, List[Tensor], List[str], str]) -> Literal["image", "text"]:
+    """Automatically detect the modality of the input data.
+
+    Args:
+        input_data: Input data that can be either image tensors or text strings
+
+    Returns:
+        str: Either "image" or "text"
+
+    Raises:
+        ValueError: If the input_data is an empty list or modality cannot be determined
+
+    """
+    if isinstance(input_data, Tensor):
+        return "image"
+
+    if isinstance(input_data, list):
+        if len(input_data) == 0:
+            raise ValueError("Empty input list")
+        if isinstance(input_data[0], Tensor):
+            return "image"
+        if isinstance(input_data[0], str):
+            return "text"
+
+    if isinstance(input_data, str):
+        return "text"
+
+    raise ValueError("Could not automatically determine modality for input_data")
+
+
+def _process_image_data(images: Union[Tensor, List[Tensor]]) -> List[Tensor]:
+    """Helper function to process image data."""
+    images = [images] if not isinstance(images, list) and images.ndim == 3 else list(images)
+    if not all(i.ndim == 3 for i in images):
+        raise ValueError("Expected all images to be 3d but found an image that has either more or fewer dimensions")
+    return images
+
+
+def _process_text_data(texts: Union[str, List[str]]) -> List[str]:
+    """Helper function to process text data."""
+    if not isinstance(texts, list):
+        texts = [texts]
+    return texts
+
+
+def _get_features(
+    data: List[Union[Tensor, str]],
+    modality: str,
+    device: torch.device,
+    model: "_CLIPModel",
+    processor: "_CLIPProcessor",
+) -> Tensor:
+    """Get features from the CLIP model for either images or text.
+
+    Args:
+        data: List of input data (images or text)
+        modality: String indicating the type of input data (must be either "image" or "text")
+        device: Device to run the model on
+        model: CLIP model instance
+        processor: CLIP processor instance
+
+    Returns:
+        Tensor of features from the CLIP model
+
+    Raises:
+        ValueError: If modality is not "image" or "text"
+
+    """
+    if modality == "image":
+        # Keep only tensor entries so the processor receives valid image inputs
+        image_data = [i for i in data if isinstance(i, Tensor)]
+        processed = processor(images=[i.cpu() for i in image_data], return_tensors="pt", padding=True)
+        return model.get_image_features(processed["pixel_values"].to(device))
+    if modality == "text":
+        processed = processor(text=data, return_tensors="pt", padding=True)
+        max_position_embeddings = model.config.text_config.max_position_embeddings
+        if processed["attention_mask"].shape[-1] > max_position_embeddings:
+            rank_zero_warn(
+                f"Encountered caption longer than {max_position_embeddings=}. Will truncate captions to this length."
+                " If longer captions are needed, initialize argument `model_name_or_path` with a model that supports"
+                " longer sequences",
+                UserWarning,
+            )
+            processed["attention_mask"] = processed["attention_mask"][..., :max_position_embeddings]
+            processed["input_ids"] = processed["input_ids"][..., :max_position_embeddings]
+        return model.get_text_features(processed["input_ids"].to(device), processed["attention_mask"].to(device))
+    raise ValueError(f"invalid modality {modality}")
+
+
 def _clip_score_update(
-    images: Union[Tensor, List[Tensor]],
-    text: Union[str, list[str]],
+    source: Union[Tensor, List[Tensor], List[str], str],
+    target: Union[Tensor, List[Tensor], List[str], str],
     model: _CLIPModel,
     processor: _CLIPProcessor,
 ) -> tuple[Tensor, int]:
-    if not isinstance(images, list):
-        if images.ndim == 3:
-            images = [images]
-        else:  # unwrap into list
-            images = list(images)
+    source_modality = _detect_modality(source)
+    target_modality = _detect_modality(target)
 
-    if not all(i.ndim == 3 for i in images):
-        raise ValueError("Expected all images to be 3d but found image that has either more or less")
-
-    if not isinstance(text, list):
-        text = [text]
+    source_data = (
+        _process_image_data(cast(Union[Tensor, List[Tensor]], source))
+        if source_modality == "image"
+        else _process_text_data(cast(Union[str, List[str]], source))
+    )
+    target_data = (
+        _process_image_data(cast(Union[Tensor, List[Tensor]], target))
+        if target_modality == "image"
+        else _process_text_data(cast(Union[str, List[str]], target))
+    )
 
-    if len(text) != len(images):
+    if len(source_data) != len(target_data):
         raise ValueError(
-            f"Expected the number of images and text examples to be the same but got {len(images)} and {len(text)}"
+            "Expected the number of source and target examples to be the same but got "
+            f"{len(source_data)} and {len(target_data)}"
         )
 
-    device = images[0].device
-    processed_input = processor(text=text, images=[i.cpu() for i in images], return_tensors="pt", padding=True)
-
-    img_features = model.get_image_features(processed_input["pixel_values"].to(device))
-    img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)
-
-    max_position_embeddings = model.config.text_config.max_position_embeddings
-    if processed_input["attention_mask"].shape[-1] > max_position_embeddings:
-        rank_zero_warn(
-            f"Encountered caption longer than {max_position_embeddings=}. Will truncate captions to this length."
-            "If longer captions are needed, initialize argument `model_name_or_path` with a model that supports"
-            "longer sequences",
-            UserWarning,
-        )
-        processed_input["attention_mask"] = processed_input["attention_mask"][..., :max_position_embeddings]
-        processed_input["input_ids"] = processed_input["input_ids"][..., :max_position_embeddings]
-    txt_features = model.get_text_features(
-        processed_input["input_ids"].to(device), processed_input["attention_mask"].to(device)
+    device = torch.device("cpu")
+    if source_modality == "image" and isinstance(source_data[0], Tensor):
+        device = source_data[0].device
+    elif target_modality == "image" and isinstance(target_data[0], Tensor):
+        device = target_data[0].device
+    model = model.to(device)
+
+    source_features = _get_features(
+        cast(List[Union[Tensor, str]], source_data), source_modality, device, model, processor
+    )
+    target_features = _get_features(
+        cast(List[Union[Tensor, str]], target_data), target_modality, device, model, processor
     )
-    txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True)
+    source_features = source_features / source_features.norm(p=2, dim=-1, keepdim=True)
+    target_features = target_features / target_features.norm(p=2, dim=-1, keepdim=True)
 
-    # cosine similarity between feature vectors
-    score = 100 * (img_features * txt_features).sum(axis=-1)
-    return score, len(text)
+    # Calculate cosine similarity
+    score = 100 * (source_features * target_features).sum(axis=-1)
+    return score, len(source_data)
 
 
 def _get_clip_model_and_processor(
@@ -113,8 +200,8 @@ def _get_clip_model_and_processor(
 
 
 def clip_score(
-    images: Union[Tensor, List[Tensor]],
-    text: Union[str, list[str]],
+    source: Union[Tensor, List[Tensor], List[str], str],
+    target: Union[Tensor, List[Tensor], List[str], str],
     model_name_or_path: Literal[
         "openai/clip-vit-base-patch16",
         "openai/clip-vit-base-patch32",
@@ -122,11 +209,11 @@
         "openai/clip-vit-large-patch14",
     ] = "openai/clip-vit-large-patch14",
 ) -> Tensor:
-    r"""Calculate `CLIP Score`_ which is a text-to-image similarity metric.
+    r"""Calculates `CLIP Score`_ which is a text-to-image similarity metric.
 
     CLIP Score is a reference free metric that can be used to evaluate the correlation between a generated caption for
-    an image and the actual content of the image. It has been found to be highly correlated with human judgement. The
-    metric is defined as:
+    an image and the actual content of the image, as well as the similarity between texts or images. It has been found
+    to be highly correlated with human judgement. The metric is defined as:
 
     .. math::
         \text{CLIPScore(I, C)} = max(100 * cos(E_I, E_C), 0)
@@ -135,15 +222,33 @@
     textual CLIP embedding :math:`E_C` for an caption :math:`C`. The score is bound between 0 and 100 and the closer
     to 100 the better.
 
-    .. caution::
-        Metric is not scriptable
+    Additionally, the CLIP Score can be calculated for the same modalities:
+
+    .. math::
+        \text{CLIPScore(I_1, I_2)} = max(100 * cos(E_{I_1}, E_{I_2}), 0)
+
+    where :math:`E_{I_1}` and :math:`E_{I_2}` are the visual embeddings for images :math:`I_1` and :math:`I_2`.
+
+    .. math::
+        \text{CLIPScore(T_1, T_2)} = max(100 * cos(E_{T_1}, E_{T_2}), 0)
+
+    where :math:`E_{T_1}` and :math:`E_{T_2}` are the textual embeddings for texts :math:`T_1` and :math:`T_2`.
+
+    .. note:: Metric is not scriptable
 
     Args:
-        images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors
-        text: Either a single caption or a list of captions
-        model_name_or_path: string indicating the version of the CLIP model to use. Available models are
-            `"openai/clip-vit-base-patch16"`, `"openai/clip-vit-base-patch32"`, `"openai/clip-vit-large-patch14-336"`
-            and `"openai/clip-vit-large-patch14"`,
+        source: Source input. This can be:
+            - Images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors.
+            - Text: Either a single caption or a list of captions.
+        target: Target input. This can be:
+            - Images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors.
+            - Text: Either a single caption or a list of captions.
+        model_name_or_path: String indicating the version of the CLIP model to use. Available models are:
+            - `"openai/clip-vit-base-patch16"`
+            - `"openai/clip-vit-base-patch32"`
+            - `"openai/clip-vit-large-patch14-336"`
+            - `"openai/clip-vit-large-patch14"`
+
 
     Raises:
         ModuleNotFoundError:
@@ -155,13 +260,31 @@
 
     Example:
        >>> from torchmetrics.functional.multimodal import clip_score
-        >>> score = clip_score(torch.randint(255, (3, 224, 224)), "a photo of a cat", "openai/clip-vit-base-patch16")
+        >>> image = torch.randint(255, (3, 224, 224), generator=torch.Generator().manual_seed(42))
+        >>> score = clip_score(image, "a photo of a cat", "openai/clip-vit-base-patch16")
         >>> score.detach()
         tensor(24.4255)
 
+    Example:
+        >>> from torchmetrics.functional.multimodal import clip_score
+        >>> image1 = torch.randint(255, (3, 224, 224), generator=torch.Generator().manual_seed(42))
+        >>> image2 = torch.randint(255, (3, 224, 224), generator=torch.Generator().manual_seed(43))
+        >>> score = clip_score(image1, image2, "openai/clip-vit-base-patch16")
+        >>> score.detach()
+        tensor(99.4859)
+
+    Example:
+        >>> from torchmetrics.functional.multimodal import clip_score
+        >>> score = clip_score(
+        ...     "28-year-old chef found dead in San Francisco mall",
+        ...     "A 28-year-old chef who recently moved to San Francisco was found dead.",
+        ...     "openai/clip-vit-base-patch16"
+        ... )
+        >>> score.detach()
+        tensor(91.3950)
+
     """
     model, processor = _get_clip_model_and_processor(model_name_or_path)
-    device = images.device if isinstance(images, Tensor) else images[0].device
-    score, _ = _clip_score_update(images, text, model.to(device), processor)
+    score, _ = _clip_score_update(source, target, model, processor)
     score = score.mean(0)
     return torch.max(score, torch.zeros_like(score))
diff --git a/src/torchmetrics/multimodal/clip_score.py b/src/torchmetrics/multimodal/clip_score.py
index c89384fbb35..690a371c17d 100644
--- a/src/torchmetrics/multimodal/clip_score.py
+++ b/src/torchmetrics/multimodal/clip_score.py
@@ -11,8 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from collections.abc import Sequence
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional, Sequence, Union
 
 import torch
 from torch import Tensor
@@ -45,8 +44,8 @@ class CLIPScore(Metric):
     r"""Calculates `CLIP Score`_ which is a text-to-image similarity metric.
 
     CLIP Score is a reference free metric that can be used to evaluate the correlation between a generated caption for
-    an image and the actual content of the image. It has been found to be highly correlated with human judgement. The
-    metric is defined as:
+    an image and the actual content of the image, as well as the similarity between texts or images. It has been found
+    to be highly correlated with human judgement. The metric is defined as:
 
     .. math::
         \text{CLIPScore(I, C)} = max(100 * cos(E_I, E_C), 0)
@@ -55,15 +54,33 @@ class CLIPScore(Metric):
     textual CLIP embedding :math:`E_C` for an caption :math:`C`. The score is bound between 0 and 100 and the closer
     to 100 the better.
 
+    Additionally, the CLIP Score can be calculated for the same modalities:
+
+    .. math::
+        \text{CLIPScore(I_1, I_2)} = max(100 * cos(E_{I_1}, E_{I_2}), 0)
+
+    where :math:`E_{I_1}` and :math:`E_{I_2}` are the visual embeddings for images :math:`I_1` and :math:`I_2`.
+
+    .. math::
+        \text{CLIPScore(T_1, T_2)} = max(100 * cos(E_{T_1}, E_{T_2}), 0)
+
+    where :math:`E_{T_1}` and :math:`E_{T_2}` are the textual embeddings for texts :math:`T_1` and :math:`T_2`.
+
     .. caution::
         Metric is not scriptable
 
     As input to ``forward`` and ``update`` the metric accepts the following input
 
-    - ``images`` (:class:`~torch.Tensor` or list of tensors): tensor with images feed to the feature extractor with. If
-      a single tensor it should have shape ``(N, C, H, W)``. If a list of tensors, each tensor should have shape
-      ``(C, H, W)``. ``C`` is the number of channels, ``H`` and ``W`` are the height and width of the image.
-    - ``text`` (:class:`~str` or :class:`~list` of :class:`~str`): text to compare with the images, one for each image.
+    - ``source``: Source input. This can be:
+        - Images: (:class:`~torch.Tensor` or list of tensors): tensor with images fed to the feature extractor. If
+          a single tensor it should have shape ``(N, C, H, W)``. If a list of tensors, each tensor should have shape
+          ``(C, H, W)``. ``C`` is the number of channels, ``H`` and ``W`` are the height and width of the image.
+        - Text: (:class:`~str` or :class:`~list` of :class:`~str`): text to compare with the other input, one for
+          each sample.
+    - ``target``: Target input. This can be:
+        - Images: (:class:`~torch.Tensor` or list of tensors): tensor with images fed to the feature extractor. If
+          a single tensor it should have shape ``(N, C, H, W)``. If a list of tensors, each tensor should have shape
+          ``(C, H, W)``. ``C`` is the number of channels, ``H`` and ``W`` are the height and width of the image.
+        - Text: (:class:`~str` or :class:`~list` of :class:`~str`): text to compare with the other input, one for
+          each sample.
 
     As output of `forward` and `compute` the metric returns the following output
 
@@ -84,12 +101,29 @@ class CLIPScore(Metric):
             If transformers package is not installed or version is lower than 4.10.0
 
     Example:
-        >>> from torch import randint
         >>> from torchmetrics.multimodal.clip_score import CLIPScore
         >>> metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
-        >>> score = metric(randint(255, (3, 224, 224)), "a photo of a cat")
+        >>> image = torch.randint(255, (3, 224, 224), generator=torch.Generator().manual_seed(42))
+        >>> score = metric(image, "a photo of a cat")
+        >>> score.detach().round()
+        tensor(24.)
+
+    Example:
+        >>> from torchmetrics.multimodal.clip_score import CLIPScore
+        >>> metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
+        >>> image1 = torch.randint(255, (3, 224, 224), generator=torch.Generator().manual_seed(42))
+        >>> image2 = torch.randint(255, (3, 224, 224), generator=torch.Generator().manual_seed(43))
+        >>> score = metric(image1, image2)
         >>> score.detach().round()
-        tensor(25.)
+        tensor(99.)
+
+    Example:
+        >>> from torchmetrics.multimodal.clip_score import CLIPScore
+        >>> metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
+        >>> score = metric("28-year-old chef found dead in San Francisco mall",
+        ...                "A 28-year-old chef who recently moved to San Francisco was found dead.")
+        >>> score.detach().round()
+        tensor(91.)
 
     """
 
@@ -118,12 +152,18 @@ def __init__(
         self.add_state("score", torch.tensor(0.0), dist_reduce_fx="sum")
         self.add_state("n_samples", torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum")
 
-    def update(self, images: Union[Tensor, List[Tensor]], text: Union[str, list[str]]) -> None:
+    def update(
+        self, source: Union[Tensor, List[Tensor], List[str], str], target: Union[Tensor, List[Tensor], List[str], str]
+    ) -> None:
         """Update CLIP score on a batch of images and text.
 
         Args:
-            images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors
-            text: Either a single caption or a list of captions
+            source: Source input. This can be:
+                - Images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors.
+                - Text: Either a single caption or a list of captions.
+            target: Target input. This can be:
+                - Images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors.
+                - Text: Either a single caption or a list of captions.
 
         Raises:
             ValueError:
@@ -132,7 +172,7 @@ def update(self, images: Union[Tensor, List[Tensor]], text: Union[str, list[str]
                 If the number of images and captions do not match
 
         """
-        score, n_samples = _clip_score_update(images, text, self.model, self.processor)
+        score, n_samples = _clip_score_update(source, target, self.model, self.processor)
         self.score += score.sum(0)
         self.n_samples += n_samples
 
diff --git a/tests/unittests/multimodal/test_clip_score.py b/tests/unittests/multimodal/test_clip_score.py
index 491d64a8a78..cdabbbf502a 100644
--- a/tests/unittests/multimodal/test_clip_score.py
+++ b/tests/unittests/multimodal/test_clip_score.py
@@ -22,7 +22,12 @@
 from transformers import CLIPModel as _CLIPModel
 from transformers import CLIPProcessor as _CLIPProcessor
 
-from torchmetrics.functional.multimodal.clip_score import clip_score
+from torchmetrics.functional.multimodal.clip_score import (
+    _detect_modality,
+    _process_image_data,
+    _process_text_data,
+    clip_score,
+)
 from torchmetrics.multimodal.clip_score import CLIPScore
 from torchmetrics.utilities.imports import _TRANSFORMERS_GREATER_EQUAL_4_10
 from unittests._helpers import seed_all, skip_on_connection_issues
@@ -110,7 +115,7 @@ def test_clip_score_differentiability(self, inputs, model_name_or_path):
     def test_error_on_not_same_amount_of_input(self, inputs, model_name_or_path):
         """Test that an error is raised if the number of images and text examples does not match."""
         metric = CLIPScore(model_name_or_path=model_name_or_path)
-        with pytest.raises(ValueError, match="Expected the number of images and text examples to be the same.*"):
+        with pytest.raises(ValueError, match="Expected the number of source and target examples to be the same.*"):
             metric(torch.randint(255, (2, 3, 64, 64)), "28-year-old chef found dead in San Francisco mall")
 
     @skip_on_connection_issues()
@@ -143,3 +148,93 @@ def test_warning_on_long_caption(self, inputs, model_name_or_path):
             match="Encountered caption longer than max_position_embeddings=77. Will truncate captions to this length.*",
         ):
             metric.update(preds[0], target[0])
+
+    @skip_on_connection_issues()
+    def test_clip_score_image_to_image(self, inputs, model_name_or_path):
+        """Test CLIP score for image-to-image comparison."""
+        metric = CLIPScore(model_name_or_path=model_name_or_path)
+        preds, _ = inputs
+        score = metric(preds[0][0], preds[0][1])
+        assert score.detach().round() == torch.tensor(96.0)
+
+    @skip_on_connection_issues()
+    def test_clip_score_text_to_text(self, inputs, model_name_or_path):
+        """Test CLIP score for text-to-text comparison."""
+        metric = CLIPScore(model_name_or_path=model_name_or_path)
+        _, target = inputs
+        score = metric(target[0][0], target[0][1])
+        assert score.detach().round() == torch.tensor(65.0)
+
+    @skip_on_connection_issues()
+    def test_clip_score_functional_image_to_image(self, inputs, model_name_or_path):
+        """Test functional implementation of image-to-image CLIP score."""
+        preds, _ = inputs
+        score = clip_score(preds[0][0], preds[0][1], model_name_or_path=model_name_or_path)
+        assert score.detach().round() == torch.tensor(96.0)
+
+    @skip_on_connection_issues()
+    def test_clip_score_functional_text_to_text(self, inputs, model_name_or_path):
+        """Test functional implementation of text-to-text CLIP score."""
+        _, target = inputs
+        score = clip_score(target[0][0], target[0][1], model_name_or_path=model_name_or_path)
+        assert score.detach().round() == torch.tensor(65.0)
+
+
+@pytest.mark.parametrize(
+    ("input_data", "expected"),
+    [
+        (torch.randn(3, 64, 64), "image"),
+        ([torch.randn(3, 64, 64)], "image"),
+        ("some text", "text"),
+        (["text1", "text2"], "text"),
+    ],
+)
+def test_detect_modality(input_data, expected):
+    """Test that modality detection works correctly."""
+    assert _detect_modality(input_data) == expected
+
+    with pytest.raises(ValueError, match="Empty input list"):
+        _detect_modality([])
+
+    with pytest.raises(ValueError, match="Could not automatically determine modality"):
+        _detect_modality(123)
+
+
+@pytest.mark.parametrize(
+    ("images", "expected_len", "should_raise"),
+    [
+        (torch.randn(3, 64, 64), 1, False),
+        (torch.randn(2, 3, 64, 64), 2, False),
+        ([torch.randn(3, 64, 64)], 1, False),
+        ([torch.randn(3, 64, 64), torch.randn(3, 64, 64)], 2, False),
+        (torch.randn(64, 64), 0, True),
+        ([torch.randn(64, 64)], 0, True),
+    ],
+)
+def test_process_image_data(images, expected_len, should_raise):
+    """Test that image processing works correctly."""
+    if should_raise:
+        with pytest.raises(ValueError, match="Expected all images to be 3d"):
+            _process_image_data(images)
+    else:
+        processed = _process_image_data(images)
+        assert isinstance(processed, list)
+        assert len(processed) == expected_len
+        assert all(isinstance(img, Tensor) and img.ndim == 3 for img in processed)
+
+
+@pytest.mark.parametrize(
+    ("texts", "expected_len"),
+    [
+        ("single text", 1),
+        (["text1", "text2"], 2),
+        ([""], 1),
+        ([], 0),
+    ],
+)
+def test_process_text_data(texts, expected_len):
+    """Test that text processing works correctly."""
+    processed = _process_text_data(texts)
+    assert isinstance(processed, list)
+    assert len(processed) == expected_len
+    assert all(isinstance(text, str) for text in processed)
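
Usage note (reviewer sketch, not part of the patch): the doctest-style lines below only restate how the reworked `source`/`target` API introduced by this diff is called, mirroring the docstring examples added above. The model name, seeds, and `# doctest: +SKIP` markers are illustrative assumptions; no expected scores beyond those already shown in the diff are claimed.

>>> import torch
>>> from torchmetrics.functional.multimodal import clip_score
>>> from torchmetrics.multimodal.clip_score import CLIPScore
>>> img1 = torch.randint(255, (3, 224, 224), generator=torch.Generator().manual_seed(42))
>>> img2 = torch.randint(255, (3, 224, 224), generator=torch.Generator().manual_seed(43))
>>> # image-to-text, the original behaviour
>>> clip_score(img1, "a photo of a cat", "openai/clip-vit-base-patch16")  # doctest: +SKIP
>>> # image-to-image and text-to-text, enabled by this patch
>>> clip_score(img1, img2, "openai/clip-vit-base-patch16")  # doctest: +SKIP
>>> metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
>>> metric("28-year-old chef found dead in San Francisco mall",
...        "A 28-year-old chef who recently moved to San Francisco was found dead.")  # doctest: +SKIP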