diff --git a/burr/core/action.py b/burr/core/action.py index 60aa979f..73a57252 100644 --- a/burr/core/action.py +++ b/burr/core/action.py @@ -1011,6 +1011,18 @@ def bind(self, **kwargs: Any) -> Self: ... +def copy_func(f: types.FunctionType) -> types.FunctionType: + """Copies a function. This is used internally to bind parameters to a function + so we don't accidentally overwrite them. + + :param f: Function to copy + :return: The copied function + """ + fn = types.FunctionType(f.__code__, f.__globals__, f.__name__, f.__defaults__, f.__closure__) + fn.__dict__.update(f.__dict__) + return fn + + def bind(self: FunctionRepresentingAction, **kwargs: Any) -> FunctionRepresentingAction: """Binds an action to the given parameters. This is functionally equivalent to functools.partial, but is more explicit and is meant to be used in the API. This only works with @@ -1028,6 +1040,7 @@ def my_action(state: State, z: int) -> tuple[dict, State]: :param kwargs: The keyword arguments to bind :return: The decorated function with the given parameters bound """ + self = copy_func(self) # we have to bind to a copy of the function, otherwise it will override self.action_function = self.action_function.with_params(**kwargs) return self @@ -1092,6 +1105,7 @@ def streaming_response(state: State) -> Generator[dict, None, tuple[dict, State] """ def wrapped(fn) -> FunctionRepresentingAction: + fn = copy_func(fn) setattr( fn, FunctionBasedAction.ACTION_FUNCTION, diff --git a/burr/tracking/server/run.py b/burr/tracking/server/run.py index 9f318a11..be99070a 100644 --- a/burr/tracking/server/run.py +++ b/burr/tracking/server/run.py @@ -18,6 +18,7 @@ # dynamic importing due to the dashes (which make reading the examples on github easier) email_assistant = importlib.import_module("burr.examples.email-assistant.server") chatbot = importlib.import_module("burr.examples.multi-modal-chatbot.server") + streaming_chatbot = importlib.import_module("burr.examples.streaming-fastapi.server") except ImportError as e: require_plugin( @@ -97,7 +98,7 @@ async def version() -> dict: # Examples -- todo -- put them behind `if` statements app.include_router(chatbot.router, prefix="/api/v0/chatbot") app.include_router(email_assistant.router, prefix="/api/v0/email_assistant") -# email_assistant.register(app, "/api/v0/email_assistant") +app.include_router(streaming_chatbot.router, prefix="/api/v0/streaming_chatbot") if SERVE_STATIC: diff --git a/examples/streaming-fastapi/README.md b/examples/streaming-fastapi/README.md new file mode 100644 index 00000000..32bddafc --- /dev/null +++ b/examples/streaming-fastapi/README.md @@ -0,0 +1,214 @@ +# Streaming in FastAPI + +This example demonstrates how to stream data from Burr's streaming mode through FastAPI. + +This is gone over in more detail in our blog post (coming soon). This README will go over the main code + roles and how to run the example. + +This uses Server Sent Events (SSE) to stream data from FastAPI to the frontend. This also uses Async Generators to ensure optimal performance. + +## Example + +The application we created will be a simple chatbot proxy. It has a few diffrent modes -- it can either decide a prompt is "unsafe" (in this case meaning that it has the word "unsafe" in it, but this would typically go to specific model), +or do one of the following: + +1. Generate code +2. Answer a question +3. Generate a poem +4. Prompt for more + +It will use an LLM to decide which to do. It streams back text using async streaming in Burr. 
Read more about how that is implemented [here](https://burr.dagworks.io/concepts/streaming-actions/).

Note that, even though not every response is streamed (e.g. the unsafe response, which is hardcoded), they are all modeled as streaming to make interaction with the app simpler.

The app looks like this:

![application graph](statemachine.png)

## Running the example

You will need an OpenAI API key to run this.

You'll first need `burr[start]` installed. You can then view this demo in the Burr UI by running:

```bash
burr
```

This will open a browser at [http://localhost:7241](http://localhost:7241).

Navigate to the [streaming example](http://localhost:7241/demos/streaming-chatbot).

## Streaming in Burr

Read more about streaming actions [here](https://burr.dagworks.io/concepts/streaming-actions/).

To use streaming in Burr, you write your actions as generators. If you're using the function-based API (as we do in this example),
the function should yield a tuple, consisting of:
1. The result (intermediate or final)
2. The updated state (`None` if intermediate, present if final)

(2) will always be the last yield, and indicates that the streaming action is complete. Take, for example, the
"unsafe" response, where the LLM has determined that it cannot respond. This is a simple example -- just to illustrate streaming.
It sleeps between words to make a point (and to make the demo more fun by giving the appearance of the app "thinking") -- in a real application you would not want to do this.

```python
@streaming_action(reads=["prompt", "chat_history"], writes=["response"])
async def unsafe_response(state: State) -> Tuple[dict, State]:
    result = {
        "response": {
            "content": "I am afraid I can't respond to that...",
            "type": "text",
            "role": "assistant",
        }
    }
    for word in result["response"]["content"].split():
        await asyncio.sleep(0.1)
        yield {"delta": word + " "}, None
    yield result, state.update(**result).append(chat_history=result["response"])
```

This is an async generator that yields just the delta until it gets to the end. This could just as easily proxy from another service (OpenAI, for example)
or perform some other async operation.

When you call the action, you will get back an `AsyncStreamingResultContainer` object. This is *also* an async generator!

```python
action, streaming_container = await app.astream_result(
    halt_after=TERMINAL_ACTIONS, inputs={"prompt": "Please generate a limerick about Alexander Hamilton and Aaron Burr"}
)

async for item in streaming_container:
    print(item['delta'], end="")
```

This will stream the results out.

## Connecting to FastAPI

To connect to FastAPI, we need to do the following:

1. Instantiate a Burr Application in FastAPI
2. Create a route that streams the data to the frontend -- this is *also* an async generator
3. Bridge the two together

In [server.py](server.py), we have a helpful `_get_application` function that will get or create an application for us.
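For reference, here it is, abridged from [server.py](server.py). It caches applications by `(project_id, app_id)` and loads prior state from the tracking log when it exists (`graph` is the chatbot graph defined in [application.py](application.py)):

```python
import functools

from burr.core import Application, ApplicationBuilder
from burr.tracking import LocalTrackingClient

# `graph` is chat_application.graph -- server.py imports it from application.py


@functools.lru_cache(maxsize=128)
def _get_application(project_id: str, app_id: str) -> Application:
    """Quick tool to get the application -- caches it"""
    tracker = LocalTrackingClient(project=project_id, storage_dir="~/.burr")
    return (
        ApplicationBuilder()
        .with_graph(graph)
        # loads prior state from the tracking log if it exists, else uses the defaults below
        .initialize_from(
            tracker,
            resume_at_next_action=False,  # always resume from the entrypoint in the case of failure
            default_state={"chat_history": []},
            default_entrypoint="prompt",
        )
        .with_tracker(tracker)
        .with_identifiers(app_id=app_id)
        .build()
    )
```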
We can then call a `chat_response` function that looks like this:

```python
@router.post("/response/{project_id}/{app_id}", response_class=StreamingResponse)
async def chat_response(project_id: str, app_id: str, prompt: PromptInput) -> StreamingResponse:
    burr_app = _get_application(project_id, app_id)
    chat_history = burr_app.state.get("chat_history", [])
    action, streaming_container = await burr_app.astream_result(
        halt_after=chat_application.TERMINAL_ACTIONS, inputs=dict(prompt=prompt.prompt)
    )

    async def sse_generator():
        yield f"data: {json.dumps({'type': 'chat_history', 'value': chat_history})}\n\n"

        async for item in streaming_container:
            yield f"data: {json.dumps({'type': 'delta', 'value': item['delta']})} \n\n"

    return StreamingResponse(sse_generator())
```

Note this returns a [StreamingResponse](https://fastapi.tiangolo.com/advanced/custom-response/#streamingresponse)
and does a few notable things with the SSE API. In particular:
1. It first yields the current chat history, so the UI can update to the latest state (not strictly necessary, but nice to have for rendering)
2. It streams the deltas as they come in
3. It returns the data in the format `data: ...\n\n`, as this is the standard for SSE

And it's as simple as that! You can now stream data from Burr through FastAPI.

## Streaming in TypeScript/React

This part can get a little messy with state management/chat history, but here are the basics of it. There are multiple approaches
to managing SSE in React, but we will be using the bare-bones `fetch` and `getReader()` APIs.

The following code is the `submitPrompt` function that sends the prompt and modifies the state. This gets called when the
user submits a prompt (e.g. on the `onClick` of a button).

It relies on the state variables:

- `currentPrompt`/`setCurrentPrompt` - the current prompt
- `displayedChatHistory`/`setDisplayedChatHistory` - the chat history
- `currentResponse`/`setCurrentResponse` - the response as it streams in

This also assumes the server exposes a POST endpoint with the project/app IDs in the URL and the prompt in the JSON body, matching the endpoint above.

### Fetch the result (POST)

First, we'll fetch the result with a POST request to match the endpoint above.
We will also get a reader object to help us iterate through the streamed response:

```typescript
const response = await fetch(
  `/api/v0/streaming_chatbot/response/${props.projectId}/${props.appId}`,
  {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ prompt: currentPrompt })
  }
);
const reader = response.body?.getReader();
```

Then we'll run through the reader object and parse the data, modifying the state as we go:

```typescript
if (reader) {
  const decoder = new TextDecoder('utf-8');
  // eslint-disable-next-line no-constant-condition
  while (true) {
    const result = await reader.read();
    if (result.done) {
      break;
    }
    const message = decoder.decode(result.value, { stream: true });
    message
      .split('data: ')
      .slice(1)
      .forEach((item) => {
        const event: Event = JSON.parse(item);
        if (event.type === 'chat_history') {
          const chatMessageEvent = event as ChatHistoryEvent;
          setDisplayedChatHistory(chatMessageEvent.value);
        }
        if (event.type === 'delta') {
          const chatMessageEvent = event as ChatMessageEvent;
          chatResponse += chatMessageEvent.value;
          setCurrentResponse(chatResponse);
        }
      });
  }
  setDisplayedChatHistory((chatHistory) => [
    ...chatHistory,
    {
      role: ChatItem.role.USER,
      content: currentPrompt,
      type: ChatItem.type.TEXT
    },
    {
      role: ChatItem.role.ASSISTANT,
      content: chatResponse,
      type: ChatItem.type.TEXT
    }
  ]);
  setCurrentPrompt('');
  setCurrentResponse('');
  setIsChatWaiting(false);
}
```

In the above we:
1. Check if the reader is present (it is likely worth adding more error handling here)
2. Break if the reader is done
3. Decode the message
4. Parse the message
   a. If it is a chat history event, update the chat history
   b. If it is a delta event, update the chat response
5. Update the chat history with the new prompt and response
6. Reset the variables so we don't render them twice

While the logic of updating state is bespoke to how we do it here, looping through the reader and parsing the data
is a common, highly generalizable operation.

Note there are multiple ways of doing this -- this was just the simplest.
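As a quick sanity check, the same SSE stream can also be consumed outside the browser. Below is a minimal Python sketch (not part of this example's code) that assumes `httpx` is installed, the demo server is running locally on port 7241, and uses hypothetical project/app IDs:

```python
import json

import httpx  # assumed dependency, purely for illustration


def stream_chat(project_id: str, app_id: str, prompt: str) -> None:
    """Consume the streaming chatbot endpoint and print deltas as they arrive."""
    url = f"http://localhost:7241/api/v0/streaming_chatbot/response/{project_id}/{app_id}"
    with httpx.stream("POST", url, json={"prompt": prompt}, timeout=None) as response:
        for line in response.iter_lines():
            if not line.startswith("data: "):
                continue  # skip the blank separator lines between SSE messages
            event = json.loads(line[len("data: "):])
            if event["type"] == "chat_history":
                print(f"(loaded {len(event['value'])} prior messages)")
            elif event["type"] == "delta":
                print(event["value"], end="", flush=True)


if __name__ == "__main__":
    stream_chat("demo_chatbot_streaming", "my-test-app", "Write a haiku about state machines")
```

Each SSE message arrives on a `data: ...` line followed by a blank line, which is why the client filters on that prefix before parsing the JSON payload.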
diff --git a/examples/streaming-fastapi/__init__.py b/examples/streaming-fastapi/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/streaming-fastapi/application.py b/examples/streaming-fastapi/application.py new file mode 100644 index 00000000..8e90df8e --- /dev/null +++ b/examples/streaming-fastapi/application.py @@ -0,0 +1,192 @@ +import asyncio +import copy +from typing import AsyncGenerator, Optional, Tuple + +import openai + +from burr.core import ApplicationBuilder, State, default, when +from burr.core.action import action, streaming_action +from burr.core.graph import GraphBuilder + +MODES = [ + "answer_question", + "generate_poem", + "generate_code", + "unknown", +] + + +@action(reads=[], writes=["chat_history", "prompt"]) +def process_prompt(state: State, prompt: str) -> Tuple[dict, State]: + result = {"chat_item": {"role": "user", "content": prompt, "type": "text"}} + return result, state.wipe(keep=["prompt", "chat_history"]).append( + chat_history=result["chat_item"] + ).update(prompt=prompt) + + +@action(reads=["prompt"], writes=["safe"]) +def check_safety(state: State) -> Tuple[dict, State]: + result = {"safe": "unsafe" not in state["prompt"]} # quick hack to demonstrate + return result, state.update(safe=result["safe"]) + + +def _get_openai_client(): + return openai.AsyncOpenAI() + + +@action(reads=["prompt"], writes=["mode"]) +async def choose_mode(state: State) -> Tuple[dict, State]: + prompt = ( + f"You are a chatbot. You've been prompted this: {state['prompt']}. " + f"You have the capability of responding in the following modes: {', '.join(MODES)}. " + "Please respond with *only* a single word representing the mode that most accurately " + "corresponds to the prompt. Fr instance, if the prompt is 'write a poem about Alexander Hamilton and Aaron Burr', " + "the mode would be 'generate_poem'. If the prompt is 'what is the capital of France', the mode would be 'answer_question'." + "And so on, for every mode. If none of these modes apply, please respond with 'unknown'." + ) + + result = await _get_openai_client().chat.completions.create( + model="gpt-4o", + messages=[ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": prompt}, + ], + ) + content = result.choices[0].message.content + mode = content.lower() + if mode not in MODES: + mode = "unknown" + result = {"mode": mode} + return result, state.update(**result) + + +@streaming_action(reads=["prompt", "chat_history"], writes=["response"]) +async def prompt_for_more(state: State) -> AsyncGenerator[Tuple[dict, Optional[State]], None]: + """Not streaming, as we have the result immediately.""" + result = { + "response": { + "content": "None of the response modes I support apply to your question. Please clarify?", + "type": "text", + "role": "assistant", + } + } + for word in result["response"]["content"].split(): + await asyncio.sleep(0.1) + yield {"delta": word + " "}, None + yield result, state.update(**result).append(chat_history=result["response"]) + + +@streaming_action(reads=["prompt", "chat_history", "mode"], writes=["response"]) +async def chat_response( + state: State, prepend_prompt: str, model: str = "gpt-3.5-turbo" +) -> AsyncGenerator[Tuple[dict, Optional[State]], None]: + """Streaming action, as we don't have the result immediately. 
This makes it more interactive""" + chat_history = copy.deepcopy(state["chat_history"]) + chat_history[-1]["content"] = f"{prepend_prompt}: {chat_history[-1]['content']}" + chat_history_api_format = [ + { + "role": chat["role"], + "content": chat["content"], + } + for chat in chat_history + ] + client = _get_openai_client() + result = await client.chat.completions.create( + model=model, messages=chat_history_api_format, stream=True + ) + buffer = [] + async for chunk in result: + chunk_str = chunk.choices[0].delta.content + if chunk_str is None: + continue + buffer.append(chunk_str) + yield { + "delta": chunk_str, + }, None + + result = { + "response": {"content": "".join(buffer), "type": "text", "role": "assistant"}, + "modified_chat_history": chat_history, + } + yield result, state.update(**result).append(chat_history=result["response"]) + + +@streaming_action(reads=["prompt", "chat_history"], writes=["response"]) +async def unsafe_response(state: State) -> Tuple[dict, State]: + result = { + "response": { + "content": "I am afraid I can't respond to that...", + "type": "text", + "role": "assistant", + } + } + for word in result["response"]["content"].split(): + await asyncio.sleep(0.1) + yield {"delta": word + " "}, None + yield result, state.update(**result).append(chat_history=result["response"]) + + +graph = ( + GraphBuilder() + .with_actions( + prompt=process_prompt, + check_safety=check_safety, + unsafe_response=unsafe_response, + decide_mode=choose_mode, + generate_code=chat_response.bind( + prepend_prompt="Please respond with *only* code and no other text (at all) to the following", + ), + answer_question=chat_response.bind( + prepend_prompt="Please answer the following question", + ), + generate_poem=chat_response.bind( + prepend_prompt="Please generate a poem based on the following prompt", + ), + prompt_for_more=prompt_for_more, + ) + .with_transitions( + ("prompt", "check_safety", default), + ("check_safety", "decide_mode", when(safe=True)), + ("check_safety", "unsafe_response", default), + ("decide_mode", "generate_code", when(mode="generate_code")), + ("decide_mode", "answer_question", when(mode="answer_question")), + ("decide_mode", "generate_poem", when(mode="generate_poem")), + ("decide_mode", "prompt_for_more", default), + ( + [ + "answer_question", + "generate_poem", + "generate_code", + "prompt_for_more", + "unsafe_response", + ], + "prompt", + ), + ) + .build() +) + + +def application(app_id: Optional[str] = None): + return ( + ApplicationBuilder() + .with_entrypoint("prompt") + .with_state(chat_history=[]) + .with_graph(graph) + .with_tracker(project="demo_chatbot_streaming") + .with_identifiers(app_id=app_id) + .build() + ) + + +# TODO -- replace these with action tags when we have the availability +TERMINAL_ACTIONS = [ + "answer_question", + "generate_code", + "prompt_for_more", + "unsafe_response", + "generate_poem", +] +if __name__ == "__main__": + app = application() + app.visualize(output_file_path="statemachine", include_conditions=True, view=True, format="png") diff --git a/examples/streaming-fastapi/notebook.ipynb b/examples/streaming-fastapi/notebook.ipynb new file mode 100644 index 00000000..5bb50255 --- /dev/null +++ b/examples/streaming-fastapi/notebook.ipynb @@ -0,0 +1,336 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "1ce738e6-14b4-4892-9f2d-a1db94cfb29a", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install burr[start]" + ] + }, + { + "cell_type": "markdown", + "id": 
"2118a2f4-d969-4771-81d3-156b432d1dc8", + "metadata": {}, + "source": [ + "# Streaming applications\n", + "\n", + "This shows how one goes about working with a streaming response with Burr, using FastAPI.\n", + "The code for implementation is in [application.py](application.py).\n", + "\n", + "This notebook only shows the streaming side. To check out FastAPI in Burr, check out\n", + "- The [Burr code](./application.py) -- imported and used here\n", + "- The [backend FastAPI server](./server.py) for the streaming output using SSE\n", + "- The [frontend typescript code](https://github.com/dagworks-inc/burr/blob/main/telemetry/ui/src/examples/StreamingChatbot.tsx) that renders and interacts with the stream\n", + "\n", + "You can view this demo in your app by running Burr:\n", + "\n", + "```bash\n", + "burr \n", + "```\n", + "\n", + "This will open a browser on [http://localhost:7241](http://localhost:7241)\n", + "\n", + "Then navigate to the [streaming example](http://localhost:7241/demos/streaming-chatbot)." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "5b53de32-86af-475f-8976-e88a08986e34", + "metadata": {}, + "outputs": [], + "source": [ + "from application import application as streaming_application\n", + "from application import TERMINAL_ACTIONS\n", + "import pprint" + ] + }, + { + "cell_type": "markdown", + "id": "675a4245-816f-4d43-b412-307aee83db8e", + "metadata": {}, + "source": [ + "# The application\n", + "\n", + "The application we created will be a simple chatbot proxy. It has a few diffrent modes -- it can either decide a prompt is \"unsafe\" (in this case meaning that it has the word \"unsafe\" in it, but this would typically go to specific model),\n", + "or do one of the following:\n", + "\n", + "1. Generate code\n", + "2. Answer a question\n", + "3. Generate a poem\n", + "4. Prompt for more\n", + "\n", + "It will use an LLM to decide which to do. It streams back text using async streaming in Burr. Read more about how that is implemented [here](https://burr.dagworks.io/concepts/streaming-actions/).\n", + "\n", + "Note that, even though not every response is streaming (E.G. unsafe response, which is hardcoded), they are modeled as streaming to make interaction with the app simpler." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "81799877-8f0f-4572-9a96-f6ce6c430d9a", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "prompt\n", + "\n", + "prompt\n", + "\n", + "\n", + "\n", + "check_safety\n", + "\n", + "check_safety\n", + "\n", + "\n", + "\n", + "prompt->check_safety\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input__prompt\n", + "\n", + "input: prompt\n", + "\n", + "\n", + "\n", + "input__prompt->prompt\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "unsafe_response\n", + "\n", + "unsafe_response\n", + "\n", + "\n", + "\n", + "check_safety->unsafe_response\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "decide_mode\n", + "\n", + "decide_mode\n", + "\n", + "\n", + "\n", + "check_safety->decide_mode\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "unsafe_response->prompt\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "generate_code\n", + "\n", + "generate_code\n", + "\n", + "\n", + "\n", + "decide_mode->generate_code\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "answer_question\n", + "\n", + "answer_question\n", + "\n", + "\n", + "\n", + "decide_mode->answer_question\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "generate_poem\n", + "\n", + "generate_poem\n", + "\n", + "\n", + "\n", + "decide_mode->generate_poem\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "prompt_for_more\n", + "\n", + "prompt_for_more\n", + "\n", + "\n", + "\n", + "decide_mode->prompt_for_more\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "generate_code->prompt\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input__model\n", + "\n", + "input: model\n", + "\n", + "\n", + "\n", + "input__model->generate_code\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input__model->answer_question\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input__model->generate_poem\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "answer_question->prompt\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "generate_poem->prompt\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "prompt_for_more->prompt\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "app = streaming_application()\n", + "app.visualize()" + ] + }, + { + "cell_type": "markdown", + "id": "6f57c117-8951-4f98-86eb-ea1eac16347b", + "metadata": {}, + "source": [ + "# Calling the application\n", + "\n", + "With async streaming, we get back an `AsyncStreamingResultContainer`. 
This allows us to get partial results streaming in, while also allowing us to get the full result.\n", + "In the following case, we just " + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "eeb23b59-e207-47b1-b500-b60d431396f8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Alexander and Aaron, a duo renowned,\n", + "Their story in history forever bound,\n", + "In a duel they met,\n", + "A fate they couldn't forget,\n", + "A tragic end to a rivalry unsewn, unground.\n", + "\n", + "{'response': {'content': 'Alexander and Aaron, a duo renowned,\\n'\n", + " 'Their story in history forever bound,\\n'\n", + " 'In a duel they met,\\n'\n", + " \"A fate they couldn't forget,\\n\"\n", + " 'A tragic end to a rivalry unsewn, unground.',\n", + " 'role': 'assistant',\n", + " 'type': 'text'}}\n" + ] + } + ], + "source": [ + "action, streaming_container = await app.astream_result(\n", + " halt_after=TERMINAL_ACTIONS, inputs={\"prompt\": \"Please generate a limerick about Alexander Hamilton and Aaron Burr\"}\n", + ")\n", + "# Stream results in\n", + "async for item in streaming_container:\n", + " print(item['delta'], end=\"\")\n", + "\n", + "# Or just get the final result\n", + "result, state = await streaming_container.get()\n", + "print(\"\\n\")\n", + "pprint.pprint(result)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/streaming-fastapi/server.py b/examples/streaming-fastapi/server.py new file mode 100644 index 00000000..51a1b15e --- /dev/null +++ b/examples/streaming-fastapi/server.py @@ -0,0 +1,125 @@ +import functools +import importlib +import json +from typing import List, Literal + +import pydantic +from fastapi import APIRouter +from starlette.responses import StreamingResponse + +from burr.core import Application, ApplicationBuilder +from burr.tracking import LocalTrackingClient + +"""This file represents a simple chatbot API backed with Burr. +We manage an application, write to it with post endpoints, and read with +get/ endpoints. + +This demonstrates how you can build interactive web applications with Burr! +""" +# We're doing dynamic import cause this lives within examples/ (and that module has dashes) +# navigate to the examples directory to read more about this! +chat_application = importlib.import_module( + "burr.examples.streaming-fastapi.application" +) # noqa: F401 + +# the app is commented out as we include the router. +# app = FastAPI() + +router = APIRouter() + +graph = chat_application.graph + + +class ChatItem(pydantic.BaseModel): + """Pydantic model for a chat item. 
This is used to render the chat history.""" + + content: str + type: Literal["image", "text", "code", "error"] + role: Literal["user", "assistant"] + + +@functools.lru_cache(maxsize=128) +def _get_application(project_id: str, app_id: str) -> Application: + """Quick tool to get the application -- caches it""" + tracker = LocalTrackingClient(project=project_id, storage_dir="~/.burr") + return ( + ApplicationBuilder() + .with_graph(graph) + # initializes from the tracking log if it does not already exist + .initialize_from( + tracker, + resume_at_next_action=False, # always resume from entrypoint in the case of failure + default_state={"chat_history": []}, + default_entrypoint="prompt", + ) + .with_tracker(tracker) + .with_identifiers(app_id=app_id) + .build() + ) + + +class PromptInput(pydantic.BaseModel): + prompt: str + + +@router.post("/response/{project_id}/{app_id}", response_class=StreamingResponse) +async def chat_response(project_id: str, app_id: str, prompt: PromptInput) -> StreamingResponse: + """Chat response endpoint. User passes in a prompt and the system returns the + full chat history, so its easier to render. + + :param project_id: Project ID to run + :param app_id: Application ID to run + :param prompt: Prompt to send to the chatbot + :return: + """ + burr_app = _get_application(project_id, app_id) + chat_history = burr_app.state.get("chat_history", []) + action, streaming_container = await burr_app.astream_result( + halt_after=chat_application.TERMINAL_ACTIONS, inputs=dict(prompt=prompt.prompt) + ) + + async def sse_generator(): + """This is a generator that yields Server-Sent Events (SSE) to the client + It is necessary to yield them in a special format to ensure the client can + access them streaming. We type them (using our own simple typing system) then + parse on the client side. Unfortunately, typing these in FastAPI is not feasible.""" + yield f"data: {json.dumps({'type': 'chat_history', 'value': chat_history})}\n\n" + + async for item in streaming_container: + yield f"data: {json.dumps({'type': 'delta', 'value': item['delta']})} \n\n" + + return StreamingResponse(sse_generator()) + + +@router.get("/history/{project_id}/{app_id}", response_model=List[ChatItem]) +def chat_history(project_id: str, app_id: str) -> List[ChatItem]: + """Endpoint to get chat history. Gets the application and returns the chat history from state. + + :param project_id: Project ID + :param app_id: App ID. 
+ :return: The list of chat items in the state + """ + chat_app = _get_application(project_id, app_id) + state = chat_app.state + return state.get("chat_history", []) + + +@router.post("/create/{project_id}/{app_id}", response_model=str) +async def create_new_application(project_id: str, app_id: str) -> str: + """Endpoint to create a new application -- used by the FE when + the user types in a new App ID + + :param project_id: Project ID + :param app_id: App ID + :return: The app ID + """ + # side-effect of this persists it -- see the application function for details + _get_application(app_id=app_id, project_id=project_id) + return app_id # just return it for now + + +# # comment this back in for a standalone chatbot API +# import fastapi +# +# app = fastapi.FastAPI() +# app.include_router(router, prefix="/api/v0/chatbot") diff --git a/examples/streaming-fastapi/statemachine.png b/examples/streaming-fastapi/statemachine.png new file mode 100644 index 00000000..ba5527a4 Binary files /dev/null and b/examples/streaming-fastapi/statemachine.png differ diff --git a/examples/streaming-fastapi/streamlit_app.py b/examples/streaming-fastapi/streamlit_app.py new file mode 100644 index 00000000..9b6bd50b --- /dev/null +++ b/examples/streaming-fastapi/streamlit_app.py @@ -0,0 +1,55 @@ +import asyncio +import uuid + +import application as chatbot_application +import streamlit as st + +import burr.core +from burr.core.action import AsyncStreamingResultContainer + + +def render_chat_message(chat_item: dict): + content = chat_item["content"] + role = chat_item["role"] + with st.chat_message(role): + st.write(content) + + +async def render_streaming_chat_message(stream: AsyncStreamingResultContainer): + buffer = "" + with st.chat_message("assistant"): + # This is very ugly as streamlit does not support async generators + # Thus we have to ignore the benefit of writing the delta and instead write *everything* + with st.empty(): + async for item in stream: + buffer += item["delta"] + st.write(buffer) + + +def initialize_app() -> burr.core.Application: + if "burr_app" not in st.session_state: + st.session_state.burr_app = chatbot_application.application( + app_id=f"chat_streaming:{str(uuid.uuid4())[0:6]}" + ) + return st.session_state.burr_app + + +async def main(): + st.title("Streaming chatbot with Burr") + app = initialize_app() + + prompt = st.chat_input("Ask me a question!", key="chat_input") + for chat_message in app.state.get("chat_history", []): + render_chat_message(chat_message) + + if prompt: + render_chat_message({"role": "user", "content": prompt, "type": "text"}) + with st.spinner(text="Waiting for response..."): + action, streaming_container = await app.astream_result( + halt_after=chatbot_application.TERMINAL_ACTIONS, inputs={"prompt": prompt} + ) + await render_streaming_chat_message(streaming_container) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/telemetry/ui/package-lock.json b/telemetry/ui/package-lock.json index 4a2bf7e2..65b10a90 100644 --- a/telemetry/ui/package-lock.json +++ b/telemetry/ui/package-lock.json @@ -10,6 +10,7 @@ "dependencies": { "@headlessui/react": "^2.0.0-alpha.4", "@heroicons/react": "^2.1.1", + "@microsoft/fetch-event-source": "^2.0.1", "@testing-library/jest-dom": "^5.17.0", "@testing-library/react": "^13.4.0", "@testing-library/user-event": "^13.5.0", @@ -3496,6 +3497,11 @@ "resolved": "https://registry.npmjs.org/@leichtgewicht/ip-codec/-/ip-codec-2.0.4.tgz", "integrity": 
"sha512-Hcv+nVC0kZnQ3tD9GVu5xSMR4VVYOteQIr/hwFPVEvPdlXqgGEuRjiheChHgdM+JyqdgNcmzZOX/tnl0JOiI7A==" }, + "node_modules/@microsoft/fetch-event-source": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/@microsoft/fetch-event-source/-/fetch-event-source-2.0.1.tgz", + "integrity": "sha512-W6CLUJ2eBMw3Rec70qrsEW0jOm/3twwJv21mrmj2yORiaVmVYGS4sSS5yUwvQc1ZlDLYGPnClVWmUUMagKNsfA==" + }, "node_modules/@nicolo-ribaudo/eslint-scope-5-internals": { "version": "5.1.1-v1", "resolved": "https://registry.npmjs.org/@nicolo-ribaudo/eslint-scope-5-internals/-/eslint-scope-5-internals-5.1.1-v1.tgz", diff --git a/telemetry/ui/package.json b/telemetry/ui/package.json index 4470de6a..0b361479 100644 --- a/telemetry/ui/package.json +++ b/telemetry/ui/package.json @@ -5,6 +5,7 @@ "dependencies": { "@headlessui/react": "^2.0.0-alpha.4", "@heroicons/react": "^2.1.1", + "@microsoft/fetch-event-source": "^2.0.1", "@testing-library/jest-dom": "^5.17.0", "@testing-library/react": "^13.4.0", "@testing-library/user-event": "^13.5.0", diff --git a/telemetry/ui/src/App.tsx b/telemetry/ui/src/App.tsx index a9f9e9aa..32d5ae00 100644 --- a/telemetry/ui/src/App.tsx +++ b/telemetry/ui/src/App.tsx @@ -8,6 +8,7 @@ import { AppContainer } from './components/nav/appcontainer'; import { ChatbotWithTelemetry } from './examples/Chatbot'; import { Counter } from './examples/Counter'; import { EmailAssistantWithTelemetry } from './examples/EmailAssistant'; +import { StreamingChatbotWithTelemetry } from './examples/StreamingChatbot'; /** * Basic application. We have an AppContainer -- this has a breadcrumb and a sidebar. @@ -36,6 +37,7 @@ const App = () => { } /> } /> } /> + } /> } /> } /> diff --git a/telemetry/ui/src/api/models/ChildApplicationModel.ts b/telemetry/ui/src/api/models/ChildApplicationModel.ts index f98fc169..11206f37 100644 --- a/telemetry/ui/src/api/models/ChildApplicationModel.ts +++ b/telemetry/ui/src/api/models/ChildApplicationModel.ts @@ -12,7 +12,7 @@ export type ChildApplicationModel = { child: PointerModel; event_time: string; event_type: ChildApplicationModel.event_type; - sequence_id: number; + sequence_id: number | null; }; export namespace ChildApplicationModel { export enum event_type { diff --git a/telemetry/ui/src/api/services/DefaultService.ts b/telemetry/ui/src/api/services/DefaultService.ts index 447da977..4c58539f 100644 --- a/telemetry/ui/src/api/services/DefaultService.ts +++ b/telemetry/ui/src/api/services/DefaultService.ts @@ -96,6 +96,18 @@ export class DefaultService { url: '/api/v0/ready' }); } + /** + * Version + * Returns the burr version + * @returns any Successful Response + * @throws ApiError + */ + public static versionApiV0VersionGet(): CancelablePromise> { + return __request(OpenAPI, { + method: 'GET', + url: '/api/v0/version' + }); + } /** * Chat Response * Chat response endpoint. User passes in a prompt and the system returns the @@ -350,6 +362,98 @@ export class DefaultService { url: '/api/v0/email_assistant/validate/{project_id}/{app_id}' }); } + /** + * Chat Response + * Chat response endpoint. User passes in a prompt and the system returns the + * full chat history, so its easier to render. 
+ * + * :param project_id: Project ID to run + * :param app_id: Application ID to run + * :param prompt: Prompt to send to the chatbot + * :return: + * @param projectId + * @param appId + * @param prompt + * @returns any Successful Response + * @throws ApiError + */ + public static chatResponseApiV0StreamingChatbotResponseProjectIdAppIdPost( + projectId: string, + appId: string, + prompt: string + ): CancelablePromise { + return __request(OpenAPI, { + method: 'POST', + url: '/api/v0/streaming_chatbot/response/{project_id}/{app_id}', + path: { + project_id: projectId, + app_id: appId + }, + query: { + prompt: prompt + }, + errors: { + 422: `Validation Error` + } + }); + } + /** + * Chat History + * Endpoint to get chat history. Gets the application and returns the chat history from state. + * + * :param project_id: Project ID + * :param app_id: App ID. + * :return: The list of chat items in the state + * @param projectId + * @param appId + * @returns ChatItem Successful Response + * @throws ApiError + */ + public static chatHistoryApiV0StreamingChatbotHistoryProjectIdAppIdGet( + projectId: string, + appId: string + ): CancelablePromise> { + return __request(OpenAPI, { + method: 'GET', + url: '/api/v0/streaming_chatbot/history/{project_id}/{app_id}', + path: { + project_id: projectId, + app_id: appId + }, + errors: { + 422: `Validation Error` + } + }); + } + /** + * Create New Application + * Endpoint to create a new application -- used by the FE when + * the user types in a new App ID + * + * :param project_id: Project ID + * :param app_id: App ID + * :return: The app ID + * @param projectId + * @param appId + * @returns string Successful Response + * @throws ApiError + */ + public static createNewApplicationApiV0StreamingChatbotCreateProjectIdAppIdPost( + projectId: string, + appId: string + ): CancelablePromise { + return __request(OpenAPI, { + method: 'POST', + url: '/api/v0/streaming_chatbot/create/{project_id}/{app_id}', + path: { + project_id: projectId, + app_id: appId + }, + errors: { + 422: `Validation Error` + } + }); + } /** * React App * Quick trick to server the react app diff --git a/telemetry/ui/src/components/nav/appcontainer.tsx b/telemetry/ui/src/components/nav/appcontainer.tsx index 2d04e00c..70a4c817 100644 --- a/telemetry/ui/src/components/nav/appcontainer.tsx +++ b/telemetry/ui/src/components/nav/appcontainer.tsx @@ -93,6 +93,12 @@ export const AppContainer = (props: { children: React.ReactNode }) => { href: '/demos/email-assistant', current: false, linkType: 'internal' + }, + { + name: 'streaming-chatbot', + href: '/demos/streaming-chatbot', + current: false, + linkType: 'internal' } ] }, diff --git a/telemetry/ui/src/components/routes/app/StepList.tsx b/telemetry/ui/src/components/routes/app/StepList.tsx index 53f9c8cd..a361765d 100644 --- a/telemetry/ui/src/components/routes/app/StepList.tsx +++ b/telemetry/ui/src/components/routes/app/StepList.tsx @@ -470,9 +470,9 @@ export const StepList = (props: { const displaySpansCol = props.steps.some((step) => step.spans.length > 0); const displayLinksCol = props.links.length > 0; const linksBySequenceID = props.links.reduce((acc, child) => { - const existing = acc.get(child.sequence_id) || []; + const existing = acc.get(child.sequence_id || -1) || []; existing.push(child); - acc.set(child.sequence_id, existing); + acc.set(child.sequence_id || -1, existing); return acc; }, new Map()); return ( diff --git a/telemetry/ui/src/examples/Chatbot.tsx b/telemetry/ui/src/examples/Chatbot.tsx index c66a0aeb..18e2edd5 100644 --- 
a/telemetry/ui/src/examples/Chatbot.tsx +++ b/telemetry/ui/src/examples/Chatbot.tsx @@ -2,15 +2,13 @@ import { ComputerDesktopIcon, UserIcon } from '@heroicons/react/24/outline'; import { classNames } from '../utils/tailwind'; import { Button } from '../components/common/button'; import { TwoColumnLayout } from '../components/common/layout'; -import { MiniTelemetry } from './MiniTelemetry'; import { ApplicationSummary, ChatItem, DefaultService } from '../api'; import { KeyboardEvent, useEffect, useState } from 'react'; import { useMutation, useQuery } from 'react-query'; import { Loading } from '../components/common/loading'; import Markdown from 'react-markdown'; import remarkGfm from 'remark-gfm'; -import { DateTimeDisplay } from '../components/common/dates'; -import AsyncCreatableSelect from 'react-select/async-creatable'; +import { TelemetryWithSelector } from './Common'; type Role = 'assistant' | 'user'; @@ -211,95 +209,6 @@ export const Chatbot = (props: { projectId: string; appId: string | undefined }) ); }; -export const TelemetryWithSelector = (props: { - projectId: string; - currentApp: ApplicationSummary | undefined; - setCurrentApp: (app: ApplicationSummary) => void; -}) => { - return ( -
-
- -
- -
- ); -}; - -const Label = (props: { application: ApplicationSummary }) => { - return ( -
-
- {props.application.num_steps} - {props.application.app_id} -
- -
- ); -}; - -export const ChatbotAppSelector = (props: { - projectId: string; - setApp: (app: ApplicationSummary) => void; - currentApp: ApplicationSummary | undefined; - placeholder: string; -}) => { - const { projectId, setApp } = props; - const { data, refetch } = useQuery( - ['apps', projectId], - () => DefaultService.getAppsApiV0ProjectIdAppsGet(projectId as string), - { enabled: projectId !== undefined } - ); - const createAndUpdateMutation = useMutation( - (app_id: string) => - DefaultService.createNewApplicationApiV0ChatbotCreateProjectIdAppIdPost(projectId, app_id), - - { - onSuccess: (appID) => { - refetch().then((data) => { - const appSummaries = data.data || []; - const app = appSummaries.find((app) => app.app_id === appID); - if (app) { - setApp(app); - } - }); - } - } - ); - const appSetter = (appID: string) => createAndUpdateMutation.mutate(appID); - const dataOrEmpty = Array.from(data || []); - const options = dataOrEmpty - .sort((a, b) => { - return new Date(a.last_written) > new Date(b.last_written) ? -1 : 1; - }) - .map((app) => { - return { - value: app.app_id, - label: