From d3a6c5a39dc24e5fcfb44ef0b0166f198abd7cd1 Mon Sep 17 00:00:00 2001
From: Ondrej Metelka
Date: Thu, 5 Dec 2024 15:29:02 +0100
Subject: [PATCH] Ability to create ref docs from rag chunks

Signed-off-by: Pavel Tisnovsky
---
 ols/app/endpoints/ols.py             | 11 ++--------
 ols/app/models/models.py             | 19 ++++++++++++++++-
 tests/unit/app/endpoints/test_ols.py | 32 ----------------------------
 tests/unit/app/models/test_models.py | 21 ++++++++++++++++++
 4 files changed, 41 insertions(+), 42 deletions(-)

diff --git a/ols/app/endpoints/ols.py b/ols/app/endpoints/ols.py
index 16233a9f..3eb6bb84 100644
--- a/ols/app/endpoints/ols.py
+++ b/ols/app/endpoints/ols.py
@@ -157,15 +157,8 @@ def conversation_request(
 
     timestamps["store transcripts"] = time.time()
 
-    # De-dup & retain order to create list of referenced documents
-    referenced_documents = list(
-        {
-            rag_chunk.doc_url: ReferencedDocument(
-                rag_chunk.doc_url,
-                rag_chunk.doc_title,
-            )
-            for rag_chunk in summarizer_response.rag_chunks
-        }.values()
+    referenced_documents = ReferencedDocument.from_rag_chunks(
+        summarizer_response.rag_chunks
     )
 
     timestamps["add references"] = time.time()
diff --git a/ols/app/models/models.py b/ols/app/models/models.py
index 164c112f..1b2cfa11 100644
--- a/ols/app/models/models.py
+++ b/ols/app/models/models.py
@@ -1,5 +1,6 @@
 """Data models representing payloads for REST API calls."""
 
+from collections import OrderedDict
 from typing import Optional, Self
 
 from pydantic import BaseModel, field_validator, model_validator
@@ -123,7 +124,7 @@ def validate_provider_and_model(self) -> Self:
         return self
 
 
-@dataclass
+@dataclass(frozen=True, unsafe_hash=False)
 class ReferencedDocument:
     """RAG referenced document.
 
@@ -135,6 +136,22 @@ class ReferencedDocument:
     docs_url: str
     title: str
 
+    @staticmethod
+    def from_rag_chunks(rag_chunks: list["RagChunk"]) -> list["ReferencedDocument"]:
+        """Create a list of ReferencedDocument from a list of rag_chunks.
+
+        Order of items is preserved.
+        """
+        return list(
+            OrderedDict(
+                (
+                    rag_chunk.doc_url,
+                    ReferencedDocument(rag_chunk.doc_url, rag_chunk.doc_title),
+                )
+                for rag_chunk in rag_chunks
+            ).values()
+        )
+
 
 class LLMResponse(BaseModel):
     """Model representing a response from the LLM (Language Model).
diff --git a/tests/unit/app/endpoints/test_ols.py b/tests/unit/app/endpoints/test_ols.py
index 1f0aa0e9..722f3d83 100644
--- a/tests/unit/app/endpoints/test_ols.py
+++ b/tests/unit/app/endpoints/test_ols.py
@@ -665,38 +665,6 @@ def test_conversation_request(
     assert len(response.conversation_id) == 0
 
 
-@pytest.mark.usefixtures("_load_config")
-@patch("ols.src.query_helpers.question_validator.QuestionValidator.validate_question")
-@patch("ols.src.query_helpers.docs_summarizer.DocsSummarizer.summarize")
-@patch("ols.config.conversation_cache.get")
-def test_conversation_request_dedup_ref_docs(
-    mock_conversation_cache_get,
-    mock_summarize,
-    mock_validate_question,
-    auth,
-):
-    """Test deduplication of referenced docs."""
-    mock_rag_chunk = [
-        RagChunk("text1", "url-b", "title-b"),
-        RagChunk("text2", "url-b", "title-b"),  # duplicate doc
-        RagChunk("text3", "url-a", "title-a"),
-    ]
-    mock_validate_question.return_value = True
-    mock_summarize.return_value = SummarizerResponse(
-        response="some response",
-        rag_chunks=mock_rag_chunk,
-        history_truncated=False,
-    )
-    llm_request = LLMRequest(query="some query")
-    response = ols.conversation_request(llm_request, auth)
-
-    assert len(response.referenced_documents) == 2
-    assert response.referenced_documents[0].docs_url == "url-b"
-    assert response.referenced_documents[0].title == "title-b"
-    assert response.referenced_documents[1].docs_url == "url-a"
-    assert response.referenced_documents[1].title == "title-a"
-
-
 @pytest.mark.usefixtures("_load_config")
 @patch(
     "ols.app.endpoints.ols.config.ols_config.query_validation_method",
diff --git a/tests/unit/app/models/test_models.py b/tests/unit/app/models/test_models.py
index 39d7108d..64520f7e 100644
--- a/tests/unit/app/models/test_models.py
+++ b/tests/unit/app/models/test_models.py
@@ -11,6 +11,7 @@
     LivenessResponse,
     LLMRequest,
     LLMResponse,
+    RagChunk,
     ReadinessResponse,
     ReferencedDocument,
     StatusResponse,
@@ -411,3 +412,23 @@ def test_cache_entries_to_history_no_response():
         "human: what?",
         "ai: ",
     ]
+
+
+def test_ref_docs_from_rag_chunks():
+    """Test the ReferencedDocument model method `from_rag_chunks`."""
+    # urls are unsorted to ensure there is not a hidden sorting
+    rag_chunk_1 = RagChunk("bla2", "url2", "title2")
+    rag_chunk_2 = RagChunk("bla1", "url1", "title1")
+    rag_chunk_3 = RagChunk("bla3", "url3", "title3")
+    rag_chunk_4 = RagChunk("bla2", "url2", "title2")  # duplicated doc
+
+    ref_docs = ReferencedDocument.from_rag_chunks(
+        [rag_chunk_1, rag_chunk_2, rag_chunk_3, rag_chunk_4]
+    )
+    expected = [
+        ReferencedDocument(docs_url="url2", title="title2"),
+        ReferencedDocument(docs_url="url1", title="title1"),
+        ReferencedDocument(docs_url="url3", title="title3"),
+    ]
+
+    assert ref_docs == expected