Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance Clip_Score to calculate similarities between same modalities #2875

Open
wants to merge 75 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 51 commits
Commits
Show all changes
75 commits
Select commit Hold shift + click to select a range
8302300
Fix: Handle zero division error in binary IoU (Jaccard index) calcula…
rittik9 Sep 9, 2024
9098d0a
chlog
Borda Sep 9, 2024
7803302
Merge branch 'master' into fix/handle-zero-division-iou-calculation
Borda Sep 9, 2024
65b2714
Merge branch 'master' into fix/handle-zero-division-iou-calculation
mergify[bot] Sep 10, 2024
31087e3
Merge branch 'master' into fix/handle-zero-division-iou-calculation
mergify[bot] Sep 10, 2024
b792368
[wip]feat: enchance clip_score to claculate similarity between same m…
rittik9 Dec 19, 2024
74ccbb1
Update CHANGELOG.md
rittik9 Dec 19, 2024
67540e3
Merge branch 'master' into enhance/clip
rittik9 Dec 19, 2024
eb3590c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 19, 2024
f2761fd
Update clip_score.py
rittik9 Dec 20, 2024
5af7443
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 20, 2024
ec82ed5
Update clip_score.py
rittik9 Dec 20, 2024
244a4d7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 20, 2024
0917814
Update clip_score.py
rittik9 Dec 20, 2024
29c0a9a
Update clip_score.py
rittik9 Dec 20, 2024
4f7a4b6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 20, 2024
8124d0e
Update test_clip_score.py
rittik9 Dec 20, 2024
47f4fc8
refactor: clip_score.py
rittik9 Dec 20, 2024
2b46025
refactor
rittik9 Dec 20, 2024
37ab156
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 20, 2024
7557008
refactor: replace deprecated `List` with built-in `list` for type ann…
rittik9 Dec 20, 2024
36f0b72
refactor
rittik9 Dec 20, 2024
f20ecf2
fix: resolve mypy type errors by adding runtime type checks
rittik9 Dec 20, 2024
7b23137
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 20, 2024
043bb0c
refactor: clip_score.py
rittik9 Dec 20, 2024
3af99bb
Merge branch 'enhance/clip' of https://github.com/rittik9/torchmetric…
rittik9 Dec 20, 2024
7283005
refactor: clip_score.py
rittik9 Dec 20, 2024
9f6fc32
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 20, 2024
93c2830
refactor
rittik9 Dec 20, 2024
82ba7d6
Merge branch 'enhance/clip' of https://github.com/rittik9/torchmetric…
rittik9 Dec 20, 2024
1d4f16b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 20, 2024
a18616e
refactor
rittik9 Dec 20, 2024
28fbef4
refactor
rittik9 Dec 20, 2024
c6c433e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 20, 2024
a33d69a
refactor
rittik9 Dec 20, 2024
e6a9ede
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 20, 2024
58262f4
refactor
rittik9 Dec 20, 2024
fe5a42e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 20, 2024
76fbcaa
Update clip_score.py
rittik9 Dec 21, 2024
01bc8bc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 21, 2024
fec61e3
Merge branch 'master' into enhance/clip
rittik9 Dec 21, 2024
7ae0d2f
Update clip_score.py
rittik9 Dec 23, 2024
bd40ffb
Update test_clip_score.py
rittik9 Dec 23, 2024
f41bc55
Merge branch 'master' into enhance/clip
rittik9 Dec 24, 2024
6504ee4
Merge branch 'master' into enhance/clip
rittik9 Dec 26, 2024
89d96f8
Merge branch 'master' into enhance/clip
rittik9 Jan 2, 2025
20a218f
Update test_clip_score.py
rittik9 Jan 2, 2025
6711b4d
Merge branch 'master' into enhance/clip
rittik9 Jan 6, 2025
37c1b8e
Merge branch 'master' into enhance/clip
rittik9 Jan 6, 2025
43d955a
Merge branch 'master' into enhance/clip
rittik9 Jan 7, 2025
fd2b3d2
uncomment test
rittik9 Jan 7, 2025
ad38bb0
Apply suggestions from code review
Borda Jan 7, 2025
db7e199
Merge branch 'Lightning-AI:master' into enhance/clip
rittik9 Jan 7, 2025
caa02ff
Update clip_score.py
rittik9 Jan 7, 2025
4cffc23
Update clip_score.py
rittik9 Jan 7, 2025
4ff62e8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 7, 2025
7ab790a
Update clip_score.py
rittik9 Jan 7, 2025
4f476a0
Update clip_score.py
rittik9 Jan 7, 2025
0b29244
update docs
rittik9 Jan 8, 2025
2b7347d
typefix
rittik9 Jan 8, 2025
7c93760
improve _get_features
rittik9 Jan 8, 2025
00ee2e9
improve _get_features docs
rittik9 Jan 8, 2025
81b4405
clip_score.py
rittik9 Jan 9, 2025
cd3663f
Revert "clip_score.py"
rittik9 Jan 9, 2025
03efd65
add tests
rittik9 Jan 9, 2025
b71fe12
add doctest for same modality
rittik9 Jan 9, 2025
9690417
fix device
rittik9 Jan 9, 2025
887be9d
fix doctests
rittik9 Jan 9, 2025
0a56001
fix doctests
rittik9 Jan 9, 2025
3f2a5c3
fix doctests
rittik9 Jan 9, 2025
045faad
fix doctests
rittik9 Jan 9, 2025
4c074c5
add unittests
rittik9 Jan 9, 2025
579d200
add doctests
rittik9 Jan 10, 2025
1571f8c
add random seed in doctests
rittik9 Jan 10, 2025
1156b18
modify doctest
rittik9 Jan 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 140 additions & 48 deletions src/torchmetrics/functional/multimodal/clip_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List, Union
from typing import TYPE_CHECKING, List, Union, cast

import torch
from torch import Tensor
Expand Down Expand Up @@ -41,53 +41,140 @@ def _download_clip_for_clip_score() -> None:
_CLIPProcessor = None


def _detect_modality(input_data: Union[Tensor, List[Tensor], List[str], str]) -> Literal["image", "text"]:
"""Automatically detect the modality of the input data.

Args:
input_data: Input data that can be either image tensors or text strings

Returns:
str: Either "image" or "text"

Raises:
ValueError: If the input_data is an empty list or modality cannot be determined

"""
if isinstance(input_data, Tensor):
return "image"

if isinstance(input_data, list):
if len(input_data) == 0:
raise ValueError("Empty input list")
if isinstance(input_data[0], Tensor):
return "image"
if isinstance(input_data[0], str):
return "text"

if isinstance(input_data, str):
return "text"

raise ValueError("Could not automatically determine modality for input_data")


def _process_image_data(images: Union[Tensor, List[Tensor]]) -> List[Tensor]:
"""Helper function to process image data."""
if isinstance(images, Tensor):
if images.ndim == 3:
return [images]
raise ValueError("Expected all images to be 3d but found image that has either more or less")
if not all(i.ndim == 3 for i in images):
raise ValueError("Expected all images to be 3d but found image that has either more or less")
return images


def _process_text_data(texts: Union[str, List[str]]) -> List[str]:
"""Helper function to process text data."""
if not isinstance(texts, list):
Borda marked this conversation as resolved.
Show resolved Hide resolved
texts = [texts]
return texts


def _get_features(
data: List[Union[Tensor, str]],
modality: Literal["image", "text"],
device: torch.device,
model: "_CLIPModel",
processor: "_CLIPProcessor",
) -> Tensor:
"""Get features from the CLIP model for either images or text.

Args:
data: List of input data (images or text)
modality: Type of input data ("image" or "text")
device: Device to run the model on
model: CLIP model instance
processor: CLIP processor instance
Returns:
Borda marked this conversation as resolved.
Show resolved Hide resolved
Tensor of features from the CLIP model

"""
if modality == "image":
# Add type checking for images
image_data = [i for i in data if isinstance(i, Tensor)]
processed = processor(images=[i.cpu() for i in image_data], return_tensors="pt", padding=True)
features = model.get_image_features(processed["pixel_values"].to(device))
Borda marked this conversation as resolved.
Show resolved Hide resolved
else:
Borda marked this conversation as resolved.
Show resolved Hide resolved
processed = processor(text=data, return_tensors="pt", padding=True)
max_position_embeddings = model.config.text_config.max_position_embeddings
if processed["attention_mask"].shape[-1] > max_position_embeddings:
rank_zero_warn(
f"Encountered caption longer than {max_position_embeddings=}. Will truncate captions to this length."
"If longer captions are needed, initialize argument `model_name_or_path` with a model that supports"
"longer sequences",
UserWarning,
)
processed["attention_mask"] = processed["attention_mask"][..., :max_position_embeddings]
processed["input_ids"] = processed["input_ids"][..., :max_position_embeddings]
features = model.get_text_features(processed["input_ids"].to(device), processed["attention_mask"].to(device))

return features
Borda marked this conversation as resolved.
Show resolved Hide resolved


def _clip_score_update(
images: Union[Tensor, List[Tensor]],
text: Union[str, list[str]],
source: Union[Tensor, List[Tensor], List[str], str],
target: Union[Tensor, List[Tensor], List[str], str],
model: _CLIPModel,
processor: _CLIPProcessor,
) -> tuple[Tensor, int]:
if not isinstance(images, list):
if images.ndim == 3:
images = [images]
else: # unwrap into list
images = list(images)
source_modality = _detect_modality(source)
target_modality = _detect_modality(target)

if not all(i.ndim == 3 for i in images):
raise ValueError("Expected all images to be 3d but found image that has either more or less")

if not isinstance(text, list):
text = [text]
source_data = (
_process_image_data(cast(Union[Tensor, List[Tensor]], source))
if source_modality == "image"
else _process_text_data(cast(Union[str, List[str]], source))
)
target_data = (
_process_image_data(cast(Union[Tensor, List[Tensor]], target))
if target_modality == "image"
else _process_text_data(cast(Union[str, List[str]], target))
)

if len(text) != len(images):
# Verify matching lengths
if len(source_data) != len(target_data):
raise ValueError(
f"Expected the number of images and text examples to be the same but got {len(images)} and {len(text)}"
)
device = images[0].device
processed_input = processor(text=text, images=[i.cpu() for i in images], return_tensors="pt", padding=True)

img_features = model.get_image_features(processed_input["pixel_values"].to(device))
img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)

max_position_embeddings = model.config.text_config.max_position_embeddings
if processed_input["attention_mask"].shape[-1] > max_position_embeddings:
rank_zero_warn(
f"Encountered caption longer than {max_position_embeddings=}. Will truncate captions to this length."
"If longer captions are needed, initialize argument `model_name_or_path` with a model that supports"
"longer sequences",
UserWarning,
"Expected the number of source and target examples to be the same but got "
f"{len(source_data)} and {len(target_data)}"
)
processed_input["attention_mask"] = processed_input["attention_mask"][..., :max_position_embeddings]
processed_input["input_ids"] = processed_input["input_ids"][..., :max_position_embeddings]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if source_modality == "image" and isinstance(source_data[0], Tensor):
device = source_data[0].device
elif target_modality == "image" and isinstance(target_data[0], Tensor):
device = target_data[0].device
model = model.to(device)

txt_features = model.get_text_features(
processed_input["input_ids"].to(device), processed_input["attention_mask"].to(device)
source_features = _get_features(
cast(List[Union[Tensor, str]], source_data), source_modality, device, model, processor
)
txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True)
target_features = _get_features(
cast(List[Union[Tensor, str]], target_data), target_modality, device, model, processor
)
source_features = source_features / source_features.norm(p=2, dim=-1, keepdim=True)
target_features = target_features / target_features.norm(p=2, dim=-1, keepdim=True)

# cosine similarity between feature vectors
score = 100 * (img_features * txt_features).sum(axis=-1)
return score, len(text)
# Calculate cosine similarity
score = 100 * (source_features * target_features).sum(axis=-1)
return score, len(source_data)


def _get_clip_model_and_processor(
Expand All @@ -113,8 +200,8 @@ def _get_clip_model_and_processor(


def clip_score(
images: Union[Tensor, List[Tensor]],
text: Union[str, list[str]],
source: Union[Tensor, List[Tensor], List[str], str],
target: Union[Tensor, List[Tensor], List[str], str],
model_name_or_path: Literal[
"openai/clip-vit-base-patch16",
"openai/clip-vit-base-patch32",
Expand All @@ -135,15 +222,21 @@ def clip_score(
textual CLIP embedding :math:`E_C` for an caption :math:`C`. The score is bound between 0 and 100 and the closer
to 100 the better.

.. caution::
Metric is not scriptable
.. note:: Metric is not scriptable

Args:
images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors
text: Either a single caption or a list of captions
model_name_or_path: string indicating the version of the CLIP model to use. Available models are
`"openai/clip-vit-base-patch16"`, `"openai/clip-vit-base-patch32"`, `"openai/clip-vit-large-patch14-336"`
and `"openai/clip-vit-large-patch14"`,
source: Source input. This can be:
- Images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors.
- Text: Either a single caption or a list of captions.
target: Target input. This can be:
- Images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors.
- Text: Either a single caption or a list of captions.
model_name_or_path: String indicating the version of the CLIP model to use. Available models are:
- `"openai/clip-vit-base-patch16"`
- `"openai/clip-vit-base-patch32"`
- `"openai/clip-vit-large-patch14-336"`
- `"openai/clip-vit-large-patch14"`


Raises:
ModuleNotFoundError:
Expand All @@ -161,7 +254,6 @@ def clip_score(

"""
model, processor = _get_clip_model_and_processor(model_name_or_path)
device = images.device if isinstance(images, Tensor) else images[0].device
score, _ = _clip_score_update(images, text, model.to(device), processor)
score, _ = _clip_score_update(source, target, model, processor)
score = score.mean(0)
return torch.max(score, torch.zeros_like(score))
17 changes: 11 additions & 6 deletions src/torchmetrics/multimodal/clip_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from typing import Any, List, Optional, Union
from typing import Any, List, Optional, Sequence, Union

import torch
from torch import Tensor
Expand Down Expand Up @@ -118,12 +117,18 @@ def __init__(
self.add_state("score", torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("n_samples", torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum")

def update(self, images: Union[Tensor, List[Tensor]], text: Union[str, list[str]]) -> None:
def update(
self, source: Union[Tensor, List[Tensor], List[str], str], target: Union[Tensor, List[Tensor], List[str], str]
) -> None:
"""Update CLIP score on a batch of images and text.

Args:
images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors
text: Either a single caption or a list of captions
source: Source input. This can be:
- Images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors.
- Text: Either a single caption or a list of captions.
target: Target input. This can be:
- Images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors.
- Text: Either a single caption or a list of captions.

Raises:
ValueError:
Expand All @@ -132,7 +137,7 @@ def update(self, images: Union[Tensor, List[Tensor]], text: Union[str, list[str]
If the number of images and captions do not match

"""
score, n_samples = _clip_score_update(images, text, self.model, self.processor)
score, n_samples = _clip_score_update(source, target, self.model, self.processor)
self.score += score.sum(0)
self.n_samples += n_samples

Expand Down
2 changes: 1 addition & 1 deletion tests/unittests/multimodal/test_clip_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_clip_score_differentiability(self, inputs, model_name_or_path):
def test_error_on_not_same_amount_of_input(self, inputs, model_name_or_path):
"""Test that an error is raised if the number of images and text examples does not match."""
metric = CLIPScore(model_name_or_path=model_name_or_path)
with pytest.raises(ValueError, match="Expected the number of images and text examples to be the same.*"):
with pytest.raises(ValueError, match="Expected the number of source and target examples to be the same.*"):
metric(torch.randint(255, (2, 3, 64, 64)), "28-year-old chef found dead in San Francisco mall")

@skip_on_connection_issues()
Expand Down
Loading