diff --git a/.env b/.env
index 78d7cc00..d5654652 100644
--- a/.env
+++ b/.env
@@ -1,5 +1,6 @@
OPENAI_API_KEY=placeholder
ANTHROPIC_API_KEY=placeholder
+COHERE_API_KEY=placeholder
YDC_API_KEY=placeholder
TAVILY_API_KEY=placeholder
AZURE_OPENAI_DEPLOYMENT_NAME=placeholder
diff --git a/backend/packages/gizmo-agent/gizmo_agent/agent_types/__init__.py b/backend/packages/gizmo-agent/gizmo_agent/agent_types/__init__.py
index e9b139f6..c8856f04 100644
--- a/backend/packages/gizmo-agent/gizmo_agent/agent_types/__init__.py
+++ b/backend/packages/gizmo-agent/gizmo_agent/agent_types/__init__.py
@@ -1,6 +1,7 @@
from enum import Enum
from .openai import get_openai_function_agent
+from .cohere import get_cohere_function_agent
from .xml.agent import get_xml_agent
@@ -10,10 +11,12 @@ class GizmoAgentType(str, Enum):
# AZURE_OPENAI = "GPT 4 (Azure OpenAI)"
# CLAUDE2 = "Claude 2"
# BEDROCK_CLAUDE2 = "Claude 2 (Amazon Bedrock)"
+ COHERE_COMMAND = "Command (Cohere)"
__all__ = [
"get_openai_function_agent",
+ "get_cohere_function_agent",
"get_xml_agent",
"GizmoAgentType",
]
diff --git a/backend/packages/gizmo-agent/gizmo_agent/agent_types/cohere.py b/backend/packages/gizmo-agent/gizmo_agent/agent_types/cohere.py
new file mode 100644
index 00000000..20590317
--- /dev/null
+++ b/backend/packages/gizmo-agent/gizmo_agent/agent_types/cohere.py
@@ -0,0 +1,129 @@
+import json
+
+from langchain.chat_models import ChatCohere
+from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain.schema.messages import AIMessage, HumanMessage
+from langchain.schema import AgentAction, AgentFinish
+
+
+def _collapse_messages(messages):
+ log = ""
+ if isinstance(messages[-1], AIMessage):
+ scratchpad = messages[:-1]
+ final = messages[-1]
+ else:
+ scratchpad = messages
+ final = None
+ if len(scratchpad) % 2 != 0:
+ raise ValueError("Unexpected")
+ for i in range(0, len(scratchpad), 2):
+ action = messages[i]
+ observation = messages[i + 1]
+ log += f"{action.content}{observation.content}"
+ if final is not None:
+ log += final.content
+ return AIMessage(content=log)
+
+def construct_chat_history(messages):
+ collapsed_messages = []
+ temp_messages = []
+ for message in messages:
+ if isinstance(message, HumanMessage):
+ if temp_messages:
+ collapsed_messages.append(_collapse_messages(temp_messages))
+ temp_messages = []
+ collapsed_messages.append(message)
+ else:
+ temp_messages.append(message)
+
+ # Don't forget to add the last non-human message if it exists
+ if temp_messages:
+ collapsed_messages.append(_collapse_messages(temp_messages))
+
+ return collapsed_messages
+
+def get_cohere_function_agent(tools, system_message):
+ llm = ChatCohere(model="command-nightly", streaming=True, temperature=0.2)
+
+ prompt = conversational_prompt.partial(
+ tools=render_json_description(tools),
+ tool_names=", ".join([t.name for t in tools]),
+ system_message=system_message,
+ )
+ llm_with_stop = llm.bind(stop=[""])
+ agent = (
+ {"messages": lambda x: construct_chat_history(x["messages"])}
+ | prompt
+ | llm_with_stop
+ | parse_output
+ )
+ return agent
+
+template = """{system_message}
+
+Instead of responding with text, you may invoke any of the following tools that will provide useful information for a future response:
+
+{tools}
+
+If any tool seems relevant to the user's question, be sure to invoke it to provide more information for the user.
+In order to use a tool, respond with the name of the tool surrounded with the tags , and input for the tool with the tags tags.
+The user will invoke the tool and return the response surrounded by the tags .
+
+For example, if the user asks "What is the weather in SF?", you can respond with:
+
+searchweather in SF
+
+The user may respond with something like this:
+
+64 degrees
+
+After receiving the observation, respond to the as normal, including information from the observation.
+
+Example conversation:
+
+User: What is the weather in San Fransisco?
+Chatbot: weather_searchweather in San Fransisco
+User: 64 degrees
+Chatbot: It is 64 degrees in SF
+""" # noqa: E501
+
+conversational_prompt = ChatPromptTemplate.from_messages(
+ [
+ ("user", template),
+ MessagesPlaceholder(variable_name="messages"),
+ ]
+)
+
+def parse_output(message):
+ text = message.content
+ if "" in text:
+ tool, tool_input, *rest = text.split("")
+ _tool = tool.split("")[1]
+ _tool_input = ""
+ if "" in tool_input:
+ _tool_input = tool_input.split("")[1]
+ if "" in _tool_input:
+ _tool_input = _tool_input.split("")[0]
+ return AgentAction(tool=_tool, tool_input=_tool_input, log=text)
+ else:
+ return AgentFinish(return_values={"output": text}, log=text)
+
+def tool_to_object(tool):
+ inputs = []
+ for input_name, input_schema in tool.args_schema.schema().get("properties").items():
+ inputs.append({
+ "name": input_name,
+ "description": input_schema.get("description"),
+ "type": input_schema.get("type"),
+ })
+
+ return {
+ "name": tool.name,
+ "definition": {
+ "description": tool.description,
+ "inputs": inputs,
+ }
+ }
+
+def render_json_description(tools):
+ return json.dumps([tool_to_object(tool) for tool in tools])
diff --git a/backend/packages/gizmo-agent/gizmo_agent/main.py b/backend/packages/gizmo-agent/gizmo_agent/main.py
index 8399ef78..f4995ce8 100644
--- a/backend/packages/gizmo-agent/gizmo_agent/main.py
+++ b/backend/packages/gizmo-agent/gizmo_agent/main.py
@@ -16,6 +16,7 @@
GizmoAgentType,
get_openai_function_agent,
# get_xml_agent,
+ get_cohere_function_agent,
)
from gizmo_agent.tools import TOOL_OPTIONS, TOOLS, AvailableTools, get_retrieval_tool
@@ -61,6 +62,8 @@ def __init__(
# _agent = get_xml_agent(_tools, system_message)
# elif agent == GizmoAgentType.BEDROCK_CLAUDE2:
# _agent = get_xml_agent(_tools, system_message, bedrock=True)
+ elif agent == GizmoAgentType.COHERE_COMMAND:
+ _agent = get_cohere_function_agent(_tools, system_message)
else:
raise ValueError("Unexpected agent type")
agent_executor = get_agent_executor(
diff --git a/backend/poetry.lock b/backend/poetry.lock
index 2c7a9a99..6df89548 100644
--- a/backend/poetry.lock
+++ b/backend/poetry.lock
@@ -652,6 +652,25 @@ hard-encoding-detection = ["chardet"]
toml = ["tomli"]
types = ["chardet (>=5.1.0)", "mypy", "pytest", "pytest-cov", "pytest-dependency"]
+[[package]]
+name = "cohere"
+version = "4.40"
+description = "Python SDK for the Cohere API"
+optional = false
+python-versions = ">=3.8,<4.0"
+files = [
+ {file = "cohere-4.40-py3-none-any.whl", hash = "sha256:75dac8369d97fadc05901352d9db64a0ca6cd40c08423f3c4691f57eb7b131e7"},
+ {file = "cohere-4.40.tar.gz", hash = "sha256:d9e5c1fa7f80a193c03330a634954b927bf188ead7dcfdb51865480f73aebda8"},
+]
+
+[package.dependencies]
+aiohttp = ">=3.0,<4.0"
+backoff = ">=2.0,<3.0"
+fastavro = ">=1.8,<2.0"
+importlib_metadata = ">=6.0,<7.0"
+requests = ">=2.25.0,<3.0.0"
+urllib3 = ">=1.26,<3"
+
[[package]]
name = "colorama"
version = "0.4.6"
@@ -896,6 +915,52 @@ typing-extensions = ">=4.5.0"
[package.extras]
all = ["email-validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.5)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"]
+[[package]]
+name = "fastavro"
+version = "1.9.2"
+description = "Fast read/write of AVRO files"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "fastavro-1.9.2-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:223cecf135fd29b83ca6a30035b15b8db169aeaf8dc4f9a5d34afadc4b31638a"},
+ {file = "fastavro-1.9.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e08c9be8c6f7eed2cf30f8b64d50094cba38a81b751c7db9f9c4be2656715259"},
+ {file = "fastavro-1.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394f06cc865c6fbae3bbca323633a28a5d914c55dc2c1cdefb75432456ef8f6f"},
+ {file = "fastavro-1.9.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:7a7caadd47bdd04bda534ff70b4b98d2823800c488fd911918115aec4c4dc09b"},
+ {file = "fastavro-1.9.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:68478a1b8a583d83ad6550e9dceac6cbb148a99a52c3559a0413bf4c0b9c8786"},
+ {file = "fastavro-1.9.2-cp310-cp310-win_amd64.whl", hash = "sha256:b59a1123f1d534743af33fdbda80dd7b9146685bdd7931eae12bee6203065222"},
+ {file = "fastavro-1.9.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:887c20dc527a549764c91f9e48ece071f2f26d217af66ebcaeb87bf29578fee5"},
+ {file = "fastavro-1.9.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:46458f78b481c12db62d3d8a81bae09cb0b5b521c0d066c6856fc2746908d00d"},
+ {file = "fastavro-1.9.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f4a2a4bed0e829f79fa1e4f172d484b2179426e827bcc80c0069cc81328a5af"},
+ {file = "fastavro-1.9.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6167f9bbe1c5a28fbc2db767f97dbbb4981065e6eeafd4e613f6fe76c576ffd4"},
+ {file = "fastavro-1.9.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d574bc385f820da0404528157238de4e5fdd775d2cb3d05b3b0f1b475d493837"},
+ {file = "fastavro-1.9.2-cp311-cp311-win_amd64.whl", hash = "sha256:ec600eb15b3ec931904c5bf8da62b3b725cb0f369add83ba47d7b5e9322f92a0"},
+ {file = "fastavro-1.9.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:c82b0761503420cd45f7f50bc31975ac1c75b5118e15434c1d724b751abcc249"},
+ {file = "fastavro-1.9.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:db62d9b8c944b8d9c481e5f980d5becfd034bdd58c72e27c9333bd504b06bda0"},
+ {file = "fastavro-1.9.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65e61f040bc9494646f42a466e9cd428783b82d7161173f3296710723ba5a453"},
+ {file = "fastavro-1.9.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:6278b93cdd5bef1778c0232ce1f265137f90bc6be97a5c1dd7e0d99a406c0488"},
+ {file = "fastavro-1.9.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:cd003ddea5d89720194b6e57011c37221d9fc4ddc750e6f4723516eb659be686"},
+ {file = "fastavro-1.9.2-cp312-cp312-win_amd64.whl", hash = "sha256:43f09d100a26e8b59f30dde664d93e423b648e008abfc43132608a18fe8ddcc2"},
+ {file = "fastavro-1.9.2-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:3ddffeff5394f285c69f9cd481f47b6cf62379840cdbe6e0dc74683bd589b56e"},
+ {file = "fastavro-1.9.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e75a2b2ec697d2058a7d96522e921f03f174cf9049ace007c24be7ab58c5370"},
+ {file = "fastavro-1.9.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd2e8fd0567483eb0fdada1b979ad4d493305dfdd3f351c82a87df301f0ae1f"},
+ {file = "fastavro-1.9.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:c652dbe3f087c943a5b89f9a50a574e64f23790bfbec335ce2b91a2ae354a443"},
+ {file = "fastavro-1.9.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:bba73e9a1822162f1b3a43de0362f29880014c5c4d49d63ad7fcce339ef73ea2"},
+ {file = "fastavro-1.9.2-cp38-cp38-win_amd64.whl", hash = "sha256:beeef2964bbfd09c539424808539b956d7425afbb7055b89e2aa311374748b56"},
+ {file = "fastavro-1.9.2-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:d5fa48266d75e057b27d8586b823d6d7d7c94593fd989d75033eb4c8078009fb"},
+ {file = "fastavro-1.9.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b69aeb0d063f5955a0e412f9779444fc452568a49db75a90a8d372f9cb4a01c8"},
+ {file = "fastavro-1.9.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ce336c59fb40fdb8751bda8cc6076cfcdf9767c3c107f6049e049166b26c61f"},
+ {file = "fastavro-1.9.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:581036e18661f045415a51ad528865e1d7ba5a9690a3dede9e6ea50f94ed6c4c"},
+ {file = "fastavro-1.9.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:39b6b5c3cda569c0a130fd2d08d4c53a326ede7e05174a24eda08f7698f70eda"},
+ {file = "fastavro-1.9.2-cp39-cp39-win_amd64.whl", hash = "sha256:d33e40f246bf07f106f9d2da68d0234efcc62276b6e35bde00ff920ea7f871fd"},
+ {file = "fastavro-1.9.2.tar.gz", hash = "sha256:5c1ffad986200496bd69b5c4748ae90b5d934d3b1456f33147bee3a0bb17f89b"},
+]
+
+[package.extras]
+codecs = ["cramjam", "lz4", "zstandard"]
+lz4 = ["lz4"]
+snappy = ["cramjam"]
+zstandard = ["zstandard"]
+
[[package]]
name = "feedparser"
version = "6.0.10"
@@ -1283,6 +1348,25 @@ files = [
{file = "idna-3.6.tar.gz", hash = "sha256:9ecdbbd083b06798ae1e86adcbfe8ab1479cf864e4ee30fe4e46a003d12491ca"},
]
+[[package]]
+name = "importlib-metadata"
+version = "6.11.0"
+description = "Read metadata from Python packages"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "importlib_metadata-6.11.0-py3-none-any.whl", hash = "sha256:f0afba6205ad8f8947c7d338b5342d5db2afbfd82f9cbef7879a9539cc12eb9b"},
+ {file = "importlib_metadata-6.11.0.tar.gz", hash = "sha256:1231cf92d825c9e03cfc4da076a16de6422c863558229ea0b22b675657463443"},
+]
+
+[package.dependencies]
+zipp = ">=0.5"
+
+[package.extras]
+docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"]
+perf = ["ipython"]
+testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"]
+
[[package]]
name = "iniconfig"
version = "2.0.0"
@@ -3369,7 +3453,22 @@ files = [
idna = ">=2.0"
multidict = ">=4.0"
+[[package]]
+name = "zipp"
+version = "3.17.0"
+description = "Backport of pathlib-compatible object wrapper for zip files"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "zipp-3.17.0-py3-none-any.whl", hash = "sha256:0e923e726174922dce09c53c59ad483ff7bbb8e572e00c7f7c46b88556409f31"},
+ {file = "zipp-3.17.0.tar.gz", hash = "sha256:84e64a1c28cf7e91ed2078bb8cc8c259cb19b76942096c8d7b84947690cabaf0"},
+]
+
+[package.extras]
+docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"]
+testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"]
+
[metadata]
lock-version = "2.0"
python-versions = "^3.8.1"
-content-hash = "60304f12bed53e0d58baec54df54e79e1b5936dc651cc50ec76802c0cf40860c"
+content-hash = "9b2c9ab985bc8d708b71806bcf24e704235f0682290fde1ec9af4acad33527e9"
diff --git a/backend/pyproject.toml b/backend/pyproject.toml
index 1bcbbacf..20447d86 100644
--- a/backend/pyproject.toml
+++ b/backend/pyproject.toml
@@ -26,6 +26,7 @@ langchain = ">=0.0.338"
permchain = ">=0.0.8"
pydantic = "<2.0"
python-magic = "^0.4.27"
+cohere = "^4.39"
[tool.poetry.group.dev.dependencies]
uvicorn = "^0.23.2"