Commit b4127e7 (1 parent: 36cda37)
Showing 15 changed files with 913 additions and 95 deletions.
Empty file.
@@ -0,0 +1,191 @@
import asyncio
import copy
from typing import AsyncGenerator, Optional, Tuple

import openai

from burr.core import ApplicationBuilder, State, default, when
from burr.core.action import action, streaming_action
from burr.core.graph import GraphBuilder

MODES = [
    "answer_question",
    "generate_poem",
    "generate_code",
    "unknown",
]


@action(reads=[], writes=["chat_history", "prompt"])
def process_prompt(state: State, prompt: str) -> Tuple[dict, State]:
    result = {"chat_item": {"role": "user", "content": prompt, "type": "text"}}
    return result, state.wipe(keep=["prompt", "chat_history"]).append(
        chat_history=result["chat_item"]
    ).update(prompt=prompt)


@action(reads=["prompt"], writes=["safe"])
def check_safety(state: State) -> Tuple[dict, State]:
    result = {"safe": "unsafe" not in state["prompt"]}  # quick hack to demonstrate
    return result, state.update(safe=result["safe"])


def _get_openai_client():
    return openai.AsyncOpenAI()


@action(reads=["prompt"], writes=["mode"])
async def choose_mode(state: State) -> Tuple[dict, State]:
    prompt = (
        f"You are a chatbot. You've been prompted this: {state['prompt']}. "
        f"You have the capability of responding in the following modes: {', '.join(MODES)}. "
        "Please respond with *only* a single word representing the mode that most accurately "
        "corresponds to the prompt. For instance, if the prompt is 'write a poem about Alexander Hamilton and Aaron Burr', "
        "the mode would be 'generate_poem'. If the prompt is 'what is the capital of France', the mode would be 'answer_question'. "
        "And so on, for every mode. If none of these modes apply, please respond with 'unknown'."
    )

    result = await _get_openai_client().chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": prompt},
        ],
    )
    content = result.choices[0].message.content
    mode = content.lower()
    if mode not in MODES:
        mode = "unknown"
    result = {"mode": mode}
    return result, state.update(**result)


@streaming_action(reads=["prompt", "chat_history"], writes=["response"])
async def prompt_for_more(state: State) -> AsyncGenerator[Tuple[dict, Optional[State]], None]:
    """Not streaming, as we have the result immediately."""
    result = {
        "response": {
            "content": "None of the response modes I support apply to your question. Please clarify?",
            "type": "text",
            "role": "assistant",
        }
    }
    for word in result["response"]["content"].split():
        await asyncio.sleep(0.1)
        yield {"delta": word + " "}, None
    yield result, state.update(**result).append(chat_history=result["response"])


@streaming_action(reads=["prompt", "chat_history", "mode"], writes=["response"])
async def chat_response(
    state: State, prepend_prompt: str, model: str = "gpt-3.5-turbo"
) -> AsyncGenerator[Tuple[dict, Optional[State]], None]:
    """Streaming action, as we don't have the result immediately. This makes it more interactive."""
    chat_history = copy.deepcopy(state["chat_history"])
    chat_history[-1]["content"] = f"{prepend_prompt}: {chat_history[-1]['content']}"
    chat_history_api_format = [
        {
            "role": chat["role"],
            "content": chat["content"],
        }
        for chat in chat_history
    ]
    client = _get_openai_client()
    result = await client.chat.completions.create(
        model=model, messages=chat_history_api_format, stream=True
    )
    buffer = []
    async for chunk in result:
        chunk_str = chunk.choices[0].delta.content
        if chunk_str is None:
            continue
        buffer.append(chunk_str)
        yield {
            "delta": chunk_str,
        }, None

    result = {"response": {"content": "".join(buffer), "type": "text", "role": "assistant"}}
    yield result, state.update(**result).append(chat_history=result["response"])


@streaming_action(reads=["prompt", "chat_history"], writes=["response"])
async def unsafe_response(state: State) -> AsyncGenerator[Tuple[dict, Optional[State]], None]:
    result = {
        "response": {
            "content": "I am afraid I can't respond to that...",
            "type": "text",
            "role": "assistant",
        }
    }
    for word in result["response"]["content"].split():
        await asyncio.sleep(0.03)
        yield {"delta": word + " "}, None
    yield result, state.update(**result).append(chat_history=result["response"])


graph = (
    GraphBuilder()
    .with_actions(
        prompt=process_prompt,
        check_safety=check_safety,
        unsafe_response=unsafe_response,
        decide_mode=choose_mode,
        generate_code=chat_response.bind(
            prepend_prompt="Please respond with *only* code and no other text (at all) to the following",
        ),
        answer_question=chat_response.bind(
            prepend_prompt="Please answer the following question",
        ),
        generate_poem=chat_response.bind(
            prepend_prompt="Please generate a poem based on the following prompt",
        ),
        prompt_for_more=prompt_for_more,
    )
    .with_transitions(
        ("prompt", "check_safety", default),
        ("check_safety", "decide_mode", when(safe=True)),
        ("check_safety", "unsafe_response", default),
        ("decide_mode", "generate_code", when(mode="generate_code")),
        ("decide_mode", "answer_question", when(mode="answer_question")),
        ("decide_mode", "generate_poem", when(mode="generate_poem")),
        ("decide_mode", "prompt_for_more", default),
        (
            [
                "answer_question",
                "generate_poem",
                "generate_code",
                "prompt_for_more",
                "unsafe_response",
            ],
            "prompt",
        ),
    )
    .build()
)


def application(app_id: Optional[str] = None):
    return (
        ApplicationBuilder()
        .with_entrypoint("prompt")
        .with_state(chat_history=[])
        .with_graph(graph)
        .with_tracker(project="demo_chatbot_streaming")
        .with_identifiers(app_id=app_id)
        .build()
    )


# TODO -- replace these with action tags when we have the availability
TERMINAL_ACTIONS = [
    "answer_question",
    "generate_code",
    "prompt_for_more",
    "unsafe_response",
    "generate_poem",
]
if __name__ == "__main__":
    app = application()
    app.visualize(
        output_file_path="statemachine", include_conditions=False, view=True, format="png"
    )
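For orientation, here is a minimal sketch (not part of this commit) of driving the graph above directly: it builds the application and streams a single answer to stdout. It assumes the file above is importable as application and that an OpenAI API key is configured in the environment; the app_id and prompt are made up for illustration.

import asyncio

import application as chatbot_application  # assumed module name for the file above


async def run_once(prompt: str) -> None:
    # build the app with a hypothetical app_id, then stream until a terminal action halts
    app = chatbot_application.application(app_id="local-test")
    action, streaming_container = await app.astream_result(
        halt_after=chatbot_application.TERMINAL_ACTIONS, inputs={"prompt": prompt}
    )
    # print each streamed delta as it arrives
    async for item in streaming_container:
        print(item["delta"], end="", flush=True)
    print()


if __name__ == "__main__":
    asyncio.run(run_once("what is the capital of France?"))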
@@ -0,0 +1,121 @@
import functools
import importlib
import json
from typing import List, Literal

import pydantic
from fastapi import APIRouter
from starlette.responses import StreamingResponse

from burr.core import Application, ApplicationBuilder
from burr.tracking import LocalTrackingClient

"""This file represents a simple chatbot API backed by Burr.
We manage an application, write to it with POST endpoints, and read with
GET endpoints.
This demonstrates how you can build interactive web applications with Burr!
"""
# We're doing a dynamic import because this lives within examples/ (and that module has dashes).
# Navigate to the examples directory to read more about this!
chat_application = importlib.import_module(
    "burr.examples.streaming-fastapi.application"
)  # noqa: F401

# The app is commented out, as we include the router instead.
# app = FastAPI()

router = APIRouter()

graph = chat_application.graph


class ChatItem(pydantic.BaseModel):
    """Pydantic model for a chat item. This is used to render the chat history."""

    content: str
    type: Literal["image", "text", "code", "error"]
    role: Literal["user", "assistant"]


@functools.lru_cache(maxsize=128)
def _get_application(project_id: str, app_id: str) -> Application:
    """Quick tool to get the application -- caches it"""
    tracker = LocalTrackingClient(project=project_id, storage_dir="~/.burr")
    return (
        ApplicationBuilder()
        .with_graph(graph)
        # initializes from the tracking log if it does not already exist
        .initialize_from(
            tracker,
            resume_at_next_action=False,  # always resume from entrypoint in the case of failure
            default_state={"chat_history": []},
            default_entrypoint="prompt",
        )
        .with_tracker(tracker)
        .with_identifiers(app_id=app_id)
        .build()
    )


@router.post("/response/{project_id}/{app_id}", response_class=StreamingResponse)
async def chat_response(project_id: str, app_id: str, prompt: str) -> StreamingResponse:
    """Chat response endpoint. The user passes in a prompt and the system streams back
    the full chat history followed by the response deltas, so it's easier to render.

    :param project_id: Project ID to run
    :param app_id: Application ID to run
    :param prompt: Prompt to send to the chatbot
    :return: A streaming response of Server-Sent Events
    """
    burr_app = _get_application(project_id, app_id)
    chat_history = burr_app.state.get("chat_history", [])
    action, streaming_container = await burr_app.astream_result(
        halt_after=chat_application.TERMINAL_ACTIONS, inputs=dict(prompt=prompt)
    )

    async def sse_generator():
        """This is a generator that yields Server-Sent Events (SSE) to the client.
        It is necessary to yield them in a special format to ensure the client can
        consume them as they stream. We type them (using our own simple typing system) and
        parse them on the client side. Unfortunately, typing these in FastAPI is not feasible."""
        yield f"data: {json.dumps({'type': 'chat_history', 'value': chat_history})}\n\n"

        async for item in streaming_container:
            yield f"data: {json.dumps({'type': 'delta', 'value': item['delta']})}\n\n"

    return StreamingResponse(sse_generator())


@router.get("/response/{project_id}/{app_id}", response_model=List[ChatItem])
def chat_history(project_id: str, app_id: str) -> List[ChatItem]:
    """Endpoint to get chat history. Gets the application and returns the chat history from state.

    :param project_id: Project ID
    :param app_id: App ID
    :return: The list of chat items in the state
    """
    chat_app = _get_application(project_id, app_id)
    state = chat_app.state
    return state.get("chat_history", [])


@router.post("/create/{project_id}/{app_id}", response_model=str)
async def create_new_application(project_id: str, app_id: str) -> str:
    """Endpoint to create a new application -- used by the FE when
    the user types in a new App ID.

    :param project_id: Project ID
    :param app_id: App ID
    :return: The app ID
    """
    # the side effect of this call persists it -- see the application function for details
    _get_application(app_id=app_id, project_id=project_id)
    return app_id  # just return it for now


# # comment this back in for a standalone chatbot API
# import fastapi
#
# app = fastapi.FastAPI()
# app.include_router(router, prefix="/api/v0/chatbot")
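To see what the streaming endpoint emits on the wire, a small client sketch follows (not part of this commit). It assumes the standalone app at the bottom is commented back in and served locally on port 8000, that the project/app IDs already exist, and that httpx is available; any HTTP client that can iterate over a response line by line would work the same way.

import json

import httpx


def stream_chat(project_id: str, app_id: str, prompt: str) -> None:
    # the prompt travels as a query parameter; the endpoint answers with Server-Sent Events
    url = f"http://localhost:8000/api/v0/chatbot/response/{project_id}/{app_id}"
    with httpx.stream("POST", url, params={"prompt": prompt}, timeout=None) as response:
        for line in response.iter_lines():
            if not line.startswith("data: "):
                continue
            event = json.loads(line[len("data: "):])
            if event["type"] == "chat_history":
                print(f"(restored {len(event['value'])} prior chat items)")
            elif event["type"] == "delta":
                print(event["value"], end="", flush=True)
    print()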
@@ -0,0 +1,55 @@
import asyncio
import uuid

import application as chatbot_application
import streamlit as st

import burr.core
from burr.core.action import AsyncStreamingResultContainer


def render_chat_message(chat_item: dict):
    content = chat_item["content"]
    role = chat_item["role"]
    with st.chat_message(role):
        st.write(content)


async def render_streaming_chat_message(stream: AsyncStreamingResultContainer):
    buffer = ""
    with st.chat_message("assistant"):
        # This is very ugly as streamlit does not support async generators
        # Thus we have to ignore the benefit of writing the delta and instead write *everything*
        with st.empty():
            async for item in stream:
                buffer += item["delta"]
                st.write(buffer)


def initialize_app() -> burr.core.Application:
    if "burr_app" not in st.session_state:
        st.session_state.burr_app = chatbot_application.application(
            app_id=f"chat_streaming:{str(uuid.uuid4())[0:6]}"
        )
    return st.session_state.burr_app


async def main():
    st.title("Streaming chatbot with Burr")
    app = initialize_app()

    prompt = st.chat_input("Ask me a question!", key="chat_input")
    for chat_message in app.state.get("chat_history", []):
        render_chat_message(chat_message)

    if prompt:
        render_chat_message({"role": "user", "content": prompt, "type": "text"})
        with st.spinner(text="Waiting for response..."):
            action, streaming_container = await app.astream_result(
                halt_after=chatbot_application.TERMINAL_ACTIONS, inputs={"prompt": prompt}
            )
        await render_streaming_chat_message(streaming_container)


if __name__ == "__main__":
    asyncio.run(main())
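As a rough pointer on running this UI (the filename is not shown in this diff): if the Streamlit script above were saved as, say, streamlit_app.py, it could be launched locally with streamlit run streamlit_app.py. It drives the same graph through astream_result that the FastAPI endpoint uses, just rendered in the browser via st.write.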