Skip to content

Commit

Permalink
Merge pull request #228 from TamiTakamiya/TamiTakamiya/reranker
Browse files Browse the repository at this point in the history
Allow reranking nodes retrieved from vector DB
  • Loading branch information
tisnik authored Jan 2, 2025
2 parents a3f019d + ac104b6 commit 3faebfa
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 1 deletion.
1 change: 1 addition & 0 deletions ols/customize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
project = os.getenv("PROJECT", "ols")
prompts = importlib.import_module(f"ols.customize.{project}.prompts")
keywords = importlib.import_module(f"ols.customize.{project}.keywords")
reranker = importlib.import_module(f"ols.customize.{project}.reranker")
14 changes: 14 additions & 0 deletions ols/customize/ols/reranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Reranker for post-processing the Vector DB search results."""

import logging

from llama_index.core.schema import NodeWithScore

logger = logging.getLogger(__name__)


def rerank(retrieved_nodes: list[NodeWithScore]) -> list[NodeWithScore]:
"""Rerank Vector DB search results."""
message = f"reranker.rerank() is called with {len(retrieved_nodes)} result(s)."
logger.debug(message)
return retrieved_nodes
5 changes: 4 additions & 1 deletion ols/src/query_helpers/docs_summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ols.app.metrics import TokenMetricUpdater
from ols.app.models.models import SummarizerResponse
from ols.constants import RAG_CONTENT_LIMIT, GenericLLMParameters
from ols.customize.ols.reranker import rerank
from ols.src.prompts.prompt_generator import GeneratePrompt
from ols.src.query_helpers.query_helper import QueryHelper
from ols.utils.token_handler import TokenHandler
Expand Down Expand Up @@ -80,8 +81,10 @@ def summarize(

if vector_index is not None:
retriever = vector_index.as_retriever(similarity_top_k=RAG_CONTENT_LIMIT)
retrieved_nodes = retriever.retrieve(query)
retrieved_nodes = rerank(retrieved_nodes)
rag_chunks, available_tokens = token_handler.truncate_rag_context(
retriever.retrieve(query), self.model, available_tokens
retrieved_nodes, self.model, available_tokens
)
else:
logger.warning("Proceeding without RAG content. Check start up messages.")
Expand Down
25 changes: 25 additions & 0 deletions tests/unit/query_helpers/test_docs_summarizer.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
"""Unit tests for DocsSummarizer class."""

import logging
from unittest.mock import ANY, patch

import pytest

from ols import config
from ols.app.models.config import LoggingConfig
from ols.src.query_helpers.docs_summarizer import DocsSummarizer, QueryHelper
from ols.utils import suid
from ols.utils.logging_configurator import configure_logging
from tests import constants
from tests.mock_classes.mock_langchain_interface import mock_langchain_interface
from tests.mock_classes.mock_llama_index import MockLlamaIndex
Expand Down Expand Up @@ -129,3 +132,25 @@ def test_summarize_no_reference_content():
assert question in summary.response
assert summary.rag_chunks == []
assert not summary.history_truncated


@patch("ols.utils.token_handler.RAG_SIMILARITY_CUTOFF", 0.4)
@patch("ols.utils.token_handler.MINIMUM_CONTEXT_TOKEN_LIMIT", 3)
@patch("ols.src.query_helpers.docs_summarizer.LLMChain", new=mock_llm_chain(None))
def test_summarize_reranker(caplog):
"""Basic test to make sure the reranker is called as expected."""
logging_config = LoggingConfig(app_log_level="debug")

configure_logging(logging_config)
logger = logging.getLogger("ols")
logger.handlers = [caplog.handler] # add caplog handler to logger

summarizer = DocsSummarizer(llm_loader=mock_llm_loader(None))
question = "What's the ultimate question with answer 42?"
rag_index = MockLlamaIndex()
# no history is passed into summarize() method
summary = summarizer.summarize(conversation_id, question, rag_index)
check_summary_result(summary, question)

# Check captured log text to see if reranker was called.
assert "reranker.rerank() is called with 1 result(s)." in caplog.text

0 comments on commit 3faebfa

Please sign in to comment.