diff --git a/langchain_openai_api_bridge/chat_completion/chat_completion_compatible_api.py b/langchain_openai_api_bridge/chat_completion/chat_completion_compatible_api.py
index ecaddcf..03b6512 100644
--- a/langchain_openai_api_bridge/chat_completion/chat_completion_compatible_api.py
+++ b/langchain_openai_api_bridge/chat_completion/chat_completion_compatible_api.py
@@ -15,12 +15,16 @@ class ChatCompletionCompatibleAPI:
 
     @staticmethod
     def from_agent(
-        agent: Runnable, llm_model: str, system_fingerprint: Optional[str] = ""
+        agent: Runnable,
+        llm_model: str,
+        system_fingerprint: Optional[str] = "",
+        event_adapter: callable = lambda event: None,
     ):
         return ChatCompletionCompatibleAPI(
             LangchainStreamAdapter(llm_model, system_fingerprint),
             LangchainInvokeAdapter(llm_model, system_fingerprint),
             agent,
+            event_adapter,
         )
 
     def __init__(
@@ -28,10 +32,12 @@ def __init__(
         stream_adapter: LangchainStreamAdapter,
         invoke_adapter: LangchainInvokeAdapter,
         agent: Runnable,
+        event_adapter: callable = lambda event: None,
     ) -> None:
         self.stream_adapter = stream_adapter
         self.invoke_adapter = invoke_adapter
         self.agent = agent
+        self.event_adapter = event_adapter
 
     def astream(self, messages: List[OpenAIChatMessage]) -> AsyncIterator[dict]:
         input = self.__to_input(messages)
@@ -40,7 +46,7 @@ def astream(self, messages: List[OpenAIChatMessage]) -> AsyncIterator[dict]:
             version="v2",
         )
         return ato_dict(
-            self.stream_adapter.ato_chat_completion_chunk_stream(astream_event)
+            self.stream_adapter.ato_chat_completion_chunk_stream(astream_event, event_adapter=self.event_adapter)
         )
 
     def invoke(self, messages: List[OpenAIChatMessage]) -> dict:
diff --git a/langchain_openai_api_bridge/chat_completion/langchain_stream_adapter.py b/langchain_openai_api_bridge/chat_completion/langchain_stream_adapter.py
index 9c778e3..f56577b 100644
--- a/langchain_openai_api_bridge/chat_completion/langchain_stream_adapter.py
+++ b/langchain_openai_api_bridge/chat_completion/langchain_stream_adapter.py
@@ -22,20 +22,21 @@ async def ato_chat_completion_chunk_stream(
         self,
         astream_event: AsyncIterator[StreamEvent],
         id: str = "",
+        event_adapter=lambda event: None,
     ) -> AsyncIterator[OpenAIChatCompletionChunkObject]:
         if id == "":
             id = str(uuid.uuid4())
         async for event in astream_event:
-            kind = event["event"]
-            match kind:
-                case "on_chat_model_stream":
-                    chunk = to_openai_chat_completion_chunk_object(
-                        event=event,
-                        id=id,
-                        model=self.llm_model,
-                        system_fingerprint=self.system_fingerprint,
-                    )
-                    yield chunk
+            custom_event = event_adapter(event)
+            event_to_process = custom_event if custom_event is not None else event
+            kind = event_to_process["event"]
+            if kind == "on_chat_model_stream" or custom_event is not None:
+                yield to_openai_chat_completion_chunk_object(
+                    event=event_to_process,
+                    id=id,
+                    model=self.llm_model,
+                    system_fingerprint=self.system_fingerprint,
+                )
 
         stop_chunk = create_final_chat_completion_chunk_object(
             id=id, model=self.llm_model
diff --git a/langchain_openai_api_bridge/fastapi/chat_completion_router.py b/langchain_openai_api_bridge/fastapi/chat_completion_router.py
index 6130cbe..b09e756 100644
--- a/langchain_openai_api_bridge/fastapi/chat_completion_router.py
+++ b/langchain_openai_api_bridge/fastapi/chat_completion_router.py
@@ -16,10 +16,11 @@
 
 def create_chat_completion_router(
     tiny_di_container: TinyDIContainer,
+    event_adapter: callable = lambda event: None,
 ):
-    chat_completion_router = APIRouter(prefix="/chat/completions")
+    chat_completion_router = APIRouter(prefix="/chat")
 
-    @chat_completion_router.post("/")
+    @chat_completion_router.post("/completions")
     async def assistant_retreive_thread_messages(
         request: OpenAIChatCompletionRequest, authorization: str = Header(None)
     ):
@@ -33,7 +34,7 @@ async def assistant_retreive_thread_messages(
 
         agent = agent_factory.create_agent(dto=create_agent_dto)
 
-        adapter = ChatCompletionCompatibleAPI.from_agent(agent, create_agent_dto.model)
+        adapter = ChatCompletionCompatibleAPI.from_agent(agent, create_agent_dto.model, event_adapter=event_adapter)
 
         response_factory = HttpStreamResponseAdapter()
         if request.stream is True:
@@ -46,9 +47,11 @@
 
 def create_openai_chat_completion_router(
-    tiny_di_container: TinyDIContainer, prefix: str = ""
+    tiny_di_container: TinyDIContainer,
+    prefix: str = "",
+    event_adapter: callable = lambda event: None,
 ):
-    router = create_chat_completion_router(tiny_di_container=tiny_di_container)
+    router = create_chat_completion_router(tiny_di_container=tiny_di_container, event_adapter=event_adapter)
 
     open_ai_router = APIRouter(prefix=f"{prefix}/openai/v1")
     open_ai_router.include_router(router)
 
diff --git a/langchain_openai_api_bridge/fastapi/langchain_openai_api_bridge_fastapi.py b/langchain_openai_api_bridge/fastapi/langchain_openai_api_bridge_fastapi.py
index 2eec58e..0a2b66c 100644
--- a/langchain_openai_api_bridge/fastapi/langchain_openai_api_bridge_fastapi.py
+++ b/langchain_openai_api_bridge/fastapi/langchain_openai_api_bridge_fastapi.py
@@ -97,9 +97,9 @@ def bind_openai_assistant_api(
 
         self.app.include_router(assistant_router)
 
-    def bind_openai_chat_completion(self, prefix: str = "") -> None:
+    def bind_openai_chat_completion(self, prefix: str = "", event_adapter: callable = lambda event: None) -> None:
         chat_completion_router = create_openai_chat_completion_router(
-            self.tiny_di_container, prefix=prefix
+            self.tiny_di_container, prefix=prefix, event_adapter=event_adapter
         )
         self.app.include_router(chat_completion_router)
 
diff --git a/tests/test_functional/fastapi_chat_completion_openai/server_openai_event_adapter.py b/tests/test_functional/fastapi_chat_completion_openai/server_openai_event_adapter.py
new file mode 100644
index 0000000..fd2e9c8
--- /dev/null
+++ b/tests/test_functional/fastapi_chat_completion_openai/server_openai_event_adapter.py
@@ -0,0 +1,55 @@
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from dotenv import load_dotenv, find_dotenv
+import uvicorn
+
+from langchain_openai_api_bridge.core.create_agent_dto import CreateAgentDto
+from langchain_openai_api_bridge.fastapi.langchain_openai_api_bridge_fastapi import (
+    LangchainOpenaiApiBridgeFastAPI,
+)
+from langchain_openai import ChatOpenAI
+
+_ = load_dotenv(find_dotenv())
+
+
+app = FastAPI(
+    title="Langchain Agent OpenAI API Bridge",
+    version="1.0",
+    description="OpenAI API exposing langchain agent",
+)
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+    expose_headers=["*"],
+)
+
+
+def create_agent(dto: CreateAgentDto):
+    return ChatOpenAI(
+        temperature=dto.temperature or 0.7,
+        model=dto.model,
+        max_tokens=dto.max_tokens,
+        api_key=dto.api_key,
+    )
+
+
+bridge = LangchainOpenaiApiBridgeFastAPI(app=app, agent_factory_provider=create_agent)
+
+
+def event_adapter(event):
+    kind = event["event"]
+    match kind:
+        case "on_chat_model_stream":
+            return event
+
+
+bridge.bind_openai_chat_completion(
+    prefix="/my-custom-events-path", event_adapter=event_adapter
+)
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="localhost")
diff --git a/tests/test_functional/fastapi_chat_completion_openai/test_server_openai_event_adapter.py b/tests/test_functional/fastapi_chat_completion_openai/test_server_openai_event_adapter.py
new file mode 100644
index 0000000..3a9f60d
--- /dev/null
+++ b/tests/test_functional/fastapi_chat_completion_openai/test_server_openai_event_adapter.py
@@ -0,0 +1,44 @@
+import pytest
+from openai import OpenAI
+from fastapi.testclient import TestClient
+from server_openai_event_adapter import app
+
+
+test_api = TestClient(app)
+
+
+@pytest.fixture
+def openai_client_custom_events():
+    return OpenAI(
+        base_url="http://testserver/my-custom-events-path/openai/v1",
+        http_client=test_api,
+    )
+
+
+def test_chat_completion_invoke_custom_events(openai_client_custom_events):
+    chat_completion = openai_client_custom_events.chat.completions.create(
+        model="gpt-4o-mini",
+        messages=[
+            {
+                "role": "user",
+                "content": 'Say "This is a test"',
+            }
+        ],
+    )
+    assert "This is a test" in chat_completion.choices[0].message.content
+
+
+def test_chat_completion_stream_custom_events(openai_client_custom_events):
+    chunks = openai_client_custom_events.chat.completions.create(
+        model="gpt-4o-mini",
+        messages=[{"role": "user", "content": 'Say "This is a test"'}],
+        stream=True,
+    )
+    every_content = []
+    for chunk in chunks:
+        if chunk.choices and isinstance(chunk.choices[0].delta.content, str):
+            every_content.append(chunk.choices[0].delta.content)
+
+    stream_output = "".join(every_content)
+
+    assert "This is a test" in stream_output
diff --git a/tests/test_unit/chat_completion/test_langchain_stream_adapter.py b/tests/test_unit/chat_completion/test_langchain_stream_adapter.py
index 450edb3..1638256 100644
--- a/tests/test_unit/chat_completion/test_langchain_stream_adapter.py
+++ b/tests/test_unit/chat_completion/test_langchain_stream_adapter.py
@@ -43,3 +43,36 @@ async def test_stream_contains_every_on_chat_model_stream(
         items = await assemble_stream(response_stream)
         assert items[0].dict() == ChatCompletionChunkStub({"key": "hello"}).dict()
         assert items[1].dict() == ChatCompletionChunkStub({"key": "moto"}).dict()
+
+    @pytest.mark.asyncio
+    @patch(
+        "langchain_openai_api_bridge.chat_completion.langchain_stream_adapter.to_openai_chat_completion_chunk_object",
+        side_effect=lambda event, id, model, system_fingerprint: (
+            ChatCompletionChunkStub({"key": event["data"]["chunk"].content})
+        ),
+    )
+    async def test_stream_contains_every_custom_handled_stream(
+        self, to_openai_chat_completion_chunk_object
+    ):
+        on_chat_model_stream_event1 = create_on_chat_model_stream_event(content="hello")
+        on_chat_model_stream_event2 = create_on_chat_model_stream_event(content="moto")
+        input_stream = generate_stream(
+            [
+                on_chat_model_stream_event1,
+                on_chat_model_stream_event2,
+            ]
+        )
+
+        def event_adapter(event):
+            kind = event["event"]
+            match kind:
+                case "on_chat_model_stream":
+                    return event
+
+        response_stream = self.instance.ato_chat_completion_chunk_stream(
+            input_stream, event_adapter=event_adapter
+        )
+
+        items = await assemble_stream(response_stream)
+        assert items[0].dict() == ChatCompletionChunkStub({"key": "hello"}).dict()
+        assert items[1].dict() == ChatCompletionChunkStub({"key": "moto"}).dict()
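
Note on the `event_adapter` contract introduced above: the adapter is called with every raw LangChain `astream_events` (v2) event. Returning `None` keeps the default behavior, where only `on_chat_model_stream` events become OpenAI chunks; returning an event dict forces conversion, so the returned dict must be shaped like a chat-model-stream event, exposing `data.chunk` with a `content` attribute, which is the path the bundled unit test's stub assumes `to_openai_chat_completion_chunk_object` reads. Below is a minimal sketch of an adapter that also surfaces custom events; the `progress` event name and payload handling are hypothetical, while `on_custom_event` is the kind LangChain emits for events dispatched with `adispatch_custom_event`:

    from langchain_core.messages import AIMessageChunk

    def event_adapter(event):
        kind = event["event"]
        if kind == "on_chat_model_stream":
            return event  # pass model token chunks through unchanged
        if kind == "on_custom_event" and event.get("name") == "progress":
            # Hypothetical: repackage a custom payload so it has the
            # data.chunk.content shape the chunk converter expects.
            return {
                "event": "on_chat_model_stream",
                "data": {"chunk": AIMessageChunk(content=str(event["data"]))},
            }
        return None  # every other event is ignored by the stream filter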