Skip to content

Commit

Permalink
revert changes to retriever
Browse files Browse the repository at this point in the history
  • Loading branch information
nboyse committed Feb 13, 2025
1 parent 951e079 commit 45fb278
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 16 deletions.
25 changes: 10 additions & 15 deletions redbox-core/redbox/retriever/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,22 +233,17 @@ def __init__(self, es_client: Union[Elasticsearch, OpenSearch], **kwargs: Any) -

def _get_relevant_documents(
self, query: RedboxState, *, run_manager: CallbackManagerForRetrieverRun
) -> list[Document]:
body = self.body_func(query) # Generate the query
print(f"Executing OpenSearch query: {body}")

# Run a normal search instead of scan()
response = self.es_client.search(index=self.index_name, body=body)

print(f"Raw OpenSearch response: {response}")

hits = response.get("hits", {}).get("hits", []) # Extract documents
for hit in hits:
print(f"Hit: {hit}")
) -> list[Document]: # noqa:ARG002
# if not self.es_client or not self.document_mapper:
# msg = "faulty configuration"
# raise ValueError(msg) # should not happen

print(f"Retrieved {len(hits)} documents from OpenSearch")
body = self.body_func(query) # type: ignore

results = [self.document_mapper(hit) for hit in hits]
results = [
self.document_mapper(hit)
for hit in scan(client=self.es_client, index=self.index_name, query=body, _source=True)
]

return sorted(results, key=lambda result: result.metadata["index"])

Expand Down Expand Up @@ -281,4 +276,4 @@ def _get_relevant_documents(
for hit in scan(client=self.es_client, index=self.index_name, query=body, _source=True)
]

return sorted(results, key=lambda result: result.metadata["index"])
return sorted(results, key=lambda result: result.metadata["index"])
13 changes: 12 additions & 1 deletion redbox-core/tests/test_ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,19 @@ def test_ingest_from_loader(
metadata=metadata,
)

mapping = {
"properties": {
"embedding": {
"type": "dense_vector",
"dims": 1024
}
}
}

es_client.indices.create(index="my_index", body={"mappings": mapping})

# Mock embeddings
monkeypatch.setattr(ingester, "get_embeddings", lambda _: FakeEmbeddings(size=3072))
monkeypatch.setattr(ingester, "get_embeddings", lambda _: FakeEmbeddings(size=1024))

ingest_chain = ingest_from_loader(loader=loader, s3_client=s3_client, vectorstore=es_vector_store, env=env)

Expand Down

0 comments on commit 45fb278

Please sign in to comment.