From 45fb2788f5e8f9f9900bcf2175071c8ebeb8e5ab Mon Sep 17 00:00:00 2001 From: Natasha Boyse Date: Thu, 13 Feb 2025 15:22:32 +0000 Subject: [PATCH] revert changes to retriever --- redbox-core/redbox/retriever/retrievers.py | 25 +++++++++------------- redbox-core/tests/test_ingest.py | 13 ++++++++++- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/redbox-core/redbox/retriever/retrievers.py b/redbox-core/redbox/retriever/retrievers.py index 9886d196..d29caefd 100644 --- a/redbox-core/redbox/retriever/retrievers.py +++ b/redbox-core/redbox/retriever/retrievers.py @@ -233,22 +233,17 @@ def __init__(self, es_client: Union[Elasticsearch, OpenSearch], **kwargs: Any) - def _get_relevant_documents( self, query: RedboxState, *, run_manager: CallbackManagerForRetrieverRun - ) -> list[Document]: - body = self.body_func(query) # Generate the query - print(f"Executing OpenSearch query: {body}") - - # Run a normal search instead of scan() - response = self.es_client.search(index=self.index_name, body=body) - - print(f"Raw OpenSearch response: {response}") - - hits = response.get("hits", {}).get("hits", []) # Extract documents - for hit in hits: - print(f"Hit: {hit}") + ) -> list[Document]: # noqa:ARG002 + # if not self.es_client or not self.document_mapper: + # msg = "faulty configuration" + # raise ValueError(msg) # should not happen - print(f"Retrieved {len(hits)} documents from OpenSearch") + body = self.body_func(query) # type: ignore - results = [self.document_mapper(hit) for hit in hits] + results = [ + self.document_mapper(hit) + for hit in scan(client=self.es_client, index=self.index_name, query=body, _source=True) + ] return sorted(results, key=lambda result: result.metadata["index"]) @@ -281,4 +276,4 @@ def _get_relevant_documents( for hit in scan(client=self.es_client, index=self.index_name, query=body, _source=True) ] - return sorted(results, key=lambda result: result.metadata["index"]) + return sorted(results, key=lambda result: result.metadata["index"]) \ No newline at end of file diff --git a/redbox-core/tests/test_ingest.py b/redbox-core/tests/test_ingest.py index 9477d9ce..38843507 100644 --- a/redbox-core/tests/test_ingest.py +++ b/redbox-core/tests/test_ingest.py @@ -247,8 +247,19 @@ def test_ingest_from_loader( metadata=metadata, ) + mapping = { + "properties": { + "embedding": { + "type": "dense_vector", + "dims": 1024 + } + } + } + + es_client.indices.create(index="my_index", body={"mappings": mapping}) + # Mock embeddings - monkeypatch.setattr(ingester, "get_embeddings", lambda _: FakeEmbeddings(size=3072)) + monkeypatch.setattr(ingester, "get_embeddings", lambda _: FakeEmbeddings(size=1024)) ingest_chain = ingest_from_loader(loader=loader, s3_client=s3_client, vectorstore=es_vector_store, env=env)