From 6602c8667b4218460e6122bf4fdceef7ada6f647 Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Tue, 19 Mar 2024 14:17:33 -0700 Subject: [PATCH] Updates chatbot in-app demo to use new persistence capability This also changes the tracker to enable that. --- burr/tracking/client.py | 55 +++++++++++++++++++++++- burr/tracking/server/examples/chatbot.py | 10 +---- examples/gpt/application.py | 15 +++++-- 3 files changed, 67 insertions(+), 13 deletions(-) diff --git a/burr/tracking/client.py b/burr/tracking/client.py index ba3d0c53..6c6f0930 100644 --- a/burr/tracking/client.py +++ b/burr/tracking/client.py @@ -83,6 +83,7 @@ def __init__( self.f = None self.storage_dir = LocalTrackingClient.get_storage_path(project, storage_dir) + self.project_id = project @classmethod def get_storage_path(cls, project, storage_dir): @@ -118,7 +119,9 @@ def load_state( sequence_id: int = -1, storage_dir: str = DEFAULT_STORAGE_DIR, ) -> tuple[dict, str]: - """Function to load state from what the tracking client got. + """This is deprecated and will be removed when we migrate over demos. Do not use! Instead use + the persistence API :py:class:`initialize_from ` + to load state. It defaults to loading the last state, but you can supply a sequence number. 
@@ -293,7 +296,55 @@ def load( self, partition_key: str, app_id: Optional[str], sequence_id: Optional[int] = None ) -> Optional[PersistedStateData]: # TODO: - pass + if app_id is None: + return # no application ID + if sequence_id is None: + sequence_id = -1 # get the last one + path = os.path.join(self.storage_dir, app_id, self.LOG_FILENAME) + if not os.path.exists(path): + raise ValueError( + f"No logs found for {self.project_id}/{app_id} under {self.storage_dir}" + ) + with open(path, "r") as f: + json_lines = f.readlines() + # load as JSON + json_lines = [json.loads(js_line) for js_line in json_lines] + # filter to only end_entry + json_lines = [js_line for js_line in json_lines if js_line["type"] == "end_entry"] + try: + line = json_lines[sequence_id] + except IndexError: + raise ValueError( + f"Sequence number {sequence_id} not found for {self.project_id}/{app_id}." + ) + # check sequence number matches if non-negative; will break if either is None. + line_seq = int(line["sequence_id"]) + if -1 < sequence_id != line_seq: + logger.warning( + f"Sequence number mismatch. For {self.project_id}/{app_id}: " + f"actual:{line_seq} != expected:{sequence_id}" + ) + # get the prior state + prior_state = line["state"] + position = line["action"] + # delete internal stuff. 
We can't loop over the keys and delete them in the same loop + to_delete = [] + for key in prior_state.keys(): + # remove any internal "__" state + if key.startswith("__"): + to_delete.append(key) + for key in to_delete: + del prior_state[key] + prior_state["__SEQUENCE_ID"] = line_seq # add the sequence id back + return { + "partition_key": partition_key, + "app_id": app_id, + "sequence_id": line_seq, + "position": position, + "state": State(prior_state), + "created_at": datetime.datetime.fromtimestamp(os.path.getctime(path)).isoformat(), + "status": "success" if line["exception"] is None else "failed", + } # TODO -- implement async version diff --git a/burr/tracking/server/examples/chatbot.py b/burr/tracking/server/examples/chatbot.py index 65efabf0..b2d988a4 100644 --- a/burr/tracking/server/examples/chatbot.py +++ b/burr/tracking/server/examples/chatbot.py @@ -5,9 +5,8 @@ from fastapi import FastAPI from starlette.requests import Request -from burr.core import Application, State +from burr.core import Application from burr.examples.gpt import application as chat_application -from burr.tracking import LocalTrackingClient class ChatItem(pydantic.BaseModel): @@ -18,12 +17,7 @@ class ChatItem(pydantic.BaseModel): @functools.lru_cache(maxsize=128) def _get_application(project_id: str, app_id: str) -> Application: - app = chat_application.application(use_hamilton=False, app_id=app_id, project_id="demo:chatbot") - if LocalTrackingClient.app_log_exists(project_id, app_id): - state, _ = LocalTrackingClient.load_state(project_id, app_id) # TODO -- handle entrypoint - app.update_state( - State(state) - ) # TODO -- handle the entrypoint -- this will always reset to prompt + app = chat_application.application(use_hamilton=False, app_id=app_id, project_id=project_id) return app diff --git a/examples/gpt/application.py b/examples/gpt/application.py index 379f4725..38ad8281 100644 --- a/examples/gpt/application.py +++ b/examples/gpt/application.py @@ -9,6 +9,7 @@ from 
burr.core.action import action from burr.integrations.hamilton import Hamilton, append_state, from_state, update_state from burr.lifecycle import LifecycleAdapter +from burr.tracking import LocalTrackingClient MODES = { "answer_question": "text", @@ -149,6 +150,9 @@ def base_application( ): if hooks is None: hooks = [] + # we're initializing above so we can load from this as well + # we could also use `with_tracker("local", project=project_id, params={"storage_dir": storage_dir})` + tracker = LocalTrackingClient(project=project_id, storage_dir=storage_dir) return ( ApplicationBuilder() .with_actions( @@ -166,8 +170,6 @@ def base_application( prompt_for_more=prompt_for_more, response=response, ) - .with_entrypoint("prompt") - .with_state(chat_history=[]) .with_transitions( ("prompt", "check_openai_key", default), ("check_openai_key", "check_safety", when(has_openai_key=True)), @@ -184,8 +186,15 @@ def base_application( ), ("response", "prompt", default), ) + # initializes from the tracking log if it exists, otherwise uses the defaults below + .initialize_from( + tracker, + resume_at_next_action=False, # always resume from entrypoint in the case of failure + default_state={"chat_history": []}, + default_entrypoint="prompt", + ) .with_hooks(*hooks) - .with_tracker("local", project=project_id, params={"storage_dir": storage_dir}) + .with_tracker(tracker) .with_identifiers(app_id=app_id) .build() )