Skip to content

Commit

Permalink
cr
Browse files Browse the repository at this point in the history
  • Loading branch information
hwchase17 committed Jan 25, 2024
1 parent eb88e7d commit bd59e7f
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 59 deletions.
125 changes: 76 additions & 49 deletions backend/packages/agent-executor/agent_executor/permchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from langchain.schema.agent import AgentAction, AgentActionMessageLog, AgentFinish
from langchain.schema.messages import AIMessage, AnyMessage, FunctionMessage
from langchain_core.language_models.base import LanguageModelLike
from langchain.schema.runnable import (
Runnable,
RunnableConfig,
Expand All @@ -13,7 +14,12 @@
from langchain.tools import BaseTool
from langgraph.channels import Topic
from langgraph.checkpoint import BaseCheckpointSaver
from langgraph.pregel import Channel, Pregel, ReservedChannels
from langgraph.graph.message import MessageGraph
from langgraph.graph import END
from langgraph.prebuilt import ToolExecutor
from langgraph.prebuilt import ToolInvocation




def _create_agent_message(
Expand Down Expand Up @@ -73,54 +79,75 @@ async def _arun_tool(

def get_agent_executor(
tools: list[BaseTool],
agent: Runnable[dict[str, list[AnyMessage]], AgentAction | AgentFinish],
llm: LanguageModelLike,
checkpoint: BaseCheckpointSaver,
) -> Pregel:
tool_map = {tool.name: tool for tool in tools}
tool_lambda = RunnableLambda(_run_tool, _arun_tool).bind(tools=tool_map)

tool_chain = itemgetter("messages") | tool_lambda | Channel.write_to("messages")
agent_chain = agent | _create_agent_message | Channel.write_to("messages")

def route_last_message(input: dict[str, bool | Sequence[AnyMessage]]) -> Runnable:
if not input["messages"]:
# no messages, do nothing
return agent_chain

message: AnyMessage = input["messages"][-1]
if isinstance(message.additional_kwargs.get("agent"), AgentFinish):
# finished, do nothing
return RunnablePassthrough()

if input[ReservedChannels.is_last_step]:
# exhausted iterations without finishing, return stop message
return Channel.write_to(
messages=_create_agent_message(
AgentFinish(
{
"output": "Agent stopped due to iteration limit or time limit."
},
"",
)
)
)

if isinstance(message.additional_kwargs.get("agent"), AgentAction):
# agent action, run it
return tool_chain

# otherwise, run the agent
return agent_chain

executor = (
Channel.subscribe_to(["messages"]).join([ReservedChannels.is_last_step])
| route_last_message
):
tool_executor = ToolExecutor(tools)

# Define the function that determines whether to continue or not
def should_continue(messages):
last_message = messages[-1]
# If there is no function call, then we finish
if "function_call" not in last_message.additional_kwargs:
return "end"
# Otherwise if there is, we continue
else:
return "continue"

# Define the function to execute tools
async def call_tool(messages):
# Based on the continue condition
# we know the last message involves a function call
last_message = messages[-1]
# We construct an ToolInvocation from the function_call
action = ToolInvocation(
tool=last_message.additional_kwargs["function_call"]["name"],
tool_input=json.loads(last_message.additional_kwargs["function_call"]["arguments"]),
)
# We call the tool_executor and get back a response
response = await tool_executor.ainvoke(action)
# We use the response to create a FunctionMessage
function_message = FunctionMessage(content=str(response), name=action.tool)
# We return a list, because this will get added to the existing list
return function_message

workflow = MessageGraph()

# Define the two nodes we will cycle between
workflow.add_node("agent", llm)
workflow.add_node("action", call_tool)

# Set the entrypoint as `agent`
# This means that this node is the first one called
workflow.set_entry_point("agent")

# We now add a conditional edge
workflow.add_conditional_edges(
# First, we define the start node. We use `agent`.
# This means these are the edges taken after the `agent` node is called.
"agent",
# Next, we pass in the function that will determine which node is called next.
should_continue,
# Finally we pass in a mapping.
# The keys are strings, and the values are other nodes.
# END is a special node marking that the graph should finish.
# What will happen is we will call `should_continue`, and then the output of that
# will be matched against the keys in this mapping.
# Based on which one it matches, that node will then be called.
{
# If `tools`, then we call the tool node.
"continue": "action",
# Otherwise we finish.
"end": END
}
)

return Pregel(
nodes={"executor": executor},
channels={"messages": Topic(AnyMessage, accumulate=True)},
input=["messages"],
output=["messages"],
checkpointer=checkpoint,
)
# We now add a normal edge from `tools` to `agent`.
# This means that after `tools` is called, `agent` node is called next.
workflow.add_edge('action', 'agent')

# Finally, we compile it!
# This compiles it into a LangChain Runnable,
# meaning you can use it as you would any other runnable
app = workflow.compile(checkpointer=checkpoint)
return app
13 changes: 6 additions & 7 deletions backend/packages/gizmo-agent/gizmo_agent/agent_types/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.tools.render import format_tool_to_openai_function
from langchain_core.messages import SystemMessage


def get_openai_function_agent(
Expand All @@ -23,17 +24,15 @@ def get_openai_function_agent(
openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
streaming=True,
)
prompt = ChatPromptTemplate.from_messages(
[
("system", system_message),
MessagesPlaceholder(variable_name="messages"),
]
)

def _get_messages(messages):
return [SystemMessage(content=system_message)] + messages

if tools:
llm_with_tools = llm.bind(
functions=[format_tool_to_openai_function(t) for t in tools]
)
else:
llm_with_tools = llm
agent = prompt | llm_with_tools | OpenAIFunctionsAgentOutputParser()
agent = _get_messages | llm_with_tools
return agent
6 changes: 3 additions & 3 deletions backend/packages/gizmo-agent/gizmo_agent/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(
else:
raise ValueError("Unexpected agent type")
agent_executor = get_agent_executor(
tools=_tools, agent=_agent, checkpoint=RedisCheckpoint()
tools=_tools, llm=_agent, checkpoint=RedisCheckpoint()
).with_config({"recursion_limit": 10})
super().__init__(
tools=tools,
Expand Down Expand Up @@ -152,9 +152,9 @@ class AgentOutput(BaseModel):
from langchain.schema.messages import HumanMessage

async def run():
async for m in agent.astream_log(
{"messages": HumanMessage(content="whats your name")},
async for m in agent.astream_events(HumanMessage(content="whats your name"),
config={"configurable": {"user_id": "1", "thread_id": "test1"}},
version="v1",
):
print(m)

Expand Down

0 comments on commit bd59e7f

Please sign in to comment.