Skip to content

Commit

Permalink
cr
Browse files Browse the repository at this point in the history
  • Loading branch information
hwchase17 committed Jan 26, 2024
1 parent bd59e7f commit d5bb47b
Show file tree
Hide file tree
Showing 31 changed files with 277 additions and 5,258 deletions.
Original file line number Diff line number Diff line change
@@ -1,35 +1,40 @@
from typing import Any, Mapping, Optional, Sequence

from agent_executor.checkpoint import RedisCheckpoint
from agent_executor.dnd import create_dnd_bot
from agent_executor.permchain import get_agent_executor
from langchain_openai import ChatOpenAI
from app.checkpoint import RedisCheckpoint
from app.agent_types.openai_agent import get_openai_agent_executor
from app.agent_types.xml_agent import get_xml_agent_executor
from app.llms import get_openai_llm, get_anthropic_llm
from langchain.pydantic_v1 import BaseModel, Field
from langchain_core.messages import AnyMessage
from langchain_core.runnables import (
ConfigurableField,
ConfigurableFieldMultiOption,
RunnableBinding,
)
from gizmo_agent.agent_types import (
GizmoAgentType,
get_openai_function_agent,
# get_xml_agent,
)
from gizmo_agent.tools import (
from app.tools import (
RETRIEVAL_DESCRIPTION,
TOOL_OPTIONS,
TOOLS,
AvailableTools,
get_retrieval_tool,
)

from enum import Enum


class AgentType(str, Enum):
GPT_35_TURBO = "GPT 3.5 Turbo"
GPT_4 = "GPT 4"
AZURE_OPENAI = "GPT 4 (Azure OpenAI)"
CLAUDE2 = "Claude 2"
BEDROCK_CLAUDE2 = "Claude 2 (Amazon Bedrock)"

DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."


class ConfigurableAgent(RunnableBinding):
tools: Sequence[str]
agent: GizmoAgentType
agent: AgentType
system_message: str = DEFAULT_SYSTEM_MESSAGE
retrieval_description: str = RETRIEVAL_DESCRIPTION
assistant_id: Optional[str] = None
Expand All @@ -39,7 +44,7 @@ def __init__(
self,
*,
tools: Sequence[str],
agent: GizmoAgentType = GizmoAgentType.GPT_35_TURBO,
agent: AgentType = AgentType.GPT_35_TURBO,
system_message: str = DEFAULT_SYSTEM_MESSAGE,
assistant_id: Optional[str] = None,
retrieval_description: str = RETRIEVAL_DESCRIPTION,
Expand All @@ -58,21 +63,24 @@ def __init__(
_tools.append(get_retrieval_tool(assistant_id, retrieval_description))
else:
_tools.append(TOOLS[_tool]())
if agent == GizmoAgentType.GPT_35_TURBO:
_agent = get_openai_function_agent(_tools, system_message)
# elif agent == GizmoAgentType.GPT_4:
# _agent = get_openai_function_agent(_tools, system_message, gpt_4=True)
# elif agent == GizmoAgentType.AZURE_OPENAI:
# _agent = get_openai_function_agent(_tools, system_message, azure=True)
# elif agent == GizmoAgentType.CLAUDE2:
# _agent = get_xml_agent(_tools, system_message)
# elif agent == GizmoAgentType.BEDROCK_CLAUDE2:
# _agent = get_xml_agent(_tools, system_message, bedrock=True)
if agent == AgentType.GPT_35_TURBO:
llm = get_openai_llm()
_agent = get_openai_agent_executor(_tools, llm, system_message, RedisCheckpoint())
elif agent == AgentType.GPT_4:
llm = get_openai_llm(gpt_4=True)
_agent = get_openai_agent_executor(_tools, llm, system_message, RedisCheckpoint())
elif agent == AgentType.AZURE_OPENAI:
llm = get_openai_llm(azure=True)
_agent = get_openai_agent_executor(_tools, llm, system_message, RedisCheckpoint())
elif agent == AgentType.CLAUDE2:
llm = get_anthropic_llm()
_agent = get_xml_agent_executor(_tools, llm, system_message, RedisCheckpoint())
elif agent == AgentType.BEDROCK_CLAUDE2:
llm = get_anthropic_llm(bedrock=True)
_agent = get_xml_agent_executor(_tools, llm, system_message, RedisCheckpoint())
else:
raise ValueError("Unexpected agent type")
agent_executor = get_agent_executor(
tools=_tools, llm=_agent, checkpoint=RedisCheckpoint()
).with_config({"recursion_limit": 10})
agent_executor = _agent.with_config({"recursion_limit": 10})
super().__init__(
tools=tools,
agent=agent,
Expand All @@ -92,30 +100,9 @@ class AgentOutput(BaseModel):
messages: Sequence[AnyMessage] = Field(..., extra={"widget": {"type": "chat"}})


dnd_llm = ChatOpenAI(
model="gpt-3.5-turbo-1106", temperature=0, streaming=True
).configurable_alternatives(
ConfigurableField(id="llm", name="LLM"),
default_key="gpt-35-turbo",
# azure_openai=AzureChatOpenAI(
# temperature=0,
# deployment_name=os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"],
# openai_api_base=os.environ["AZURE_OPENAI_API_BASE"],
# openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
# openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
# streaming=True,
# ),
)


dnd_bot = create_dnd_bot(dnd_llm, checkpoint=RedisCheckpoint()).with_types(
input_type=AgentInput, output_type=AgentOutput
)


agent = (
ConfigurableAgent(
agent=GizmoAgentType.GPT_35_TURBO,
agent=AgentType.GPT_35_TURBO,
tools=[],
system_message=DEFAULT_SYSTEM_MESSAGE,
retrieval_description=RETRIEVAL_DESCRIPTION,
Expand All @@ -137,12 +124,6 @@ class AgentOutput(BaseModel):
id="retrieval_description", name="Retrieval Description"
),
)
.configurable_alternatives(
ConfigurableField(id="type", name="Bot Type"),
default_key="agent",
prefix_keys=True,
dungeons_and_dragons=dnd_bot,
)
.with_types(input_type=AgentInput, output_type=AgentOutput)
)

Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,87 +1,34 @@
import json
from operator import itemgetter
from typing import Sequence

from langchain.schema.agent import AgentAction, AgentActionMessageLog, AgentFinish
from langchain.schema.messages import AIMessage, AnyMessage, FunctionMessage
from langchain.schema.messages import FunctionMessage
from langchain_core.language_models.base import LanguageModelLike
from langchain.schema.runnable import (
Runnable,
RunnableConfig,
RunnableLambda,
RunnablePassthrough,
)

from langchain.tools.render import format_tool_to_openai_function
from langchain_core.messages import SystemMessage

from langchain.tools import BaseTool
from langgraph.channels import Topic
from langgraph.checkpoint import BaseCheckpointSaver
from langgraph.graph.message import MessageGraph
from langgraph.graph import END
from langgraph.prebuilt import ToolExecutor
from langgraph.prebuilt import ToolInvocation




def _create_agent_message(
output: AgentAction | AgentFinish
) -> list[AnyMessage] | AnyMessage:
if isinstance(output, AgentAction):
if isinstance(output, AgentActionMessageLog):
output.message_log[-1].additional_kwargs["agent"] = output
messages = output.message_log
output.message_log = [] # avoid circular reference for json dumps
return messages
else:
return AIMessage(
content=output.log,
additional_kwargs={"agent": output},
)
else:
return AIMessage(
content=output.return_values["output"],
additional_kwargs={"agent": output},
)


def _create_function_message(
agent_action: AgentAction, observation: str
) -> FunctionMessage:
if not isinstance(observation, str):
try:
content = json.dumps(observation, ensure_ascii=False)
except Exception:
content = str(observation)
else:
content = observation
return FunctionMessage(
name=agent_action.tool,
content=content,
)


def _run_tool(
messages: list[AnyMessage], config: RunnableConfig, *, tools: dict[str, BaseTool]
) -> FunctionMessage:
action: AgentAction = messages[-1].additional_kwargs["agent"]
tool = tools[action.tool]
result = tool.invoke(action.tool_input, config)
return _create_function_message(action, result)


async def _arun_tool(
messages: list[AnyMessage], config: RunnableConfig, *, tools: dict[str, BaseTool]
) -> FunctionMessage:
action: AgentAction = messages[-1].additional_kwargs["agent"]
tool = tools[action.tool]
result = await tool.ainvoke(action.tool_input, config)
return _create_function_message(action, result)


def get_agent_executor(
def get_openai_agent_executor(
tools: list[BaseTool],
llm: LanguageModelLike,
system_message: str,
checkpoint: BaseCheckpointSaver,
):
def _get_messages(messages):
return [SystemMessage(content=system_message)] + messages

if tools:
llm_with_tools = llm.bind(
functions=[format_tool_to_openai_function(t) for t in tools]
)
else:
llm_with_tools = llm
agent = _get_messages | llm_with_tools
tool_executor = ToolExecutor(tools)

# Define the function that determines whether to continue or not
Expand Down Expand Up @@ -114,7 +61,7 @@ async def call_tool(messages):
workflow = MessageGraph()

# Define the two nodes we will cycle between
workflow.add_node("agent", llm)
workflow.add_node("agent", agent)
workflow.add_node("action", call_tool)

# Set the entrypoint as `agent`
Expand Down
27 changes: 27 additions & 0 deletions backend/app/agent_types/prompts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
xml_template = """{system_message}
You have access to the following tools:
{tools}
In order to use a tool, you can use <tool></tool> and <tool_input></tool_input> tags. You will then get back a response in the form <observation></observation>
For example, if you have a tool called 'search' that could run a google search, in order to search for the weather in SF you would respond:
<tool>search</tool><tool_input>weather in SF</tool_input>
<observation>64 degrees</observation>
When you are done, you can respond as normal to the user.
Example 1:
Human: Hi!
Assistant: Hi! How are you?
Human: What is the weather in SF?
Assistant: <tool>search</tool><tool_input>weather in SF</tool_input>
<observation>64 degrees</observation>
It is 64 degrees in SF
Begin!"""
Loading

0 comments on commit d5bb47b

Please sign in to comment.