Skip to content

Commit

Permalink
Merge pull request #186 from tisnik/ref-docs-from-rag-chunks
Browse files Browse the repository at this point in the history
Ability to create ref docs from rag chunks
  • Loading branch information
tisnik authored Dec 6, 2024
2 parents b4287e3 + d3a6c5a commit d3df32b
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 42 deletions.
11 changes: 2 additions & 9 deletions ols/app/endpoints/ols.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,15 +157,8 @@ def conversation_request(

timestamps["store transcripts"] = time.time()

# De-dup & retain order to create list of referenced documents
referenced_documents = list(
{
rag_chunk.doc_url: ReferencedDocument(
rag_chunk.doc_url,
rag_chunk.doc_title,
)
for rag_chunk in summarizer_response.rag_chunks
}.values()
referenced_documents = ReferencedDocument.from_rag_chunks(
summarizer_response.rag_chunks
)

timestamps["add references"] = time.time()
Expand Down
19 changes: 18 additions & 1 deletion ols/app/models/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Data models representing payloads for REST API calls."""

from collections import OrderedDict
from typing import Optional, Self

from pydantic import BaseModel, field_validator, model_validator
Expand Down Expand Up @@ -123,7 +124,7 @@ def validate_provider_and_model(self) -> Self:
return self


@dataclass
@dataclass(frozen=True, unsafe_hash=False)
class ReferencedDocument:
"""RAG referenced document.
Expand All @@ -135,6 +136,22 @@ class ReferencedDocument:
docs_url: str
title: str

@staticmethod
def from_rag_chunks(rag_chunks: list["RagChunk"]) -> list["ReferencedDocument"]:
"""Create a list of ReferencedDocument from a list of rag_chunks.
Order of items is preserved.
"""
return list(
OrderedDict(
(
rag_chunk.doc_url,
ReferencedDocument(rag_chunk.doc_url, rag_chunk.doc_title),
)
for rag_chunk in rag_chunks
).values()
)


class LLMResponse(BaseModel):
"""Model representing a response from the LLM (Language Model).
Expand Down
32 changes: 0 additions & 32 deletions tests/unit/app/endpoints/test_ols.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,38 +665,6 @@ def test_conversation_request(
assert len(response.conversation_id) == 0


@pytest.mark.usefixtures("_load_config")
@patch("ols.src.query_helpers.question_validator.QuestionValidator.validate_question")
@patch("ols.src.query_helpers.docs_summarizer.DocsSummarizer.summarize")
@patch("ols.config.conversation_cache.get")
def test_conversation_request_dedup_ref_docs(
mock_conversation_cache_get,
mock_summarize,
mock_validate_question,
auth,
):
"""Test deduplication of referenced docs."""
mock_rag_chunk = [
RagChunk("text1", "url-b", "title-b"),
RagChunk("text2", "url-b", "title-b"), # duplicate doc
RagChunk("text3", "url-a", "title-a"),
]
mock_validate_question.return_value = True
mock_summarize.return_value = SummarizerResponse(
response="some response",
rag_chunks=mock_rag_chunk,
history_truncated=False,
)
llm_request = LLMRequest(query="some query")
response = ols.conversation_request(llm_request, auth)

assert len(response.referenced_documents) == 2
assert response.referenced_documents[0].docs_url == "url-b"
assert response.referenced_documents[0].title == "title-b"
assert response.referenced_documents[1].docs_url == "url-a"
assert response.referenced_documents[1].title == "title-a"


@pytest.mark.usefixtures("_load_config")
@patch(
"ols.app.endpoints.ols.config.ols_config.query_validation_method",
Expand Down
21 changes: 21 additions & 0 deletions tests/unit/app/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
LivenessResponse,
LLMRequest,
LLMResponse,
RagChunk,
ReadinessResponse,
ReferencedDocument,
StatusResponse,
Expand Down Expand Up @@ -411,3 +412,23 @@ def test_cache_entries_to_history_no_response():
"human: what?",
"ai: ",
]


def test_ref_docs_from_rag_chunks():
"""Test the ReferencedDocument model method `from_rag_chunks`."""
# urls are unsorted to ensure there is not a hidden sorting
rag_chunk_1 = RagChunk("bla2", "url2", "title2")
rag_chunk_2 = RagChunk("bla1", "url1", "title1")
rag_chunk_3 = RagChunk("bla3", "url3", "title3")
rag_chunk_4 = RagChunk("bla2", "url2", "title2") # duplicated doc

ref_docs = ReferencedDocument.from_rag_chunks(
[rag_chunk_1, rag_chunk_2, rag_chunk_3, rag_chunk_4]
)
expected = [
ReferencedDocument(docs_url="url2", title="title2"),
ReferencedDocument(docs_url="url1", title="title1"),
ReferencedDocument(docs_url="url3", title="title3"),
]

assert ref_docs == expected

0 comments on commit d3df32b

Please sign in to comment.