diff --git a/.env b/.env index 78d7cc00..d5654652 100644 --- a/.env +++ b/.env @@ -1,5 +1,6 @@ OPENAI_API_KEY=placeholder ANTHROPIC_API_KEY=placeholder +COHERE_API_KEY=placeholder YDC_API_KEY=placeholder TAVILY_API_KEY=placeholder AZURE_OPENAI_DEPLOYMENT_NAME=placeholder diff --git a/backend/packages/gizmo-agent/gizmo_agent/agent_types/__init__.py b/backend/packages/gizmo-agent/gizmo_agent/agent_types/__init__.py index e9b139f6..c8856f04 100644 --- a/backend/packages/gizmo-agent/gizmo_agent/agent_types/__init__.py +++ b/backend/packages/gizmo-agent/gizmo_agent/agent_types/__init__.py @@ -1,6 +1,7 @@ from enum import Enum from .openai import get_openai_function_agent +from .cohere import get_cohere_function_agent from .xml.agent import get_xml_agent @@ -10,10 +11,12 @@ class GizmoAgentType(str, Enum): # AZURE_OPENAI = "GPT 4 (Azure OpenAI)" # CLAUDE2 = "Claude 2" # BEDROCK_CLAUDE2 = "Claude 2 (Amazon Bedrock)" + COHERE_COMMAND = "Command (Cohere)" __all__ = [ "get_openai_function_agent", + "get_cohere_function_agent", "get_xml_agent", "GizmoAgentType", ] diff --git a/backend/packages/gizmo-agent/gizmo_agent/agent_types/cohere.py b/backend/packages/gizmo-agent/gizmo_agent/agent_types/cohere.py new file mode 100644 index 00000000..20590317 --- /dev/null +++ b/backend/packages/gizmo-agent/gizmo_agent/agent_types/cohere.py @@ -0,0 +1,129 @@ +import json + +from langchain.chat_models import ChatCohere +from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain.schema.messages import AIMessage, HumanMessage +from langchain.schema import AgentAction, AgentFinish + + +def _collapse_messages(messages): + log = "" + if isinstance(messages[-1], AIMessage): + scratchpad = messages[:-1] + final = messages[-1] + else: + scratchpad = messages + final = None + if len(scratchpad) % 2 != 0: + raise ValueError("Unexpected") + for i in range(0, len(scratchpad), 2): + action = messages[i] + observation = messages[i + 1] + log += f"{action.content}{observation.content}" + if final is not None: + log += final.content + return AIMessage(content=log) + +def construct_chat_history(messages): + collapsed_messages = [] + temp_messages = [] + for message in messages: + if isinstance(message, HumanMessage): + if temp_messages: + collapsed_messages.append(_collapse_messages(temp_messages)) + temp_messages = [] + collapsed_messages.append(message) + else: + temp_messages.append(message) + + # Don't forget to add the last non-human message if it exists + if temp_messages: + collapsed_messages.append(_collapse_messages(temp_messages)) + + return collapsed_messages + +def get_cohere_function_agent(tools, system_message): + llm = ChatCohere(model="command-nightly", streaming=True, temperature=0.2) + + prompt = conversational_prompt.partial( + tools=render_json_description(tools), + tool_names=", ".join([t.name for t in tools]), + system_message=system_message, + ) + llm_with_stop = llm.bind(stop=[""]) + agent = ( + {"messages": lambda x: construct_chat_history(x["messages"])} + | prompt + | llm_with_stop + | parse_output + ) + return agent + +template = """{system_message} + +Instead of responding with text, you may invoke any of the following tools that will provide useful information for a future response: + +{tools} + +If any tool seems relevant to the user's question, be sure to invoke it to provide more information for the user. +In order to use a tool, respond with the name of the tool surrounded with the tags , and input for the tool with the tags tags. +The user will invoke the tool and return the response surrounded by the tags . + +For example, if the user asks "What is the weather in SF?", you can respond with: + +searchweather in SF + +The user may respond with something like this: + +64 degrees + +After receiving the observation, respond to the as normal, including information from the observation. + +Example conversation: + +User: What is the weather in San Fransisco? +Chatbot: weather_searchweather in San Fransisco +User: 64 degrees +Chatbot: It is 64 degrees in SF +""" # noqa: E501 + +conversational_prompt = ChatPromptTemplate.from_messages( + [ + ("user", template), + MessagesPlaceholder(variable_name="messages"), + ] +) + +def parse_output(message): + text = message.content + if "" in text: + tool, tool_input, *rest = text.split("") + _tool = tool.split("")[1] + _tool_input = "" + if "" in tool_input: + _tool_input = tool_input.split("")[1] + if "" in _tool_input: + _tool_input = _tool_input.split("")[0] + return AgentAction(tool=_tool, tool_input=_tool_input, log=text) + else: + return AgentFinish(return_values={"output": text}, log=text) + +def tool_to_object(tool): + inputs = [] + for input_name, input_schema in tool.args_schema.schema().get("properties").items(): + inputs.append({ + "name": input_name, + "description": input_schema.get("description"), + "type": input_schema.get("type"), + }) + + return { + "name": tool.name, + "definition": { + "description": tool.description, + "inputs": inputs, + } + } + +def render_json_description(tools): + return json.dumps([tool_to_object(tool) for tool in tools]) diff --git a/backend/packages/gizmo-agent/gizmo_agent/main.py b/backend/packages/gizmo-agent/gizmo_agent/main.py index 8399ef78..f4995ce8 100644 --- a/backend/packages/gizmo-agent/gizmo_agent/main.py +++ b/backend/packages/gizmo-agent/gizmo_agent/main.py @@ -16,6 +16,7 @@ GizmoAgentType, get_openai_function_agent, # get_xml_agent, + get_cohere_function_agent, ) from gizmo_agent.tools import TOOL_OPTIONS, TOOLS, AvailableTools, get_retrieval_tool @@ -61,6 +62,8 @@ def __init__( # _agent = get_xml_agent(_tools, system_message) # elif agent == GizmoAgentType.BEDROCK_CLAUDE2: # _agent = get_xml_agent(_tools, system_message, bedrock=True) + elif agent == GizmoAgentType.COHERE_COMMAND: + _agent = get_cohere_function_agent(_tools, system_message) else: raise ValueError("Unexpected agent type") agent_executor = get_agent_executor( diff --git a/backend/poetry.lock b/backend/poetry.lock index 2c7a9a99..6df89548 100644 --- a/backend/poetry.lock +++ b/backend/poetry.lock @@ -652,6 +652,25 @@ hard-encoding-detection = ["chardet"] toml = ["tomli"] types = ["chardet (>=5.1.0)", "mypy", "pytest", "pytest-cov", "pytest-dependency"] +[[package]] +name = "cohere" +version = "4.40" +description = "Python SDK for the Cohere API" +optional = false +python-versions = ">=3.8,<4.0" +files = [ + {file = "cohere-4.40-py3-none-any.whl", hash = "sha256:75dac8369d97fadc05901352d9db64a0ca6cd40c08423f3c4691f57eb7b131e7"}, + {file = "cohere-4.40.tar.gz", hash = "sha256:d9e5c1fa7f80a193c03330a634954b927bf188ead7dcfdb51865480f73aebda8"}, +] + +[package.dependencies] +aiohttp = ">=3.0,<4.0" +backoff = ">=2.0,<3.0" +fastavro = ">=1.8,<2.0" +importlib_metadata = ">=6.0,<7.0" +requests = ">=2.25.0,<3.0.0" +urllib3 = ">=1.26,<3" + [[package]] name = "colorama" version = "0.4.6" @@ -896,6 +915,52 @@ typing-extensions = ">=4.5.0" [package.extras] all = ["email-validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.5)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] +[[package]] +name = "fastavro" +version = "1.9.2" +description = "Fast read/write of AVRO files" +optional = false +python-versions = ">=3.8" +files = [ + {file = "fastavro-1.9.2-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:223cecf135fd29b83ca6a30035b15b8db169aeaf8dc4f9a5d34afadc4b31638a"}, + {file = "fastavro-1.9.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e08c9be8c6f7eed2cf30f8b64d50094cba38a81b751c7db9f9c4be2656715259"}, + {file = "fastavro-1.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394f06cc865c6fbae3bbca323633a28a5d914c55dc2c1cdefb75432456ef8f6f"}, + {file = "fastavro-1.9.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:7a7caadd47bdd04bda534ff70b4b98d2823800c488fd911918115aec4c4dc09b"}, + {file = "fastavro-1.9.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:68478a1b8a583d83ad6550e9dceac6cbb148a99a52c3559a0413bf4c0b9c8786"}, + {file = "fastavro-1.9.2-cp310-cp310-win_amd64.whl", hash = "sha256:b59a1123f1d534743af33fdbda80dd7b9146685bdd7931eae12bee6203065222"}, + {file = "fastavro-1.9.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:887c20dc527a549764c91f9e48ece071f2f26d217af66ebcaeb87bf29578fee5"}, + {file = "fastavro-1.9.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:46458f78b481c12db62d3d8a81bae09cb0b5b521c0d066c6856fc2746908d00d"}, + {file = "fastavro-1.9.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f4a2a4bed0e829f79fa1e4f172d484b2179426e827bcc80c0069cc81328a5af"}, + {file = "fastavro-1.9.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6167f9bbe1c5a28fbc2db767f97dbbb4981065e6eeafd4e613f6fe76c576ffd4"}, + {file = "fastavro-1.9.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d574bc385f820da0404528157238de4e5fdd775d2cb3d05b3b0f1b475d493837"}, + {file = "fastavro-1.9.2-cp311-cp311-win_amd64.whl", hash = "sha256:ec600eb15b3ec931904c5bf8da62b3b725cb0f369add83ba47d7b5e9322f92a0"}, + {file = "fastavro-1.9.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:c82b0761503420cd45f7f50bc31975ac1c75b5118e15434c1d724b751abcc249"}, + {file = "fastavro-1.9.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:db62d9b8c944b8d9c481e5f980d5becfd034bdd58c72e27c9333bd504b06bda0"}, + {file = "fastavro-1.9.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65e61f040bc9494646f42a466e9cd428783b82d7161173f3296710723ba5a453"}, + {file = "fastavro-1.9.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:6278b93cdd5bef1778c0232ce1f265137f90bc6be97a5c1dd7e0d99a406c0488"}, + {file = "fastavro-1.9.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:cd003ddea5d89720194b6e57011c37221d9fc4ddc750e6f4723516eb659be686"}, + {file = "fastavro-1.9.2-cp312-cp312-win_amd64.whl", hash = "sha256:43f09d100a26e8b59f30dde664d93e423b648e008abfc43132608a18fe8ddcc2"}, + {file = "fastavro-1.9.2-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:3ddffeff5394f285c69f9cd481f47b6cf62379840cdbe6e0dc74683bd589b56e"}, + {file = "fastavro-1.9.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e75a2b2ec697d2058a7d96522e921f03f174cf9049ace007c24be7ab58c5370"}, + {file = "fastavro-1.9.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd2e8fd0567483eb0fdada1b979ad4d493305dfdd3f351c82a87df301f0ae1f"}, + {file = "fastavro-1.9.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:c652dbe3f087c943a5b89f9a50a574e64f23790bfbec335ce2b91a2ae354a443"}, + {file = "fastavro-1.9.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:bba73e9a1822162f1b3a43de0362f29880014c5c4d49d63ad7fcce339ef73ea2"}, + {file = "fastavro-1.9.2-cp38-cp38-win_amd64.whl", hash = "sha256:beeef2964bbfd09c539424808539b956d7425afbb7055b89e2aa311374748b56"}, + {file = "fastavro-1.9.2-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:d5fa48266d75e057b27d8586b823d6d7d7c94593fd989d75033eb4c8078009fb"}, + {file = "fastavro-1.9.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b69aeb0d063f5955a0e412f9779444fc452568a49db75a90a8d372f9cb4a01c8"}, + {file = "fastavro-1.9.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ce336c59fb40fdb8751bda8cc6076cfcdf9767c3c107f6049e049166b26c61f"}, + {file = "fastavro-1.9.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:581036e18661f045415a51ad528865e1d7ba5a9690a3dede9e6ea50f94ed6c4c"}, + {file = "fastavro-1.9.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:39b6b5c3cda569c0a130fd2d08d4c53a326ede7e05174a24eda08f7698f70eda"}, + {file = "fastavro-1.9.2-cp39-cp39-win_amd64.whl", hash = "sha256:d33e40f246bf07f106f9d2da68d0234efcc62276b6e35bde00ff920ea7f871fd"}, + {file = "fastavro-1.9.2.tar.gz", hash = "sha256:5c1ffad986200496bd69b5c4748ae90b5d934d3b1456f33147bee3a0bb17f89b"}, +] + +[package.extras] +codecs = ["cramjam", "lz4", "zstandard"] +lz4 = ["lz4"] +snappy = ["cramjam"] +zstandard = ["zstandard"] + [[package]] name = "feedparser" version = "6.0.10" @@ -1283,6 +1348,25 @@ files = [ {file = "idna-3.6.tar.gz", hash = "sha256:9ecdbbd083b06798ae1e86adcbfe8ab1479cf864e4ee30fe4e46a003d12491ca"}, ] +[[package]] +name = "importlib-metadata" +version = "6.11.0" +description = "Read metadata from Python packages" +optional = false +python-versions = ">=3.8" +files = [ + {file = "importlib_metadata-6.11.0-py3-none-any.whl", hash = "sha256:f0afba6205ad8f8947c7d338b5342d5db2afbfd82f9cbef7879a9539cc12eb9b"}, + {file = "importlib_metadata-6.11.0.tar.gz", hash = "sha256:1231cf92d825c9e03cfc4da076a16de6422c863558229ea0b22b675657463443"}, +] + +[package.dependencies] +zipp = ">=0.5" + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] +perf = ["ipython"] +testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] + [[package]] name = "iniconfig" version = "2.0.0" @@ -3369,7 +3453,22 @@ files = [ idna = ">=2.0" multidict = ">=4.0" +[[package]] +name = "zipp" +version = "3.17.0" +description = "Backport of pathlib-compatible object wrapper for zip files" +optional = false +python-versions = ">=3.8" +files = [ + {file = "zipp-3.17.0-py3-none-any.whl", hash = "sha256:0e923e726174922dce09c53c59ad483ff7bbb8e572e00c7f7c46b88556409f31"}, + {file = "zipp-3.17.0.tar.gz", hash = "sha256:84e64a1c28cf7e91ed2078bb8cc8c259cb19b76942096c8d7b84947690cabaf0"}, +] + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] +testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"] + [metadata] lock-version = "2.0" python-versions = "^3.8.1" -content-hash = "60304f12bed53e0d58baec54df54e79e1b5936dc651cc50ec76802c0cf40860c" +content-hash = "9b2c9ab985bc8d708b71806bcf24e704235f0682290fde1ec9af4acad33527e9" diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 1bcbbacf..20447d86 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -26,6 +26,7 @@ langchain = ">=0.0.338" permchain = ">=0.0.8" pydantic = "<2.0" python-magic = "^0.4.27" +cohere = "^4.39" [tool.poetry.group.dev.dependencies] uvicorn = "^0.23.2"