Skip to content

Commit

Permalink
fixes device issue when using any other device than cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
aakashks committed Dec 29, 2024
1 parent 46423f8 commit 359ab5f
Showing 1 changed file with 4 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,22 +48,23 @@ def __init__(
model_name=model_name, pretrained=checkpoint
)
self._model = model
self._model.to(device)
self._device = device
self._model.to(self._device)
self._preprocess = preprocess
self._tokenizer = open_clip.get_tokenizer(model_name=model_name)

def _encode_image(self, image: Image) -> Embedding:
pil_image = self._PILImage.fromarray(image)
with self._torch.no_grad():
image_features = self._model.encode_image(
self._preprocess(pil_image).unsqueeze(0)
self._preprocess(pil_image).unsqueeze(0).to(self._device)
)
image_features /= image_features.norm(dim=-1, keepdim=True)
return cast(Embedding, image_features.squeeze().cpu().numpy())

def _encode_text(self, text: Document) -> Embedding:
with self._torch.no_grad():
text_features = self._model.encode_text(self._tokenizer(text))
text_features = self._model.encode_text(self._tokenizer(text).to(self._device))
text_features /= text_features.norm(dim=-1, keepdim=True)
return cast(Embedding, text_features.squeeze().cpu().numpy())

Expand Down

0 comments on commit 359ab5f

Please sign in to comment.