Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix/self route toggle #93

Open
wants to merge 12 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,5 @@ REDBOX_API_KEY = myapi
# AUTHBROKER_CLIENT_SECRET=REPLACE_WITH_GITLAB_SECRET
# AUTHBROKER_URL=https://sso.trade.gov.uk

ENABLE_METADATA_EXTRACTION = True
ENABLE_METADATA_EXTRACTION = True
SELF_ROUTE_ENABLED = False
39 changes: 37 additions & 2 deletions redbox-core/redbox/graph/edges.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import logging
import re
from typing import Literal
from typing import Any, Literal

from langchain_core.runnables import Runnable
from langchain_core.runnables import Runnable, RunnableLambda

from redbox.chains.components import get_tokeniser
from redbox.graph.nodes.processes import PromptSet
from redbox.models import ChatRoute
from redbox.models.chain import RedboxState, get_prompts
from redbox.models.graph import ROUTE_NAME_TAG
from redbox.transform import get_document_token_count

log = logging.getLogger()
Expand Down Expand Up @@ -48,6 +49,40 @@ def _total_tokens_request_handler_conditional(
return _total_tokens_request_handler_conditional


def has_exceed_max_limit(state: RedboxState) -> Literal["max_exceeded", "pass"]:
    """Edge conditional: compare the selected files' total token count against
    the request's configured maximum document-token budget.

    Returns "max_exceeded" when the total is over the limit, otherwise "pass".
    """
    limit = state.request.ai_settings.max_document_tokens
    used = state.metadata.selected_files_total_tokens
    return "max_exceeded" if used > limit else "pass"


def set_route_based_on_token_limit(prompt_set: PromptSet) -> Runnable[RedboxState, dict[str, Any]]:
    """Build a Runnable that picks the chat-with-docs route from the token budget.

    Uses the given prompt set to work out how much of the context window remains
    after the system and question prompts, then routes to map-reduce when the
    selected files' total tokens cannot fit, or to plain chat-with-docs when
    they can. Tagged with ROUTE_NAME_TAG so the route choice shows up in traces.
    """

    def _set_route_based_on_token_limit(
        state: RedboxState,
    ) -> dict[str, Any]:
        system_prompt, question_prompt = get_prompts(state, prompt_set)
        budget_remaining = calculate_token_budget(state, system_prompt, question_prompt)
        selected_tokens = state.metadata.selected_files_total_tokens

        route = (
            ChatRoute.chat_with_docs_map_reduce
            if selected_tokens > budget_remaining
            else ChatRoute.chat_with_docs
        )
        return {"route_name": route}

    return RunnableLambda(_set_route_based_on_token_limit).with_config(tags=[ROUTE_NAME_TAG])


def build_documents_bigger_than_context_conditional(prompt_set: PromptSet) -> Runnable:
"""Uses a set of prompts to build the correct conditional for exceeding the context window."""

Expand Down
29 changes: 27 additions & 2 deletions redbox-core/redbox/graph/nodes/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
import re
import textwrap
import time
from random import uniform
from collections.abc import Callable
from functools import reduce
from random import uniform
from typing import Any, Iterable
from uuid import uuid4

from botocore.exceptions import EventStreamError

from langchain.schema import StrOutputParser
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.documents import Document
Expand Down Expand Up @@ -193,6 +192,10 @@ def _set_route(state: RedboxState) -> dict[str, Any]:
return RunnableLambda(_set_route).with_config(tags=[ROUTE_NAME_TAG])


def llm_cannot_answer_question(state: RedboxState):
    """Return True when the LLM's last message marks the question as unanswerable.

    The self-route prompt instructs the model to reply with "unanswerable" when
    the retrieved context cannot answer the question; this predicate detects
    that marker case-insensitively.
    """
    answer_text = state.last_message.content
    return "unanswerable" in answer_text.lower()


def build_set_self_route_from_llm_answer(
conditional: Callable[[str], bool],
true_condition_state_update: dict,
Expand All @@ -201,6 +204,28 @@ def build_set_self_route_from_llm_answer(
) -> Runnable[RedboxState, dict[str, Any]]:
"""A Runnable which sets the route based on a conditional on state['text']"""

@RunnableLambda
def _set_self_route_from_llm_answer(state: RedboxState):
llm_response = state.last_message.content
if conditional(llm_response):
return true_condition_state_update
else:
return false_condition_state_update

runnable = _set_self_route_from_llm_answer
if true_condition_state_update:
runnable = _set_self_route_from_llm_answer.with_config(tags=[ROUTE_NAME_TAG])
return runnable


def build_set_route_from_condition(
conditional: Callable[[str], bool],
true_condition_state_update: dict,
false_condition_state_update: dict,
final_route_response: bool = True,
) -> Runnable[RedboxState, dict[str, Any]]:
"""A Runnable which sets the route based on a conditional on state['text']"""

@RunnableLambda
def _set_self_route_from_llm_answer(state: RedboxState):
llm_response = state.last_message.content
Expand Down
73 changes: 26 additions & 47 deletions redbox-core/redbox/graph/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@
from langgraph.graph import END, START, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.prebuilt import ToolNode
from langgraph.pregel import RetryPolicy

from redbox.chains.components import get_structured_response_with_citations_parser
from redbox.chains.runnables import build_self_route_output_parser
from redbox.graph.edges import (
build_documents_bigger_than_context_conditional,
build_keyword_detection_conditional,
build_total_tokens_request_handler_conditional,
documents_selected_conditional,
has_exceed_max_limit,
multiple_docs_in_group_conditional,
set_route_based_on_token_limit,
)
from redbox.graph.nodes.processes import (
PromptSet,
Expand All @@ -25,10 +27,10 @@
build_retrieve_pattern,
build_set_metadata_pattern,
build_set_route_pattern,
build_set_self_route_from_llm_answer,
build_stuff_pattern,
clear_documents_process,
empty_process,
llm_cannot_answer_question,
report_sources_process,
)
from redbox.graph.nodes.sends import build_document_chunk_send, build_document_group_send, build_tool_send
Expand All @@ -37,14 +39,13 @@
from redbox.models.chat import ChatRoute, ErrorRoute
from redbox.models.graph import ROUTABLE_KEYWORDS, RedboxActivityEvent
from redbox.transform import structure_documents_by_file_name, structure_documents_by_group_and_indices
from langgraph.pregel import RetryPolicy


def get_self_route_graph(retriever: VectorStoreRetriever, prompt_set: PromptSet, debug: bool = False):
builder = StateGraph(RedboxState)

def self_route_question_is_unanswerable(llm_response: str):
return "unanswerable" in llm_response
return "unanswerable" in llm_response.lower()

# Processes
builder.add_node("p_condense_question", build_chat_pattern(prompt_set=PromptSet.CondenseQuestion))
Expand All @@ -68,29 +69,20 @@ def self_route_question_is_unanswerable(llm_response: str):
final_response_chain=False,
),
)
builder.add_node(
"p_set_route_name_from_answer",
build_set_self_route_from_llm_answer(
self_route_question_is_unanswerable,
true_condition_state_update={"route_name": ChatRoute.chat_with_docs_map_reduce},
false_condition_state_update={"route_name": ChatRoute.search},
),
)

builder.add_node("p_set_search_route", build_set_route_pattern(ChatRoute.search))
builder.add_node("p_clear_documents", clear_documents_process)

# Edges
builder.add_edge(START, "p_condense_question")
builder.add_edge("p_condense_question", "p_retrieve_docs")
builder.add_edge("p_retrieve_docs", "p_answer_question_or_decide_unanswerable")
builder.add_edge("p_answer_question_or_decide_unanswerable", "p_set_route_name_from_answer")
builder.add_conditional_edges(
"p_set_route_name_from_answer",
lambda state: state.route_name,
{
ChatRoute.chat_with_docs_map_reduce: "p_clear_documents",
ChatRoute.search: END,
},
"p_answer_question_or_decide_unanswerable",
llm_cannot_answer_question,
{True: "p_clear_documents", False: "p_set_search_route"},
)
builder.add_edge("p_set_search_route", END)
builder.add_edge("p_clear_documents", END)

return builder.compile(debug=debug)
Expand Down Expand Up @@ -158,9 +150,6 @@ def get_agentic_search_graph(tools: List[StructuredTool], debug: bool = False) -

citations_output_parser, format_instructions = get_structured_response_with_citations_parser()
builder = StateGraph(RedboxState)
# Tools
# agent_tool_names = ["_search_documents", "_search_wikipedia", "_search_govuk"]
# agent_tools: list[StructuredTool] = [tools[tool_name] for tool_name in agent_tool_names]

# Processes
builder.add_node("p_set_agentic_search_route", build_set_route_pattern(route=ChatRoute.gadget))
Expand Down Expand Up @@ -232,11 +221,6 @@ def get_chat_with_documents_graph(

# Processes
builder.add_node("p_pass_question_to_text", build_passthrough_pattern())
builder.add_node("p_set_chat_docs_route", build_set_route_pattern(route=ChatRoute.chat_with_docs))
builder.add_node(
"p_set_chat_docs_map_reduce_route",
build_set_route_pattern(route=ChatRoute.chat_with_docs_map_reduce),
)
builder.add_node(
"p_summarise_each_document",
build_merge_pattern(prompt_set=PromptSet.ChatwithDocsMapReduce),
Expand Down Expand Up @@ -264,7 +248,7 @@ def get_chat_with_documents_graph(
),
)
builder.add_node(
"p_answer_or_decide_route",
"can_RAG_answer_question",
get_self_route_graph(parameterised_retriever, PromptSet.SelfRoute),
)
builder.add_node(
Expand All @@ -280,12 +264,12 @@ def get_chat_with_documents_graph(
"p_activity_log_tool_decision",
build_activity_log_node(lambda state: RedboxActivityEvent(message=f"Using _{state.route_name}_")),
)

builder.add_node("set_route_based_on_token_limit", set_route_based_on_token_limit(PromptSet.ChatwithDocsMapReduce))
# Decisions
builder.add_node("d_request_handler_from_total_tokens", empty_process)
builder.add_node("d_single_doc_summaries_bigger_than_context", empty_process)
builder.add_node("d_doc_summaries_bigger_than_context", empty_process)
builder.add_node("d_groups_have_multiple_docs", empty_process)
builder.add_node("d_exceed_max_limit", empty_process)
builder.add_node("d_self_route_is_enabled", empty_process)

# Sends
Expand All @@ -295,32 +279,27 @@ def get_chat_with_documents_graph(

# Edges
builder.add_edge(START, "p_pass_question_to_text")
builder.add_edge("p_pass_question_to_text", "d_request_handler_from_total_tokens")
builder.add_edge("p_pass_question_to_text", "d_exceed_max_limit")
builder.add_conditional_edges(
"d_request_handler_from_total_tokens",
build_total_tokens_request_handler_conditional(PromptSet.ChatwithDocsMapReduce),
{
"max_exceeded": "p_too_large_error",
"context_exceeded": "d_self_route_is_enabled",
"pass": "p_set_chat_docs_route",
},
"d_exceed_max_limit",
has_exceed_max_limit,
{"max_exceeded": "p_too_large_error", "pass": "d_self_route_is_enabled"},
)

builder.add_conditional_edges(
"d_self_route_is_enabled",
lambda s: s.request.ai_settings.self_route_enabled,
{True: "p_answer_or_decide_route", False: "p_set_chat_docs_map_reduce_route"},
{True: "can_RAG_answer_question", False: "p_retrieve_all_chunks"},
then="p_activity_log_tool_decision",
)
builder.add_conditional_edges(
"p_answer_or_decide_route",
lambda state: state.route_name,
{
ChatRoute.search: END,
ChatRoute.chat_with_docs_map_reduce: "p_retrieve_all_chunks",
},
"can_RAG_answer_question",
lambda state: state.route_name == ChatRoute.search,
{True: END, False: "set_route_based_on_token_limit"},
)
builder.add_edge("p_set_chat_docs_route", "p_retrieve_all_chunks")
builder.add_edge("p_set_chat_docs_map_reduce_route", "p_retrieve_all_chunks")

builder.add_edge("set_route_based_on_token_limit", "p_retrieve_all_chunks")

builder.add_conditional_edges(
"p_retrieve_all_chunks",
lambda s: s.route_name,
Expand Down
8 changes: 7 additions & 1 deletion redbox-core/redbox/models/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from typing import Annotated, Literal, NotRequired, Required, TypedDict, get_args, get_origin
from uuid import UUID, uuid4

import environ
from dotenv import load_dotenv
from langchain_core.documents import Document
from langchain_core.messages import AnyMessage
from langgraph.graph.message import add_messages
Expand All @@ -14,6 +16,10 @@
from redbox.models import prompts
from redbox.models.settings import ChatLLMBackend

load_dotenv()

env = environ.Env()


class ChainChatMessage(TypedDict):
role: Literal["user", "ai", "system"]
Expand All @@ -29,7 +35,7 @@ class AISettings(BaseModel):

# Prompts and LangGraph settings
max_document_tokens: int = 1_000_000
self_route_enabled: bool = False
self_route_enabled: bool = env.bool("SELF_ROUTE_ENABLED", default=False)
map_max_concurrency: int = 128
stuff_chunk_context_ratio: float = 0.75
recursion_limit: int = 50
Expand Down
Loading