
Commit

WIP
elijahbenizzy committed Mar 12, 2024
1 parent 2f2bf62 commit 1435bac
Showing 19 changed files with 2,006 additions and 170 deletions.
40 changes: 30 additions & 10 deletions burr/tracking/client.py
@@ -85,28 +85,48 @@ def __init__(
    def get_storage_path(cls, project, storage_dir):
        return os.path.join(os.path.expanduser(storage_dir), project)

+    @classmethod
+    def app_log_exists(
+        cls,
+        project: str,
+        app_id: str,
+        storage_dir: str = DEFAULT_STORAGE_DIR,
+    ) -> bool:
+        """Function to check if state exists for a given project and app_id.
+        :param project: the name of the project
+        :param app_id: the application instance id
+        :param storage_dir: the storage directory.
+        :return: True if state exists, False otherwise.
+        """
+        path = os.path.join(cls.get_storage_path(project, storage_dir), app_id, cls.LOG_FILENAME)
+        return os.path.exists(path)

    @classmethod
    def load_state(
        cls,
        project: str,
        app_id: str,
-        sequence_no: int = -1,
+        sequence_id: int = -1,
        storage_dir: str = DEFAULT_STORAGE_DIR,
    ) -> tuple[dict, str]:
        """Function to load state from what the tracking client got.
-        It defaults to loading the last state, but you can supply a sequence number.
-        We will make loading state more ergonomic, but at this time this is what you get.
+        This is a temporary solution -- not particularly ergonomic, and makes assumptions (particularly that
+        all logging is in order), which is fine for now. We will be improving this and making it a first-class
+        citizen.
        :param project: the name of the project
        :param app_id: the application instance id
-        :param sequence_no: the sequence number of the state to load. Defaults to last index (i.e. -1).
+        :param sequence_id: the sequence number of the state to load. Defaults to last index (i.e. -1).
        :param storage_dir: the storage directory.
        :return: the state as a dictionary, and the entry point as a string.
        """
-        if sequence_no is None:
-            sequence_no = -1  # get the last one
+        if sequence_id is None:
+            sequence_id = -1  # get the last one
        path = os.path.join(cls.get_storage_path(project, storage_dir), app_id, cls.LOG_FILENAME)
        if not os.path.exists(path):
            raise ValueError(f"No logs found for {project}/{app_id} under {storage_dir}")
@@ -117,15 +137,15 @@ def load_state(
        # filter to only end_entry
        json_lines = [js_line for js_line in json_lines if js_line["type"] == "end_entry"]
        try:
-            line = json_lines[sequence_no]
+            line = json_lines[sequence_id]
        except IndexError:
-            raise ValueError(f"Sequence number {sequence_no} not found for {project}/{app_id}.")
+            raise ValueError(f"Sequence number {sequence_id} not found for {project}/{app_id}.")
        # check sequence number matches if non-negative
-        line_seq = int(line["sequence_no"])
-        if -1 < sequence_no != line_seq:
+        line_seq = int(line["sequence_id"])
+        if -1 < sequence_id != line_seq:
            logger.warning(
                f"Sequence number mismatch. For {project}/{app_id}: "
-                f"actual:{line_seq} != expected:{sequence_no}"
+                f"actual:{line_seq} != expected:{sequence_id}"
            )
        # get the prior state
        prior_state = line["state"]
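
A minimal sketch of how these two classmethods might be used together to rehydrate state, assuming a project/app pair that has already been logged (the identifiers below are hypothetical):

from burr.core import State
from burr.tracking import LocalTrackingClient

project, app_id = "demo:chatbot", "my-app-id"  # hypothetical identifiers
if LocalTrackingClient.app_log_exists(project, app_id):
    # load_state returns (state_dict, entrypoint); the default sequence_id of -1 picks the last logged step
    state_dict, entrypoint = LocalTrackingClient.load_state(project, app_id)
    state = State(state_dict)
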
8 changes: 3 additions & 5 deletions burr/tracking/server/backend.py
@@ -145,10 +145,10 @@ async def get_application_logs(
        if not os.path.exists(graph_file):
            raise fastapi.HTTPException(
                status_code=404,
-                detail=f"Graph file not found for app: "
-                f"{app_id} from project: {project_id}. "
-                f"Was this properly executed?",
+                detail=f"Graph file for app: {app_id} from project: {project_id} not found",
            )
+        async with aiofiles.open(graph_file) as f:
+            str_graph = await f.read()
        steps_by_sequence_id = {}
        spans_by_id = {}
        if os.path.exists(log_file):
@@ -179,8 +179,6 @@ async def get_application_logs(
        for span in spans_by_id.values():
            step = steps_by_sequence_id[span.begin_entry.action_sequence_id]
            step.spans.append(span)
-        async with aiofiles.open(graph_file) as f:
-            str_graph = await f.read()
        return ApplicationLogs(
            application=schema.ApplicationModel.parse_raw(str_graph),
            steps=list(steps_by_sequence_id.values()),
60 changes: 60 additions & 0 deletions burr/tracking/server/examples/chatbot.py
@@ -0,0 +1,60 @@
import functools
from typing import List, Literal

import pydantic
from fastapi import FastAPI
from starlette.requests import Request

import burr
from burr.core import Application, ApplicationBuilder, State
from burr.tracking import LocalTrackingClient

from examples.gpt import application as chat_application

PROJECT_ID = "demo:chatbot"


class ChatItem(pydantic.BaseModel):
    content: str
    type: Literal["image", "text", "code"]
    role: Literal["user", "assistant"]


@functools.lru_cache(maxsize=128)
def _get_application(project_id: str, app_id: str) -> Application:
    app = chat_application.application(use_hamilton=False, app_id=app_id, project_id="demo:chatbot")
    if LocalTrackingClient.app_log_exists(project_id, app_id):
        state, _ = LocalTrackingClient.load_state(project_id, app_id)  # TODO -- handle entrypoint
        app.update_state(
            State(state)
        )  # TODO -- handle the entrypoint -- this will always reset to prompt
    return app


def chat_response(project_id: str, app_id: str, prompt: str) -> List[ChatItem]:
    """Chat response endpoint.
    :param project_id: Project ID to run
    :param app_id: Application ID to run
    :param prompt: Prompt to send to the chatbot
    :return:
    """
    burr_app = _get_application(project_id, app_id)
    _, _, state = burr_app.run(halt_after=["response"], inputs=dict(prompt=prompt))
    return state.get("chat_history", [])


def chat_history(project_id: str, app_id: str) -> List[ChatItem]:
    burr_app = _get_application(project_id, app_id)
    state = burr_app.state
    print(state)
    return state.get("chat_history", [])


def register(app: FastAPI, api_prefix: str):
    app.post(f"{api_prefix}/{{project_id}}/{{app_id}}/response", response_model=List[ChatItem])(
        chat_response
    )
    app.get(f"{api_prefix}/{{project_id}}/{{app_id}}/history", response_model=List[ChatItem])(
        chat_history
    )
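
A rough sketch of exercising these handlers directly in Python, outside of FastAPI routing (identifiers are hypothetical; called this way the return value is the raw chat_history list from state rather than validated ChatItem models):

# FastAPI normally fills project_id/app_id from the path and prompt from the query string
history = chat_response("demo:chatbot", "my-app-id", "What is Burr?")
for item in history:  # entries follow the ChatItem shape: role / content / type
    print(item["role"], item["content"])

print(chat_history("demo:chatbot", "my-app-id"))
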
12 changes: 9 additions & 3 deletions burr/tracking/server/run.py
@@ -3,14 +3,16 @@
from typing import Sequence

from burr.integrations.base import require_plugin
+from burr.tracking.server.examples import chatbot

try:
    import uvicorn
    from fastapi import FastAPI, Request
    from fastapi.staticfiles import StaticFiles
    from starlette.templating import Jinja2Templates

-    from burr.tracking.server import backend, schema
+    from burr.tracking.server import backend as backend_module
+    from burr.tracking.server import schema
    from burr.tracking.server.schema import ApplicationLogs
except ImportError as e:
    require_plugin(
@@ -30,10 +32,10 @@

app = FastAPI()

-backend = backend.LocalBackend()

SERVE_STATIC = os.getenv("BURR_SERVE_STATIC", "true").lower() == "true"

+backend = backend_module.LocalBackend()


@app.get("/api/v0/projects", response_model=Sequence[schema.Project])
async def get_projects(request: Request) -> Sequence[schema.Project]:
@@ -75,6 +77,10 @@ async def ready() -> bool:
    return True


+# Examples -- todo -- put them behind `if` statements
+chatbot.register(app, "/api/v0/chatbot")


if SERVE_STATIC:
    BASE_ASSET_DIRECTORY = str(files("burr").joinpath("tracking/server/build"))

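
With the registration above, the chatbot routes land under /api/v0/chatbot. A sketch of hitting them in-process with FastAPI's test client; the URLs and identifiers are illustrative, and prompt is passed as a query parameter since chat_response declares it as a plain str:

from fastapi.testclient import TestClient

from burr.tracking.server.run import app

client = TestClient(app)
# fetch chat history for a hypothetical project/app pair
print(client.get("/api/v0/chatbot/demo:chatbot/my-app-id/history").json())
# send a prompt to the response endpoint
print(client.post("/api/v0/chatbot/demo:chatbot/my-app-id/response", params={"prompt": "Hi"}).json())
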
18 changes: 11 additions & 7 deletions examples/gpt/application.py
@@ -1,3 +1,4 @@
+import copy
from typing import List, Optional, Tuple

import dag
@@ -77,7 +78,7 @@ def prompt_for_more(state: State) -> Tuple[dict, State]:
def chat_response(
    state: State, prepend_prompt: str, display_type: str = "text", model: str = "gpt-3.5-turbo"
) -> Tuple[dict, State]:
-    chat_history = state["chat_history"].copy()
+    chat_history = copy.deepcopy(state["chat_history"])
    chat_history[-1]["content"] = f"{prepend_prompt}: {chat_history[-1]['content']}"
    chat_history_api_format = [
        {
@@ -129,7 +130,7 @@ def response(state: State) -> Tuple[dict, State]:
    # return result, state.append(chat_history=result["chat_record"])


-def base_application(hooks: List[LifecycleAdapter], app_id: str, storage_dir: str):
+def base_application(hooks: List[LifecycleAdapter], app_id: str, storage_dir: str, project_id: str):
    if hooks is None:
        hooks = []
    return (
@@ -165,12 +166,14 @@ def base_application(hooks: List[LifecycleAdapter], app_id: str, storage_dir: st
("response", "prompt", default),
)
.with_hooks(*hooks)
.with_tracker("demo:chatbot", params={"app_id": app_id, "storage_dir": storage_dir})
.with_tracker(project_id, params={"app_id": app_id, "storage_dir": storage_dir})
.build()
)


-def hamilton_application(hooks: List[LifecycleAdapter], app_id: str, storage_dir: str):
+def hamilton_application(
+    hooks: List[LifecycleAdapter], app_id: str, storage_dir: str, project_id: str
+):
    dr = driver.Driver({"provider": "openai"}, dag)  # TODO -- add modules
    Hamilton.set_driver(dr)
    application = (
@@ -231,7 +234,7 @@ def hamilton_application(hooks: List[LifecycleAdapter], app_id: str, storage_dir
("response", "prompt", default),
)
.with_hooks(*hooks)
.with_tracker("demo:chatbot", params={"app_id": app_id, "storage_dir": storage_dir})
.with_tracker(project_id, params={"app_id": app_id, "storage_dir": storage_dir})
.build()
)
return application
@@ -240,12 +243,13 @@ def hamilton_application(hooks: List[LifecycleAdapter], app_id: str, storage_dir
def application(
    use_hamilton: bool,
    app_id: Optional[str] = None,
+    project_id: str = "demo:chatbot",
    storage_dir: Optional[str] = "~/.burr",
    hooks: Optional[List[LifecycleAdapter]] = None,
) -> Application:
    if use_hamilton:
-        return hamilton_application(hooks, app_id, storage_dir)
-    return base_application(hooks, app_id, storage_dir)
+        return hamilton_application(hooks, app_id, storage_dir, project_id=project_id)
+    return base_application(hooks, app_id, storage_dir, project_id=project_id)


if __name__ == "__main__":
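
A minimal sketch of building and running the example with the new project_id parameter, mirroring how the chatbot server example above constructs it (the import path and identifiers are assumptions):

from examples.gpt import application as chat_application  # import path as used in chatbot.py above

app = chat_application.application(
    use_hamilton=False, app_id="my-app-id", project_id="demo:chatbot"
)
_, _, state = app.run(halt_after=["response"], inputs={"prompt": "Hello!"})
print(state.get("chat_history", []))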