Skip to content

Commit

Permalink
add llm choose route node
Browse files Browse the repository at this point in the history
  • Loading branch information
Saisakul Chernbumroong authored and Saisakul Chernbumroong committed Feb 27, 2025
1 parent d421e30 commit c7d6d98
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 16 deletions.
12 changes: 10 additions & 2 deletions redbox-core/redbox/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
get_parameterised_retriever,
)
from redbox.graph.nodes.tools import build_govuk_search_tool, build_search_documents_tool, build_search_wikipedia_tool
from redbox.graph.root import get_agentic_search_graph, get_chat_with_documents_graph, get_root_graph
from redbox.graph.root import get_agentic_search_graph, get_chat_with_documents_graph, new_root_graph
from redbox.models.chain import RedboxState
from redbox.models.chat import ChatRoute
from redbox.models.file import ChunkResolution
Expand Down Expand Up @@ -65,7 +65,15 @@ def __init__(
search_govuk = build_govuk_search_tool()

self.tools = [search_documents, search_wikipedia, search_govuk]
self.graph = get_root_graph(
# self.graph = get_root_graph(
# all_chunks_retriever=self.all_chunks_retriever,
# parameterised_retriever=self.parameterised_retriever,
# metadata_retriever=self.metadata_retriever,
# tools=self.tools,
# debug=debug,
# )

self.graph = new_root_graph(
all_chunks_retriever=self.all_chunks_retriever,
parameterised_retriever=self.parameterised_retriever,
metadata_retriever=self.metadata_retriever,
Expand Down
26 changes: 25 additions & 1 deletion redbox-core/redbox/chains/runnables.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

from redbox.api.format import format_documents
from redbox.chains.activity import log_activity
from redbox.chains.components import get_tokeniser
from redbox.chains.components import get_chat_llm, get_tokeniser
from redbox.chains.parser import ClaudeParser
from redbox.models.chain import ChainChatMessage, PromptSet, RedboxState, get_prompts
from redbox.models.errors import QuestionLengthError
from redbox.models.graph import RedboxEventType
Expand Down Expand Up @@ -263,3 +264,26 @@ def _identifying_params(self) -> dict[str, Any]:
def _llm_type(self) -> str:
"""Get the type of language model used by this chat model. Used for logging purposes only."""
return "canned"


def basic_chat_chain(system_prompt, _additional_variables: dict = {}, parser=None):
@chain
def _basic_chat_chain(state: RedboxState):
nonlocal parser
llm = get_chat_llm(state.request.ai_settings.chat_backend)
context = {
"question": state.request.question,
} | _additional_variables

if parser:
format_instructions = parser.get_format_instructions()
prompt = ChatPromptTemplate(
[(system_prompt)], partial_variables={"format_instructions": format_instructions}
)
else:
prompt = ChatPromptTemplate([(system_prompt)])
parser = ClaudeParser()
chain = prompt | llm | parser
return chain.invoke(context)

return _basic_chat_chain
25 changes: 16 additions & 9 deletions redbox-core/redbox/graph/nodes/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,22 @@
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import Runnable, RunnableLambda, RunnableParallel, chain
from langchain_core.runnables import (Runnable, RunnableLambda,
RunnableParallel)
from langchain_core.tools import StructuredTool
from langchain_core.vectorstores import VectorStoreRetriever

from redbox.chains.activity import log_activity
from redbox.chains.components import get_basic_metadata_retriever, get_chat_llm, get_tokeniser
from redbox.chains.components import (get_basic_metadata_retriever,
get_chat_llm, get_tokeniser)
from redbox.chains.parser import ClaudeParser
from redbox.chains.runnables import CannedChatLLM, basic_chat_chain, build_llm_chain
from redbox.chains.runnables import (CannedChatLLM, basic_chat_chain,
build_llm_chain)
from redbox.models import ChatRoute
from redbox.models.chain import DocumentState, PromptSet, RedboxState, RequestMetadata
from redbox.models.graph import ROUTE_NAME_TAG, SOURCE_DOCUMENTS_TAG, RedboxActivityEvent, RedboxEventType
from redbox.models.chain import (DocumentState, PromptSet, RedboxState,
RequestMetadata)
from redbox.models.graph import (ROUTE_NAME_TAG, SOURCE_DOCUMENTS_TAG,
RedboxActivityEvent, RedboxEventType)
from redbox.models.settings import get_settings
from redbox.transform import combine_documents, flatten_document_state

Expand Down Expand Up @@ -339,21 +344,21 @@ def _activity_log_node(state: RedboxState):
return _activity_log_node


def lm_choose_route(parser: ClaudeParser):
def lm_choose_route(state: RedboxState, parser: ClaudeParser):
"""
LLM choose the route (search/summarise) based on user question and file metadata
"""
metadata = None

@chain
@RunnableLambda
def get_metadata(state: RedboxState):
nonlocal metadata
env = get_settings()
retriever = get_basic_metadata_retriever(env)
metadata = retriever.invoke(state)
return state

@chain
@RunnableLambda
def use_result(state: RedboxState):
chain = basic_chat_chain(
system_prompt=state.request.ai_settings.llm_decide_route_prompt,
Expand All @@ -362,4 +367,6 @@ def use_result(state: RedboxState):
)
return chain.invoke(state)

return get_metadata | use_result
chain = get_metadata | use_result
res = chain.invoke(state)
return res.next.value
13 changes: 9 additions & 4 deletions redbox-core/redbox/graph/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ def get_summarise_graph():
def new_root_graph(all_chunks_retriever, parameterised_retriever, metadata_retriever, tools, debug):
agent_parser = ClaudeParser(pydantic_object=AgentDecision)

def lm_choose_route_wrapper(state: RedboxState):
return lm_choose_route(state, parser=agent_parser)

builder = StateGraph(RedboxState)
builder.add_node("has_keyword", empty_process)
builder.add_node(
Expand All @@ -65,8 +68,10 @@ def new_root_graph(all_chunks_retriever, parameterised_retriever, metadata_retri
"new_route_graph",
build_new_graph(all_chunks_retriever, parameterised_retriever, metadata_retriever, tools, debug),
)
builder.add_node("chat_graph", get_chat_graph(debug=debug))
builder.add_node("is_self_route_enabled", empty_process)
builder.add_node("llm_choose_route", lm_choose_route(parser=agent_parser))
builder.add_node("any_documents_selected", empty_process)
builder.add_node("llm_choose_route", empty_process)
builder.add_node("summarise_graph", get_summarise_graph())
builder.add_node(
"log_user_request",
Expand All @@ -80,7 +85,6 @@ def new_root_graph(all_chunks_retriever, parameterised_retriever, metadata_retri
]
),
)
builder.add_node("summarise_graph", get_summarise_graph())
# edges
builder.add_edge(START, "log_user_request")
builder.add_edge(START, "retrieve_metadata")
Expand All @@ -92,7 +96,7 @@ def new_root_graph(all_chunks_retriever, parameterised_retriever, metadata_retri
ChatRoute.search: "search_graph",
ChatRoute.gadget: "gadget_graph",
ChatRoute.newroute: "new_route_graph",
ChatRoute.summarise: "summarise_graph",
ChatRoute.chat_with_docs: "summarise_graph",
"DEFAULT": "any_documents_selected",
},
)
Expand All @@ -111,13 +115,14 @@ def new_root_graph(all_chunks_retriever, parameterised_retriever, metadata_retri
)
builder.add_conditional_edges(
"llm_choose_route",
lambda s: s.last_message.next.value,
lm_choose_route_wrapper,
{"search": "search_graph", "summarise": "summarise_graph"},
)
builder.add_edge("search_graph", END)
builder.add_edge("gadget_graph", END)
builder.add_edge("new_route_graph", END)
builder.add_edge("summarise_graph", END)
return builder.compile()


def get_self_route_graph(retriever: VectorStoreRetriever, prompt_set: PromptSet, debug: bool = False):
Expand Down

0 comments on commit c7d6d98

Please sign in to comment.