From 81b44057af61e4b4315578f34eb5feed7cd70acc Mon Sep 17 00:00:00 2001 From: rittik9 Date: Thu, 9 Jan 2025 06:53:58 +0000 Subject: [PATCH] clip_score.py --- src/torchmetrics/functional/multimodal/clip_score.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/torchmetrics/functional/multimodal/clip_score.py b/src/torchmetrics/functional/multimodal/clip_score.py index 140f0e2de76..81a910a0f3a 100644 --- a/src/torchmetrics/functional/multimodal/clip_score.py +++ b/src/torchmetrics/functional/multimodal/clip_score.py @@ -163,12 +163,8 @@ def _clip_score_update( device = target_data[0].device model = model.to(device) - source_features = _get_features( - cast(List[Union[Tensor, str]], source_data), source_modality, device, model, processor - ) - target_features = _get_features( - cast(List[Union[Tensor, str]], target_data), target_modality, device, model, processor - ) + source_features = _get_features(source_data, source_modality, device, model, processor) + target_features = _get_features(target_data, target_modality, device, model, processor) source_features = source_features / source_features.norm(p=2, dim=-1, keepdim=True) target_features = target_features / target_features.norm(p=2, dim=-1, keepdim=True)