
Commit

test: add langgraph multi agent tests (#39)
Co-authored-by: Samuel <[email protected]>
Valdanitooooo and samuelint authored Sep 6, 2024
1 parent dc398f0 commit d33a5d8
Showing 8 changed files with 1,216 additions and 839 deletions.
langchain_openai_api_bridge/assistant/adapter/openai_event_factory.py
@@ -11,6 +11,7 @@
    ThreadRunStepDelta,
    RunStep,
)
from langchain_core.messages.tool import ToolMessage

from openai.types.beta.threads import (
    Message,
@@ -163,7 +164,26 @@ def create_langchain_function(
    output: Optional[Union[dict[str, object], float, str, ToolMessage]] = None,
) -> function_tool_call.Function:
    arguments_json = json.dumps(arguments) if arguments else None
    output_json = json.dumps(output) if output else None

    output_json = _serialize_output(output=output)

    return function_tool_call.Function(
        name=name, arguments=arguments_json, output=output_json
    )


def _serialize_output(output: Optional[Union[dict[str, object], float, str, ToolMessage]] = None):
    if output is None:
        return None

    if isinstance(output, ToolMessage):
        output_obj = {
            "content": output.content,
            "tool_call_id": output.tool_call_id,
            "status": output.status,
        }
        if output.artifact is not None:
            output_obj["artifact"] = output.artifact
        return json.dumps(output_obj)

    return json.dumps(output) if output else None
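With this helper, a ToolMessage produced by a tool call round-trips into the OpenAI function-call shape. A minimal usage sketch, mirroring the keyword-only call style of the unit tests added below (the argument values are illustrative; per those tests, status defaults to "success"):

# Hedged sketch: a ToolMessage output is serialized field by field.
from langchain_core.messages.tool import ToolMessage

msg = ToolMessage(content="42", tool_call_id="call_1")
fn = create_langchain_function(arguments={"expr": "6*7"}, output=msg)
print(fn.output)  # '{"content": "42", "tool_call_id": "call_1", "status": "success"}'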
@@ -31,7 +31,7 @@ class OpenAIChatCompletionChoice(BaseModel):


class OpenAIChatCompletionObject(BaseModel):
    id: str
    id: Optional[str]
    object: str = "chat.completion"
    created: int
    model: str
1,750 changes: 914 additions & 836 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -15,7 +15,7 @@ langchain = { version = "^0.2.6", optional = true }
langchain-openai = { version = "^0.1.8", optional = true }
fastapi = { version = "^0.111.0", optional = true }
python-dotenv = { version = "^1.0.1", optional = true }
langgraph = { version = "^0.0.62", optional = true }
langgraph = { version = "^0.2.16", optional = true }
langchain-anthropic = { version = "^0.1.19", optional = true }
langchain-groq = { version = "^0.1.6", optional = true }

tests/test_functional/fastapi_chat_completion_multi_agent_openai/multi_agent_server_openai.py
@@ -0,0 +1,35 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv, find_dotenv
import uvicorn

from langchain_openai_api_bridge.fastapi.langchain_openai_api_bridge_fastapi import (
    LangchainOpenaiApiBridgeFastAPI,
)
from tests.test_functional.fastapi_chat_completion_multi_agent_openai.my_openai_multi_agent_factory import (
    MyOpenAIMultiAgentFactory,
)

_ = load_dotenv(find_dotenv())
app = FastAPI(
    title="Langgraph Multi Agent OpenAI API Bridge",
    version="1.0",
    description="OpenAI API exposing langgraph multi agent",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
)

bridge = LangchainOpenaiApiBridgeFastAPI(
    app=app, agent_factory_provider=lambda: MyOpenAIMultiAgentFactory()
)
bridge.bind_openai_chat_completion()

if __name__ == "__main__":
    uvicorn.run(app, host="localhost")
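Once this server is running, any OpenAI-compatible client can talk to it. A hedged sketch of a manual smoke test (the port assumes uvicorn's default of 8000; the api_key value is a placeholder, which the bridge forwards to the downstream model via the factory's dto.api_key):

# Hedged sketch: point the official OpenAI client at the local bridge.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/openai/v1",  # assumed default uvicorn port
    api_key="sk-placeholder",  # forwarded to the downstream provider
)
response = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "What time is it?"}],
)
print(response.choices[0].message.content)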
tests/test_functional/fastapi_chat_completion_multi_agent_openai/my_openai_multi_agent_factory.py
@@ -0,0 +1,171 @@
import functools
import operator
from datetime import datetime
from typing import TypedDict, Annotated, Sequence
from langchain.agents import create_openai_tools_agent, AgentExecutor
from langchain_core.messages import HumanMessage, BaseMessage
from langchain_core.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.tools import Tool
from langchain_openai import ChatOpenAI
from langgraph.constants import START, END
from langgraph.graph import StateGraph

from langchain_openai_api_bridge.core.base_agent_factory import BaseAgentFactory
from langchain_openai_api_bridge.core.create_agent_dto import CreateAgentDto


# Define a new tool that returns the current datetime
datetime_tool = Tool(
    name="Datetime",
    func=lambda x: datetime.now().isoformat(),
    description="Returns the current datetime",
)

mock_search_tool = Tool(
    name="Search",
    func=lambda x: "light",
    description="Search the web about something",
)
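Both tools ignore their input: one returns the real current time, the other a canned answer the functional tests can assert on. A quick sanity check outside the graph (a sketch; invoke is the standard LangChain tool entry point):

# Hedged sketch: exercising the tools directly.
print(datetime_tool.invoke("ignored"))      # e.g. "2024-09-06T12:00:00.000000"
print(mock_search_tool.invoke("anything"))  # always "light"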


def create_agent(llm: ChatOpenAI, system_prompt: str, tools: list):
    # Each worker node will be given a name and some tools.
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                system_prompt,
            ),
            MessagesPlaceholder(variable_name="messages"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    )
    agent = create_openai_tools_agent(llm, tools, prompt)
    executor = AgentExecutor(agent=agent, tools=tools)
    return executor


def agent_node(state, agent, name):
    result = agent.invoke(state)
    return {"messages": [HumanMessage(content=result["output"], name=name)]}


members = ["Researcher", "CurrentTime"]
system_prompt = (
    "You are a supervisor tasked with managing a conversation between the"
    " following workers: {members}. Given the following user request,"
    " respond with the worker to act next. Each worker will perform a"
    " task and respond with their results and status. When finished,"
    " respond with FINISH."
)
# Our team supervisor is an LLM node. It picks the next agent to act and decides when the work is complete.
options = ["FINISH"] + members

# Using OpenAI function calling makes output parsing easier for us
function_def = {
    "name": "route",
    "description": "Select the next role.",
    "parameters": {
        "title": "routeSchema",
        "type": "object",
        "properties": {
            "next": {
                "title": "Next",
                "anyOf": [
                    {"enum": options},
                ],
            }
        },
        "required": ["next"],
    },
}
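Because function_call="route" forces the model to call this function, the supervisor chain always yields a dict that JsonOutputFunctionsParser can unpack. Illustratively (a sketch of the parsed shape, not a live call):

# Hedged sketch: the parsed routing decision the graph edges key on.
routing_decision = {"next": "Researcher"}  # or "CurrentTime", or "FINISH"
assert routing_decision["next"] in options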

# Create the prompt using ChatPromptTemplate
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="messages"),
        (
            "system",
            "Given the conversation above, who should act next?"
            " Or should we FINISH? Select one of: {options}",
        ),
    ]
).partial(options=str(options), members=", ".join(members))


# The agent state is the input to each node in the graph
class AgentState(TypedDict):
    # The annotation tells the graph that new messages are always appended to the current state
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # The 'next' field indicates where to route to next
    next: str
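The operator.add reducer is what makes worker messages accumulate instead of overwriting each other: LangGraph uses it to merge each node's partial state update into the channel. The merge semantics in miniature (variable names are illustrative):

# Hedged sketch: the Annotated reducer concatenates message lists.
current = [HumanMessage(content="What time is it?")]
update = [HumanMessage(content="It is noon.", name="CurrentTime")]
merged = operator.add(current, update)  # plain list concatenation
assert len(merged) == 2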


def create_graph(llm):
    # Construct the chain for the supervisor agent
    supervisor_chain = (
        prompt
        | llm.bind_functions(functions=[function_def], function_call="route")
        | JsonOutputFunctionsParser()
    )

    # Add the research agent using the create_agent helper function
    research_agent = create_agent(llm, "You are a web researcher.", [mock_search_tool])
    research_node = functools.partial(
        agent_node, agent=research_agent, name="Researcher"
    )

    # Add the time agent using the create_agent helper function
    current_time_agent = create_agent(
        llm, "You can tell the current time.", [datetime_tool]
    )
    current_time_node = functools.partial(
        agent_node, agent=current_time_agent, name="CurrentTime"
    )

    workflow = StateGraph(AgentState)

    # Add the worker and supervisor nodes. Nodes represent units of work; they are typically regular Python functions.
    workflow.add_node("Researcher", research_node)
    workflow.add_node("CurrentTime", current_time_node)
    workflow.add_node("supervisor", supervisor_chain)

    # We want our workers to ALWAYS "report back" to the supervisor when done
    for member in members:
        workflow.add_edge(member, "supervisor")

    # Conditional edges usually contain "if" statements to route
    # to different nodes depending on the current graph state.
    # These functions receive the current graph state and return a string
    # or list of strings indicating which node(s) to call next.
    conditional_map = {k: k for k in members}
    conditional_map["FINISH"] = END
    workflow.add_conditional_edges("supervisor", lambda x: x["next"], conditional_map)

    # Add an entry point. This tells our graph where to start its work each time we run it.
    workflow.add_edge(START, "supervisor")

    # To be able to run our graph, call "compile()" on the graph builder.
    # This creates a "CompiledGraph" we can invoke on our state.
    graph = workflow.compile(debug=True).with_config(
        RunnableConfig(
            recursion_limit=10,
        )
    )

    return graph


class MyOpenAIMultiAgentFactory(BaseAgentFactory):

    def create_agent(self, dto: CreateAgentDto) -> Runnable:
        llm = ChatOpenAI(
            model=dto.model,
            api_key=dto.api_key,
            streaming=True,
            temperature=dto.temperature,
        )
        return create_graph(llm)
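The compiled graph can also be driven without the FastAPI bridge, which is handy when debugging the routing. A hedged sketch (the model name and trailing-message indexing are assumptions):

# Hedged sketch: invoke the compiled graph directly.
graph = create_graph(ChatOpenAI(model="gpt-4o-mini", streaming=True))
final_state = graph.invoke({"messages": [HumanMessage(content="What time is it?")]})
print(final_state["messages"][-1].content)  # the latest worker's report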
@@ -0,0 +1,43 @@
import pytest
from openai import OpenAI
from fastapi.testclient import TestClient
from multi_agent_server_openai import app

test_api = TestClient(app)


@pytest.fixture
def openai_client():
    return OpenAI(
        base_url="http://testserver/openai/v1",
        http_client=test_api,
    )


def test_chat_completion_invoke(openai_client):
    chat_completion = openai_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "user",
                "content": "What time is it?",
            }
        ],
    )
    assert "time" in chat_completion.choices[0].message.content


def test_chat_completion_stream(openai_client):
    chunks = openai_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "How does photosynthesis work?"}],
        stream=True,
    )
    every_content = []
    for chunk in chunks:
        if chunk.choices and isinstance(chunk.choices[0].delta.content, str):
            every_content.append(chunk.choices[0].delta.content)

    stream_output = "".join(every_content)

    assert "light" in stream_output
30 changes: 30 additions & 0 deletions tests/test_unit/assistant/adapter/test_openai_event_factory.py
@@ -1,6 +1,8 @@
import json
from langchain_openai_api_bridge.assistant.adapter.openai_event_factory import (
    create_langchain_function,
)
from langchain_core.messages.tool import ToolMessage


class TestCreateLangchainFunction:
@@ -20,3 +22,31 @@ def test_float_output_is_set_to_string(self):
        result = create_langchain_function(arguments={"a": "a"}, output=2.1)

        assert result.output == "2.1"

    def test_ToolMessageOutput_is_serialized_to_json(self):
        tool_message_output = ToolMessage(
            content="some",
            tool_call_id="123",
        )
        result = create_langchain_function(
            arguments={"a": "a"}, output=tool_message_output
        )

        output = json.loads(result.output)
        assert output["content"] == "some"
        assert output["tool_call_id"] == "123"
        assert output["status"] == "success"
        assert output.get("artifact") is None

    def test_ToolMessageOutput_with_artifact_is_serialized_to_json(self):
        tool_message_output = ToolMessage(
            content="some",
            tool_call_id="123",
            artifact={"test": "test"},
        )
        result = create_langchain_function(
            arguments={"a": "a"}, output=tool_message_output
        )

        output = json.loads(result.output)
        assert output["artifact"] == {"test": "test"}