
Commit

correcting tokeniser
nboyse committed Feb 18, 2025
1 parent 9bfeb7f commit d1cfecb
Showing 2 changed files with 27 additions and 26 deletions.
47 changes: 24 additions & 23 deletions redbox-core/redbox/graph/nodes/tools.py
@@ -2,7 +2,6 @@

import numpy as np
import requests
-import anthropic
from elasticsearch import Elasticsearch
from langchain_community.utilities import WikipediaAPIWrapper
from langchain_core.documents import Document
@@ -23,15 +22,9 @@
from redbox.transform import merge_documents, sort_documents


-def get_tokeniser_for_claude():
-    client = anthropic.Anthropic()
-    model_name = "claude-3-sonnet-20240229-v1:0"
-
-    def count_tokens(text):
-        response = client.messages.count_tokens(model=model_name, messages=[{"role": "user", "content": text}])
-        return response.json()["input_tokens"]
-
-    return count_tokens
+def bedrock_tokeniser(text: str) -> int:
+    # Simple tokenizer that counts the number of words in the text
+    return len(text.split())


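For context, here is a minimal usage sketch (not part of the diff) showing what the replacement tokeniser returns: it counts whitespace-separated words rather than model tokens. The import path is assumed from the file location above.

```python
# Illustrative only; assumes the package layout matches redbox-core/redbox/graph/nodes/tools.py.
from redbox.graph.nodes.tools import bedrock_tokeniser

print(bedrock_tokeniser("What is the retention period for FOI requests?"))  # 8 words
print(bedrock_tokeniser(""))  # 0 -- str.split() on an empty string yields no words
```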
def build_search_documents_tool(
Expand Down Expand Up @@ -100,7 +93,7 @@ def _search_documents(query: str, state: Annotated[RedboxState, InjectedState])
def build_govuk_search_tool(filter=True) -> Tool:
"""Constructs a tool that searches gov.uk and sets state["documents"]."""

tokeniser = get_tokeniser_for_claude()
tokeniser = bedrock_tokeniser

def recalculate_similarity(response, query, num_results):
embedding_model = get_embeddings(get_settings())
@@ -185,7 +178,7 @@ def build_search_wikipedia_tool(number_wikipedia_results=1, max_chars_per_wiki_p
        top_k_results=number_wikipedia_results,
        doc_content_chars_max=max_chars_per_wiki_page,
    )
-    tokeniser = get_tokeniser_for_claude()
+    tokeniser = bedrock_tokeniser

    @tool(response_format="content_and_artifact")
    def _search_wikipedia(query: str) -> tuple[str, list[Document]]:
@@ -202,18 +195,26 @@ def _search_wikipedia(query: str) -> tuple[str, list[Document]]:
            response (str): The content of the relevant Wikipedia page
        """
        response = _wikipedia_wrapper.load(query)
-        mapped_documents = [
-            Document(
-                page_content=doc.page_content,
-                metadata=ChunkMetadata(
-                    index=i,
-                    uri=doc.metadata["source"],
-                    token_count=tokeniser(doc.page_content),
-                    creator_type=ChunkCreatorType.wikipedia,
-                ).model_dump(),
-            )
-            for i, doc in enumerate(response)
-        ]
+        if not response:
+            print("No Wikipedia response found.")
+            return "", []
+
+        mapped_documents = []
+        for i, doc in enumerate(response):
+            token_count = tokeniser(doc.page_content)
+            print(f"Document {i} token count: {token_count}")
+
+            mapped_documents.append(
+                Document(
+                    page_content=doc.page_content,
+                    metadata=ChunkMetadata(
+                        index=i,
+                        uri=doc.metadata["source"],
+                        token_count=token_count,
+                        creator_type=ChunkCreatorType.wikipedia,
+                    ).model_dump(),
+                )
+            )
        docs = mapped_documents
        return format_documents(docs), docs

6 changes: 3 additions & 3 deletions redbox-core/tests/test_tools.py
@@ -175,13 +175,13 @@ def test_wikipedia_tool():
            ]
        }
    )

    assert response["messages"][0].content != ""
    assert response["messages"][0].artifact is not None

    for document in response["messages"][0].artifact:
        assert document.page_content != ""
        metadata = ChunkMetadata.model_validate(document.metadata)
        assert urlparse(metadata.uri).hostname == "en.wikipedia.org"
        assert metadata.creator_type == ChunkCreatorType.wikipedia
        assert "token_count" in document.metadata


@pytest.mark.parametrize(

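The updated test only asserts that token_count appears in the document metadata. A more direct check of the new tokeniser could look like the hypothetical sketch below (not part of this commit); it again assumes bedrock_tokeniser is importable from redbox.graph.nodes.tools.

```python
# Hypothetical follow-up test, not included in this commit.
from redbox.graph.nodes.tools import bedrock_tokeniser


def test_bedrock_tokeniser_counts_words():
    assert bedrock_tokeniser("") == 0
    assert bedrock_tokeniser("one two three") == 3
    # Runs of whitespace collapse, so the count reflects words rather than characters.
    assert bedrock_tokeniser("spaced    out\twords\nhere") == 4
```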
