Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ability to create ref docs from rag chunks #186

Merged
merged 1 commit into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions ols/app/endpoints/ols.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,15 +157,8 @@ def conversation_request(

timestamps["store transcripts"] = time.time()

# De-dup & retain order to create list of referenced documents
referenced_documents = list(
{
rag_chunk.doc_url: ReferencedDocument(
rag_chunk.doc_url,
rag_chunk.doc_title,
)
for rag_chunk in summarizer_response.rag_chunks
}.values()
referenced_documents = ReferencedDocument.from_rag_chunks(
summarizer_response.rag_chunks
)

timestamps["add references"] = time.time()
Expand Down
19 changes: 18 additions & 1 deletion ols/app/models/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Data models representing payloads for REST API calls."""

from collections import OrderedDict
from typing import Optional, Self

from pydantic import BaseModel, field_validator, model_validator
Expand Down Expand Up @@ -123,7 +124,7 @@ def validate_provider_and_model(self) -> Self:
return self


@dataclass
@dataclass(frozen=True, unsafe_hash=False)
class ReferencedDocument:
"""RAG referenced document.
Expand All @@ -135,6 +136,22 @@ class ReferencedDocument:
docs_url: str
title: str

@staticmethod
def from_rag_chunks(rag_chunks: list["RagChunk"]) -> list["ReferencedDocument"]:
    """Build the de-duplicated list of referenced documents for RAG chunks.

    Chunks are keyed by their `doc_url`, so every distinct URL yields exactly
    one `ReferencedDocument`. The result follows the order in which each URL
    is first seen in `rag_chunks`.
    """
    docs_by_url: OrderedDict = OrderedDict()
    for chunk in rag_chunks:
        # Re-assigning an existing key keeps its original position, so a
        # repeated URL never moves the document; only its value is refreshed.
        docs_by_url[chunk.doc_url] = ReferencedDocument(
            chunk.doc_url, chunk.doc_title
        )
    return list(docs_by_url.values())


class LLMResponse(BaseModel):
"""Model representing a response from the LLM (Language Model).
Expand Down
32 changes: 0 additions & 32 deletions tests/unit/app/endpoints/test_ols.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,38 +665,6 @@ def test_conversation_request(
assert len(response.conversation_id) == 0


@pytest.mark.usefixtures("_load_config")
@patch("ols.src.query_helpers.question_validator.QuestionValidator.validate_question")
@patch("ols.src.query_helpers.docs_summarizer.DocsSummarizer.summarize")
@patch("ols.config.conversation_cache.get")
def test_conversation_request_dedup_ref_docs(
    mock_conversation_cache_get,
    mock_summarize,
    mock_validate_question,
    auth,
):
    """Check that referenced docs in the response are de-duplicated in order."""
    # The first two chunks point at the same document; the third is distinct.
    chunks = [
        RagChunk("text1", "url-b", "title-b"),
        RagChunk("text2", "url-b", "title-b"),
        RagChunk("text3", "url-a", "title-a"),
    ]
    mock_validate_question.return_value = True
    mock_summarize.return_value = SummarizerResponse(
        response="some response",
        rag_chunks=chunks,
        history_truncated=False,
    )

    response = ols.conversation_request(LLMRequest(query="some query"), auth)

    # One entry per document, listed in the order the documents first appear.
    found = [(doc.docs_url, doc.title) for doc in response.referenced_documents]
    assert found == [("url-b", "title-b"), ("url-a", "title-a")]


@pytest.mark.usefixtures("_load_config")
@patch(
"ols.app.endpoints.ols.config.ols_config.query_validation_method",
Expand Down
21 changes: 21 additions & 0 deletions tests/unit/app/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
LivenessResponse,
LLMRequest,
LLMResponse,
RagChunk,
ReadinessResponse,
ReferencedDocument,
StatusResponse,
Expand Down Expand Up @@ -411,3 +412,23 @@ def test_cache_entries_to_history_no_response():
"human: what?",
"ai: ",
]


def test_ref_docs_from_rag_chunks():
    """Check `ReferencedDocument.from_rag_chunks` de-duplication and ordering."""
    # URLs are deliberately out of order so any hidden sorting would be caught.
    rag_chunks = [
        RagChunk("bla2", "url2", "title2"),
        RagChunk("bla1", "url1", "title1"),
        RagChunk("bla3", "url3", "title3"),
        RagChunk("bla2", "url2", "title2"),  # same document as the first chunk
    ]

    assert ReferencedDocument.from_rag_chunks(rag_chunks) == [
        ReferencedDocument(docs_url="url2", title="title2"),
        ReferencedDocument(docs_url="url1", title="title1"),
        ReferencedDocument(docs_url="url3", title="title3"),
    ]
Loading