diff --git a/trec_biogen/pipelines.py b/trec_biogen/pipelines.py index 4cc06d7..a1d7398 100644 --- a/trec_biogen/pipelines.py +++ b/trec_biogen/pipelines.py @@ -2,6 +2,7 @@ from functools import cached_property from os import environ from typing import Any, Hashable +from pyterrier_t5 import MonoT5ReRanker from elasticsearch7 import Elasticsearch from elasticsearch7_dsl.query import Query, Match, Exists, Bool @@ -94,6 +95,7 @@ def _build_result(article: Article) -> dict[Hashable, Any]: @dataclass(frozen=True) class Pipeline(Transformer): + @cached_property def _elasticsearch(self) -> Elasticsearch: return elasticsearch_connection() @@ -104,6 +106,9 @@ def _elasticsearch_index_pubmed(self) -> str | None: @cached_property def _pipeline(self) -> Transformer: + + monoT5 = MonoT5ReRanker(verbose=True, batch_size=16) + pipeline = Transformer.identity() # Retrieve or re-rank documents with Elasticsearch (BM25). @@ -112,11 +117,10 @@ def _pipeline(self) -> Transformer: client=self._elasticsearch, query_builder=_build_query, result_builder=_build_result, - num_results=10, + num_results=100, index=self._elasticsearch_index_pubmed, verbose=True, - ) - + ) >> monoT5 # TODO: Re-rank documents? # TODO: Split passages.