Skip to content

Commit

Permalink
Add neural re-rankers
Browse files Browse the repository at this point in the history
  • Loading branch information
janheinrichmerker committed Aug 29, 2024
1 parent 02c39d8 commit d1fe3ff
Showing 1 changed file with 127 additions and 8 deletions.
135 changes: 127 additions & 8 deletions trec_biogen/optimization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Callable, Literal, Sequence
from warnings import catch_warnings, simplefilter
from typing import Callable, Literal, Sequence, TypeAlias
from warnings import catch_warnings, simplefilter, filterwarnings

from dspy import settings as dspy_settings
from optuna import Study, Trial, create_study
Expand All @@ -9,6 +9,8 @@
from optuna.trial import FrozenTrial
from optuna_integration import WeightsAndBiasesCallback
from pyterrier.transformer import Transformer
from pyterrier_t5 import MonoT5ReRanker, DuoT5ReRanker
from pyterrier_dr import TasB, TctColBert, Ance

from trec_biogen.answering import IndependentAnsweringModule, RecurrentAnsweringModule
from trec_biogen.dspy_generation import (
Expand All @@ -27,6 +29,7 @@
from trec_biogen.language_models import LanguageModelName, get_dspy_language_model
from trec_biogen.model import Answer
from trec_biogen.modules import AnsweringModule, GenerationModule, RetrievalModule
from trec_biogen.pyterrier import CutoffRerank
from trec_biogen.pyterrier_pubmed import (
PubMedElasticsearchRetrieve,
PubMedSentencePassager,
Expand Down Expand Up @@ -56,6 +59,21 @@ def _suggest_must_should(trial: Trial, name: str) -> Literal["must", "should"] |
raise ValueError(f"Illegal value: {must_should}")



PointwiseRerankerModel: TypeAlias = Literal[
"castorini/monot5-base-msmarco",
"castorini/monot5-3b-msmarco",
"castorini/monot5-3b-med-msmarco",
"sentence-transformers/msmarco-distilbert-base-tas-b",
"sentence-transformers/msmarco-roberta-base-ance-firstp",
"castorini/tct_colbert-v2-hnp-msmarco",
]
PairwiseRerankerModel: TypeAlias = Literal[
"castorini/duot5-base-msmarco",
"castorini/duot5-3b-msmarco",
"castorini/duot5-3b-med-msmarco",
]

def build_retrieval_module(
trial: Trial,
) -> RetrievalModule:
Expand Down Expand Up @@ -207,13 +225,114 @@ def build_retrieval_module(
)
pipeline = pipeline >> pubmed_sentence_passager

# TODO: Re-ranking.
# Pointwise re-ranking.
pointwise_reranker_model: PointwiseRerankerModel | None = trial.suggest_categorical(
name="pointwise_reranker_model",
choices=[
"castorini/monot5-base-msmarco",
"castorini/monot5-3b-msmarco",
"castorini/monot5-3b-med-msmarco",
"sentence-transformers/msmarco-distilbert-base-tas-b",
"sentence-transformers/msmarco-roberta-base-ance-firstp",
"castorini/tct_colbert-v2-hnp-msmarco",
None,
],
) # type: ignore
pointwise_reranker: Transformer | None
if pointwise_reranker_model in (
"castorini/monot5-base-msmarco",
"castorini/monot5-3b-msmarco",
"castorini/monot5-3b-med-msmarco",
):
pointwise_reranker = MonoT5ReRanker(
model=pointwise_reranker_model,
verbose=True,
)
elif pointwise_reranker_model in (
"sentence-transformers/msmarco-distilbert-base-tas-b",
):
with catch_warnings():
filterwarnings(
action="ignore",
message="TypedStorage is deprecated",
category=UserWarning,
)
pointwise_reranker = TasB(
model_name=pointwise_reranker_model,
verbose=True,
)
elif pointwise_reranker_model in (
"castorini/tct_colbert-v2-hnp-msmarco",
):
pointwise_reranker = TctColBert(
model_name=pointwise_reranker_model,
verbose=True,
)
elif pointwise_reranker_model in (
"sentence-transformers/msmarco-roberta-base-ance-firstp",
):
pointwise_reranker = Ance(
model_name=pointwise_reranker_model,
verbose=True,
)
else:
pointwise_reranker = None
if pointwise_reranker is not None:
pointwise_reranker_cutoff=trial.suggest_categorical(
name="pointwise_reranker_cutoff",
choices=[
10,
50,
100,
]
)
pipeline = CutoffRerank(
candidates=pipeline,
reranker=pointwise_reranker,
cutoff=pointwise_reranker_cutoff,
)

# max_sentences=trial.suggest_int(
# name="pubmed_sentence_passager_max_sentences",
# low=1,
# high=5,
# ),
# Pairwise re-ranking.
pairwise_reranker_model: PairwiseRerankerModel | None = trial.suggest_categorical(
name="pairwise_reranker_model",
choices=[
"castorini/duot5-base-msmarco",
"castorini/duot5-3b-msmarco",
"castorini/duot5-3b-med-msmarco",
None,
],
) # type: ignore
pairwise_reranker: Transformer | None
if pairwise_reranker_model in (
"castorini/duot5-base-msmarco",
"castorini/duot5-3b-msmarco",
"castorini/duot5-3b-med-msmarco",
):
with catch_warnings():
filterwarnings(
action="ignore",
message="TypedStorage is deprecated",
category=UserWarning,
)
pairwise_reranker = DuoT5ReRanker(
model=pairwise_reranker_model,
verbose=True,
)
else:
pairwise_reranker = None
if pairwise_reranker is not None:
pairwise_reranker_cutoff=trial.suggest_categorical(
name="pairwise_reranker_cutoff",
choices=[
3,
5,
]
)
pipeline = CutoffRerank(
candidates=pipeline,
reranker=pairwise_reranker,
cutoff=pairwise_reranker_cutoff,
)

retrieval_module = PyterrierRetrievalModule(pipeline, progress=True)

Expand Down

0 comments on commit d1fe3ff

Please sign in to comment.