From fc8f5946deecbf542d24d4000606d1d78cbefcea Mon Sep 17 00:00:00 2001 From: sam-bercovici <149293354+sam-bercovici@users.noreply.github.com> Date: Tue, 12 Nov 2024 09:57:01 +0200 Subject: [PATCH] Accept **kwargs so that model_kwargs and tokenizer_kwargs can be passed through. Also fix tests, as Document has doc_id and not id. Lastly, add .conda/ to .gitignore (#44) --- .gitignore | 1 + rerankers/models/colbert_ranker.py | 6 +++++- rerankers/models/llm_layerwise_ranker.py | 15 ++++++++++++--- rerankers/models/t5ranker.py | 16 +++++++++++----- rerankers/models/transformer_ranker.py | 12 ++++++++++-- rerankers/results.py | 3 ++- tests/test_crossenc.py | 9 +++++++-- tests/test_results.py | 6 +++--- 8 files changed, 51 insertions(+), 17 deletions(-) diff --git a/.gitignore b/.gitignore index 8e56b4d..505af8b 100644 --- a/.gitignore +++ b/.gitignore @@ -97,6 +97,7 @@ ENV/ env.bak/ venv.bak/ .envrc +.conda/ # mkdocs documentation /site diff --git a/rerankers/models/colbert_ranker.py b/rerankers/models/colbert_ranker.py index ffd7ffe..f676bf5 100644 --- a/rerankers/models/colbert_ranker.py +++ b/rerankers/models/colbert_ranker.py @@ -221,6 +221,7 @@ def __init__( verbose: int = 1, query_token: str = "[unused0]", document_token: str = "[unused1]", + **kwargs, ): self.verbose = verbose self.device = get_device(device, self.verbose) @@ -230,10 +231,13 @@ def __init__( f"Loading model {model_name}, this might take a while...", self.verbose, ) - self.tokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {}) + self.tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs) + model_kwargs = kwargs.get("model_kwargs", {}) self.model = ( ColBERTModel.from_pretrained( model_name, + **model_kwargs ) .to(self.device) .to(self.dtype) diff --git a/rerankers/models/llm_layerwise_ranker.py b/rerankers/models/llm_layerwise_ranker.py index c1e11e4..c6dbbbe 100644 --- a/rerankers/models/llm_layerwise_ranker.py +++ 
b/rerankers/models/llm_layerwise_ranker.py @@ -38,6 +38,7 @@ def __init__( cutoff_layers: Optional[List[int]] = None, compress_ratio: Optional[int] = None, compress_layer: Optional[List[int]] = None, + **kwargs, ): self.verbose = verbose self.device = get_device(device, verbose=self.verbose) @@ -50,16 +51,24 @@ def __init__( ) vprint(f"Using device {self.device}.", self.verbose) vprint(f"Using dtype {self.dtype}.", self.verbose) - + tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {}) + tokenizer_trust_remote_code = tokenizer_kwargs.pop("trust_remote_code", True) self.tokenizer = AutoTokenizer.from_pretrained( - model_name_or_path, trust_remote_code=True + model_name_or_path, + trust_remote_code=tokenizer_trust_remote_code, + **tokenizer_kwargs, ) self.max_sequence_length = max_sequence_length self.tokenizer.model_max_length = self.max_sequence_length self.tokenizer.padding_side = "right" + model_kwargs = kwargs.get("model_kwargs", {}) + model_trust_remote_code = model_kwargs.pop("trust_remote_code", True) self.model = AutoModelForCausalLM.from_pretrained( - model_name_or_path, trust_remote_code=True, torch_dtype=self.dtype + model_name_or_path, + trust_remote_code=model_trust_remote_code, + torch_dtype=self.dtype, + **model_kwargs, ).to(self.device) self.model.eval() diff --git a/rerankers/models/t5ranker.py b/rerankers/models/t5ranker.py index 8b51b5c..73ecbc8 100644 --- a/rerankers/models/t5ranker.py +++ b/rerankers/models/t5ranker.py @@ -14,8 +14,6 @@ from rerankers.documents import Document -import torch - from rerankers.results import RankedResults, Result from rerankers.utils import ( vprint, @@ -89,7 +87,8 @@ def __init__( token_false: str = "auto", token_true: str = "auto", return_logits: bool = False, - inputs_template: str = "Query: {query} Document: {text} Relevant:" + inputs_template: str = "Query: {query} Document: {text} Relevant:", + **kwargs, ): """ Implementation of the key functions from 
https://github.com/unicamp-dl/InRanker/blob/main/inranker/rankers.py @@ -113,11 +112,18 @@ def __init__( ) vprint(f"Using device {self.device}.", self.verbose) vprint(f"Using dtype {self.dtype}.", self.verbose) + model_kwargs = kwargs.get("model_kwargs", {}) self.model = AutoModelForSeq2SeqLM.from_pretrained( - model_name_or_path, torch_dtype=self.dtype + model_name_or_path, + torch_dtype=self.dtype, + **model_kwargs, ).to(self.device) self.model.eval() - self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {}) + self.tokenizer = AutoTokenizer.from_pretrained( + model_name_or_path, + **tokenizer_kwargs, + ) token_false, token_true = _get_output_tokens( model_name_or_path=model_name_or_path, diff --git a/rerankers/models/transformer_ranker.py b/rerankers/models/transformer_ranker.py index 9122d8f..a092beb 100644 --- a/rerankers/models/transformer_ranker.py +++ b/rerankers/models/transformer_ranker.py @@ -25,18 +25,26 @@ def __init__( device: Optional[Union[str, torch.device]] = None, batch_size: int = 16, verbose: int = 1, + **kwargs, ): self.verbose = verbose self.device = get_device(device, verbose=self.verbose) self.dtype = get_dtype(dtype, self.device, self.verbose) + model_kwargs = kwargs.get("model_kwargs", {}) self.model = AutoModelForSequenceClassification.from_pretrained( - model_name_or_path, torch_dtype=self.dtype + model_name_or_path, + torch_dtype=self.dtype, + **model_kwargs, ).to(self.device) vprint(f"Loaded model {model_name_or_path}", self.verbose) vprint(f"Using device {self.device}.", self.verbose) vprint(f"Using dtype {self.dtype}.", self.verbose) self.model.eval() - self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {}) + self.tokenizer = AutoTokenizer.from_pretrained( + model_name_or_path, + **tokenizer_kwargs, + ) self.ranking_type = "pointwise" self.batch_size = batch_size diff --git a/rerankers/results.py 
b/rerankers/results.py index ab5ef49..275cfd6 100644 --- a/rerankers/results.py +++ b/rerankers/results.py @@ -1,4 +1,5 @@ -from typing import Union, Optional, List +from typing import List, Optional, Union + from pydantic import BaseModel, validator from rerankers.documents import Document diff --git a/tests/test_crossenc.py b/tests/test_crossenc.py index 9637708..b929264 100644 --- a/tests/test_crossenc.py +++ b/tests/test_crossenc.py @@ -15,12 +15,17 @@ def test_transformer_ranker_rank(mock_rank): expected_results = RankedResults( results=[ Result( - document=Document(id=1, text="Gone with the wind is an all-time classic"), + document=Document( + doc_id=1, text="Gone with the wind is an all-time classic" + ), score=1.6181640625, rank=1, ), Result( - document=Document(id=0, text="Gone with the wind is a masterclass in bad storytelling."), + document=Document( + doc_id=0, + text="Gone with the wind is a masterclass in bad storytelling.", + ), score=0.88427734375, rank=2, ), diff --git a/tests/test_results.py b/tests/test_results.py index 2fbeb22..f60be44 100644 --- a/tests/test_results.py +++ b/tests/test_results.py @@ -6,8 +6,8 @@ def test_ranked_results_functions(): results = RankedResults( results=[ - Result(document=Document(id=0, text="Doc 0"), score=0.9, rank=2), - Result(document=Document(id=1, text="Doc 1"), score=0.95, rank=1), + Result(document=Document(doc_id=0, text="Doc 0"), score=0.9, rank=2), + Result(document=Document(doc_id=1, text="Doc 1"), score=0.95, rank=1), ], query="Test Query", has_scores=True, @@ -20,7 +20,7 @@ def test_ranked_results_functions(): def test_result_attributes(): - result = Result(document=Document(id=1, text="Doc 1"), score=0.95, rank=1) + result = Result(document=Document(doc_id=1, text="Doc 1"), score=0.95, rank=1) assert result.doc_id == 1 assert result.text == "Doc 1" assert result.score == 0.95