test: add langgraph multi agent tests (#39)
Co-authored-by: Samuel <[email protected]>
1 parent dc398f0 · commit d33a5d8
Showing 8 changed files with 1,216 additions and 839 deletions.
Large diffs are not rendered by default.
35 changes: 35 additions & 0 deletions
...s/test_functional/fastapi_chat_completion_multi_agent_openai/multi_agent_server_openai.py
@@ -0,0 +1,35 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv, find_dotenv
import uvicorn

from langchain_openai_api_bridge.fastapi.langchain_openai_api_bridge_fastapi import (
    LangchainOpenaiApiBridgeFastAPI,
)
from tests.test_functional.fastapi_chat_completion_multi_agent_openai.my_openai_multi_agent_factory import (
    MyOpenAIMultiAgentFactory,
)

_ = load_dotenv(find_dotenv())
app = FastAPI(
    title="Langgraph Multi Agent OpenAI API Bridge",
    version="1.0",
    description="OpenAI API exposing langgraph multi agent",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
)

bridge = LangchainOpenaiApiBridgeFastAPI(
    app=app, agent_factory_provider=lambda: MyOpenAIMultiAgentFactory()
)
bridge.bind_openai_chat_completion()

if __name__ == "__main__":
    uvicorn.run(app, host="localhost")
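
Once this server is running, any stock OpenAI client can talk to it; the tests further down do exactly that through an in-process TestClient. A minimal sketch against a live instance, assuming uvicorn's default port 8000, an OPENAI_API_KEY in the environment, and the /openai/v1 prefix the tests rely on:

from openai import OpenAI

# Point the standard OpenAI client at the bridge instead of api.openai.com.
client = OpenAI(base_url="http://localhost:8000/openai/v1")

reply = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "What time is it?"}],
)
print(reply.choices[0].message.content)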
171 changes: 171 additions & 0 deletions
...st_functional/fastapi_chat_completion_multi_agent_openai/my_openai_multi_agent_factory.py
@@ -0,0 +1,171 @@
import functools
import operator
from datetime import datetime
from typing import TypedDict, Annotated, Sequence
from langchain.agents import create_openai_tools_agent, AgentExecutor
from langchain_core.messages import HumanMessage, BaseMessage
from langchain_core.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.tools import Tool
from langchain_openai import ChatOpenAI
from langgraph.constants import START, END
from langgraph.graph import StateGraph

from langchain_openai_api_bridge.core.base_agent_factory import BaseAgentFactory
from langchain_openai_api_bridge.core.create_agent_dto import CreateAgentDto


# Define a new tool that returns the current datetime
datetime_tool = Tool(
    name="Datetime",
    func=lambda x: datetime.now().isoformat(),
    description="Returns the current datetime",
)

mock_search_tool = Tool(
    name="Search",
    func=lambda x: "light",
    description="Search the web about something",
)


def create_agent(llm: ChatOpenAI, system_prompt: str, tools: list):
    # Each worker node will be given a name and some tools.
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                system_prompt,
            ),
            MessagesPlaceholder(variable_name="messages"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    )
    agent = create_openai_tools_agent(llm, tools, prompt)
    executor = AgentExecutor(agent=agent, tools=tools)
    return executor


def agent_node(state, agent, name):
    result = agent.invoke(state)
    return {"messages": [HumanMessage(content=result["output"], name=name)]}

members = ["Researcher", "CurrentTime"]
system_prompt = (
    "You are a supervisor tasked with managing a conversation between the"
    " following workers: {members}. Given the following user request,"
    " respond with the worker to act next. Each worker will perform a"
    " task and respond with their results and status. When finished,"
    " respond with FINISH."
)
# Our team supervisor is an LLM node. It just picks the next agent to process and decides when the work is completed.
options = ["FINISH"] + members

# Using OpenAI function calling can make output parsing easier for us
function_def = {
    "name": "route",
    "description": "Select the next role.",
    "parameters": {
        "title": "routeSchema",
        "type": "object",
        "properties": {
            "next": {
                "title": "Next",
                "anyOf": [
                    {"enum": options},
                ],
            }
        },
        "required": ["next"],
    },
}
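
# Because function_call="route" forces the model to call the "route" function,
# JsonOutputFunctionsParser (used in create_graph below) always yields a dict
# like {"next": "Researcher"} or {"next": "FINISH"}, which drives the routing.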

# Create the prompt using ChatPromptTemplate
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="messages"),
        (
            "system",
            "Given the conversation above, who should act next?"
            " Or should we FINISH? Select one of: {options}",
        ),
    ]
).partial(options=str(options), members=", ".join(members))


# The agent state is the input to each node in the graph
class AgentState(TypedDict):
    # The annotation tells the graph that new messages are always appended to the current state
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # The 'next' field indicates where to route to next
    next: str


def create_graph(llm):
    # Construct the chain for the supervisor agent
    supervisor_chain = (
        prompt
        | llm.bind_functions(functions=[function_def], function_call="route")
        | JsonOutputFunctionsParser()
    )

    # Add the research agent using the create_agent helper function
    research_agent = create_agent(llm, "You are a web researcher.", [mock_search_tool])
    research_node = functools.partial(
        agent_node, agent=research_agent, name="Researcher"
    )

    # Add the time agent using the create_agent helper function
    current_time_agent = create_agent(
        llm, "You can tell the current time.", [datetime_tool]
    )
    current_time_node = functools.partial(
        agent_node, agent=current_time_agent, name="CurrentTime"
    )

    workflow = StateGraph(AgentState)

    # Add the worker and supervisor nodes. Nodes represent units of work; they are typically regular python functions.
    workflow.add_node("Researcher", research_node)
    workflow.add_node("CurrentTime", current_time_node)
    workflow.add_node("supervisor", supervisor_chain)

    # We want our workers to ALWAYS "report back" to the supervisor when done
    for member in members:
        workflow.add_edge(member, "supervisor")

    # Conditional edges usually contain "if" statements to route
    # to different nodes depending on the current graph state.
    # These functions receive the current graph state and return a string
    # or list of strings indicating which node(s) to call next.
    conditional_map = {k: k for k in members}
    conditional_map["FINISH"] = END
    workflow.add_conditional_edges("supervisor", lambda x: x["next"], conditional_map)

    # Add an entry point. This tells our graph where to start its work each time we run it.
    workflow.add_edge(START, "supervisor")

    # To be able to run our graph, call "compile()" on the graph builder.
    # This creates a "CompiledGraph" we can invoke on our state.
    graph = workflow.compile(debug=True).with_config(
        RunnableConfig(
            recursion_limit=10,
        )
    )

    return graph


class MyOpenAIMultiAgentFactory(BaseAgentFactory):

    def create_agent(self, dto: CreateAgentDto) -> Runnable:
        llm = ChatOpenAI(
            model=dto.model,
            api_key=dto.api_key,
            streaming=True,
            temperature=dto.temperature,
        )
        return create_graph(llm)
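
For a sense of what the factory produces, the compiled graph can also be invoked directly, without the FastAPI bridge in front of it. A minimal sketch, assuming a valid OPENAI_API_KEY in the environment:

from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI

# Build the supervisor/worker graph around a concrete model.
graph = create_graph(ChatOpenAI(model="gpt-4o-mini", temperature=0))

# The graph consumes the AgentState shape: a list of messages to accumulate.
result = graph.invoke({"messages": [HumanMessage(content="What time is it?")]})
print(result["messages"][-1].content)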
43 changes: 43 additions & 0 deletions
...t_functional/fastapi_chat_completion_multi_agent_openai/test_multi_agent_server_openai.py
@@ -0,0 +1,43 @@
import pytest
from openai import OpenAI
from fastapi.testclient import TestClient
from multi_agent_server_openai import app

test_api = TestClient(app)


@pytest.fixture
def openai_client():
    return OpenAI(
        base_url="http://testserver/openai/v1",
        http_client=test_api,
    )


def test_chat_completion_invoke(openai_client):
    chat_completion = openai_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "user",
                "content": "What time is it?",
            }
        ],
    )
    assert "time" in chat_completion.choices[0].message.content


def test_chat_completion_stream(openai_client):
    # The factory's mock Search tool always returns "light", so a question routed
    # to the Researcher should surface "light" in the streamed reply.
    chunks = openai_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "How does photosynthesis work?"}],
        stream=True,
    )
    every_content = []
    for chunk in chunks:
        if chunk.choices and isinstance(chunk.choices[0].delta.content, str):
            every_content.append(chunk.choices[0].delta.content)

    stream_output = "".join(every_content)

    assert "light" in stream_output