From 119d235efa7ed94b9d11537608bac67316ce8dc0 Mon Sep 17 00:00:00 2001 From: Tami Takamiya Date: Sat, 28 Dec 2024 01:28:31 -0500 Subject: [PATCH 1/2] Allow reranking nodes retrieved from vector DB --- ols/customize/__init__.py | 1 + ols/customize/ols/reranker.py | 14 +++++++++++ ols/src/query_helpers/docs_summarizer.py | 6 ++++- .../query_helpers/test_docs_summarizer.py | 25 +++++++++++++++++++ 4 files changed, 45 insertions(+), 1 deletion(-) create mode 100644 ols/customize/ols/reranker.py diff --git a/ols/customize/__init__.py b/ols/customize/__init__.py index b5fff8ec..d334ab3d 100644 --- a/ols/customize/__init__.py +++ b/ols/customize/__init__.py @@ -6,3 +6,4 @@ project = os.getenv("PROJECT", "ols") prompts = importlib.import_module(f"ols.customize.{project}.prompts") keywords = importlib.import_module(f"ols.customize.{project}.keywords") +reranker = importlib.import_module(f"ols.customize.{project}.reranker") diff --git a/ols/customize/ols/reranker.py b/ols/customize/ols/reranker.py new file mode 100644 index 00000000..da3f476c --- /dev/null +++ b/ols/customize/ols/reranker.py @@ -0,0 +1,14 @@ +"""Reranker for post-processing the Vector DB search results.""" + +import logging + +from llama_index.core.schema import NodeWithScore + +logger = logging.getLogger(__name__) + + +def rerank(retrieved_nodes: list[NodeWithScore]) -> list[NodeWithScore]: + """Rerank Vector DB search results.""" + message = f"reranker.rerank() is called with {len(retrieved_nodes)} result(s)." + logger.debug(message) + return retrieved_nodes diff --git a/ols/src/query_helpers/docs_summarizer.py b/ols/src/query_helpers/docs_summarizer.py index 8c3076e1..46c29fe9 100644 --- a/ols/src/query_helpers/docs_summarizer.py +++ b/ols/src/query_helpers/docs_summarizer.py @@ -10,6 +10,8 @@ from ols.app.metrics import TokenMetricUpdater from ols.app.models.models import SummarizerResponse from ols.constants import RAG_CONTENT_LIMIT, GenericLLMParameters +from ols.customize import prompts +from ols.customize.ols.reranker import rerank from ols.src.prompts.prompt_generator import GeneratePrompt from ols.src.query_helpers.query_helper import QueryHelper from ols.utils.token_handler import TokenHandler @@ -80,8 +82,10 @@ def summarize( if vector_index is not None: retriever = vector_index.as_retriever(similarity_top_k=RAG_CONTENT_LIMIT) + retrieved_nodes = retriever.retrieve(query) + retrieved_nodes = rerank(retrieved_nodes) rag_chunks, available_tokens = token_handler.truncate_rag_context( - retriever.retrieve(query), self.model, available_tokens + retrieved_nodes, self.model, available_tokens ) else: logger.warning("Proceeding without RAG content. Check start up messages.") diff --git a/tests/unit/query_helpers/test_docs_summarizer.py b/tests/unit/query_helpers/test_docs_summarizer.py index dbd19cd1..2ed8c6f7 100644 --- a/tests/unit/query_helpers/test_docs_summarizer.py +++ b/tests/unit/query_helpers/test_docs_summarizer.py @@ -1,12 +1,15 @@ """Unit tests for DocsSummarizer class.""" +import logging from unittest.mock import ANY, patch import pytest from ols import config +from ols.app.models.config import LoggingConfig from ols.src.query_helpers.docs_summarizer import DocsSummarizer, QueryHelper from ols.utils import suid +from ols.utils.logging_configurator import configure_logging from tests import constants from tests.mock_classes.mock_langchain_interface import mock_langchain_interface from tests.mock_classes.mock_llama_index import MockLlamaIndex @@ -129,3 +132,25 @@ def test_summarize_no_reference_content(): assert question in summary.response assert summary.rag_chunks == [] assert not summary.history_truncated + + +@patch("ols.utils.token_handler.RAG_SIMILARITY_CUTOFF", 0.4) +@patch("ols.utils.token_handler.MINIMUM_CONTEXT_TOKEN_LIMIT", 3) +@patch("ols.src.query_helpers.docs_summarizer.LLMChain", new=mock_llm_chain(None)) +def test_summarize_reranker(caplog): + """Basic test to make sure the reranker is called as expected.""" + logging_config = LoggingConfig(app_log_level="debug") + + configure_logging(logging_config) + logger = logging.getLogger("ols") + logger.handlers = [caplog.handler] # add caplog handler to logger + + summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None)) + question = "What's the ultimate question with answer 42?" + rag_index = MockLlamaIndex() + # no history is passed into summarize() method + summary = summarizer.summarize(conversation_id, question, rag_index) + check_summary_result(summary, question) + + # Check captured log text to see if reranker was called. + assert "reranker.rerank() is called with 1 result(s)." in caplog.text From ac104b65d2d0d976e90ce3eaf93b360dc5f39dde Mon Sep 17 00:00:00 2001 From: Tami Takamiya Date: Thu, 2 Jan 2025 08:05:50 -0500 Subject: [PATCH 2/2] Fix a merge (linter) error --- ols/src/query_helpers/docs_summarizer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ols/src/query_helpers/docs_summarizer.py b/ols/src/query_helpers/docs_summarizer.py index 46c29fe9..0041b546 100644 --- a/ols/src/query_helpers/docs_summarizer.py +++ b/ols/src/query_helpers/docs_summarizer.py @@ -10,7 +10,6 @@ from ols.app.metrics import TokenMetricUpdater from ols.app.models.models import SummarizerResponse from ols.constants import RAG_CONTENT_LIMIT, GenericLLMParameters -from ols.customize import prompts from ols.customize.ols.reranker import rerank from ols.src.prompts.prompt_generator import GeneratePrompt from ols.src.query_helpers.query_helper import QueryHelper