Very WIP streaming example
elijahbenizzy committed Jul 4, 2024
1 parent 36cda37 commit b4127e7
Showing 15 changed files with 913 additions and 95 deletions.
2 changes: 2 additions & 0 deletions burr/tracking/server/run.py
@@ -18,6 +18,7 @@
# dynamic importing due to the dashes (which make reading the examples on github easier)
email_assistant = importlib.import_module("burr.examples.email-assistant.server")
chatbot = importlib.import_module("burr.examples.multi-modal-chatbot.server")
streaming_chatbot = importlib.import_module("burr.examples.streaming-fastapi.server")

except ImportError as e:
require_plugin(
@@ -97,6 +98,7 @@ async def version() -> dict:
# Examples -- todo -- put them behind `if` statements
app.include_router(chatbot.router, prefix="/api/v0/chatbot")
app.include_router(email_assistant.router, prefix="/api/v0/email_assistant")
app.include_router(streaming_chatbot.router, prefix="/api/v0/streaming_chatbot")
# email_assistant.register(app, "/api/v0/email_assistant")


Empty file.
191 changes: 191 additions & 0 deletions examples/streaming-fastapi/application.py
@@ -0,0 +1,191 @@
import asyncio
import copy
from typing import AsyncGenerator, Optional, Tuple

import openai

from burr.core import ApplicationBuilder, State, default, when
from burr.core.action import action, streaming_action
from burr.core.graph import GraphBuilder

MODES = [
"answer_question",
"generate_poem",
"generate_code",
"unknown",
]


@action(reads=[], writes=["chat_history", "prompt"])
def process_prompt(state: State, prompt: str) -> Tuple[dict, State]:
result = {"chat_item": {"role": "user", "content": prompt, "type": "text"}}
return result, state.wipe(keep=["prompt", "chat_history"]).append(
chat_history=result["chat_item"]
).update(prompt=prompt)


@action(reads=["prompt"], writes=["safe"])
def check_safety(state: State) -> Tuple[dict, State]:
result = {"safe": "unsafe" not in state["prompt"]} # quick hack to demonstrate
return result, state.update(safe=result["safe"])


def _get_openai_client():
return openai.AsyncOpenAI()


@action(reads=["prompt"], writes=["mode"])
async def choose_mode(state: State) -> Tuple[dict, State]:
prompt = (
f"You are a chatbot. You've been prompted this: {state['prompt']}. "
f"You have the capability of responding in the following modes: {', '.join(MODES)}. "
"Please respond with *only* a single word representing the mode that most accurately "
"corresponds to the prompt. Fr instance, if the prompt is 'write a poem about Alexander Hamilton and Aaron Burr', "
"the mode would be 'generate_poem'. If the prompt is 'what is the capital of France', the mode would be 'answer_question'."
"And so on, for every mode. If none of these modes apply, please respond with 'unknown'."
)

result = await _get_openai_client().chat.completions.create(
model="gpt-4o",
messages=[
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": prompt},
],
)
content = result.choices[0].message.content
mode = content.lower()
if mode not in MODES:
mode = "unknown"
result = {"mode": mode}
return result, state.update(**result)


@streaming_action(reads=["prompt", "chat_history"], writes=["response"])
async def prompt_for_more(state: State) -> AsyncGenerator[Tuple[dict, Optional[State]], None]:
"""Not streaming, as we have the result immediately."""
result = {
"response": {
"content": "None of the response modes I support apply to your question. Please clarify?",
"type": "text",
"role": "assistant",
}
}
for word in result["response"]["content"].split():
await asyncio.sleep(0.1)
yield {"delta": word + " "}, None
yield result, state.update(**result).append(chat_history=result["response"])


@streaming_action(reads=["prompt", "chat_history", "mode"], writes=["response"])
async def chat_response(
state: State, prepend_prompt: str, model: str = "gpt-3.5-turbo"
) -> AsyncGenerator[Tuple[dict, Optional[State]], None]:
"""Streaming action, as we don't have the result immediately. This makes it more interactive"""
chat_history = copy.deepcopy(state["chat_history"])
chat_history[-1]["content"] = f"{prepend_prompt}: {chat_history[-1]['content']}"
chat_history_api_format = [
{
"role": chat["role"],
"content": chat["content"],
}
for chat in chat_history
]
client = _get_openai_client()
result = await client.chat.completions.create(
model=model, messages=chat_history_api_format, stream=True
)
buffer = []
async for chunk in result:
chunk_str = chunk.choices[0].delta.content
if chunk_str is None:
continue
buffer.append(chunk_str)
yield {
"delta": chunk_str,
}, None

result = {"response": {"content": "".join(buffer), "type": "text", "role": "assistant"}}
yield result, state.update(**result).append(chat_history=result["response"])


@streaming_action(reads=["prompt", "chat_history"], writes=["response"])
async def unsafe_response(state: State) -> AsyncGenerator[Tuple[dict, Optional[State]], None]:
result = {
"response": {
"content": "I am afraid I can't respond to that...",
"type": "text",
"role": "assistant",
}
}
for word in result["response"]["content"].split():
await asyncio.sleep(0.03)
yield {"delta": word + " "}, None
yield result, state.update(**result).append(chat_history=result["response"])


graph = (
GraphBuilder()
.with_actions(
prompt=process_prompt,
check_safety=check_safety,
unsafe_response=unsafe_response,
decide_mode=choose_mode,
generate_code=chat_response.bind(
prepend_prompt="Please respond with *only* code and no other text (at all) to the following",
),
answer_question=chat_response.bind(
prepend_prompt="Please answer the following question",
),
generate_poem=chat_response.bind(
prepend_prompt="Please generate a poem based on the following prompt",
),
prompt_for_more=prompt_for_more,
)
.with_transitions(
("prompt", "check_safety", default),
("check_safety", "decide_mode", when(safe=True)),
("check_safety", "unsafe_response", default),
("decide_mode", "generate_code", when(mode="generate_code")),
("decide_mode", "answer_question", when(mode="answer_question")),
("decide_mode", "generate_poem", when(mode="generate_poem")),
("decide_mode", "prompt_for_more", default),
(
[
"answer_question",
"generate_poem",
"generate_code",
"prompt_for_more",
"unsafe_response",
],
"prompt",
),
)
.build()
)


def application(app_id: Optional[str] = None):
return (
ApplicationBuilder()
.with_entrypoint("prompt")
.with_state(chat_history=[])
.with_graph(graph)
.with_tracker(project="demo_chatbot_streaming")
.with_identifiers(app_id=app_id)
.build()
)


# TODO -- replace these with action tags when we have the availability
TERMINAL_ACTIONS = [
"answer_question",
"generate_code",
"prompt_for_more",
"unsafe_response",
"generate_poem",
]
if __name__ == "__main__":
app = application()
app.visualize(
output_file_path="statemachine", include_conditions=False, view=True, format="png"
)
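For orientation, here is a minimal sketch (not part of this commit) of driving the application above directly from a terminal, without FastAPI or Streamlit. The app ID and prompt are illustrative, and it assumes an OpenAI API key is available in the environment, since the actions call the OpenAI API.

import asyncio

from application import TERMINAL_ACTIONS, application


async def run_once(prompt: str) -> None:
    # "local-test" is an arbitrary example app ID
    app = application(app_id="local-test")
    # astream_result returns the action that ran plus an async container of deltas
    action, streaming_container = await app.astream_result(
        halt_after=TERMINAL_ACTIONS, inputs={"prompt": prompt}
    )
    async for item in streaming_container:
        print(item["delta"], end="", flush=True)
    print()


if __name__ == "__main__":
    asyncio.run(run_once("what is the capital of France?"))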
121 changes: 121 additions & 0 deletions examples/streaming-fastapi/server.py
@@ -0,0 +1,121 @@
import functools
import importlib
import json
from typing import List, Literal

import pydantic
from fastapi import APIRouter
from starlette.responses import StreamingResponse

from burr.core import Application, ApplicationBuilder
from burr.tracking import LocalTrackingClient

"""This file represents a simple chatbot API backed with Burr.
We manage an application, write to it with post endpoints, and read with
get/ endpoints.
This demonstrates how you can build interactive web applications with Burr!
"""
# We're doing a dynamic import because this lives within examples/ (and that module path has dashes)
# navigate to the examples directory to read more about this!
chat_application = importlib.import_module(
"burr.examples.streaming-fastapi.application"
) # noqa: F401

# the app is commented out as we include the router.
# app = FastAPI()

router = APIRouter()

graph = chat_application.graph


class ChatItem(pydantic.BaseModel):
"""Pydantic model for a chat item. This is used to render the chat history."""

content: str
type: Literal["image", "text", "code", "error"]
role: Literal["user", "assistant"]


@functools.lru_cache(maxsize=128)
def _get_application(project_id: str, app_id: str) -> Application:
"""Quick tool to get the application -- caches it"""
tracker = LocalTrackingClient(project=project_id, storage_dir="~/.burr")
return (
ApplicationBuilder()
.with_graph(graph)
# initializes from the tracking log if it does not already exist
.initialize_from(
tracker,
resume_at_next_action=False, # always resume from entrypoint in the case of failure
default_state={"chat_history": []},
default_entrypoint="prompt",
)
.with_tracker(tracker)
.with_identifiers(app_id=app_id)
.build()
)


@router.post("/response/{project_id}/{app_id}", response_class=StreamingResponse)
async def chat_response(project_id: str, app_id: str, prompt: str) -> StreamingResponse:
"""Chat response endpoint. User passes in a prompt and the system returns the
full chat history, so its easier to render.
:param project_id: Project ID to run
:param app_id: Application ID to run
:param prompt: Prompt to send to the chatbot
:return:
"""
burr_app = _get_application(project_id, app_id)
chat_history = burr_app.state.get("chat_history", [])
action, streaming_container = await burr_app.astream_result(
halt_after=chat_application.TERMINAL_ACTIONS, inputs=dict(prompt=prompt)
)

async def sse_generator():
"""This is a generator that yields Server-Sent Events (SSE) to the client
It is necessary to yield them in a special format to ensure the client can
access them streaming. We type them (using our own simple typing system) then
parse on the client side. Unfortunately, typing these in FastAPI is not feasible."""
yield f"data: {json.dumps({'type': 'chat_history', 'value': chat_history})}\n\n"

async for item in streaming_container:
yield f"data: {json.dumps({'type': 'delta', 'value' : item['delta']})} \n\n"

return StreamingResponse(sse_generator())


@router.get("/response/{project_id}/{app_id}", response_model=List[ChatItem])
def chat_history(project_id: str, app_id: str) -> List[ChatItem]:
"""Endpoint to get chat history. Gets the application and returns the chat history from state.
:param project_id: Project ID
:param app_id: App ID.
:return: The list of chat items in the state
"""
chat_app = _get_application(project_id, app_id)
state = chat_app.state
return state.get("chat_history", [])


@router.post("/create/{project_id}/{app_id}", response_model=str)
async def create_new_application(project_id: str, app_id: str) -> str:
"""Endpoint to create a new application -- used by the FE when
the user types in a new App ID
:param project_id: Project ID
:param app_id: App ID
:return: The app ID
"""
# side-effect of this persists it -- see the application function for details
_get_application(app_id=app_id, project_id=project_id)
return app_id # just return it for now


# # comment this back in for a standalone chatbot API
# import fastapi
#
# app = fastapi.FastAPI()
# app.include_router(router, prefix="/api/v0/chatbot")
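For completeness, here is a minimal client-side sketch (not part of this commit) of consuming the SSE endpoint above. It assumes the server is running locally (the host and port below are placeholders), uses httpx (not a dependency introduced here), and takes the /api/v0/streaming_chatbot prefix from the include_router call in run.py. The project and app IDs are illustrative.

import asyncio
import json

import httpx

BASE_URL = "http://localhost:7241"  # placeholder -- point this at wherever the server runs


async def stream_chat(project_id: str, app_id: str, prompt: str) -> None:
    url = f"{BASE_URL}/api/v0/streaming_chatbot/response/{project_id}/{app_id}"
    async with httpx.AsyncClient(timeout=None) as client:
        # the prompt is a query parameter on the POST endpoint
        async with client.stream("POST", url, params={"prompt": prompt}) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue  # skip the blank lines separating SSE events
                event = json.loads(line[len("data: "):])
                if event["type"] == "chat_history":
                    print(f"(resuming with {len(event['value'])} prior messages)")
                elif event["type"] == "delta":
                    print(event["value"], end="", flush=True)
    print()


if __name__ == "__main__":
    asyncio.run(stream_chat("demo_chatbot_streaming", "example-app-id", "generate a haiku about state machines"))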
55 changes: 55 additions & 0 deletions examples/streaming-fastapi/streamlit_app.py
@@ -0,0 +1,55 @@
import asyncio
import uuid

import application as chatbot_application
import streamlit as st

import burr.core
from burr.core.action import AsyncStreamingResultContainer


def render_chat_message(chat_item: dict):
content = chat_item["content"]
role = chat_item["role"]
with st.chat_message(role):
st.write(content)


async def render_streaming_chat_message(stream: AsyncStreamingResultContainer):
buffer = ""
with st.chat_message("assistant"):
# This is a bit awkward, as Streamlit does not support async generators,
# so instead of writing just the delta we rewrite the full buffer on every chunk
with st.empty():
async for item in stream:
buffer += item["delta"]
st.write(buffer)


def initialize_app() -> burr.core.Application:
if "burr_app" not in st.session_state:
st.session_state.burr_app = chatbot_application.application(
app_id=f"chat_streaming:{str(uuid.uuid4())[0:6]}"
)
return st.session_state.burr_app


async def main():
st.title("Streaming chatbot with Burr")
app = initialize_app()

prompt = st.chat_input("Ask me a question!", key="chat_input")
for chat_message in app.state.get("chat_history", []):
render_chat_message(chat_message)

if prompt:
render_chat_message({"role": "user", "content": prompt, "type": "text"})
with st.spinner(text="Waiting for response..."):
action, streaming_container = await app.astream_result(
halt_after=chatbot_application.TERMINAL_ACTIONS, inputs={"prompt": prompt}
)
await render_streaming_chat_message(streaming_container)


if __name__ == "__main__":
asyncio.run(main())
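As a usage note: the Streamlit front end above talks to the Burr application directly (it imports application.py rather than going through the FastAPI endpoints), so it would typically be launched on its own with something like streamlit run examples/streaming-fastapi/streamlit_app.py, with an OpenAI API key available in the environment.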