Skip to content

Commit

Permalink
Feat/multimodal (#45)
Browse files Browse the repository at this point in the history
* inference_mode instead of no_grad

* cleaner verbosity

* add image prep

* add image documents

* monovlm ranker

* version bump

* update pyproject

* wip

* wip

* example fully functional
  • Loading branch information
bclavie authored Nov 12, 2024
1 parent ef30e19 commit c4be8ff
Show file tree
Hide file tree
Showing 12 changed files with 666 additions and 27 deletions.
360 changes: 360 additions & 0 deletions examples/reranker_images.ipynb

Large diffs are not rendered by default.

12 changes: 8 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ packages = [
name = "rerankers"


version = "0.5.3"
version = "0.6.0"

description = "A unified API for various document re-ranking models."

Expand Down Expand Up @@ -53,22 +53,26 @@ dependencies = [

[project.optional-dependencies]
all = [
"transformers",
"transformers>=4.45.0",
"torch",
"litellm",
"requests",
"sentencepiece",
"protobuf",
"flashrank",
"flash-attn",
"pillow",
"accelerate>=0.26.0",
"peft>=0.13.0",
"nmslib-metabrainz; python_version >= '3.10'",
"rank-llm; python_version >= '3.10'"
]
transformers = ["transformers", "torch", "sentencepiece", "protobuf"]
transformers = ["transformers>=4.45.0", "torch", "sentencepiece", "protobuf"]
api = ["requests"]
gpt = ["litellm"]
flashrank = ["flashrank"]
llmlayerwise = ["transformers", "torch", "sentencepiece", "protobuf", "flash-attn"]
llmlayerwise = ["transformers>=4.45.0", "torch", "sentencepiece", "protobuf", "flash-attn"]
monovlm = ["transformers>=4.45.0", "torch", "sentencepiece", "protobuf", "flash-attn", "pillow", "accelerate>=0.26.0", "peft>=0.13.0"]
rankllm = [
"nmslib-metabrainz; python_version >= '3.10'",
"rank-llm; python_version >= '3.10'"
Expand Down
2 changes: 1 addition & 1 deletion rerankers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
from rerankers.documents import Document

__all__ = ["Reranker", "Document"]
__version__ = "0.5.3"
__version__ = "0.6.0"
29 changes: 24 additions & 5 deletions rerankers/documents.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,35 @@
from typing import Optional, Union
from pydantic import BaseModel
from typing import Optional, Union, Literal
from pydantic import BaseModel, validator


class Document(BaseModel):
    """A document to be reranked: either plain text or an image.

    For ``document_type == "text"`` the ``text`` field is required
    (enforced by the validator below). For ``document_type == "image"``
    the payload is carried by ``base64`` and/or ``image_path`` and
    ``text`` may stay ``None``.
    """

    # `document_type` is declared BEFORE `text` on purpose: pydantic
    # v1-style validators receive previously-validated fields in `values`,
    # so the text validator can see the document type.
    document_type: Literal["text", "image"] = "text"
    text: Optional[str] = None
    base64: Optional[str] = None
    image_path: Optional[str] = None
    doc_id: Optional[Union[str, int]] = None
    metadata: Optional[dict] = None

    # NOTE(review): `validator` is deprecated in pydantic v2 in favour of
    # `field_validator`; kept here for compatibility with the existing import.
    @validator("text")
    def validate_text(cls, v, values):
        """Require a text payload for text documents."""
        if values.get("document_type") == "text" and v is None:
            raise ValueError("text field is required when document_type is 'text'")
        return v

    def __init__(
        self,
        text: Optional[str] = None,
        doc_id: Optional[Union[str, int]] = None,
        metadata: Optional[dict] = None,
        document_type: Literal["text", "image"] = "text",
        image_path: Optional[str] = None,
        base64: Optional[str] = None,
    ):
        # Explicit __init__ so `text` can be passed positionally
        # (e.g. Document("some text")), which pydantic's generated
        # keyword-only constructor would not allow.
        super().__init__(
            text=text,
            doc_id=doc_id,
            metadata=metadata,
            document_type=document_type,
            base64=base64,
            image_path=image_path,
        )
6 changes: 6 additions & 0 deletions rerankers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,9 @@
AVAILABLE_RANKERS["LLMLayerWiseRanker"] = LLMLayerWiseRanker
except ImportError:
pass

# MonoVLM needs optional heavy dependencies (transformers>=4.45.0, torch,
# pillow, flash-attn, ...); register the ranker only when they are installed,
# mirroring the conditional-registration pattern used for the other rankers.
try:
    from rerankers.models.monovlm_ranker import MonoVLMRanker
    AVAILABLE_RANKERS["MonoVLMRanker"] = MonoVLMRanker
except ImportError:
    pass
6 changes: 3 additions & 3 deletions rerankers/models/colbert_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def _document_encode(self, documents: list[str]):
return self._encode(documents, self.document_token_id)

def _to_embs(self, encoding) -> torch.Tensor:
with torch.no_grad():
with torch.inference_mode():
# embs = self.model(**encoding).last_hidden_state.squeeze(1)
embs = self.model(**encoding)
if self.normalize:
Expand Down Expand Up @@ -271,7 +271,7 @@ def score(self, query: str, doc: str) -> float:
scores = self._colbert_rank(query, [doc])
return scores[0] if scores else 0.0

@torch.no_grad()
@torch.inference_mode()
def _colbert_rank(
self,
query: str,
Expand Down Expand Up @@ -377,7 +377,7 @@ def _encode(
return encoding

def _to_embs(self, encoding) -> torch.Tensor:
with torch.no_grad():
with torch.inference_mode():
batched_embs = []
for i in range(0, encoding["input_ids"].size(0), self.batch_size):
batch_encoding = {
Expand Down
4 changes: 2 additions & 2 deletions rerankers/models/llm_layerwise_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def _get_inputs(self, pairs, max_sequence_length: int):
return_tensors="pt",
)

@torch.no_grad()
@torch.inference_mode()
def rank(
self,
query: str,
Expand Down Expand Up @@ -177,7 +177,7 @@ def rank(
]
return RankedResults(results=ranked_results, query=query, has_scores=True)

@torch.no_grad()
@torch.inference_mode()
def score(self, query: str, doc: str) -> float:
inputs = self._get_inputs(
[(query, doc)], max_sequence_length=self.max_sequence_length
Expand Down
164 changes: 164 additions & 0 deletions rerankers/models/monovlm_ranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import torch
from PIL import Image
import base64
import io
# TODO: Support more than Qwen
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from rerankers.models.ranker import BaseRanker
from rerankers.documents import Document
from typing import Union, List, Optional
from rerankers.utils import vprint, get_device, get_dtype, prep_image_docs
from rerankers.results import RankedResults, Result

PREDICTION_TOKENS = {
"default": ["False", "True"],
"lightonai/MonoQwen2-VL-v0.1": ["False", "True"]
}

def _get_output_tokens(model_name_or_path, token_false: str, token_true: str):
if token_false == "auto":
if model_name_or_path in PREDICTION_TOKENS:
token_false = PREDICTION_TOKENS[model_name_or_path][0]
else:
token_false = PREDICTION_TOKENS["default"][0]
print(
f"WARNING: Model {model_name_or_path} does not have known True/False tokens. Defaulting token_false to `{token_false}`."
)
if token_true == "auto":
if model_name_or_path in PREDICTION_TOKENS:
token_true = PREDICTION_TOKENS[model_name_or_path][1]
else:
token_true = PREDICTION_TOKENS["default"][1]
print(
f"WARNING: Model {model_name_or_path} does not have known True/False tokens. Defaulting token_true to `{token_true}`."
)

return token_false, token_true

class MonoVLMRanker(BaseRanker):
    """Pointwise reranker for *image* documents backed by a vision-language model.

    Each document is scored independently: the image plus a templated prompt
    are fed to the VLM, and the relevance score is derived from the logits of
    the configured True/False answer tokens at the final position (softmax
    probability of "True" by default, raw "True" logit if ``return_logits``).

    NOTE: currently hard-wired to Qwen2-VL architectures (see module TODO).
    """

    def __init__(
        self,
        model_name_or_path: str,
        processor_name: Optional[str] = None,
        dtype: Optional[Union[str, torch.dtype]] = 'bf16',
        device: Optional[Union[str, torch.device]] = None,
        batch_size: int = 1,
        verbose: int = 1,
        token_false: str = "auto",
        token_true: str = "auto",
        return_logits: bool = False,
        prompt_template: str = "Assert the relevance of the previous image document to the following query, answer True or False. The query is: {query}",
        **kwargs
    ):
        """Load the VLM and resolve the True/False scoring tokens.

        Extra ``kwargs`` recognised: ``processor_kwargs``, ``model_kwargs``
        and ``attention_implementation`` (default ``flash_attention_2``).
        """
        self.verbose = verbose
        self.device = get_device(device, verbose=self.verbose)
        if self.device == 'mps':
            print("WARNING: MPS is not supported by MonoVLMRanker due to PyTorch limitations. Falling back to CPU.")
            self.device = 'cpu'
        self.dtype = get_dtype(dtype, self.device, self.verbose)
        # NOTE(review): batch_size is stored but _get_scores currently
        # processes documents one at a time; kept for future batching.
        self.batch_size = batch_size
        self.return_logits = return_logits
        self.prompt_template = prompt_template

        vprint(f"Loading model {model_name_or_path}, this might take a while...", self.verbose)
        vprint(f"Using device {self.device}.", self.verbose)
        vprint(f"Using dtype {self.dtype}.", self.verbose)

        # The fine-tuned checkpoint may not ship a processor; default to the
        # base Qwen2-VL processor.
        processor_name = processor_name or "Qwen/Qwen2-VL-2B-Instruct"
        processor_kwargs = kwargs.get("processor_kwargs", {})
        model_kwargs = kwargs.get("model_kwargs", {})
        attention_implementation = kwargs.get("attention_implementation", "flash_attention_2")
        self.processor = AutoProcessor.from_pretrained(processor_name, **processor_kwargs)
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_name_or_path,
            device_map=self.device,
            torch_dtype=self.dtype,
            attn_implementation=attention_implementation,
            **model_kwargs
        )
        self.model.eval()

        token_false, token_true = _get_output_tokens(
            model_name_or_path=model_name_or_path,
            token_false=token_false,
            token_true=token_true,
        )
        self.token_false_id = self.processor.tokenizer.convert_tokens_to_ids(token_false)
        self.token_true_id = self.processor.tokenizer.convert_tokens_to_ids(token_true)

        vprint(f"VLM true token set to {token_true}", self.verbose)
        vprint(f"VLM false token set to {token_false}", self.verbose)

    @torch.inference_mode()
    def _get_scores(self, query: str, docs: List[Document]) -> List[float]:
        """Score each image document against ``query``, one at a time.

        Raises ValueError for documents that are not base64-backed images.
        """
        scores = []
        for doc in docs:
            if doc.document_type != "image" or not doc.base64:
                raise ValueError("MonoVLMRanker requires image documents with base64 data")

            # Decode base64 payload into a PIL image; RGB conversion
            # normalises palette/alpha images for the processor.
            image = Image.open(io.BytesIO(base64.b64decode(doc.base64))).convert('RGB')

            # Build the chat-style prompt: image first, then the templated
            # relevance question (the template references "the previous
            # image document").
            prompt = self.prompt_template.format(query=query)
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": prompt},
                    ],
                }
            ]

            text = self.processor.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
            inputs = self.processor(
                text=text,
                images=image,
                return_tensors="pt"
            ).to(self.device).to(self.dtype)

            # Single forward pass; only the final position's logits matter,
            # as that is where the model would emit its True/False answer.
            outputs = self.model(**inputs)
            logits = outputs.logits[:, -1, :]

            # Restrict to the two answer tokens: column 0 = False, 1 = True.
            relevant_logits = logits[:, [self.token_false_id, self.token_true_id]]
            if self.return_logits:
                score = relevant_logits[0, 1].cpu().item()  # raw True logit
            else:
                probs = torch.softmax(relevant_logits, dim=-1)
                score = probs[0, 1].cpu().item()  # P(True | {True, False})

            scores.append(score)

        return scores

    def rank(
        self,
        query: str,
        docs: Union[str, List[str], Document, List[Document]],
        doc_ids: Optional[Union[List[str], List[int]]] = None,
        metadata: Optional[List[dict]] = None,
    ) -> RankedResults:
        """Rank image documents by relevance to ``query``, best first."""
        docs = prep_image_docs(docs, doc_ids, metadata)
        scores = self._get_scores(query, docs)
        ranked_results = [
            Result(document=doc, score=score, rank=idx + 1)
            for idx, (doc, score) in enumerate(
                sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)
            )
        ]
        return RankedResults(results=ranked_results, query=query, has_scores=True)

    def score(self, query: str, doc: Union[str, Document]) -> float:
        """Score a single document against ``query``.

        Routes through prep_image_docs (as rank() does) so a plain string
        input is converted to an image Document instead of crashing in
        _get_scores.
        """
        scores = self._get_scores(query, prep_image_docs(doc))
        return scores[0]
4 changes: 2 additions & 2 deletions rerankers/models/t5ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def score(self, query: str, doc: str) -> float:
scores = self._get_scores(query, [doc])
return scores[0] if scores else 0.0

@torch.no_grad()
@torch.inference_mode()
def _get_scores(
self,
query: str,
Expand Down Expand Up @@ -231,7 +231,7 @@ def _get_scores(
return logits
return scores

@torch.no_grad()
@torch.inference_mode()
def _greedy_decode(
self,
model,
Expand Down
4 changes: 2 additions & 2 deletions rerankers/models/transformer_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def tokenize(self, inputs: Union[str, List[str], List[Tuple[str, str]]]):
inputs, return_tensors="pt", padding=True, truncation=True
).to(self.device)

@torch.no_grad()
@torch.inference_mode()
def rank(
self,
query: str,
Expand Down Expand Up @@ -83,7 +83,7 @@ def rank(
]
return RankedResults(results=ranked_results, query=query, has_scores=True)

@torch.no_grad()
@torch.inference_mode()
def score(self, query: str, doc: str) -> float:
inputs = self.tokenize((query, doc))
outputs = self.model(**inputs)
Expand Down
10 changes: 9 additions & 1 deletion rerankers/reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
"en": "BAAI/bge-reranker-v2.5-gemma2-lightweight",
"other": "BAAI/bge-reranker-v2.5-gemma2-lightweight",
},
"monovlm": {
"en": "lightonai/MonoQwen2-VL-v0.1",
"other": "lightonai/MonoQwen2-VL-v0.1"
}
}

DEPS_MAPPING = {
Expand All @@ -47,6 +51,7 @@
"FlashRankRanker": "flashrank",
"RankLLMRanker": "rankllm",
"LLMLayerWiseRanker": "transformers",
"MonoVLMRanker": "transformers"
}

PROVIDERS = ["cohere", "jina", "voyage", "mixedbread.ai", "text-embeddings-inference"]
Expand Down Expand Up @@ -84,6 +89,7 @@ def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None)
"flashrank": "FlashRankRanker",
"rankllm": "RankLLMRanker",
"llm-layerwise": "LLMLayerWiseRanker",
"monovlm": "MonoVLMRanker"
}
return model_mapping.get(explicit_model_type, explicit_model_type)
else:
Expand All @@ -105,6 +111,8 @@ def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None)
"vicuna": "RankLLMRanker",
"zephyr": "RankLLMRanker",
"bge-reranker-v2.5-gemma2-lightweight": "LLMLayerWiseRanker",
"monovlm": "MonoVLMRanker",
"monoqwen2-vl": "MonoVLMRanker"
}
for key, value in model_mapping.items():
if key in model_name:
Expand Down Expand Up @@ -198,7 +206,7 @@ def Reranker(
model_type = _get_model_type(model_name, model_type)

try:
print(f"Loading {model_type} model {model_name}")
vprint(f"Loading {model_type} model {model_name} (this message can be suppressed by setting verbose=0)", verbose)
return AVAILABLE_RANKERS[model_type](model_name, verbose=verbose, **kwargs)
except KeyError:
print(
Expand Down
Loading

0 comments on commit c4be8ff

Please sign in to comment.