From 00104983e5903317023c7a6cd22c10f8252d515b Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Fri, 15 Mar 2024 14:24:50 -0700 Subject: [PATCH] Updates server to have chatbot example We still need to fix this up, it doesn't work with the python path, and examples are not in the server. --- burr/tracking/server/backend.py | 8 ++- burr/tracking/server/examples/chatbot.py | 68 ++++++++++++++++++++++++ burr/tracking/server/run.py | 12 +++-- 3 files changed, 80 insertions(+), 8 deletions(-) create mode 100644 burr/tracking/server/examples/chatbot.py diff --git a/burr/tracking/server/backend.py b/burr/tracking/server/backend.py index 2ae7c8ad2..4096ffd1c 100644 --- a/burr/tracking/server/backend.py +++ b/burr/tracking/server/backend.py @@ -145,10 +145,10 @@ async def get_application_logs( if not os.path.exists(graph_file): raise fastapi.HTTPException( status_code=404, - detail=f"Graph file not found for app: " - f"{app_id} from project: {project_id}. " - f"Was this properly executed?", + detail=f"Graph file for app: {app_id} from project: {project_id} not found", ) + async with aiofiles.open(graph_file) as f: + str_graph = await f.read() steps_by_sequence_id = {} spans_by_id = {} if os.path.exists(log_file): @@ -179,8 +179,6 @@ async def get_application_logs( for span in spans_by_id.values(): step = steps_by_sequence_id[span.begin_entry.action_sequence_id] step.spans.append(span) - async with aiofiles.open(graph_file) as f: - str_graph = await f.read() return ApplicationLogs( application=schema.ApplicationModel.parse_raw(str_graph), steps=list(steps_by_sequence_id.values()), diff --git a/burr/tracking/server/examples/chatbot.py b/burr/tracking/server/examples/chatbot.py new file mode 100644 index 000000000..54bcb6301 --- /dev/null +++ b/burr/tracking/server/examples/chatbot.py @@ -0,0 +1,68 @@ +import functools +from typing import List, Literal + +import pydantic +from fastapi import FastAPI +from starlette.requests import Request + +from burr.core import Application, State +from burr.tracking import LocalTrackingClient + +from examples.gpt import application as chat_application + + +class ChatItem(pydantic.BaseModel): + content: str + type: Literal["image", "text", "code", "error"] + role: Literal["user", "assistant"] + + +@functools.lru_cache(maxsize=128) +def _get_application(project_id: str, app_id: str) -> Application: + app = chat_application.application(use_hamilton=False, app_id=app_id, project_id="demo:chatbot") + if LocalTrackingClient.app_log_exists(project_id, app_id): + state, _ = LocalTrackingClient.load_state(project_id, app_id) # TODO -- handle entrypoint + app.update_state( + State(state) + ) # TODO -- handle the entrypoint -- this will always reset to prompt + return app + + +def chat_response(project_id: str, app_id: str, prompt: str) -> List[ChatItem]: + """Chat response endpoint. + + :param project_id: Project ID to run + :param app_id: Application ID to run + :param prompt: Prompt to send to the chatbot + :return: + """ + burr_app = _get_application(project_id, app_id) + _, _, state = burr_app.run(halt_after=["response"], inputs=dict(prompt=prompt)) + return state.get("chat_history", []) + + +def chat_history(project_id: str, app_id: str) -> List[ChatItem]: + burr_app = _get_application(project_id, app_id) + state = burr_app.state + return state.get("chat_history", []) + + +async def create_new_application(request: Request, project_id: str, app_id: str) -> str: + """Quick helper to create a new application. Just returns true, you'll want to fetch afterwards. + In a better chatbot you'd want to either have the frontend store this and create on demand or return + the actual application model""" + # side-effect is to create the application + chat_application.application(use_hamilton=False, app_id=app_id, project_id=project_id) + return app_id # just return it for now + + +def register(app: FastAPI, api_prefix: str): + app.post(f"{api_prefix}/{{project_id}}/{{app_id}}/response", response_model=List[ChatItem])( + chat_response + ) + app.get(f"{api_prefix}/{{project_id}}/{{app_id}}/history", response_model=List[ChatItem])( + chat_history + ) + app.post(f"{api_prefix}/{{project_id}}/{{app_id}}/create", response_model=str)( + create_new_application + ) diff --git a/burr/tracking/server/run.py b/burr/tracking/server/run.py index d51c5bb32..ddc631441 100644 --- a/burr/tracking/server/run.py +++ b/burr/tracking/server/run.py @@ -3,6 +3,7 @@ from typing import Sequence from burr.integrations.base import require_plugin +from burr.tracking.server.examples import chatbot try: import uvicorn @@ -10,7 +11,8 @@ from fastapi.staticfiles import StaticFiles from starlette.templating import Jinja2Templates - from burr.tracking.server import backend, schema + from burr.tracking.server import backend as backend_module + from burr.tracking.server import schema from burr.tracking.server.schema import ApplicationLogs except ImportError as e: require_plugin( @@ -30,10 +32,10 @@ app = FastAPI() -backend = backend.LocalBackend() - SERVE_STATIC = os.getenv("BURR_SERVE_STATIC", "true").lower() == "true" +backend = backend_module.LocalBackend() + @app.get("/api/v0/projects", response_model=Sequence[schema.Project]) async def get_projects(request: Request) -> Sequence[schema.Project]: @@ -75,6 +77,10 @@ async def ready() -> bool: return True +# Examples -- todo -- put them behind `if` statements +chatbot.register(app, "/api/v0/chatbot") + + if SERVE_STATIC: BASE_ASSET_DIRECTORY = str(files("burr").joinpath("tracking/server/build"))