diff --git a/backend/packages/agent-executor/agent_executor/permchain.py b/backend/packages/agent-executor/agent_executor/permchain.py
index 4c37152d..3b997157 100644
--- a/backend/packages/agent-executor/agent_executor/permchain.py
+++ b/backend/packages/agent-executor/agent_executor/permchain.py
@@ -4,6 +4,7 @@
 from langchain.schema.agent import AgentAction, AgentActionMessageLog, AgentFinish
 from langchain.schema.messages import AIMessage, AnyMessage, FunctionMessage
+from langchain_core.language_models.base import LanguageModelLike
 from langchain.schema.runnable import (
     Runnable,
     RunnableConfig,
@@ -13,7 +14,12 @@
 from langchain.tools import BaseTool
 from langgraph.channels import Topic
 from langgraph.checkpoint import BaseCheckpointSaver
-from langgraph.pregel import Channel, Pregel, ReservedChannels
+from langgraph.graph.message import MessageGraph
+from langgraph.graph import END
+from langgraph.prebuilt import ToolExecutor
+from langgraph.prebuilt import ToolInvocation
+
+


 def _create_agent_message(
@@ -73,54 +79,75 @@ async def _arun_tool(

 def get_agent_executor(
     tools: list[BaseTool],
-    agent: Runnable[dict[str, list[AnyMessage]], AgentAction | AgentFinish],
+    llm: LanguageModelLike,
     checkpoint: BaseCheckpointSaver,
-) -> Pregel:
-    tool_map = {tool.name: tool for tool in tools}
-    tool_lambda = RunnableLambda(_run_tool, _arun_tool).bind(tools=tool_map)
-
-    tool_chain = itemgetter("messages") | tool_lambda | Channel.write_to("messages")
-    agent_chain = agent | _create_agent_message | Channel.write_to("messages")
-
-    def route_last_message(input: dict[str, bool | Sequence[AnyMessage]]) -> Runnable:
-        if not input["messages"]:
-            # no messages, do nothing
-            return agent_chain
-
-        message: AnyMessage = input["messages"][-1]
-        if isinstance(message.additional_kwargs.get("agent"), AgentFinish):
-            # finished, do nothing
-            return RunnablePassthrough()
-
-        if input[ReservedChannels.is_last_step]:
-            # exhausted iterations without finishing, return stop message
-            return Channel.write_to(
-                messages=_create_agent_message(
-                    AgentFinish(
-                        {
-                            "output": "Agent stopped due to iteration limit or time limit."
-                        },
-                        "",
-                    )
-                )
-            )
-
-        if isinstance(message.additional_kwargs.get("agent"), AgentAction):
-            # agent action, run it
-            return tool_chain
-
-        # otherwise, run the agent
-        return agent_chain
-
-    executor = (
-        Channel.subscribe_to(["messages"]).join([ReservedChannels.is_last_step])
-        | route_last_message
+):
+    tool_executor = ToolExecutor(tools)
+
+    # Define the function that determines whether to continue or not
+    def should_continue(messages):
+        last_message = messages[-1]
+        # If there is no function call, then we finish
+        if "function_call" not in last_message.additional_kwargs:
+            return "end"
+        # Otherwise if there is, we continue
+        else:
+            return "continue"
+
+    # Define the function to execute tools
+    async def call_tool(messages):
+        # Based on the continue condition
+        # we know the last message involves a function call
+        last_message = messages[-1]
+        # We construct a ToolInvocation from the function_call
+        action = ToolInvocation(
+            tool=last_message.additional_kwargs["function_call"]["name"],
+            tool_input=json.loads(last_message.additional_kwargs["function_call"]["arguments"]),
+        )
+        # We call the tool_executor and get back a response
+        response = await tool_executor.ainvoke(action)
+        # We use the response to create a FunctionMessage
+        function_message = FunctionMessage(content=str(response), name=action.tool)
+        # We return the FunctionMessage, which gets appended to the existing message list
+        return function_message
+
+    workflow = MessageGraph()
+
+    # Define the two nodes we will cycle between
+    workflow.add_node("agent", llm)
+    workflow.add_node("action", call_tool)
+
+    # Set the entrypoint as `agent`
+    # This means that this node is the first one called
+    workflow.set_entry_point("agent")
+
+    # We now add a conditional edge
+    workflow.add_conditional_edges(
+        # First, we define the start node. We use `agent`.
+        # This means these are the edges taken after the `agent` node is called.
+        "agent",
+        # Next, we pass in the function that will determine which node is called next.
+        should_continue,
+        # Finally we pass in a mapping.
+        # The keys are strings, and the values are other nodes.
+        # END is a special node marking that the graph should finish.
+        # What will happen is we will call `should_continue`, and then the output of that
+        # will be matched against the keys in this mapping.
+        # Based on which one it matches, that node will then be called.
+        {
+            # If `continue`, then we call the tool node.
+            "continue": "action",
+            # Otherwise we finish.
+            "end": END
+        }
     )
-    return Pregel(
-        nodes={"executor": executor},
-        channels={"messages": Topic(AnyMessage, accumulate=True)},
-        input=["messages"],
-        output=["messages"],
-        checkpointer=checkpoint,
-    )
+    # We now add a normal edge from `action` to `agent`.
+    # This means that after `action` is called, `agent` node is called next.
+    workflow.add_edge('action', 'agent')
+
+    # Finally, we compile it!
+    # This compiles it into a LangChain Runnable,
+    # meaning you can use it as you would any other runnable
+    app = workflow.compile(checkpointer=checkpoint)
+    return app
diff --git a/backend/packages/gizmo-agent/gizmo_agent/agent_types/openai.py b/backend/packages/gizmo-agent/gizmo_agent/agent_types/openai.py
index 6dc0e9ac..9923d320 100644
--- a/backend/packages/gizmo-agent/gizmo_agent/agent_types/openai.py
+++ b/backend/packages/gizmo-agent/gizmo_agent/agent_types/openai.py
@@ -4,6 +4,7 @@
 from langchain_openai import AzureChatOpenAI, ChatOpenAI
 from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
 from langchain.tools.render import format_tool_to_openai_function
+from langchain_core.messages import SystemMessage


 def get_openai_function_agent(
@@ -23,17 +24,15 @@ def get_openai_function_agent(
             openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
             streaming=True,
         )
-    prompt = ChatPromptTemplate.from_messages(
-        [
-            ("system", system_message),
-            MessagesPlaceholder(variable_name="messages"),
-        ]
-    )
+
+    def _get_messages(messages):
+        return [SystemMessage(content=system_message)] + messages
+
     if tools:
         llm_with_tools = llm.bind(
             functions=[format_tool_to_openai_function(t) for t in tools]
         )
     else:
         llm_with_tools = llm
-    agent = prompt | llm_with_tools | OpenAIFunctionsAgentOutputParser()
+    agent = _get_messages | llm_with_tools
     return agent
diff --git a/backend/packages/gizmo-agent/gizmo_agent/main.py b/backend/packages/gizmo-agent/gizmo_agent/main.py
index 7df3add3..89480c29 100644
--- a/backend/packages/gizmo-agent/gizmo_agent/main.py
+++ b/backend/packages/gizmo-agent/gizmo_agent/main.py
@@ -71,7 +71,7 @@ def __init__(
         else:
             raise ValueError("Unexpected agent type")
         agent_executor = get_agent_executor(
-            tools=_tools, agent=_agent, checkpoint=RedisCheckpoint()
+            tools=_tools, llm=_agent, checkpoint=RedisCheckpoint()
         ).with_config({"recursion_limit": 10})
         super().__init__(
             tools=tools,
@@ -152,9 +152,9 @@ class AgentOutput(BaseModel):
     from langchain.schema.messages import HumanMessage

     async def run():
-        async for m in agent.astream_log(
-            {"messages": HumanMessage(content="whats your name")},
+        async for m in agent.astream_events(HumanMessage(content="whats your name"),
             config={"configurable": {"user_id": "1", "thread_id": "test1"}},
+            version="v1",
         ):
             print(m)
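For orientation, here is a minimal sketch of how the reworked executor can be assembled and exercised end to end after this change, pieced together from the three files above. The system message, tool list, checkpointer argument, and the import path for get_agent_executor are illustrative assumptions rather than part of the diff; only the new get_agent_executor signature, the message-prepending helper, and the astream_events call are taken from it.

import asyncio

from langchain.tools.render import format_tool_to_openai_function
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI

from agent_executor.permchain import get_agent_executor  # path inferred from the file layout


def build_app(tools, system_message, checkpoint):
    # Mirrors get_openai_function_agent: prepend the system message on every
    # call and, if tools are configured, expose them as OpenAI functions.
    def _get_messages(messages):
        return [SystemMessage(content=system_message)] + messages

    llm = ChatOpenAI(streaming=True)
    if tools:
        llm = llm.bind(functions=[format_tool_to_openai_function(t) for t in tools])
    agent = _get_messages | llm

    # The executor now takes the model runnable directly instead of a chain
    # that emits parsed AgentAction / AgentFinish objects.
    return get_agent_executor(tools=tools, llm=agent, checkpoint=checkpoint)


async def demo(app):
    # MessageGraph state is the raw message list, so the input is a message
    # (or a list of messages) rather than {"messages": [...]}.
    config = {"configurable": {"user_id": "1", "thread_id": "test1"}}
    async for event in app.astream_events(
        HumanMessage(content="whats your name"), config=config, version="v1"
    ):
        print(event)


# e.g. asyncio.run(demo(build_app(my_tools, "You are a helpful assistant.",
#                                 RedisCheckpoint())))  # checkpointer as used in main.py

Because workflow.compile() returns an ordinary LangChain Runnable, the .with_config({"recursion_limit": 10}) wrapper applied in gizmo_agent/main.py still works unchanged on the compiled graph.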