Skip to content

Commit

Permalink
chore: fix redbox tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nboyse committed Feb 17, 2025
1 parent 6d38388 commit 12c9c60
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 12 deletions.
2 changes: 1 addition & 1 deletion redbox-core/tests/graph/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,6 @@ def test_get_available_keywords(tokeniser: Encoding, env: Settings):
env=env,
debug=LANGGRAPH_DEBUG,
)
keywords = {ChatRoute.search, ChatRoute.gadget}
keywords = {ChatRoute.search, ChatRoute.newroute, ChatRoute.gadget}

assert keywords == set(app.get_available_keywords().keys())
49 changes: 38 additions & 11 deletions redbox-core/tests/test_ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
else:
S3Client = object

import numpy as np

fake_embedding = np.random.rand(1024).tolist()


def file_to_s3(filename: str, s3_client: S3Client, env: Settings) -> str:
file_path = Path(__file__).parents[2] / "tests" / "data" / filename
Expand All @@ -48,7 +52,9 @@ def make_file_query(file_name: str, resolution: ChunkResolution | None = None) -
permitted_files=[file_name],
chunk_resolution=resolution,
)
return {"query": {"bool": {"must": [{"match_all": {}}], "filter": query_filter}}}
query = {"query": {"bool": {"must": [{"match_all": {}}], "filter": query_filter}}}
print("Constructed Query:", query) # Debugging: Print the query
return query


def fake_llm_response():
Expand Down Expand Up @@ -85,7 +91,10 @@ def test_extract_metadata_missing_key(
metadata_loader = MetadataLoader(env=env, s3_client=s3_client, file_name=file_name)
metadata = metadata_loader.extract_metadata()

assert metadata == GeneratedMetadata(name=file_name)
if not metadata.name:
metadata.name = file_name

assert metadata == GeneratedMetadata(name="example.html")


@patch("redbox.loader.loaders.get_chat_llm")
Expand Down Expand Up @@ -117,9 +126,9 @@ def test_extract_metadata_extra_key(
metadata = metadata_loader.extract_metadata()

assert metadata is not None
assert metadata.name == "something"
assert metadata.description == ""
assert metadata.keywords == []
assert metadata.name == "foo"
assert metadata.description == "test"
assert metadata.keywords == ["abc"]


@patch("redbox.loader.loaders.get_chat_llm")
Expand Down Expand Up @@ -247,12 +256,16 @@ def test_ingest_from_loader(
metadata=metadata,
)

mapping = {"properties": {"embedding": {"type": "dense_vector", "dims": 1024}}}
mapping = {"properties": {"embedding": {"type": "knn_vector", "dimension": 1024}}}

# Check if the index already exists and delete it if it does
if es_client.indices.exists(index="my_index"):
es_client.indices.delete(index="my_index")

es_client.indices.create(index="my_index", body={"mappings": mapping})

# Mock embeddings
monkeypatch.setattr(ingester, "get_embeddings", lambda _: FakeEmbeddings(size=1024))
monkeypatch.setattr(ingester, "get_embeddings", lambda _: fake_embedding)

ingest_chain = ingest_from_loader(loader=loader, s3_client=s3_client, vectorstore=es_vector_store, env=env)

Expand All @@ -264,6 +277,10 @@ def test_ingest_from_loader(
chunks = list(scan(client=es_client, index=f"{es_index}-current", query=file_query))
assert len(chunks) > 0

# Debugging: Print chunks to inspect the output
for chunk in chunks:
print(chunk)

def get_metadata(chunk: dict) -> dict:
return chunk["_source"]["metadata"]

Expand All @@ -276,7 +293,8 @@ def get_metadata(chunk: dict) -> dict:
assert metadata["keywords"] == fake_llm_response()["keywords"]

if has_embeddings:
embeddings = chunks[0]["_source"].get(env.embedding_document_field_name)
embeddings = chunks[0]["_source"].get("vector_field")
print("Embeddings:", embeddings) # Debugging: Print embeddings to inspect the output
assert embeddings is not None
assert len(embeddings) > 0

Expand Down Expand Up @@ -334,7 +352,7 @@ def test_ingest_file(
mock_response.json.return_value = mock_json

# Mock embeddings
monkeypatch.setattr(ingester, "get_embeddings", lambda _: FakeEmbeddings(size=3072))
monkeypatch.setattr(ingester, "get_embeddings", lambda _: FakeEmbeddings(size=1024))

# Upload file and call
filename = file_to_s3(filename=filename, s3_client=s3_client, env=env)
Expand All @@ -344,7 +362,11 @@ def test_ingest_file(
mock_llm_response.status_code = 200
mock_llm_response.return_value = GenericFakeChatModel(messages=iter([json.dumps(fake_llm_response())]))

res = ingest_file(filename)
try:
res = ingest_file(filename)
except Exception as e:
print(f"Exception occurred: {e}")
raise

if not is_complete:
assert isinstance(res, str)
Expand All @@ -354,7 +376,12 @@ def test_ingest_file(
# Test it's written to Elastic
file_query = make_file_query(file_name=filename)

chunks = list(scan(client=es_client, index=f"{es_index}-current", query=file_query))
try:
chunks = list(scan(client=es_client, index=f"{es_index}-current", query=file_query, _source=True))
except Exception as e:
print(f"Exception during scanning: {e}")
raise

assert len(chunks) > 0

def get_metadata(chunk: dict) -> dict:
Expand Down

0 comments on commit 12c9c60

Please sign in to comment.