Skip to content

Commit

Permalink
pass kwargs allowing model_kwargs and tokenizer_kwargs to be passed. …
Browse files Browse the repository at this point in the history
…also fixed tests as Document has doc_id and not id. Last, add .conda/ to .gitignore (#44)
  • Loading branch information
sam-bercovici authored Nov 12, 2024
1 parent c4be8ff commit fc8f594
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 17 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ ENV/
env.bak/
venv.bak/
.envrc
.conda/

# mkdocs documentation
/site
Expand Down
6 changes: 5 additions & 1 deletion rerankers/models/colbert_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def __init__(
verbose: int = 1,
query_token: str = "[unused0]",
document_token: str = "[unused1]",
**kwargs,
):
self.verbose = verbose
self.device = get_device(device, self.verbose)
Expand All @@ -230,10 +231,13 @@ def __init__(
f"Loading model {model_name}, this might take a while...",
self.verbose,
)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {})
self.tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
model_kwargs = kwargs.get("model_kwargs", {})
self.model = (
ColBERTModel.from_pretrained(
model_name,
**model_kwargs
)
.to(self.device)
.to(self.dtype)
Expand Down
15 changes: 12 additions & 3 deletions rerankers/models/llm_layerwise_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
cutoff_layers: Optional[List[int]] = None,
compress_ratio: Optional[int] = None,
compress_layer: Optional[List[int]] = None,
**kwargs,
):
self.verbose = verbose
self.device = get_device(device, verbose=self.verbose)
Expand All @@ -50,16 +51,24 @@ def __init__(
)
vprint(f"Using device {self.device}.", self.verbose)
vprint(f"Using dtype {self.dtype}.", self.verbose)

tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {})
tokenizer_trust_remote_code = tokenizer_kwargs.pop("trust_remote_code", True)
self.tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path, trust_remote_code=True
model_name_or_path,
trust_remote_code=tokenizer_trust_remote_code,
**tokenizer_kwargs,
)
self.max_sequence_length = max_sequence_length
self.tokenizer.model_max_length = self.max_sequence_length
self.tokenizer.padding_side = "right"
model_kwargs = kwargs.get("model_kwargs", {})
model_trust_remote_code = model_kwargs.pop("trust_remote_code", True)

self.model = AutoModelForCausalLM.from_pretrained(
model_name_or_path, trust_remote_code=True, torch_dtype=self.dtype
model_name_or_path,
trust_remote_code=model_trust_remote_code,
torch_dtype=self.dtype,
**model_kwargs,
).to(self.device)
self.model.eval()

Expand Down
16 changes: 11 additions & 5 deletions rerankers/models/t5ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
from rerankers.documents import Document


import torch

from rerankers.results import RankedResults, Result
from rerankers.utils import (
vprint,
Expand Down Expand Up @@ -89,7 +87,8 @@ def __init__(
token_false: str = "auto",
token_true: str = "auto",
return_logits: bool = False,
inputs_template: str = "Query: {query} Document: {text} Relevant:"
inputs_template: str = "Query: {query} Document: {text} Relevant:",
**kwargs,
):
"""
Implementation of the key functions from https://github.com/unicamp-dl/InRanker/blob/main/inranker/rankers.py
Expand All @@ -113,11 +112,18 @@ def __init__(
)
vprint(f"Using device {self.device}.", self.verbose)
vprint(f"Using dtype {self.dtype}.", self.verbose)
model_kwargs = kwargs.get("model_kwargs", {})
self.model = AutoModelForSeq2SeqLM.from_pretrained(
model_name_or_path, torch_dtype=self.dtype
model_name_or_path,
torch_dtype=self.dtype,
**model_kwargs,
).to(self.device)
self.model.eval()
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {})
self.tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
**tokenizer_kwargs,
)

token_false, token_true = _get_output_tokens(
model_name_or_path=model_name_or_path,
Expand Down
12 changes: 10 additions & 2 deletions rerankers/models/transformer_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,26 @@ def __init__(
device: Optional[Union[str, torch.device]] = None,
batch_size: int = 16,
verbose: int = 1,
**kwargs,
):
self.verbose = verbose
self.device = get_device(device, verbose=self.verbose)
self.dtype = get_dtype(dtype, self.device, self.verbose)
model_kwargs = kwargs.get("model_kwargs", {})
self.model = AutoModelForSequenceClassification.from_pretrained(
model_name_or_path, torch_dtype=self.dtype
model_name_or_path,
torch_dtype=self.dtype,
**model_kwargs,
).to(self.device)
vprint(f"Loaded model {model_name_or_path}", self.verbose)
vprint(f"Using device {self.device}.", self.verbose)
vprint(f"Using dtype {self.dtype}.", self.verbose)
self.model.eval()
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {})
self.tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
**tokenizer_kwargs,
)
self.ranking_type = "pointwise"
self.batch_size = batch_size

Expand Down
3 changes: 2 additions & 1 deletion rerankers/results.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Union, Optional, List
from typing import List, Optional, Union

from pydantic import BaseModel, validator

from rerankers.documents import Document
Expand Down
9 changes: 7 additions & 2 deletions tests/test_crossenc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@ def test_transformer_ranker_rank(mock_rank):
expected_results = RankedResults(
results=[
Result(
document=Document(id=1, text="Gone with the wind is an all-time classic"),
document=Document(
doc_id=1, text="Gone with the wind is an all-time classic"
),
score=1.6181640625,
rank=1,
),
Result(
document=Document(id=0, text="Gone with the wind is a masterclass in bad storytelling."),
document=Document(
doc_id=0,
text="Gone with the wind is a masterclass in bad storytelling.",
),
score=0.88427734375,
rank=2,
),
Expand Down
6 changes: 3 additions & 3 deletions tests/test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
def test_ranked_results_functions():
results = RankedResults(
results=[
Result(document=Document(id=0, text="Doc 0"), score=0.9, rank=2),
Result(document=Document(id=1, text="Doc 1"), score=0.95, rank=1),
Result(document=Document(doc_id=0, text="Doc 0"), score=0.9, rank=2),
Result(document=Document(doc_id=1, text="Doc 1"), score=0.95, rank=1),
],
query="Test Query",
has_scores=True,
Expand All @@ -20,7 +20,7 @@ def test_ranked_results_functions():


def test_result_attributes():
result = Result(document=Document(id=1, text="Doc 1"), score=0.95, rank=1)
result = Result(document=Document(doc_id=1, text="Doc 1"), score=0.95, rank=1)
assert result.doc_id == 1
assert result.text == "Doc 1"
assert result.score == 0.95
Expand Down

0 comments on commit fc8f594

Please sign in to comment.