From d60908d641180eafd78a6ab98a478400f10d295c Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Wed, 31 Jul 2024 15:57:00 -0700 Subject: [PATCH 01/11] Adds s3 tracking client to Burr This way the user can write to s3 and have Burr server pick it up. --- burr/tracking/s3client.py | 375 ++++++++++++++++++++++++++++++++++++++ burr/tracking/utils.py | 7 + 2 files changed, 382 insertions(+) create mode 100644 burr/tracking/s3client.py create mode 100644 burr/tracking/utils.py diff --git a/burr/tracking/s3client.py b/burr/tracking/s3client.py new file mode 100644 index 00000000..61657799 --- /dev/null +++ b/burr/tracking/s3client.py @@ -0,0 +1,375 @@ +import dataclasses +import datetime +import json +import logging +import queue +import re +import threading +import time +import traceback +import uuid +from typing import Any, Dict, List, Optional, Tuple, Union + +import pydantic + +from burr.common import types as burr_types +from burr.core import Action, ApplicationGraph, State, serde +from burr.integrations.base import require_plugin +from burr.tracking.base import SyncTrackingClient +from burr.tracking.common.models import ( + ApplicationMetadataModel, + ApplicationModel, + BeginEntryModel, + BeginSpanModel, + EndEntryModel, + EndSpanModel, + PointerModel, +) +from burr.visibility import ActionSpan + +logger = logging.getLogger(__name__) + +try: + import boto3 +except ImportError as e: + require_plugin( + e, + ["boto3"], + "tracking-s3", + ) + + +def fire_and_forget(func): + def wrapper(self, *args, **kwargs): + if self.non_blocking: # must be used with the S3TrackingClient + + def run(): + try: + func(self, *args, **kwargs) + except Exception: + logger.exception( + "Exception occurred in fire-and-forget function: %s", func.__name__ + ) + + threading.Thread(target=run).start() + return func(self, *args, **kwargs) + + return wrapper + + +# TODO -- move to common and share with client.py + +INPUT_FILTERLIST = {"__tracer"} + + +def _format_exception(exception: Exception) -> Optional[str]: + if exception is None: + return None + return "".join(traceback.format_exception(type(exception), exception, exception.__traceback__)) + + +INPUT_FILTERLIST = {"__tracer"} + + +def _filter_inputs(d: dict) -> dict: + return {k: v for k, v in d.items() if k not in INPUT_FILTERLIST} + + +def _allowed_project_name(project_name: str, on_windows: bool) -> bool: + allowed_chars = "a-zA-Z0-9_\-" + if not on_windows: + allowed_chars += ":" + pattern = f"^[{allowed_chars}]+$" + + # Use regular expression to check if the string is valid + return bool(re.match(pattern, project_name)) + + +EventType = Union[BeginEntryModel, EndEntryModel, BeginSpanModel, EndSpanModel] + + +def unique_ordered_prefix() -> str: + return datetime.datetime.now().isoformat() + str(uuid.uuid4()) + + +def str_partition_key(partition_key: Optional[str]) -> str: + return partition_key or "__none__" + + +class S3TrackingClient(SyncTrackingClient): + """Synchronous tracking client that logs to S3. Experimental. Errs on the side of writing many little files. + General schema is: + - bucket + - data/ + - project_name_1 + - YYYY/MM/DD/HH + - application.json (optional, will be on the first write from this tracker object) + - metadata.json (optional, will be on the first write from this tracker object) + - log_.jsonl + - log_.jsonl + - YYYY/MM/DD/HH + ... + ... + + This is designed to be fast to write, generally slow(ish) to read, but doable, and require no db. 
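For concreteness, here is a hypothetical example of where a single log batch lands under this layout. The project name, partition key, app id, timestamp, and shortened uuid below are made up; the actual key composition is what `get_prefix()` and `log_object()` implement further down in this file.

```python
# Illustrative only -- mirrors "/".join(get_prefix() + path_within_project)
# as done by log_object() below (values are invented, uuid shortened).
prefix = ["data", "chatbot", "2024", "07", "31", "22", "57:00.123456"]
path_within_project = ["user-123", "app-456", "4f9c__log.jsonl"]
key = "/".join(prefix + path_within_project)
# -> "data/chatbot/2024/07/31/22/57:00.123456/user-123/app-456/4f9c__log.jsonl"
```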
+ This also has a non-blocking mode that just launches a thread (expensive but doable solution) + TODO -- get working with aiobotocore and an async tracker + """ + + def __init__( + self, + project: str, + bucket: str, + region: str = None, + endpoint_url: Optional[str] = None, + non_blocking: bool = False, + serde_kwargs: Optional[dict] = None, + unique_tracker_id: str = None, + flush_interval: int = 5, + ): + self.bucket = bucket + self.project = project + self.region = region + self.endpoint_url = endpoint_url + self.non_blocking = non_blocking + self.s3 = boto3.client("s3", region_name=region, endpoint_url=endpoint_url) + self.serde_kwargs = serde_kwargs or {} + self.unique_tracker_id = ( + unique_tracker_id + if unique_tracker_id is not None + else datetime.datetime.now().isoformat() + "-" + str(uuid.uuid4()) + ) + self.log_queue = queue.Queue() # Tuple[app_id, EventType] + self.flush_interval = flush_interval + self.max_batch_size = 10000 # rather large batch size -- why not? It'll flush every 5 seconds otherwise and we don't want to spam s3 with files + self.initialized = False + self.running = True + self.init() + + def _get_time_partition(self): + time = datetime.datetime.utcnow().isoformat() + return [time[:4], time[5:7], time[8:10], time[11:13], time[14:]] + + def get_prefix(self): + return [ + "data", + self.project, + *self._get_time_partition(), + ] + + def init(self): + if not self.initialized: + logger.info("Initializing S3TrackingClient with flushing thread") + thread = threading.Thread(target=self.thread) + # This will quit when the main thread is ready to, and gracefully + # But it will gracefully exit due to the "None" on the queue + threading.Thread( + target=lambda: threading.main_thread().join() or self.log_queue.put(None) + ).start() + thread.start() + self.initialized = True + + def thread(self): + batch = [] + last_flush_time = time.time() + + while self.running: + try: + logger.info(f"Checking for new data to flush -- batch is of size: {len(batch)}") + # Wait up to flush_interval for new data + item = self.log_queue.get(timeout=self.flush_interval) + # signal that we're done + if item is None: + self.log_events(batch) + self.running = False + break + batch.append(item) + # Check if batch is full or flush interval has passed + if ( + len(batch) >= self.max_batch_size + or (time.time() - last_flush_time) >= self.flush_interval + ): + logger.info(f"Flushing batch with {len(batch)} events") + self.log_events(batch) + batch = [] + last_flush_time = time.time() + except queue.Empty: + # Flush on timeout if there's any data + if batch: + logger.info(f"Flushing batch on queue empty with {len(batch)} events") + self.log_events(batch) + batch = [] + last_flush_time = time.time() + + def stop(self, thread: threading.Thread): + self.running = False # will stop the thread + thread.join() # then wait for it to be done + events = [] + # Flush any remaining events + while self.log_queue.qsize() > 0: + events.append(self.log_queue.get()) + self.log_events(events) + + def submit_log_event(self, event: EventType, app_id: str, partition_key: str): + self.log_queue.put((app_id, partition_key, event)) + + def log_events(self, events: List[Tuple[str, EventType]]): + events_by_app_id = {} + for app_id, partition_key, event in events: + if (app_id, partition_key) not in events_by_app_id: + events_by_app_id[(app_id, partition_key)] = [] + events_by_app_id[(app_id, partition_key)].append(event) + for (app_id, partition_key), app_events in events_by_app_id.items(): + logger.debug(f"Logging 
{len(app_events)} events for app {app_id}") + min_sequence_id = min([e.sequence_id for e in app_events]) + max_sequence_id = max([e.sequence_id for e in app_events]) + path = [ + str_partition_key(partition_key), + app_id, + str(uuid.uuid4()) + + "__log.jsonl", # in case we happen to have multiple at the same time.... + ] + self.log_object( + *path, + data=app_events, + metadata={ + "min_sequence_id": str(min_sequence_id), + "max_sequence_id": str(max_sequence_id), + }, + ) + + def log_object( + self, + *path_within_project: str, + data: Union[pydantic.BaseModel, List[pydantic.BaseModel]], + metadata: Dict[str, str] = None, + ): + if metadata is None: + metadata = {} + metadata = {**metadata, "tracker_id": self.unique_tracker_id} + full_path = self.get_prefix() + list(path_within_project) + key = "/".join(full_path) + if isinstance(data, list): + body = "\n".join([d.model_dump_json() for d in data]) + else: + body = data.model_dump_json() + self.s3.put_object(Bucket=self.bucket, Key=key, Body=body, Metadata=metadata) + + @fire_and_forget + def post_application_create( + self, + *, + app_id: str, + partition_key: Optional[str], + state: "State", + application_graph: "ApplicationGraph", + parent_pointer: Optional[burr_types.ParentPointer], + spawning_parent_pointer: Optional[burr_types.ParentPointer], + **future_kwargs: Any, + ): + graph = ApplicationModel.from_application_graph( + application_graph, + ) + graph_path = [str_partition_key(partition_key), app_id, "graph.json"] + self.log_object(*graph_path, data=graph) + metadata = ApplicationMetadataModel( + partition_key=partition_key, + parent_pointer=PointerModel.from_pointer(parent_pointer), + spawning_parent_pointer=PointerModel.from_pointer(spawning_parent_pointer), + ) + metadata_path = [str_partition_key(partition_key), app_id, "metadata.json"] + # we put these here to allow for quicker retrieval on the server side + # It's a bit of a hack to put it all into metadata, but it helps with ingestion + self.log_object( + *metadata_path, + data=metadata, + metadata={ + "parent_pointer": json.dumps(dataclasses.asdict(parent_pointer)) + if parent_pointer is not None + else "None", + "spawning_parent_pointer": json.dumps(dataclasses.asdict(spawning_parent_pointer)) + if spawning_parent_pointer is not None + else "None", + }, + ) + # TODO -- log parent relationship + + def pre_run_step( + self, + *, + app_id: str, + partition_key: str, + sequence_id: int, + state: "State", + action: "Action", + inputs: Dict[str, Any], + **future_kwargs: Any, + ): + _filtered_inputs = _filter_inputs(inputs) + pre_run_entry = BeginEntryModel( + start_time=datetime.datetime.now(), + action=action.name, + inputs=serde.serialize(_filtered_inputs, **self.serde_kwargs), + sequence_id=sequence_id, + ) + self.submit_log_event(pre_run_entry, app_id, partition_key) + + def post_run_step( + self, + *, + app_id: str, + partition_key: str, + sequence_id: int, + state: "State", + action: "Action", + result: Optional[Dict[str, Any]], + exception: Exception, + **future_kwargs: Any, + ): + post_run_entry = EndEntryModel( + end_time=datetime.datetime.now(), + action=action.name, + result=serde.serialize(result, **self.serde_kwargs), + sequence_id=sequence_id, + exception=_format_exception(exception), + state=state.serialize(), + ) + self.submit_log_event(post_run_entry, app_id, partition_key) + + def pre_start_span( + self, + *, + sequence_id: int, + partition_key: str, + app_id: str, + span: ActionSpan, + span_dependencies: list[str], + **future_kwargs: Any, + ): + 
begin_span_model = BeginSpanModel( + start_time=datetime.datetime.now(), + action_sequence_id=sequence_id, + span_id=span.uid, + parent_span_id=span.parent.uid if span.parent else None, + span_dependencies=span_dependencies, + span_name=span.name, + ) + self.submit_log_event(begin_span_model, app_id, partition_key) + + def post_end_span( + self, + *, + action: str, + sequence_id: int, + span: ActionSpan, + span_dependencies: list[str], + app_id: str, + partition_key: str, + **future_kwargs: Any, + ): # TODO -- implemenet + end_span_model = EndSpanModel( + end_time=datetime.datetime.now(), + action_sequence_id=sequence_id, + span_id=span.uid, + ) + self.submit_log_event(end_span_model, app_id, partition_key) diff --git a/burr/tracking/utils.py b/burr/tracking/utils.py new file mode 100644 index 00000000..61f1e192 --- /dev/null +++ b/burr/tracking/utils.py @@ -0,0 +1,7 @@ +import json + + +def safe_json_load(line: bytes): + # Every once in a while we'll hit a non-utf-8 character + # In this case we replace it and hope for the best + return json.loads(line.decode("utf-8", errors="replace")) From 2d97d1562b5a1408ef2ef107afdf92a806d96446 Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Wed, 31 Jul 2024 16:14:48 -0700 Subject: [PATCH 02/11] Passes app_id/partition_key to traces/spans We need this to write the log files to s3 and index them properly. --- burr/core/application.py | 5 ++++- burr/lifecycle/base.py | 8 ++++++++ burr/tracking/client.py | 4 ++-- burr/visibility/tracing.py | 20 ++++++++++++++++++-- tests/visibility/test_tracing.py | 16 ++++++++++++++++ 5 files changed, 48 insertions(+), 5 deletions(-) diff --git a/burr/core/application.py b/burr/core/application.py index 98091af0..e6774fff 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -459,7 +459,10 @@ def __init__( self._parent_pointer = fork_parent_pointer self.dependency_factory = { "__tracer": functools.partial( - visibility.tracing.TracerFactory, lifecycle_adapters=self._adapter_set + visibility.tracing.TracerFactory, + lifecycle_adapters=self._adapter_set, + app_id=self._uid, + partition_key=self._partition_key, ), "__context": self._context_factory, } diff --git a/burr/lifecycle/base.py b/burr/lifecycle/base.py index e2f235af..897f41be 100644 --- a/burr/lifecycle/base.py +++ b/burr/lifecycle/base.py @@ -167,6 +167,8 @@ def pre_start_span( action_sequence_id: int, span: "ActionSpan", span_dependencies: list[str], + app_id: str, + partition_key: Optional[str], **future_kwargs: Any, ): pass @@ -182,6 +184,8 @@ async def pre_start_span( action_sequence_id: int, span: "ActionSpan", span_dependencies: list[str], + app_id: str, + partition_key: Optional[str], **future_kwargs: Any, ): pass @@ -200,6 +204,8 @@ def post_end_span( action_sequence_id: int, span: "ActionSpan", span_dependencies: list[str], + app_id: str, + partition_key: Optional[str], **future_kwargs: Any, ): pass @@ -215,6 +221,8 @@ async def post_end_span( action_sequence_id: int, span: "ActionSpan", span_dependencies: list[str], + app_id: str, + partition_key: Optional[str], **future_kwargs: Any, ): pass diff --git a/burr/tracking/client.py b/burr/tracking/client.py index 5a8d4ac4..b175e66e 100644 --- a/burr/tracking/client.py +++ b/burr/tracking/client.py @@ -428,7 +428,7 @@ def pre_start_span( span_dependencies: list[str], **future_kwargs: Any, ): - being_span_model = BeginSpanModel( + begin_span_model = BeginSpanModel( start_time=datetime.datetime.now(), action_sequence_id=sequence_id, span_id=span.uid, @@ -436,7 +436,7 @@ def 
pre_start_span( span_dependencies=span_dependencies, span_name=span.name, ) - self._append_write_line(being_span_model) + self._append_write_line(begin_span_model) def post_end_span( self, diff --git a/burr/visibility/tracing.py b/burr/visibility/tracing.py index 1c6d84fd..99c1e175 100644 --- a/burr/visibility/tracing.py +++ b/burr/visibility/tracing.py @@ -103,6 +103,8 @@ def __init__( action_sequence_id: int, span_name: str, lifecycle_adapters: LifecycleAdapterSet, + app_id: str, + partition_key: Optional[str], span_dependencies: List[str], top_level_span_count: int = 0, context_var=execution_context_var, @@ -122,6 +124,8 @@ def __init__( self.span_dependencies = span_dependencies self.top_level_span_count = top_level_span_count self.context_var = context_var + self.app_id = app_id + self.partition_key = partition_key def _sync_hooks_enter(self, context: ActionSpan): self.lifecycle_adapters.call_all_lifecycle_hooks_sync( @@ -130,6 +134,8 @@ def _sync_hooks_enter(self, context: ActionSpan): span=context, span_dependencies=self.span_dependencies, sequence_id=self.action_sequence_id, + app_id=self.app_id, + partition_key=self.partition_key, ) async def _async_hooks_enter(self, context: ActionSpan): @@ -139,6 +145,8 @@ async def _async_hooks_enter(self, context: ActionSpan): span=context, span_dependencies=self.span_dependencies, sequence_id=self.action_sequence_id, + app_id=self.app_id, + partition_key=self.partition_key, ) async def _async_hooks_exit(self, context: ActionSpan): @@ -148,6 +156,8 @@ async def _async_hooks_exit(self, context: ActionSpan): span=context, span_dependencies=self.span_dependencies, sequence_id=self.action_sequence_id, + app_id=self.app_id, + partition_key=self.partition_key, ) def _enter(self): @@ -178,6 +188,8 @@ def _sync_hooks_exit(self, context: ActionSpan): span=context, span_dependencies=self.span_dependencies, sequence_id=self.action_sequence_id, + app_id=self.app_id, + partition_key=self.partition_key, ) def __enter__(self): @@ -241,14 +253,14 @@ def my_action(state: State, __tracer: TracerFactory) -> tuple[dict, State]: context_manager: ActionSpanTracer = __tracer("my_span_name") with context_manager: ... 
- - """ def __init__( self, action: str, sequence_id: int, + app_id: str, + partition_key: str, lifecycle_adapters: LifecycleAdapterSet, _context_var: ContextVar[Optional[ActionSpan]] = execution_context_var, ): @@ -263,6 +275,8 @@ def __init__( self.context_var = _context_var self.top_level_span_count = 0 self.action_sequence_id = sequence_id + self.app_id = app_id + self.partition_key = partition_key def __call__( self, span_name: str, span_dependencies: Optional[List[str]] = None @@ -279,4 +293,6 @@ def __call__( span_dependencies=span_dependencies, context_var=self.context_var, top_level_span_count=self.top_level_span_count, + app_id=self.app_id, + partition_key=self.partition_key, ) diff --git a/tests/visibility/test_tracing.py b/tests/visibility/test_tracing.py index 572643c7..4550a452 100644 --- a/tests/visibility/test_tracing.py +++ b/tests/visibility/test_tracing.py @@ -26,6 +26,8 @@ def test_action_span_tracer_correct_span_count(request): sequence_id=0, lifecycle_adapters=LifecycleAdapterSet(), _context_var=context_var, + app_id="test_app_id", + partition_key=None, ) assert context_var.get() is None # nothing to start assert tracer_factory.top_level_span_count == 0 # and thus no top-level spans @@ -54,6 +56,8 @@ async def test_action_span_tracer_correct_span_count_async(request): sequence_id=0, lifecycle_adapters=LifecycleAdapterSet(), _context_var=context_var, + app_id="test_app_id", + partition_key=None, ) assert context_var.get() is None # nothing to start assert tracer_factory.top_level_span_count == 0 # and thus no top-level spans @@ -82,6 +86,8 @@ def test_action_span_tracer_correct_span_count_nested(request): sequence_id=0, lifecycle_adapters=LifecycleAdapterSet(), _context_var=context_var, + app_id="test_app_id", + partition_key=None, ) with tracer_factory("0") as outside_span_0: @@ -119,6 +125,8 @@ async def test_action_span_tracer_correct_span_count_nested_async(request): sequence_id=0, lifecycle_adapters=LifecycleAdapterSet(), _context_var=context_var, + app_id="test_app_id", + partition_key=None, ) async with tracer_factory("0") as outside_span_0: @@ -188,6 +196,8 @@ def post_end_span( sequence_id=0, lifecycle_adapters=adapter_set, _context_var=context_var, + app_id="test_app_id", + partition_key=None, ) # 0:0 with tracer_factory_0("0"): @@ -212,6 +222,8 @@ def post_end_span( sequence_id=tracer_factory_0.action_sequence_id + 1, lifecycle_adapters=adapter_set, _context_var=context_var, + app_id="test_app_id", + partition_key=None, ) # 1:0 with tracer_factory_1("2"): @@ -283,6 +295,8 @@ async def post_end_span( sequence_id=0, lifecycle_adapters=adapter_set, _context_var=context_var, + app_id="test_app_id", + partition_key=None, ) # 0:0 async with tracer_factory_0("0"): @@ -307,6 +321,8 @@ async def post_end_span( sequence_id=tracer_factory_0.action_sequence_id + 1, lifecycle_adapters=adapter_set, _context_var=context_var, + app_id="test_app_id", + partition_key=None, ) # 1:0 async with tracer_factory_1("2"): From 699a283068a94af7972e521372ef612fc8b616e2 Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Wed, 31 Jul 2024 16:29:23 -0700 Subject: [PATCH 03/11] Adds S3-based server High-level architecture: 1. Clients writes to s3 bucket 2. Server powers up with a SQLite(pluggable) db 3. Server indexes the s3 on a recurring job 4. We have pointers for everything in the UI stored in the db except the data for the traces 5. Server saves/loads sqlite database with highwatermark to s3 We have not implemented (5) yet, but the rest are done. Some specifics: 1. 
backend has been broken into mixins -- e.g. indexing backend, standard backend, etc... -- this allows us to have it implement classes and have that be called 2. If it's the indexing backend we have an admin view with jobs 3. We use tortoise ORM to make switching between DBs easy -- we will very likely enable postgres soon 4. The indexing function should be easy to invert control -- E.G. rather than writing to s3, we write to the server which logs to s3. 5. We store a high-watermark so we don't go over the same one twice --- burr/cli/__main__.py | 80 ++- burr/tracking/common/models.py | 8 + burr/tracking/server/backend.py | 92 ++- burr/tracking/server/requirements-s3.txt | 4 + burr/tracking/server/run.py | 131 +++- burr/tracking/server/s3/README.md | 25 + burr/tracking/server/s3/__init__.py | 0 burr/tracking/server/s3/backend.py | 559 ++++++++++++++++++ burr/tracking/server/s3/initialize_db.py | 28 + .../models/0_20240730151503_init.py | 70 +++ burr/tracking/server/s3/models.py | 96 +++ burr/tracking/server/s3/pyproject.toml | 4 + burr/tracking/server/s3/settings.py | 13 + burr/tracking/server/s3/utils.py | 14 + burr/tracking/server/schema.py | 84 ++- 15 files changed, 1169 insertions(+), 39 deletions(-) create mode 100644 burr/tracking/server/requirements-s3.txt create mode 100644 burr/tracking/server/s3/README.md create mode 100644 burr/tracking/server/s3/__init__.py create mode 100644 burr/tracking/server/s3/backend.py create mode 100644 burr/tracking/server/s3/initialize_db.py create mode 100644 burr/tracking/server/s3/migrations/models/0_20240730151503_init.py create mode 100644 burr/tracking/server/s3/models.py create mode 100644 burr/tracking/server/s3/pyproject.toml create mode 100644 burr/tracking/server/s3/settings.py create mode 100644 burr/tracking/server/s3/utils.py diff --git a/burr/cli/__main__.py b/burr/cli/__main__.py index 51793715..a686933e 100644 --- a/burr/cli/__main__.py +++ b/burr/cli/__main__.py @@ -1,5 +1,6 @@ import importlib.util import json +import logging import os import shutil import subprocess @@ -27,21 +28,46 @@ ) +class InterceptHandler(logging.Handler): + def emit(self, record): + # Get corresponding Loguru level if it exists + try: + level = logger.level(record.levelname).name + except ValueError: + level = record.levelno + + # Find caller from where originated the log message + frame, depth = logging.currentframe(), 2 + while frame.f_code.co_filename == logging.__file__: + frame = frame.f_back + depth += 1 + + logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage()) + + +# Clear default handlers +logging.basicConfig(handlers=[InterceptHandler()], level=logging.INFO) + + # TODO -- add this as a general callback to the CLI def _telemetry_if_enabled(event: str): if telemetry.is_telemetry_enabled(): telemetry.create_and_send_cli_event(event) -def _command(command: str, capture_output: bool) -> str: +def _command(command: str, capture_output: bool, addl_env: dict = None) -> str: """Runs a simple command""" + if addl_env is None: + addl_env = {} + env = os.environ.copy() + env.update(addl_env) logger.info(f"Running command: {command}") if isinstance(command, str): command = command.split(" ") if capture_output: try: return ( - subprocess.check_output(command, stderr=subprocess.PIPE, shell=False) + subprocess.check_output(command, stderr=subprocess.PIPE, shell=False, env=env) .decode() .strip() ) @@ -49,7 +75,7 @@ def _command(command: str, capture_output: bool) -> str: print(e.stdout.decode()) print(e.stderr.decode()) raise e - 
subprocess.run(command, shell=False, check=True) + subprocess.run(command, shell=False, check=True, env=env) def _get_git_root() -> str: @@ -102,6 +128,12 @@ def build_ui(): _build_ui() +BACKEND_MODULES = { + "local": "burr.tracking.server.backend.LocalBackend", + "s3": "burr.tracking.server.s3.backend.S3Backend", +} + + def _run_server( port: int, dev_mode: bool, @@ -109,6 +141,7 @@ def _run_server( no_copy_demo_data: bool, initial_page="", host: str = "127.0.0.1", + backend: str = "local", ): _telemetry_if_enabled("run_server") # TODO: Implement server running logic here @@ -142,7 +175,10 @@ def _run_server( daemon=True, ) thread.start() - _command(cmd, capture_output=False) + env = { + "BURR_BACKEND_IMPL": BACKEND_MODULES[backend], + } + _command(cmd, capture_output=False, addl_env=env) @cli.command() @@ -156,8 +192,16 @@ def _run_server( help="Host to run the server on -- use 0.0.0.0 if you want " "to expose it to the network (E.G. in a docker image)", ) -def run_server(port: int, dev_mode: bool, no_open: bool, no_copy_demo_data: bool, host: str): - _run_server(port, dev_mode, no_open, no_copy_demo_data, host=host) +@click.option( + "--backend", + default="local", + help="Backend to use for the server.", + type=click.Choice(["local", "s3"]), +) +def run_server( + port: int, dev_mode: bool, no_open: bool, no_copy_demo_data: bool, host: str, backend: str +): + _run_server(port, dev_mode, no_open, no_copy_demo_data, host=host, backend=backend) @cli.command() @@ -186,7 +230,17 @@ def build_and_publish(prod: bool, no_wipe_dist: bool): @cli.command(help="generate demo data for the UI") -def generate_demo_data(): +@click.option( + "--s3-bucket", help="S3 URI to save to, will use the s3 tracker, not local mode", required=False +) +@click.option( + "--data-dir", + help="Local directory to save to", + required=False, + default="burr/tracking/server/demo_data", +) +@click.option("--unique-app-names", help="Use unique app names", is_flag=True) +def generate_demo_data(s3_bucket, data_dir, unique_app_names: bool): _telemetry_if_enabled("generate_demo_data") git_root = _get_git_root() # We need to add the examples directory to the path so we have all the imports @@ -194,10 +248,14 @@ def generate_demo_data(): sys.path.extend([git_root, f"{git_root}/examples/multi-modal-chatbot"]) from burr.cli.demo_data import generate_all - with cd(git_root): - logger.info("Removing old demo data") - shutil.rmtree("burr/tracking/server/demo_data", ignore_errors=True) - generate_all("burr/tracking/server/demo_data") + # local mode + if s3_bucket is None: + with cd(git_root): + logger.info("Removing old demo data") + shutil.rmtree(data_dir, ignore_errors=True) + generate_all(data_dir=data_dir, unique_app_names=unique_app_names) + else: + generate_all(s3_bucket=s3_bucket, unique_app_names=unique_app_names) def _transform_state_to_test_case(state: dict, action_name: str, test_name: str) -> dict: diff --git a/burr/tracking/common/models.py b/burr/tracking/common/models.py index c66b7641..e4ed8a4d 100644 --- a/burr/tracking/common/models.py +++ b/burr/tracking/common/models.py @@ -175,6 +175,10 @@ class BeginSpanModel(IdentifyingModel): span_dependencies: list[str] type: str = "begin_span" + @property + def sequence_id(self) -> int: + return self.action_sequence_id + class EndSpanModel(IdentifyingModel): """Pydantic model that represents an entry for the end of a span""" @@ -183,3 +187,7 @@ class EndSpanModel(IdentifyingModel): action_sequence_id: int span_id: str # unique among the application type: str = "end_span" + + 
@property + def sequence_id(self) -> int: + return self.action_sequence_id diff --git a/burr/tracking/server/backend.py b/burr/tracking/server/backend.py index 1316279e..e520ed3b 100644 --- a/burr/tracking/server/backend.py +++ b/burr/tracking/server/backend.py @@ -1,11 +1,15 @@ import abc +import importlib import json import os.path -from typing import Sequence, TypeVar +import sys +from typing import Any, Optional, Sequence, Type, TypeVar import aiofiles import aiofiles.os as aiofilesos import fastapi +from fastapi import FastAPI +from pydantic_settings import BaseSettings, SettingsConfigDict from burr.tracking.common import models from burr.tracking.common.models import ( @@ -20,7 +24,6 @@ T = TypeVar("T") - # The following is a backend for the server. # Note this is not a fixed API yet, and thus not documented (in Burr's documentation) # Specifically, this does not have: @@ -29,7 +32,42 @@ # - Authentication/Authorization +if sys.version_info <= (3, 11): + Self = Any +else: + from typing import Self + + +class BurrSettings(BaseSettings): + model_config = SettingsConfigDict(env_prefix="burr_") + + +class IndexingBackendMixin(abc.ABC): + """Base mixin for an indexing backend -- one that index from + logs (E.G. s3)""" + + @abc.abstractmethod + async def update(self): + """Updates the index""" + pass + + @abc.abstractmethod + async def update_interval_milliseconds(self) -> Optional[int]: + """Returns the update interval in milliseconds""" + pass + + @abc.abstractmethod + async def indexing_jobs( + self, offset: int = 0, limit: int = 100, filter_empty: bool = True + ) -> Sequence[schema.IndexingJob]: + """Returns the indexing jobs""" + pass + + class BackendBase(abc.ABC): + async def lifespan(self, app: FastAPI): + yield + @abc.abstractmethod async def list_projects(self, request: fastapi.Request) -> Sequence[schema.Project]: """Lists out all projects -- this relies on the paginate function to work properly. @@ -63,6 +101,39 @@ async def get_application_logs( """ pass + @classmethod + @abc.abstractmethod + def settings_model(cls) -> Type[BaseSettings]: + """Gives a settings model that tells us how to configure the backend. + This is a class of pydantic BaseSettings type + + :return: the settings model + """ + pass + + @classmethod + def from_settings(cls, settings_model: BaseSettings) -> Self: + """Creates a backend from settings, of the type of settings_model above + This defaults to assuming the constructor takes in settings parameters + + :param settings_model: + :return: + """ + return cls(**settings_model.dict()) + + @classmethod + def create_from_env(cls, dotenv_path: Optional[str] = None) -> Self: + cls_path = os.environ.get("BURR_BACKEND_IMPL", "burr.tracking.server.backend.LocalBackend") + mod_path = ".".join(cls_path.split(".")[:-1]) + mod = importlib.import_module(mod_path) + cls_name = cls_path.split(".")[-1] + if not hasattr(mod, cls_name): + raise ValueError(f"Could not find {cls_name} in {mod_path}") + cls = getattr(mod, cls_name) + return cls.from_settings( + cls.settings_model()(_env_file=dotenv_path, _env_file_encoding="utf-8") + ) + def safe_json_load(line: bytes): # Every once in a while we'll hit a non-utf-8 character @@ -80,11 +151,13 @@ def get_uri(project_id: str) -> str: return project_id_map.get(project_id, "") +DEFAULT_PATH = os.path.expanduser("~/.burr") + + class LocalBackend(BackendBase): """Quick implementation of a local backend for testing purposes. 
This is not a production backend.""" # TODO -- make this configurable through an env variable - DEFAULT_PATH = os.path.expanduser("~/.burr") def __init__(self, path: str = DEFAULT_PATH): self.path = path @@ -180,6 +253,7 @@ async def get_application_logs( str_graph = await f.read() steps_by_sequence_id = {} spans_by_id = {} + # TODO -- use the Step.from_logs method if os.path.exists(log_file): async with aiofiles.open(log_file, "rb") as f: for line in await f.readlines(): @@ -206,6 +280,7 @@ async def get_application_logs( span = spans_by_id[end_span.span_id] span.end_entry = end_span for span in spans_by_id.values(): + # They should have one, the other, or both set step = steps_by_sequence_id[span.begin_entry.action_sequence_id] step.spans.append(span) children = [] @@ -223,3 +298,14 @@ async def get_application_logs( spawning_parent_pointer=metadata.spawning_parent_pointer, children=children, ) + + class BackendSettings(BurrSettings): + path: str = DEFAULT_PATH + + @classmethod + def settings_model(cls) -> Type[BurrSettings]: + return cls.BackendSettings + + @classmethod + def from_settings(cls, settings_model: BurrSettings) -> Self: + return cls(**settings_model.dict()) diff --git a/burr/tracking/server/requirements-s3.txt b/burr/tracking/server/requirements-s3.txt new file mode 100644 index 00000000..e1f4632b --- /dev/null +++ b/burr/tracking/server/requirements-s3.txt @@ -0,0 +1,4 @@ +aerich +aiobotocore +fastapi-utils +tortoise-orm[accel, asyncmy] diff --git a/burr/tracking/server/run.py b/burr/tracking/server/run.py index be99070a..c06931fe 100644 --- a/burr/tracking/server/run.py +++ b/burr/tracking/server/run.py @@ -1,19 +1,32 @@ import importlib +import logging import os +from contextlib import asynccontextmanager from importlib.resources import files from typing import Sequence -from burr.integrations.base import require_plugin +# TODO -- remove this, just for testing +from hamilton.log_setup import setup_logging +from starlette import status + +from burr.tracking.server.backend import BackendBase, IndexingBackendMixin + +setup_logging(logging.INFO) + +logger = logging.getLogger(__name__) try: import uvicorn - from fastapi import FastAPI, Request + from fastapi import FastAPI, HTTPException, Request from fastapi.staticfiles import StaticFiles + from fastapi_utils.tasks import repeat_every from starlette.templating import Jinja2Templates - from burr.tracking.server import backend as backend_module from burr.tracking.server import schema - from burr.tracking.server.schema import ApplicationLogs + + # from burr.tracking.server import backend as backend_module + # from burr.tracking.server.s3 import backend as s3_backend + from burr.tracking.server.schema import ApplicationLogs, BackendSpec, IndexingJob # dynamic importing due to the dashes (which make reading the examples on github easier) email_assistant = importlib.import_module("burr.examples.email-assistant.server") @@ -21,26 +34,87 @@ streaming_chatbot = importlib.import_module("burr.examples.streaming-fastapi.server") except ImportError as e: - require_plugin( - e, - [ - "click", - "fastapi", - "uvicorn", - "pydantic", - "fastapi-pagination", - "aiofiles", - "requests", - "jinja2", - ], - "tracking", - ) - -app = FastAPI() + raise e + # require_plugin( + # e, + # [ + # "click", + # "fastapi", + # "uvicorn", + # "pydantic", + # "fastapi-pagination", + # "aiofiles", + # "requests", + # "jinja2", + # ], + # "tracking", + # ) SERVE_STATIC = os.getenv("BURR_SERVE_STATIC", "true").lower() == "true" -backend = 
backend_module.LocalBackend() +# TODO -- get based on a config +# backend = backend_module.LocalBackend() +# backend = s3_backend.S3Backend( +# bucket="burr-prod-test", +# ) + +backend = BackendBase.create_from_env() + + +# if it is an indexing backend we want to expose a few endpoints + + +# TODO -- add a health check for intialization + + +async def update(): + if app_spec.indexing: + logger.info("Updating backend") + await backend.update() + + +@asynccontextmanager +async def lifespan(app: FastAPI): + # No yield from allowed + await backend.lifespan(app).__anext__() + await update() # this will trigger the repeat every N seconds + yield + await backend.lifespan(app).__anext__() + + +app = FastAPI(lifespan=lifespan) + + +@app.get("/api/v0/metadata/app_spec", response_model=BackendSpec) +def get_app_spec(): + is_indexing_backend = isinstance(backend, IndexingBackendMixin) + return BackendSpec(indexing=is_indexing_backend) + + +app_spec = get_app_spec() + +logger = logging.getLogger(__name__) + +# @repeat_every( +# seconds=update_interval if update_interval is not None else float("inf"), +# wait_first=True, +# logger=logger, +# ) + + +if app_spec.indexing: + update_interval = backend.update_interval_milliseconds() / 1000 if app_spec.indexing else None + update = repeat_every( + seconds=backend.update_interval_milliseconds() / 1000, + wait_first=True, + logger=logger, + )(update) + + +@app.on_event("startup") +async def startup_event(): + if app_spec.indexing: + await update() @app.get("/api/v0/projects", response_model=Sequence[schema.Project]) @@ -72,7 +146,7 @@ async def get_application_logs(request: Request, project_id: str, app_id: str) - :param request: FastAPI :param project_id: ID of the project - :param app_id: ID of the associated application + :param app_id: ID of the assIndociated application :return: A list of steps with all associated step data """ return await backend.get_application_logs(request, project_id=project_id, app_id=app_id) @@ -83,6 +157,18 @@ async def ready() -> bool: return True +@app.get("/api/v0/indexing_jobs", response_model=Sequence[IndexingJob]) +async def get_indexing_jobs( + offset: int = 0, limit: int = 100, filter_empty: bool = True +) -> Sequence[IndexingJob]: + if not app_spec.indexing: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="This backend does not support indexing jobs.", + ) + return await backend.indexing_jobs(offset=offset, limit=limit, filter_empty=filter_empty) + + @app.get("/api/v0/version") async def version() -> dict: """Returns the burr version""" @@ -100,7 +186,6 @@ async def version() -> dict: app.include_router(email_assistant.router, prefix="/api/v0/email_assistant") app.include_router(streaming_chatbot.router, prefix="/api/v0/streaming_chatbot") - if SERVE_STATIC: BASE_ASSET_DIRECTORY = str(files("burr").joinpath("tracking/server/build")) diff --git a/burr/tracking/server/s3/README.md b/burr/tracking/server/s3/README.md new file mode 100644 index 00000000..43f5a6b2 --- /dev/null +++ b/burr/tracking/server/s3/README.md @@ -0,0 +1,25 @@ +# S3-backed server + +## Architecture + +## Deployment + +## Migrating/Initializing + +- make sure aerich is installed + +To reset, do: +```bash +rm -rf ~/.burr_server && +mkdir ~/.burr_server && +rm -rf ./burr/tracking/server/s3/migrations && +aerich init -t burr.tracking.server.s3.settings.TORTOISE_ORM --location ./burr/tracking/server/s3/migrations && +aerich init-db && +AWS_PROFILE=dagworks burr --no-open +``` +``` +- `rm -rf ~/.burr_server` (will be turned to an env 
variable) +- `mkdir ~/.burr_server` (ditto) +- (from git root) `rm -rf ./burr/tracking/server/s3/migrations` +- (from git root) `aerich init -t burr.tracking.server.s3.settings.TORTOISE_ORM --location ./burr/tracking/server/s3/migrations` +- (from git root) `aerich init-db` diff --git a/burr/tracking/server/s3/__init__.py b/burr/tracking/server/s3/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/burr/tracking/server/s3/backend.py b/burr/tracking/server/s3/backend.py new file mode 100644 index 00000000..f87c94e3 --- /dev/null +++ b/burr/tracking/server/s3/backend.py @@ -0,0 +1,559 @@ +import dataclasses +import datetime +import functools +import itertools +import json +import logging +import operator +import uuid +from collections import Counter +from typing import List, Literal, Optional, Sequence, Tuple, Type, TypeVar, Union + +import fastapi +import pydantic +from aiobotocore import session +from fastapi import FastAPI +from pydantic_settings import BaseSettings +from tortoise import functions, transactions +from tortoise.contrib.fastapi import RegisterTortoise +from tortoise.expressions import Q + +from burr.tracking.common.models import ApplicationModel +from burr.tracking.server import schema +from burr.tracking.server.backend import BackendBase, BurrSettings, IndexingBackendMixin +from burr.tracking.server.s3 import settings, utils +from burr.tracking.server.s3.models import ( + Application, + IndexingJob, + IndexingJobStatus, + IndexStatus, + LogFile, + Project, +) +from burr.tracking.server.schema import ApplicationLogs, Step + +logger = logging.getLogger(__name__) + +FileType = Literal["log", "metadata", "graph"] + +ContentsModel = TypeVar("ContentsModel", bound=pydantic.BaseModel) + + +async def _query_s3_file( + bucket: str, + key: str, + client: session.AioBaseClient, +) -> Union[ContentsModel, List[ContentsModel]]: + response = await client.get_object(Bucket=bucket, Key=key) + body = await response["Body"].read() + return body + + +@dataclasses.dataclass +class DataFile: + """Generic data file object meant to represent a file in the s3 bucket. 
This has a few possible roles (log, metadata, and graph file)""" + + prefix: str + yyyy: str + mm: str + dd: str + hh: str + minutes_string: str + partition_key: str + application_id: str + file_type: FileType + path: str + created_date: datetime.datetime + + @classmethod + def from_path(cls, path: str, created_date: datetime.datetime) -> "DataFile": + parts = path.split("/") + + # Validate that there are enough parts to extract the needed fields + if len(parts) < 9: + raise ValueError(f"Path '{path}' is not valid") + + prefix = "/".join(parts[:-8]) # Everything before the year part + yyyy = parts[2] + mm = parts[3] + dd = parts[4] + hh = parts[5] + minutes_string = parts[6] + application_id = parts[8] + partition_key = parts[7] + filename = parts[9] + file_type = ( + "graph" + if filename.endswith("graph.json") + else "metadata" + if filename.endswith("_metadata.json") + else "log" + ) + + # # Validate the date parts + # if not (yyyy.isdigit() and mm.isdigit() and dd.isdigit() and hh.isdigit()): + # raise ValueError(f"Date components in the path '{path}' are not valid") + + return cls( + prefix=prefix, + yyyy=yyyy, + mm=mm, + dd=dd, + hh=hh, + minutes_string=minutes_string, + application_id=application_id, + partition_key=partition_key, + file_type=file_type, + path=path, + created_date=created_date, + ) + + +class S3Settings(BurrSettings): + s3_bucket: str + update_interval_milliseconds: int = 60_000 + aws_max_concurrency: int = 100 + + +class S3Backend(BackendBase, IndexingBackendMixin): + @classmethod + def settings_model(cls) -> Type[BaseSettings]: + return S3Settings + + def __init__(self, s3_bucket: str, update_interval_milliseconds: int, aws_max_concurrency: int): + self._backend_id = datetime.datetime.utcnow().isoformat() + str(uuid.uuid4()) + self._bucket = s3_bucket + self._session = session.get_session() + self._update_interval_milliseconds = update_interval_milliseconds + self._aws_max_concurrency = aws_max_concurrency + + def update_interval_milliseconds(self) -> Optional[int]: + return self._update_interval_milliseconds + + async def _s3_get_first_write_date(self, project_id: str): + async with self._session.create_client("s3") as client: + paginator = client.get_paginator("list_objects_v2") + async for result in paginator.paginate( + Bucket=self._bucket, Prefix=f"data/{project_id}/", Delimiter="/", MaxKeys=1 + ): + if "Contents" in result: + first_object = result["Contents"][0] + return first_object["LastModified"] + return ( + datetime.datetime.utcnow() + ) # This should never be hit unless someone is concurrently deleting... 
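The indexing methods below repeatedly call `utils.gather_with_concurrency(self._aws_max_concurrency, ...)` to cap the number of S3 requests in flight; that helper lives in `burr/tracking/server/s3/utils.py`, which is not shown in this excerpt. A minimal sketch of what such a semaphore-bounded gather typically looks like, with the signature assumed from the call sites:

```python
import asyncio
from typing import Any, Awaitable, List


async def gather_with_concurrency(n: int, *coros: Awaitable[Any]) -> List[Any]:
    # Allow at most `n` of the given coroutines to run concurrently.
    semaphore = asyncio.Semaphore(n)

    async def bounded(coro: Awaitable[Any]) -> Any:
        async with semaphore:
            return await coro

    # Like asyncio.gather, results come back in the same order as the inputs,
    # which the callers below rely on when indexing into the result list.
    return await asyncio.gather(*(bounded(c) for c in coros))
```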
+ + async def _update_projects(self): + current_projects = await Project.all() + project_names = {project.name for project in current_projects} + logger.info(f"Current projects: {project_names}") + async with self._session.create_client("s3") as client: + paginator = client.get_paginator("list_objects_v2") + async for result in paginator.paginate( + Bucket=self._bucket, Prefix="data/", Delimiter="/" + ): + for prefix in result.get("CommonPrefixes", []): + project_name = prefix.get("Prefix").split("/")[-2] + if project_name not in project_names: + now = datetime.datetime.utcnow() + logger.info(f"Creating project: {project_name}") + await Project.create( + name=project_name, + uri=None, + created_at=await self._s3_get_first_write_date(project_id=project_name), + indexed_at=now, + updated_at=now, + ) + + async def query_applications_by_key( + self, application_keys: Sequence[tuple[str, Optional[str]]] + ): + conditions = [ + Q(name=app_id, partition_key=partition_key) + for app_id, partition_key in application_keys + ] + + # Combine the conditions with an OR operation + query = Application.filter(functools.reduce(operator.or_, conditions)) + + # Execute the query + applications = await query.all() + return applications + + async def _gather_metadata_files( + self, + metadata_files: List[DataFile], + ) -> Sequence[dict]: + """Gives a list of metadata files so we can update the application""" + + async def _query_metadata_file(metadata_file: DataFile) -> dict: + async with self._session.create_client("s3") as client: + response = await client.head_object( + Bucket=self._bucket, + Key=metadata_file.path, + ) + # metadata = await response['Body'].read() + parent_pointer_raw = response["Metadata"].get("parent_pointer") + spawning_parent_pointer_raw = response["Metadata"].get("spawning_parent_pointer") + return dict( + partition_key=metadata_file.partition_key, + parent_pointer=json.loads(parent_pointer_raw) + if parent_pointer_raw != "None" + else None, + spawning_parent_pointer=json.loads(spawning_parent_pointer_raw) + if spawning_parent_pointer_raw != "None" + else None, + ) + + out = await utils.gather_with_concurrency( + self._aws_max_concurrency, + *[_query_metadata_file(metadata_file) for metadata_file in metadata_files], + ) + return out + + async def _gather_log_file_data(self, log_files: List[DataFile]) -> Sequence[dict]: + """Gives a list of log files so we can update the application""" + + async def _query_log_file(log_file: DataFile) -> dict: + async with self._session.create_client("s3") as client: + response = await client.head_object( + Bucket=self._bucket, + Key=log_file.path, + ) + # TODO -- consider the default cases, we should not have them and instead mark this as failed + return { + "min_sequence_id": response["Metadata"].get("min_sequence_id", 0), + "max_sequence_id": response["Metadata"].get("max_sequence_id", 0), + "tracker_id": response["Metadata"].get("tracker_id", "unknown"), + } + + out = await utils.gather_with_concurrency( + self._aws_max_concurrency, *[_query_log_file(log_file) for log_file in log_files] + ) + return out + + async def _gather_paths_to_update( + self, project: Project, high_watermark_s3_path: str + ) -> Sequence[DataFile]: + """Gathers all paths to update in s3 -- we store file pointers in the db for these. + This allows us to periodically scan for more files to index. 
+ + :return: list of paths to update + """ + logger.info(f"Scanning db with highwatermark: {high_watermark_s3_path}") + paths_to_update = [] + logger.info(f"Scanning log data for project: {project.name}") + async with self._session.create_client("s3") as client: + paginator = client.get_paginator("list_objects_v2") + async for result in paginator.paginate( + Bucket=self._bucket, + Prefix=f"data/{project.name}/", + StartAfter=high_watermark_s3_path, + ): + for content in result.get("Contents", []): + key = content["Key"] + last_modified = content["LastModified"] + # Created == last_modified as we have an immutable data model + logger.info(f"Found new file: {key}") + paths_to_update.append(DataFile.from_path(key, created_date=last_modified)) + logger.info(f"Found {len(paths_to_update)} new files to index") + return paths_to_update + + async def _ensure_applications_exist( + self, paths_to_update: Sequence[DataFile], project: Project + ): + """Given the paths to update, ensure that all corresponding applications exist in the database. + + :param paths_to_update: + :param project: + :return: + """ + all_application_keys = sorted( + {(path.application_id, path.partition_key) for path in paths_to_update} + ) + counter = Counter([path.file_type for path in paths_to_update]) + logger.info( + f"Found {len(all_application_keys)} applications in the scan, " + f"including: {counter['log']} log files, " + f"{counter['metadata']} metadata files, and {counter['graph']} graph files, " + f"and {len(paths_to_update) - len(all_application_keys)} other files." + ) + + # First, let's create all applications, ignoring them if they exist + + # first let's create all the applications if they don't exist + existing_applications = { + (app.name, app.partition_key): app + for app in await self.query_applications_by_key(all_application_keys) + } + # all_applications = await Application.all() + + apps_to_create = [ + Application( + name=app_id, + partition_key=pk, + project=project, + created_at=datetime.datetime.utcnow(), + ) + for app_id, pk in all_application_keys + if (app_id, pk) not in existing_applications + ] + + logger.info( + f"Creating {len(apps_to_create)} new applications, with keys: {[(app.name, app.partition_key) for app in apps_to_create]}" + ) + await Application.bulk_create(apps_to_create) + all_applications = await self.query_applications_by_key(all_application_keys) + return all_applications + + async def _update_all_applications( + self, all_applications: Sequence[Application], paths_to_update: Sequence[DataFile] + ) -> Sequence[Application]: + """Updates all application with associate metadata and graph files + + :param all_applications: All applications that are relevant + :param paths_to_update: All paths to update + :return: + """ + logger.info(f"found: {len(all_applications)} applications to update in the db") + metadata_data = [path for path in paths_to_update if path.file_type == "metadata"] + graph_data = [path for path in paths_to_update if path.file_type == "graph"] + metadata_objects = await self._gather_metadata_files(metadata_data) + key_to_application_map = {(app.name, app.partition_key): app for app in all_applications} + # For every metadata file we want to add the metadata file + for metadata, datafile in zip(metadata_objects, metadata_data): + key = (datafile.application_id, datafile.partition_key) + app = key_to_application_map[key] + app.metadata_file_pointer = datafile.path + + # TODO -- download the metadata file and update the application + + # for every graph file, we want 
to add the pointer + for graph_file in graph_data: + key = (graph_file.application_id, graph_file.partition_key) + app = key_to_application_map[key] + app.graph_file_pointer = graph_file.path + # Go through every application and save them + async with transactions.in_transaction(): + # TODO -- look at bulk saving, instead of transactions + for app in all_applications: + await app.save() + return all_applications + + async def update_log_files( + self, paths_to_update: Sequence[DataFile], all_applications: Sequence[Application] + ): + log_data = [path for path in paths_to_update if path.file_type == "log"] + logfile_objects = await self._gather_log_file_data(log_data) + key_to_application_map = {(app.name, app.partition_key): app for app in all_applications} + + # TODO -- gather referenced apps (parent pointers) and get the map of IDs to names + + # Go through every log file we've stored and update the appropriate item in the db + logfiles_to_save = [] + for logfile, datafile in zip(logfile_objects, log_data): + # get the application for the log file + app = key_to_application_map[(datafile.application_id, datafile.partition_key)] + # create the log file object + logfiles_to_save.append( + LogFile( + s3_path=datafile.path, + application=app, + tracker_id=logfile["tracker_id"], + min_sequence_id=logfile["min_sequence_id"], + max_sequence_id=logfile["max_sequence_id"], + created_at=datafile.created_date, + ) + ) + # Save all the log files + await LogFile.bulk_create(logfiles_to_save) + + async def _update_high_watermark( + self, paths_to_update: Sequence[DataFile], project: Project, indexing_job: IndexingJob + ): + new_high_watermark = max(paths_to_update, key=lambda x: x.path).path + next_status = IndexStatus(s3_highwatermark=new_high_watermark, project=project) + await next_status.save() + return next_status + + async def _scan_and_update_db_for_project( + self, project: Project, indexing_job: IndexingJob + ) -> Tuple[IndexStatus, int]: + """Scans and updates the database for a project. + + TODO -- break this up into functions + + :param project: Project to scan/update + :param max_length: Maximum length of the scan -- will pause and return after this. This is so we don't block for too long. 
+ :return: tuple of index status/num files processed + """ + # get the current status + current_status = ( + await IndexStatus.filter(project=project).order_by("-captured_time").first() + ) + # This way we can sort by the latest captured time + high_watermark = current_status.s3_highwatermark if current_status is not None else "" + logger.info(f"Scanning db with highwatermark: {high_watermark}") + paths_to_update = await self._gather_paths_to_update( + project=project, high_watermark_s3_path=high_watermark + ) + # Nothing new to see here + if len(paths_to_update) == 0: + return current_status, 0 + + all_applications = await self._ensure_applications_exist(paths_to_update, project) + await self._update_all_applications(all_applications, paths_to_update) + await self.update_log_files(paths_to_update, all_applications) + next_status = await self._update_high_watermark(paths_to_update, project, indexing_job) + return next_status, len(paths_to_update) + + async def _scan_and_update_db(self): + for project in await Project.all(): + indexing_job = IndexingJob( + records_processed=0, # start with zero + end_time=None, + status=IndexingJobStatus.RUNNING, + ) + await indexing_job.save() + + # TODO -- add error catching + status, num_files = await self._scan_and_update_db_for_project(project, indexing_job) + logger.info(f"Scanned: {num_files} files with status stored at ID={status.id}") + + indexing_job.records_processed = num_files + indexing_job.end_time = datetime.datetime.utcnow() + # TODO -- handle failure + indexing_job.status = IndexingJobStatus.SUCCESS + indexing_job.index_status = status + await indexing_job.save() + + async def update(self): + await self._update_projects() + await self._scan_and_update_db() + + async def lifespan(self, app: FastAPI): + async with RegisterTortoise(app, config=settings.TORTOISE_ORM, add_exception_handlers=True): + yield + + async def list_projects(self, request: fastapi.Request) -> Sequence[schema.Project]: + project_query = await Project.all() + out = [] + for project in project_query: + latest_logfile = ( + await LogFile.filter(application__project=project).order_by("-created_at").first() + ) + out.append( + schema.Project( + name=project.name, + id=project.name, + uri=project.uri if project.uri is not None else "TODO", + last_written=latest_logfile.created_at + if latest_logfile is not None + else project.created_at, + created=project.created_at, + num_apps=await Application.filter(project=project).count(), + ) + ) + return out + + async def list_apps( + self, request: fastapi.Request, project_id: str, limit: int = 100, offset: int = 0 + ) -> Sequence[schema.ApplicationSummary]: + # TODO -- distinctify between project name and project ID + # Currently they're the same in the UI but we'll want to have them decoupled + applications = ( + await Application.filter(project__name=project_id) + .annotate( + latest_logfile_created_at=functions.Max("log_files__created_at"), + logfile_count=functions.Max("log_files__max_sequence_id"), + ) + .order_by("created_at") + .offset(offset) + .limit(limit) + .prefetch_related("log_files", "project") + ) + out = [] + for application in applications: + last_written = ( + datetime.datetime.fromisoformat(application.latest_logfile_created_at) + if (application.latest_logfile_created_at is not None) + else application.created_at + ) + out.append( + schema.ApplicationSummary( + app_id=application.name, + partition_key=application.partition_key, + first_written=application.created_at, + last_written=last_written, + 
num_steps=application.logfile_count, + tags={}, + ) + ) + return out + + async def get_application_logs( + self, request: fastapi.Request, project_id: str, app_id: str + ) -> ApplicationLogs: + # TODO -- handle partition keys + applications = await Application.filter(name=app_id, project__name=project_id).all() + application = applications[0] + application_logs = await LogFile.filter(application__id=application.id).order_by( + "-created_at" + ) + async with self._session.create_client("s3") as client: + # Get all the files + files = await utils.gather_with_concurrency( + 1, + _query_s3_file(self._bucket, application.graph_file_pointer, client), + # _query_s3_files(self.bucket, application.metadata_file_pointer, client), + *itertools.chain( + _query_s3_file(self._bucket, log_file.s3_path, client) + for log_file in application_logs + ), + ) + graph_data = ApplicationModel.parse_raw(files[0]) + # TODO -- deal with what happens if the application is None + # TODO -- handle metadata + # metadata = ApplicationMetadataModel.parse_raw(files[1]) + steps = Step.from_logs(list(itertools.chain(*[f.splitlines() for f in files[1:]]))) + + return ApplicationLogs( + children=[], + steps=steps, + # TODO -- get this in + parent_pointer=None, + spawning_parent_pointer=None, + application=graph_data, + ) + + async def indexing_jobs( + self, offset: int = 0, limit: int = 100, filter_empty: bool = True + ) -> Sequence[schema.IndexingJob]: + indexing_jobs_query = ( + IndexingJob.all().order_by("-start_time").prefetch_related("index_status__project") + ) + + # Apply filter conditionally + if filter_empty: + indexing_jobs_query = indexing_jobs_query.filter(records_processed__gt=0) + indexing_jobs = await indexing_jobs_query.offset(offset).limit(limit) + out = [] + for indexing_job in indexing_jobs: + out.append( + schema.IndexingJob( + id=indexing_job.id, + start_time=indexing_job.start_time, + end_time=indexing_job.end_time, + status=indexing_job.status, + records_processed=indexing_job.records_processed, + metadata={ + "project": indexing_job.index_status.project.name + if indexing_job.index_status + else "unknown", + "s3_highwatermark": indexing_job.index_status.s3_highwatermark + if indexing_job.index_status + else "unknown", + }, + ) + ) + return out diff --git a/burr/tracking/server/s3/initialize_db.py b/burr/tracking/server/s3/initialize_db.py new file mode 100644 index 00000000..2bbbf65e --- /dev/null +++ b/burr/tracking/server/s3/initialize_db.py @@ -0,0 +1,28 @@ +import os +from pathlib import Path + +from tortoise import Tortoise + +from burr.tracking.server.s3 import settings + +DB_PATH = Path("~/.burr_server/db.sqlite3").expanduser() + + +async def connect(): + if not os.path.exists(DB_PATH): + os.makedirs(os.path.dirname(DB_PATH), exist_ok=True) + await Tortoise.init( + config=settings.TORTOISE_ORM, + ) + + +# +# async def first_time_init(): +# await connect() +# # Generate the schema +# await Tortoise.generate_schemas() +# +# +# if __name__ == '__main__': +# # db_path = sys.argv[1] +# run_async(first_time_init()) diff --git a/burr/tracking/server/s3/migrations/models/0_20240730151503_init.py b/burr/tracking/server/s3/migrations/models/0_20240730151503_init.py new file mode 100644 index 00000000..b6461010 --- /dev/null +++ b/burr/tracking/server/s3/migrations/models/0_20240730151503_init.py @@ -0,0 +1,70 @@ +from tortoise import BaseDBAsyncClient + + +async def upgrade(db: BaseDBAsyncClient) -> str: + return """ + CREATE TABLE IF NOT EXISTS "project" ( + "created_at" TIMESTAMP NOT NULL, + 
"indexed_at" TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + "updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + "id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + "name" VARCHAR(255) NOT NULL UNIQUE, + "uri" VARCHAR(255) +) /* Static model representing a project */; +CREATE INDEX IF NOT EXISTS "idx_project_name_4d952a" ON "project" ("name"); +CREATE TABLE IF NOT EXISTS "application" ( + "created_at" TIMESTAMP NOT NULL, + "indexed_at" TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + "updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + "id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + "name" VARCHAR(255) NOT NULL, + "partition_key" VARCHAR(255) NOT NULL, + "graph_file_pointer" VARCHAR(255), + "metadata_file_pointer" VARCHAR(255), + "fork_parent_id" INT REFERENCES "application" ("id") ON DELETE CASCADE, + "project_id" INT NOT NULL REFERENCES "project" ("id") ON DELETE CASCADE, + "spawning_parent_id" INT REFERENCES "application" ("id") ON DELETE CASCADE, + CONSTRAINT "uid_application_name_488894" UNIQUE ("name", "partition_key") +); +CREATE INDEX IF NOT EXISTS "idx_application_name_18706d" ON "application" ("name"); +CREATE INDEX IF NOT EXISTS "idx_application_partiti_d302c8" ON "application" ("partition_key"); +CREATE INDEX IF NOT EXISTS "idx_application_project_13a4e1" ON "application" ("project_id"); +CREATE TABLE IF NOT EXISTS "indexstatus" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + "s3_highwatermark" VARCHAR(1023) NOT NULL, + "captured_time" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + "project_id" INT NOT NULL REFERENCES "project" ("id") ON DELETE CASCADE +) /* Status to index. These are per-project and the latest is chosen */; +CREATE INDEX IF NOT EXISTS "idx_indexstatus_capture_d2163c" ON "indexstatus" ("captured_time"); +CREATE INDEX IF NOT EXISTS "idx_indexstatus_project_52e6eb" ON "indexstatus" ("project_id"); +CREATE TABLE IF NOT EXISTS "indexingjob" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + "start_time" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + "records_processed" INT NOT NULL, + "end_time" TIMESTAMP, + "status" VARCHAR(7) NOT NULL /* SUCCESS: SUCCESS\nFAILURE: FAILURE\nRUNNING: RUNNING */, + "index_status_id" INT REFERENCES "indexstatus" ("id") ON DELETE CASCADE +) /* Job for indexing data in s3. 
Records only if there's something to index */; +CREATE TABLE IF NOT EXISTS "logfile" ( + "created_at" TIMESTAMP NOT NULL, + "indexed_at" TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + "updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + "id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + "s3_path" VARCHAR(1024) NOT NULL, + "tracker_id" VARCHAR(255) NOT NULL, + "min_sequence_id" INT NOT NULL, + "max_sequence_id" INT NOT NULL, + "application_id" INT NOT NULL REFERENCES "application" ("id") ON DELETE CASCADE +); +CREATE INDEX IF NOT EXISTS "idx_logfile_applica_9633be" ON "logfile" ("application_id"); +CREATE TABLE IF NOT EXISTS "aerich" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + "version" VARCHAR(255) NOT NULL, + "app" VARCHAR(100) NOT NULL, + "content" JSON NOT NULL +);""" + + +async def downgrade(db: BaseDBAsyncClient) -> str: + return """ + """ diff --git a/burr/tracking/server/s3/models.py b/burr/tracking/server/s3/models.py new file mode 100644 index 00000000..1290b677 --- /dev/null +++ b/burr/tracking/server/s3/models.py @@ -0,0 +1,96 @@ +import enum + +from tortoise import fields +from tortoise.models import Model + + +class IndexingJobStatus(enum.Enum): + SUCCESS = "SUCCESS" + FAILURE = "FAILURE" + RUNNING = "RUNNING" + + +class IndexedModel(Model): + """Base model for all models that are indexed in s3. Contains data on creating/updating""" + + created_at = fields.DatetimeField(null=False) + indexed_at = fields.DatetimeField(null=True, auto_now_add=True) + updated_at = fields.DatetimeField(null=False, auto_now=True) + + class Meta: + abstract = True + + +class IndexingJob(Model): + """Job for indexing data in s3. Records only if there's something to index""" + + id = fields.IntField(pk=True) + start_time = fields.DatetimeField(auto_now_add=True) + records_processed = fields.IntField() + end_time = fields.DatetimeField(null=True) + status = fields.CharEnumField(IndexingJobStatus) + index_status = fields.ForeignKeyField( + "models.IndexStatus", related_name="index_status", null=True + ) + + def __str__(self): + return f"{self.start_time} - {self.end_time}" + + +class IndexStatus(Model): + """Status to index. 
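The s3_highwatermark records the last S3 key that was indexed, so the next scan can resume from where the previous pass stopped.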
These are per-project and the latest is chosen""" + + id = fields.IntField(pk=True) + s3_highwatermark = fields.CharField(max_length=1023) + captured_time = fields.DatetimeField(index=True, auto_now_add=True) + project = fields.ForeignKeyField("models.Project", related_name="project", index=True) + + def __str__(self): + return f"{self.project} - {self.captured_time}" + + +class Project(IndexedModel): + """Static model representing a project""" + + id = fields.IntField(pk=True) + name = fields.CharField(index=True, max_length=255, unique=True) + uri = fields.CharField(max_length=255, null=True) + + def __str__(self): + return self.name + + +class Application(IndexedModel): + id = fields.IntField(pk=True) + name = fields.CharField(index=True, max_length=255) + partition_key = fields.CharField(max_length=255, index=True, null=False) + project = fields.ForeignKeyField("models.Project", related_name="applications", index=True) + graph_file_pointer = fields.CharField(max_length=255, null=True) + metadata_file_pointer = fields.CharField(max_length=255, null=True) + fork_parent = fields.ForeignKeyField("models.Application", related_name="forks", null=True) + spawning_parent = fields.ForeignKeyField("models.Application", related_name="spawns", null=True) + + class Meta: + # App name is unique together + unique_together = (("name", "partition_key"),) + + def graph_file_indexed(self) -> bool: + return self.graph_file_pointer is not None + + def metadata_file_indexed(self) -> bool: + return self.metadata_file_pointer is not None + + +class LogFile(IndexedModel): + # s3 path is named + # ---.jsonl + id = fields.IntField(pk=True) + s3_path = fields.CharField(max_length=1024) + application = fields.ForeignKeyField( + "models.Application", + related_name="log_files", + index=True, + ) + tracker_id = fields.CharField(max_length=255) + min_sequence_id = fields.IntField() + max_sequence_id = fields.IntField() diff --git a/burr/tracking/server/s3/pyproject.toml b/burr/tracking/server/s3/pyproject.toml new file mode 100644 index 00000000..36b1df16 --- /dev/null +++ b/burr/tracking/server/s3/pyproject.toml @@ -0,0 +1,4 @@ +[tool.aerich] +tortoise_orm = "burr.tracking.server.s3.settings.TORTOISE_ORM" +location = "./burr/tracking/server/s3/migrations" +src_folder = "./." 
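A note on first-time setup: the commented-out helper at the bottom of initialize_db.py above hints at creating the schema directly with Tortoise instead of applying the generated aerich migration. A minimal sketch of that flow, assuming the settings module added below and only standard Tortoise APIs (Tortoise.init, generate_schemas, run_async), might look like:

import os

from tortoise import Tortoise, run_async

from burr.tracking.server.s3 import settings


async def first_time_init():
    # Ensure ~/.burr_server exists, mirroring connect() in initialize_db.py
    os.makedirs(os.path.dirname(settings.DB_PATH), exist_ok=True)
    # Point Tortoise at the same ORM config aerich reads, then create the tables once
    await Tortoise.init(config=settings.TORTOISE_ORM)
    await Tortoise.generate_schemas()
    await Tortoise.close_connections()


if __name__ == "__main__":
    run_async(first_time_init())

In the patch itself the tables come from the migration above, so this is only a convenience for local experimentation.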
diff --git a/burr/tracking/server/s3/settings.py b/burr/tracking/server/s3/settings.py new file mode 100644 index 00000000..fe172ed9 --- /dev/null +++ b/burr/tracking/server/s3/settings.py @@ -0,0 +1,13 @@ +from pathlib import Path + +DB_PATH = Path("~/.burr_server/db.sqlite3").expanduser() + +TORTOISE_ORM = { + "connections": {"default": f"sqlite:///{DB_PATH}"}, + "apps": { + "models": { + "models": ["burr.tracking.server.s3.models", "aerich.models"], + "default_connection": "default", + }, + }, +} diff --git a/burr/tracking/server/s3/utils.py b/burr/tracking/server/s3/utils.py new file mode 100644 index 00000000..52f6f655 --- /dev/null +++ b/burr/tracking/server/s3/utils.py @@ -0,0 +1,14 @@ +import asyncio +from typing import Awaitable, TypeVar + +AwaitableType = TypeVar("AwaitableType") + + +async def gather_with_concurrency(n, *coros: Awaitable[AwaitableType]) -> tuple[AwaitableType, ...]: + semaphore = asyncio.Semaphore(n) + + async def sem_coro(coro: Awaitable[AwaitableType]) -> AwaitableType: + async with semaphore: + return await coro + + return await asyncio.gather(*(sem_coro(c) for c in coros)) diff --git a/burr/tracking/server/schema.py b/burr/tracking/server/schema.py index 61989c6e..950a911c 100644 --- a/burr/tracking/server/schema.py +++ b/burr/tracking/server/schema.py @@ -1,7 +1,9 @@ +import collections import datetime -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import pydantic +from pydantic import fields from burr.tracking.common.models import ( ApplicationModel, @@ -12,15 +14,16 @@ EndSpanModel, PointerModel, ) +from burr.tracking.utils import safe_json_load class Project(pydantic.BaseModel): name: str id: str # defaults to name for local, not for remote - uri: str # TODO -- figure out what last_written: datetime.datetime created: datetime.datetime num_apps: int + uri: str class ApplicationSummary(pydantic.BaseModel): @@ -40,6 +43,11 @@ class ApplicationModelWithChildren(pydantic.BaseModel): type: str = "application_with_children" +class PartialSpan(pydantic.BaseModel): + begin_entry: Optional[BeginSpanModel] = fields.Field(default_factory=lambda: None) + end_entry: Optional[EndSpanModel] = fields.Field(default_factory=lambda: None) + + class Span(pydantic.BaseModel): """Represents a span. 
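Spans that never get a begin entry are dropped when log lines are parsed into steps.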
These have action sequence IDs associated with them to put them in order.""" @@ -48,6 +56,12 @@ class Span(pydantic.BaseModel): end_entry: Optional[EndSpanModel] +class PartialStep(pydantic.BaseModel): + step_start_log: Optional[BeginEntryModel] = fields.Field(default_factory=lambda: None) + step_end_log: Optional[EndEntryModel] = fields.Field(default_factory=lambda: None) + spans: List[Span] = fields.Field(default_factory=list) + + class Step(pydantic.BaseModel): """Log of astep -- has a start and an end.""" @@ -55,6 +69,55 @@ class Step(pydantic.BaseModel): step_end_log: Optional[EndEntryModel] spans: List[Span] + @staticmethod + def from_logs(log_lines: List[bytes]) -> List["Step"]: + steps_by_sequence_id = collections.defaultdict(PartialStep) + spans_by_id = collections.defaultdict(Span) + for line in log_lines: + json_line = safe_json_load(line) + # TODO -- make these into constants + if json_line["type"] == "begin_entry": + begin_step = BeginEntryModel.parse_obj(json_line) + steps_by_sequence_id[begin_step.sequence_id].step_start_log = begin_step + elif json_line["type"] == "end_entry": + step_end_log = EndEntryModel.parse_obj(json_line) + steps_by_sequence_id[step_end_log.sequence_id].step_end_log = step_end_log + elif json_line["type"] == "begin_span": + span = BeginSpanModel.parse_obj(json_line) + spans_by_id[span.span_id] = Span( + begin_entry=span, + end_entry=None, + ) + elif json_line["type"] == "end_span": + end_span = EndSpanModel.parse_obj(json_line) + span = spans_by_id[end_span.span_id] + span.end_entry = end_span + for span in spans_by_id.values(): + sequence_id = ( + span.begin_entry.action_sequence_id + if span.begin_entry + else span.end_entry.action_sequence_id + ) + step = ( + steps_by_sequence_id[sequence_id] if sequence_id in steps_by_sequence_id else None + ) + if step is not None: + step.spans.append(span) + # filter out all the non-null start steps + return [ + Step( + step_start_log=value.step_start_log, + step_end_log=value.step_end_log, + spans=[Span(**span.dict()) for span in value.spans if span.begin_entry is not None], + ) + for key, value in sorted(steps_by_sequence_id.items()) + if value.step_start_log is not None + ] + + +class StepWithMinimalData(Step): + step_start_log: Optional[BeginEntryModel] + class ApplicationLogs(pydantic.BaseModel): """Application logs are purely flat -- @@ -65,3 +128,20 @@ class ApplicationLogs(pydantic.BaseModel): steps: List[Step] parent_pointer: Optional[PointerModel] = None spawning_parent_pointer: Optional[PointerModel] = None + + +class IndexingJob(pydantic.BaseModel): + """Generic link for indexing job -- can be exposed in 'admin mode' in the UI""" + + id: int + start_time: datetime.datetime + end_time: Optional[datetime.datetime] + status: str + records_processed: int + metadata: Dict[str, Any] + + +class BackendSpec(pydantic.BaseModel): + """Generic link for indexing job -- can be exposed in 'admin mode' in the UI""" + + indexing: bool From c793a6a5b0f00f0edee7b005820235e78d425914 Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Wed, 31 Jul 2024 16:35:30 -0700 Subject: [PATCH 04/11] Adds scripts to generate demo data on s3 We also made it so we can wire through the s3 data through the command --- burr/cli/demo_data.py | 130 ++++++++++++++---- .../simple_example/application.py | 38 +++-- examples/multi-modal-chatbot/application.py | 4 +- examples/multi-modal-chatbot/server.py | 2 +- examples/streaming-fastapi/application.py | 8 +- examples/streaming-fastapi/streamlit_app.py | 4 + 
examples/tracing-and-spans/application.py | 64 +++++---- pyproject.toml | 21 +++ 8 files changed, 202 insertions(+), 69 deletions(-) diff --git a/burr/cli/demo_data.py b/burr/cli/demo_data.py index 4582c2ae..25300931 100644 --- a/burr/cli/demo_data.py +++ b/burr/cli/demo_data.py @@ -1,15 +1,32 @@ import importlib +import logging import os +import uuid +from typing import Optional -from application import logger +from burr.core import ApplicationBuilder, Result, default, expr +from burr.core.graph import GraphBuilder +from burr.tracking import LocalTrackingClient +from burr.tracking.s3client import S3TrackingClient -conversational_rag_application = importlib.import_module("examples.conversational-rag.application") +logger = logging.getLogger(__name__) + +conversational_rag_application = importlib.import_module( + "examples.conversational-rag.simple_example.application" +) counter_application = importlib.import_module("examples.hello-world-counter.application") chatbot_application = importlib.import_module("examples.multi-modal-chatbot.application") chatbot_application_with_traces = importlib.import_module("examples.tracing-and-spans.application") -def generate_chatbot_data(data_dir: str, use_traces: bool): +def generate_chatbot_data( + data_dir: Optional[str] = None, + s3_bucket: Optional[str] = None, + use_traces: bool = False, + unique_app_names: bool = False, +): + project_id = "demo_chatbot" if not use_traces else "demo_chatbot_with_traces" + run_prefix = str(uuid.uuid4())[0:8] + "-" if unique_app_names else "" working_conversations = { "chat-1-giraffe": [ "Please draw a giraffe.", # Answered by the image mode @@ -44,39 +61,87 @@ def generate_chatbot_data(data_dir: str, use_traces: bool): } broken_conversations = {"chat-6-demonstrate-errors": working_conversations["chat-1-giraffe"]} + def _modify(app_id: str) -> str: + return run_prefix + app_id + def _run_conversation(app_id, prompts): - app = (chatbot_application_with_traces if use_traces else chatbot_application).application( - app_id=app_id, - storage_dir=data_dir, + tracker = ( + LocalTrackingClient(project=project_id, storage_dir=data_dir) + if not s3_bucket + else S3TrackingClient(project=project_id, bucket=s3_bucket) + ) + graph = (chatbot_application_with_traces if use_traces else chatbot_application).graph + app = ( + ApplicationBuilder() + .with_graph(graph) + .with_identifiers(app_id=app_id) + .with_tracker(tracker) + .with_entrypoint("prompt") + .build() ) for prompt in prompts: app.run(halt_after=["response"], inputs={"prompt": prompt}) for app_id, prompts in sorted(working_conversations.items()): - _run_conversation(app_id, prompts) + _run_conversation(_modify(app_id), prompts) old_api_key = os.environ.get("OPENAI_API_KEY") os.environ["OPENAI_API_KEY"] = "fake" for app_id, prompts in sorted(broken_conversations.items()): try: - _run_conversation(app_id, prompts) + _run_conversation(_modify(app_id), prompts) except Exception as e: print(f"Got an exception: {e}") os.environ["OPENAI_API_KEY"] = old_api_key -def generate_counter_data(data_dir: str = "~/.burr"): +def generate_counter_data( + data_dir: str = "~/.burr", s3_bucket: Optional[str] = None, unique_app_names: bool = False +): + counter = counter_application.counter + tracker = ( + LocalTrackingClient(project="demo_counter", storage_dir=data_dir) + if not s3_bucket + else S3TrackingClient(project="demo_counter", bucket=s3_bucket) + ) + counts = [1, 10, 100, 50, 42] + # This is just cause we don't want to change the code + # TODO -- add ability to grab graph from 
application or something like that + graph = ( + GraphBuilder() + .with_actions(counter=counter, result=Result("counter")) + .with_transitions( + ("counter", "counter", expr("counter < count_to")), + ("counter", "result", default), + ) + .build() + ) for i, count in enumerate(counts): - app = counter_application.application( - count_up_to=count, - app_id=f"count-to-{count}", - storage_dir=data_dir, - partition_key=f"user_{i}", + app_id = f"count-to-{count}" + if unique_app_names: + suffix = str(uuid.uuid4())[0:8] + app_id = f"{app_id}-{suffix}" + app = ( + ApplicationBuilder() + .with_graph(graph) + .with_identifiers(app_id=app_id, partition_key=f"user_{i}") + .with_state(count_to=count, counter=0) + .with_tracker(tracker) + .with_entrypoint("counter") + .build() ) + # app = counter_application.application( + # count_up_to=count, + # app_id=f"count-to-{count}", + # storage_dir=data_dir, + # partition_key=f"user_{i}", + # ) app.run(halt_after=["result"]) -def generate_rag_data(data_dir: str = "~/.burr"): +def generate_rag_data( + data_dir: Optional[str] = None, s3_bucket: Optional[str] = None, unique_app_names: bool = False +): conversations = { "rag-1-food": [ "What is Elijah's favorite food?", @@ -105,24 +170,43 @@ def generate_rag_data(data_dir: str = "~/.burr"): "Whose favorite food is better, Elijah's or Stefan's?" "exit", ], } + prefix = str(uuid.uuid4())[0:8] + "-" if unique_app_names else "" for app_id, prompts in sorted(conversations.items()): - app = conversational_rag_application.application( - app_id=app_id, - storage_dir=data_dir, + graph = conversational_rag_application.graph() + tracker = ( + LocalTrackingClient(project="demo_conversational-rag", storage_dir=data_dir) + if not s3_bucket + else S3TrackingClient(project="demo_conversational-rag", bucket=s3_bucket) + ) + app_id = f"{prefix}{app_id}" if unique_app_names else app_id + app = ( + ApplicationBuilder() + .with_graph(graph) + .with_identifiers(app_id=app_id) + .with_tracker(tracker) + .with_entrypoint("human_converse") + .build() ) + logger.warning(f"Running {app_id}...") for prompt in prompts: app.run(halt_after=["ai_converse", "terminal"], inputs={"user_question": prompt}) -def generate_all(data_dir: str): +def generate_all( + data_dir: Optional[str] = None, s3_bucket: Optional[str] = None, unique_app_names: bool = False +): logger.info("Generating chatbot data") - generate_chatbot_data(data_dir, False) + generate_chatbot_data( + data_dir=data_dir, s3_bucket=s3_bucket, use_traces=False, unique_app_names=unique_app_names + ) logger.info("Generating chatbot data with traces") - generate_chatbot_data(data_dir, True) + generate_chatbot_data( + data_dir=data_dir, s3_bucket=s3_bucket, use_traces=True, unique_app_names=unique_app_names + ) logger.info("Generating counter data") - generate_counter_data(data_dir) + generate_counter_data(data_dir=data_dir, s3_bucket=s3_bucket, unique_app_names=unique_app_names) logger.info("Generating RAG data") - generate_rag_data(data_dir) + generate_rag_data(data_dir=data_dir, s3_bucket=s3_bucket, unique_app_names=unique_app_names) # diff --git a/examples/conversational-rag/simple_example/application.py b/examples/conversational-rag/simple_example/application.py index 1f798c49..9b1c9b65 100644 --- a/examples/conversational-rag/simple_example/application.py +++ b/examples/conversational-rag/simple_example/application.py @@ -6,6 +6,7 @@ import burr.core from burr.core import Action, Application, ApplicationBuilder, State, default, expr from burr.core.action import action +from burr.core.graph 
import GraphBuilder from burr.lifecycle import LifecycleAdapter, PostRunStepHook, PreRunStepHook # create the pipeline @@ -70,12 +71,7 @@ def human_converse(state: State, user_question: str) -> Tuple[dict, State]: return {"question": user_question}, state -def application( - app_id: Optional[str] = None, - storage_dir: Optional[str] = "~/.burr", - hooks: Optional[List[LifecycleAdapter]] = None, -) -> Application: - # our initial knowledge base +def graph(): input_text = [ "harrison worked at kensho", "stefan worked at Stitch Fix", @@ -87,14 +83,8 @@ def application( "stefan likes to bake sourdough", ] vector_store = bootstrap_vector_db(conversational_rag_driver, input_text) - app = ( - ApplicationBuilder() - .with_state( - **{ - "question": "", - "chat_history": [], - } - ) + return ( + GraphBuilder() .with_actions( # bind the vector store to the AI conversational step ai_converse=ai_converse.bind(vector_store=vector_store), @@ -106,6 +96,26 @@ def application( ("human_converse", "terminal", expr("'exit' in question")), ("human_converse", "ai_converse", default), ) + .build() + ) + + +def application( + app_id: Optional[str] = None, + storage_dir: Optional[str] = "~/.burr", + hooks: Optional[List[LifecycleAdapter]] = None, +) -> Application: + # our initial knowledge base + + app = ( + ApplicationBuilder() + .with_state( + **{ + "question": "", + "chat_history": [], + } + ) + .with_graph(graph()) .with_entrypoint("human_converse") .with_tracker(project="demo_conversational-rag", params={"storage_dir": storage_dir}) .with_identifiers(app_id=app_id, partition_key="sample_user") diff --git a/examples/multi-modal-chatbot/application.py b/examples/multi-modal-chatbot/application.py index 03b95d99..5f40da57 100644 --- a/examples/multi-modal-chatbot/application.py +++ b/examples/multi-modal-chatbot/application.py @@ -142,7 +142,7 @@ def response(state: State) -> State: return state.append(chat_history=result["chat_item"]) -base_graph = ( +graph = ( graph.GraphBuilder() .with_actions( prompt=process_prompt, @@ -192,7 +192,7 @@ def base_application( tracker = LocalTrackingClient(project=project_id, storage_dir=storage_dir) return ( ApplicationBuilder() - .with_graph(base_graph) + .with_graph(graph) # initializes from the tracking log if it does not already exist .initialize_from( tracker, diff --git a/examples/multi-modal-chatbot/server.py b/examples/multi-modal-chatbot/server.py index f009f4ae..344e819a 100644 --- a/examples/multi-modal-chatbot/server.py +++ b/examples/multi-modal-chatbot/server.py @@ -25,7 +25,7 @@ router = APIRouter() -graph = chat_application.base_graph +graph = chat_application.graph class ChatItem(pydantic.BaseModel): diff --git a/examples/streaming-fastapi/application.py b/examples/streaming-fastapi/application.py index 8e90df8e..63a84d29 100644 --- a/examples/streaming-fastapi/application.py +++ b/examples/streaming-fastapi/application.py @@ -7,6 +7,7 @@ from burr.core import ApplicationBuilder, State, default, when from burr.core.action import action, streaming_action from burr.core.graph import GraphBuilder +from burr.tracking.s3client import S3TrackingClient MODES = [ "answer_question", @@ -173,7 +174,12 @@ def application(app_id: Optional[str] = None): .with_entrypoint("prompt") .with_state(chat_history=[]) .with_graph(graph) - .with_tracker(project="demo_chatbot_streaming") + # .with_tracker(project="demo_chatbot_streaming") + .with_tracker( + tracker=S3TrackingClient( + bucket="burr-prod-test", project="demo_chatbot_streaming", non_blocking=True + ) + ) 
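        # "burr-prod-test" is just the bucket this demo was written against; point the tracker at a bucket you control to run the example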
.with_identifiers(app_id=app_id) .build() ) diff --git a/examples/streaming-fastapi/streamlit_app.py b/examples/streaming-fastapi/streamlit_app.py index 9b6bd50b..bed5445a 100644 --- a/examples/streaming-fastapi/streamlit_app.py +++ b/examples/streaming-fastapi/streamlit_app.py @@ -1,12 +1,16 @@ import asyncio +import logging import uuid import application as chatbot_application import streamlit as st +from hamilton.log_setup import setup_logging import burr.core from burr.core.action import AsyncStreamingResultContainer +setup_logging(logging.INFO) + def render_chat_message(chat_item: dict): content = chat_item["content"] diff --git a/examples/tracing-and-spans/application.py b/examples/tracing-and-spans/application.py index d3bb8244..f6aa8e16 100644 --- a/examples/tracing-and-spans/application.py +++ b/examples/tracing-and-spans/application.py @@ -4,6 +4,7 @@ from burr.core import Application, ApplicationBuilder, State, default, when from burr.core.action import action +from burr.core.graph import GraphBuilder from burr.visibility import TracerFactory MODES = { @@ -146,42 +147,49 @@ def response(state: State, __tracer: TracerFactory) -> Tuple[dict, State]: return result, state.append(chat_history=result["chat_item"]) +graph = ( + GraphBuilder() + .with_actions( + prompt=process_prompt, + check_safety=check_safety, + decide_mode=choose_mode, + generate_image=image_response, + generate_code=chat_response.bind( + prepend_prompt="Please respond with *only* code and no other text (at all) to the following:", + ), + answer_question=chat_response.bind( + prepend_prompt="Please answer the following question:", + ), + prompt_for_more=prompt_for_more, + response=response, + ) + .with_transitions( + ("prompt", "check_safety", default), + ("check_safety", "decide_mode", when(safe=True)), + ("check_safety", "response", default), + ("decide_mode", "generate_image", when(mode="generate_image")), + ("decide_mode", "generate_code", when(mode="generate_code")), + ("decide_mode", "answer_question", when(mode="answer_question")), + ("decide_mode", "prompt_for_more", default), + ( + ["generate_image", "answer_question", "generate_code", "prompt_for_more"], + "response", + ), + ("response", "prompt", default), + ) + .build() +) + + def application( app_id: Optional[str] = None, storage_dir: Optional[str] = "~/.burr", ) -> Application: return ( ApplicationBuilder() - .with_actions( - prompt=process_prompt, - check_safety=check_safety, - decide_mode=choose_mode, - generate_image=image_response, - generate_code=chat_response.bind( - prepend_prompt="Please respond with *only* code and no other text (at all) to the following:", - ), - answer_question=chat_response.bind( - prepend_prompt="Please answer the following question:", - ), - prompt_for_more=prompt_for_more, - response=response, - ) .with_entrypoint("prompt") .with_state(chat_history=[]) - .with_transitions( - ("prompt", "check_safety", default), - ("check_safety", "decide_mode", when(safe=True)), - ("check_safety", "response", default), - ("decide_mode", "generate_image", when(mode="generate_image")), - ("decide_mode", "generate_code", when(mode="generate_code")), - ("decide_mode", "answer_question", when(mode="answer_question")), - ("decide_mode", "prompt_for_more", default), - ( - ["generate_image", "answer_question", "generate_code", "prompt_for_more"], - "response", - ), - ("response", "prompt", default), - ) + .with_graph(graph) .with_tracker(project="demo_tracing", params={"storage_dir": storage_dir}) .with_identifiers(app_id=app_id) .build() diff 
--git a/pyproject.toml b/pyproject.toml index f019a08a..f55e6419 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,11 +80,17 @@ tracking-client = [ "pydantic" ] +tracking-client-s3 = [ + "burr[tracking-client]", + "aiobotocore" +] + tracking-server = [ "click", "fastapi", "uvicorn", "pydantic", + "pydantic-settings", "fastapi-pagination", "aiofiles", "requests", @@ -106,6 +112,15 @@ start = [ "burr[learn]" ] +# All the bloatware from various LLM demos +# In the future most people will be relying on simple APIs, not this +# But its good for demos! +bloat = [ + "langchain", + "langchain-community", + "langchain-openai" +] + # just install everything for developers developer = [ "burr[streamlit]", @@ -113,6 +128,7 @@ developer = [ "burr[tracking]", "burr[tests]", "burr[documentation]", + "burr[bloat]", "build", "twine", "pre-commit", @@ -131,6 +147,11 @@ burr = [ "burr/tracking/server/demo_data/**/*" ] + +[tool.aerich] +tortoise_orm = "burr.tracking.server.s3.settings.TORTOISE_ORM" +location = "./burr/tracking/server/s3/migrations" +src_folder = "./." [project.urls] Homepage = "https://github.com/dagworks-inc/burr" Documentation = "https://github.com/dagworks-inc/burr" From 5f3b61c8c4eced386f0de09e992bfe38dda4191a Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Wed, 31 Jul 2024 16:38:36 -0700 Subject: [PATCH 05/11] Adds updates for s3 server on frontend This has pages on indexing jobs + a few other updates --- telemetry/ui/src/App.tsx | 2 + telemetry/ui/src/api/index.ts | 3 + telemetry/ui/src/api/models/BackendSpec.ts | 10 ++ telemetry/ui/src/api/models/IndexingJob.ts | 15 +++ telemetry/ui/src/api/models/Project.ts | 13 ++- telemetry/ui/src/api/models/PromptInput.ts | 7 ++ .../ui/src/api/services/DefaultService.ts | 49 +++++++++- .../ui/src/components/nav/appcontainer.tsx | 16 ++++ .../ui/src/components/routes/AdminView.tsx | 94 +++++++++++++++++++ .../ui/src/components/routes/AppList.tsx | 12 ++- 10 files changed, 207 insertions(+), 14 deletions(-) create mode 100644 telemetry/ui/src/api/models/BackendSpec.ts create mode 100644 telemetry/ui/src/api/models/IndexingJob.ts create mode 100644 telemetry/ui/src/api/models/PromptInput.ts create mode 100644 telemetry/ui/src/components/routes/AdminView.tsx diff --git a/telemetry/ui/src/App.tsx b/telemetry/ui/src/App.tsx index 32d5ae00..3f08361e 100644 --- a/telemetry/ui/src/App.tsx +++ b/telemetry/ui/src/App.tsx @@ -9,6 +9,7 @@ import { ChatbotWithTelemetry } from './examples/Chatbot'; import { Counter } from './examples/Counter'; import { EmailAssistantWithTelemetry } from './examples/EmailAssistant'; import { StreamingChatbotWithTelemetry } from './examples/StreamingChatbot'; +import { AdminView } from './components/routes/AdminView'; /** * Basic application. We have an AppContainer -- this has a breadcrumb and a sidebar. 
@@ -39,6 +40,7 @@ const App = () => { } /> } /> } /> + } /> } /> diff --git a/telemetry/ui/src/api/index.ts b/telemetry/ui/src/api/index.ts index 4969521a..b0c0638e 100644 --- a/telemetry/ui/src/api/index.ts +++ b/telemetry/ui/src/api/index.ts @@ -11,6 +11,7 @@ export type { ActionModel } from './models/ActionModel'; export type { ApplicationLogs } from './models/ApplicationLogs'; export type { ApplicationModel } from './models/ApplicationModel'; export type { ApplicationSummary } from './models/ApplicationSummary'; +export type { BackendSpec } from './models/BackendSpec'; export type { BeginEntryModel } from './models/BeginEntryModel'; export type { BeginSpanModel } from './models/BeginSpanModel'; export { ChatItem } from './models/ChatItem'; @@ -21,8 +22,10 @@ export type { EndEntryModel } from './models/EndEntryModel'; export type { EndSpanModel } from './models/EndSpanModel'; export type { Feedback } from './models/Feedback'; export type { HTTPValidationError } from './models/HTTPValidationError'; +export type { IndexingJob } from './models/IndexingJob'; export type { PointerModel } from './models/PointerModel'; export type { Project } from './models/Project'; +export type { PromptInput } from './models/PromptInput'; export type { QuestionAnswers } from './models/QuestionAnswers'; export type { Span } from './models/Span'; export type { Step } from './models/Step'; diff --git a/telemetry/ui/src/api/models/BackendSpec.ts b/telemetry/ui/src/api/models/BackendSpec.ts new file mode 100644 index 00000000..ff240c5d --- /dev/null +++ b/telemetry/ui/src/api/models/BackendSpec.ts @@ -0,0 +1,10 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +/** + * Generic link for indexing job -- can be exposed in 'admin mode' in the UI + */ +export type BackendSpec = { + indexing: boolean; +}; diff --git a/telemetry/ui/src/api/models/IndexingJob.ts b/telemetry/ui/src/api/models/IndexingJob.ts new file mode 100644 index 00000000..e7e17734 --- /dev/null +++ b/telemetry/ui/src/api/models/IndexingJob.ts @@ -0,0 +1,15 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +/** + * Generic link for indexing job -- can be exposed in 'admin mode' in the UI + */ +export type IndexingJob = { + id: number; + start_time: string; + end_time: string | null; + status: string; + records_processed: number; + metadata: Record; +}; diff --git a/telemetry/ui/src/api/models/Project.ts b/telemetry/ui/src/api/models/Project.ts index 11cf7773..a9a5908c 100644 --- a/telemetry/ui/src/api/models/Project.ts +++ b/telemetry/ui/src/api/models/Project.ts @@ -3,11 +3,10 @@ /* tslint:disable */ /* eslint-disable */ export type Project = { - name: string; - id: string; - uri: string; - last_written: string; - created: string; - num_apps: number; + name: string; + id: string; + last_written: string; + created: string; + num_apps: number; + uri: string; }; - diff --git a/telemetry/ui/src/api/models/PromptInput.ts b/telemetry/ui/src/api/models/PromptInput.ts new file mode 100644 index 00000000..dc867a38 --- /dev/null +++ b/telemetry/ui/src/api/models/PromptInput.ts @@ -0,0 +1,7 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export type PromptInput = { + prompt: string; +}; diff --git a/telemetry/ui/src/api/services/DefaultService.ts b/telemetry/ui/src/api/services/DefaultService.ts 
index 4c58539f..26dd2507 100644 --- a/telemetry/ui/src/api/services/DefaultService.ts +++ b/telemetry/ui/src/api/services/DefaultService.ts @@ -4,11 +4,14 @@ /* eslint-disable */ import type { ApplicationLogs } from '../models/ApplicationLogs'; import type { ApplicationSummary } from '../models/ApplicationSummary'; +import type { BackendSpec } from '../models/BackendSpec'; import type { ChatItem } from '../models/ChatItem'; import type { DraftInit } from '../models/DraftInit'; import type { EmailAssistantState } from '../models/EmailAssistantState'; import type { Feedback } from '../models/Feedback'; +import type { IndexingJob } from '../models/IndexingJob'; import type { Project } from '../models/Project'; +import type { PromptInput } from '../models/PromptInput'; import type { QuestionAnswers } from '../models/QuestionAnswers'; import type { CancelablePromise } from '../core/CancelablePromise'; import { OpenAPI } from '../core/OpenAPI'; @@ -96,6 +99,43 @@ export class DefaultService { url: '/api/v0/ready' }); } + /** + * Get App Spec + * @returns BackendSpec Successful Response + * @throws ApiError + */ + public static getAppSpecApiV0MetadataAppSpecGet(): CancelablePromise { + return __request(OpenAPI, { + method: 'GET', + url: '/api/v0/metadata/app_spec' + }); + } + /** + * Get Indexing Jobs + * @param offset + * @param limit + * @param filterEmpty + * @returns IndexingJob Successful Response + * @throws ApiError + */ + public static getIndexingJobsApiV0IndexingJobsGet( + offset?: number, + limit: number = 100, + filterEmpty: boolean = true + ): CancelablePromise> { + return __request(OpenAPI, { + method: 'GET', + url: '/api/v0/indexing_jobs', + query: { + offset: offset, + limit: limit, + filter_empty: filterEmpty + }, + errors: { + 422: `Validation Error` + } + }); + } /** * Version * Returns the burr version @@ -373,14 +413,14 @@ export class DefaultService { * :return: * @param projectId * @param appId - * @param prompt + * @param requestBody * @returns any Successful Response * @throws ApiError */ public static chatResponseApiV0StreamingChatbotResponseProjectIdAppIdPost( projectId: string, appId: string, - prompt: string + requestBody: PromptInput ): CancelablePromise { return __request(OpenAPI, { method: 'POST', @@ -389,9 +429,8 @@ export class DefaultService { project_id: projectId, app_id: appId }, - query: { - prompt: prompt - }, + body: requestBody, + mediaType: 'application/json', errors: { 422: `Validation Error` } diff --git a/telemetry/ui/src/components/nav/appcontainer.tsx b/telemetry/ui/src/components/nav/appcontainer.tsx index 70a4c817..1555fe8c 100644 --- a/telemetry/ui/src/components/nav/appcontainer.tsx +++ b/telemetry/ui/src/components/nav/appcontainer.tsx @@ -16,6 +16,8 @@ import { BreadCrumb } from './breadcrumb'; import { Link } from 'react-router-dom'; import { classNames } from '../../utils/tailwind'; import React from 'react'; +import { DefaultService } from '../../api'; +import { useQuery } from 'react-query'; // Define your GitHub logo SVG as a React component const GithubLogo = () => ( @@ -64,6 +66,11 @@ const ToggleOpenButton = (props: { open: boolean; toggleSidebar: () => void }) = export const AppContainer = (props: { children: React.ReactNode }) => { const [sidebarOpen, setSidebarOpen] = useState(true); const [smallSidebarOpen, setSmallSidebarOpen] = useState(false); + const { data: backendSpec } = useQuery(['backendSpec'], () => + DefaultService.getAppSpecApiV0MetadataAppSpecGet().then((response) => { + return response; + }) + ); const toggleSidebar = 
() => { setSidebarOpen(!sidebarOpen); }; @@ -129,6 +136,15 @@ export const AppContainer = (props: { children: React.ReactNode }) => { } ]; + if (backendSpec?.indexing) { + navigation.push({ + name: 'Admin', + href: '/admin', + icon: ListBulletIcon, + linkType: 'internal' + }); + } + const isCurrent = (href: string, linkType: string) => { if (linkType === 'external') { return false; diff --git a/telemetry/ui/src/components/routes/AdminView.tsx b/telemetry/ui/src/components/routes/AdminView.tsx new file mode 100644 index 00000000..61232c98 --- /dev/null +++ b/telemetry/ui/src/components/routes/AdminView.tsx @@ -0,0 +1,94 @@ +import { useQuery } from 'react-query'; +import { Table, TableBody, TableCell, TableHead, TableHeader, TableRow } from '../common/table'; +import { DefaultService } from '../../api'; +import { Loading } from '../common/loading'; +import { DateTimeDisplay, DurationDisplay } from '../common/dates'; +import JsonView from '@uiw/react-json-view'; +import { useState } from 'react'; +import { FunnelIcon } from '@heroicons/react/24/outline'; + +const RecordsHeader = (props: { + displayZeroCount: boolean; + setDisplayZeroCount: (displayZeroCount: boolean) => void; +}) => { + const fillColor = props.displayZeroCount ? 'fill-gray-300' : 'fill-gray-700'; + const borderColor = props.displayZeroCount ? 'text-gray-300' : 'text-gray-700'; + return ( +
+ { + props.setDisplayZeroCount(!props.displayZeroCount); + }} + /> + Seq ID +
+ ); +}; + +/** + * Currently just shows indexing jobs, but we'll likely + * want to add more depending on whether the backend supports it. + * @returns A rendered admin view object + */ +export const AdminView = () => { + const [displayZeroCount, setDisplayZeroCount] = useState(false); + + const { data, isLoading } = useQuery(['indexingJobs', displayZeroCount], () => + DefaultService.getIndexingJobsApiV0IndexingJobsGet( + 0, // TODO -- add pagination + 100, + !displayZeroCount + ) + ); + if (isLoading) { + return ; + } + + return ( + + + + ID + Start Time + Duration + Status + + + + Metadata + + + + {data?.map((job) => { + return ( + + {job.id} + + {} + + + {job.end_time ? ( + + ) : ( + <> + )} + + {job.status.toLowerCase()} + {job.records_processed} + + + + + ); + })} + +
+ ); +}; diff --git a/telemetry/ui/src/components/routes/AppList.tsx b/telemetry/ui/src/components/routes/AppList.tsx index 9aa042bc..a87e522a 100644 --- a/telemetry/ui/src/components/routes/AppList.tsx +++ b/telemetry/ui/src/components/routes/AppList.tsx @@ -29,6 +29,10 @@ const StepCountHeader = (props: { ); }; +const isNullPartitionKey = (partitionKey: string | null) => { + return partitionKey === null || partitionKey === '__none__'; +}; + const getForkID = (app: ApplicationSummary) => { if (app.parent_pointer) { return app.parent_pointer.app_id; @@ -77,7 +81,9 @@ const AppSubList = (props: { }} > {props.displayPartitionKey && ( - {app.partition_key} + + {isNullPartitionKey(app.partition_key) ? '' : app.partition_key} + )}
@@ -174,7 +180,9 @@ export const AppListTable = (props: { apps: ApplicationSummary[]; projectId: str // Display the parents no matter what const rootAppsToDisplay = appsToDisplay.filter((app) => app.spawning_parent_pointer === null); - const anyHavePartitionKey = rootAppsToDisplay.some((app) => app.partition_key !== null); + const anyHavePartitionKey = rootAppsToDisplay.some( + (app) => !isNullPartitionKey(app.partition_key) + ); return ( From 1791faf14bd97eabf86d3447caf5bed90567b970 Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Thu, 1 Aug 2024 02:52:13 -0700 Subject: [PATCH 06/11] Adds ability to save/load snapshot of DB This is a hack for the local version, soon we'll be using postgres/others and it will be less necessary. --- burr/cli/__main__.py | 2 +- burr/tracking/server/backend.py | 27 +++++++ burr/tracking/server/run.py | 63 +++++++++++----- burr/tracking/server/s3/backend.py | 116 ++++++++++++++++++++++++++--- burr/tracking/server/schema.py | 12 ++- 5 files changed, 189 insertions(+), 31 deletions(-) diff --git a/burr/cli/__main__.py b/burr/cli/__main__.py index a686933e..2f952278 100644 --- a/burr/cli/__main__.py +++ b/burr/cli/__main__.py @@ -130,7 +130,7 @@ def build_ui(): BACKEND_MODULES = { "local": "burr.tracking.server.backend.LocalBackend", - "s3": "burr.tracking.server.s3.backend.S3Backend", + "s3": "burr.tracking.server.s3.backend.SQLiteS3Backend", } diff --git a/burr/tracking/server/backend.py b/burr/tracking/server/backend.py index e520ed3b..e017b43e 100644 --- a/burr/tracking/server/backend.py +++ b/burr/tracking/server/backend.py @@ -64,8 +64,35 @@ async def indexing_jobs( pass +class SnapshottingBackendMixin(abc.ABC): + """Mixin for backend that conducts snapshotting -- e.g. saves + the data to a file or database.""" + + @abc.abstractmethod + async def load_snapshot(self): + """Loads the snapshot if it exists. + + :return: + """ + pass + + @abc.abstractmethod + async def snapshot(self): + """Snapshots the data""" + pass + + @abc.abstractmethod + def snapshot_interval_milliseconds(self) -> Optional[int]: + """Returns the snapshot interval in milliseconds""" + pass + + class BackendBase(abc.ABC): async def lifespan(self, app: FastAPI): + """Quick tool to allow plugin to the app's lifecycle. + This is fine given that it's an internal API, but if we open it up more + we should make this less flexible. 
For now this allows us to do clever + initializations in the right order.""" yield @abc.abstractmethod diff --git a/burr/tracking/server/run.py b/burr/tracking/server/run.py index c06931fe..bde037a5 100644 --- a/burr/tracking/server/run.py +++ b/burr/tracking/server/run.py @@ -9,7 +9,7 @@ from hamilton.log_setup import setup_logging from starlette import status -from burr.tracking.server.backend import BackendBase, IndexingBackendMixin +from burr.tracking.server.backend import BackendBase, IndexingBackendMixin, SnapshottingBackendMixin setup_logging(logging.INFO) @@ -52,11 +52,6 @@ SERVE_STATIC = os.getenv("BURR_SERVE_STATIC", "true").lower() == "true" -# TODO -- get based on a config -# backend = backend_module.LocalBackend() -# backend = s3_backend.S3Backend( -# bucket="burr-prod-test", -# ) backend = BackendBase.create_from_env() @@ -67,17 +62,46 @@ # TODO -- add a health check for intialization -async def update(): +async def sync_index(): if app_spec.indexing: - logger.info("Updating backend") + logger.info("Updating backend index...") await backend.update() + logger.info("Updated backend index...") + + +async def download_snapshot(): + if app_spec.snapshotting: + logger.info("Downloading snapshot of DB for backend to use") + await backend.load_snapshot() + logger.info("Downloaded snapshot of DB for backend to use") + + +first_snapshot = True + + +async def save_snapshot(): + # is_first is due to the weirdness of the repeat_every decorator + # It has to be called but we don't want this to run every time + # So we just skip the first + global first_snapshot + if first_snapshot: + first_snapshot = False + return + if app_spec.snapshotting: + logger.info("Saving snapshot of DB for recovery") + await backend.snapshot() + logger.info("Saved snapshot of DB for recovery") @asynccontextmanager async def lifespan(app: FastAPI): + # Download if it does it + # For now we do this before the lifespan + await download_snapshot() # No yield from allowed await backend.lifespan(app).__anext__() - await update() # this will trigger the repeat every N seconds + await sync_index() # this will trigger the repeat every N seconds + await save_snapshot() # this will trigger the repeat every N seconds yield await backend.lifespan(app).__anext__() @@ -88,7 +112,8 @@ async def lifespan(app: FastAPI): @app.get("/api/v0/metadata/app_spec", response_model=BackendSpec) def get_app_spec(): is_indexing_backend = isinstance(backend, IndexingBackendMixin) - return BackendSpec(indexing=is_indexing_backend) + is_snapshotting_backend = isinstance(backend, SnapshottingBackendMixin) + return BackendSpec(indexing=is_indexing_backend, snapshotting=is_snapshotting_backend) app_spec = get_app_spec() @@ -104,17 +129,21 @@ def get_app_spec(): if app_spec.indexing: update_interval = backend.update_interval_milliseconds() / 1000 if app_spec.indexing else None - update = repeat_every( + sync_index = repeat_every( seconds=backend.update_interval_milliseconds() / 1000, wait_first=True, logger=logger, - )(update) - + )(sync_index) -@app.on_event("startup") -async def startup_event(): - if app_spec.indexing: - await update() +if app_spec.snapshotting: + snapshot_interval = ( + backend.snapshot_interval_milliseconds() / 1000 if app_spec.snapshotting else None + ) + save_snapshot = repeat_every( + seconds=backend.snapshot_interval_milliseconds() / 1000, + wait_first=True, + logger=logger, + )(save_snapshot) @app.get("/api/v0/projects", response_model=Sequence[schema.Project]) diff --git a/burr/tracking/server/s3/backend.py 
b/burr/tracking/server/s3/backend.py index f87c94e3..7f21a57a 100644 --- a/burr/tracking/server/s3/backend.py +++ b/burr/tracking/server/s3/backend.py @@ -5,6 +5,7 @@ import json import logging import operator +import os.path import uuid from collections import Counter from typing import List, Literal, Optional, Sequence, Tuple, Type, TypeVar, Union @@ -20,7 +21,12 @@ from burr.tracking.common.models import ApplicationModel from burr.tracking.server import schema -from burr.tracking.server.backend import BackendBase, BurrSettings, IndexingBackendMixin +from burr.tracking.server.backend import ( + BackendBase, + BurrSettings, + IndexingBackendMixin, + SnapshottingBackendMixin, +) from burr.tracking.server.s3 import settings, utils from burr.tracking.server.s3.models import ( Application, @@ -113,19 +119,95 @@ class S3Settings(BurrSettings): s3_bucket: str update_interval_milliseconds: int = 60_000 aws_max_concurrency: int = 100 + snapshot_interval_milliseconds: int = 3_600_000 + load_snapshot_on_start: bool = True + prior_snapshots_to_keep: int = 5 -class S3Backend(BackendBase, IndexingBackendMixin): - @classmethod - def settings_model(cls) -> Type[BaseSettings]: - return S3Settings +def timestamp_to_reverse_alphabetical(timestamp: datetime) -> str: + # Get the inverse of the timestamp + epoch = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc) + total_seconds = int((timestamp - epoch).total_seconds()) + + # Invert the seconds (latest timestamps become smallest values) + inverted_seconds = 2**32 - total_seconds + + # Convert the inverted seconds to a zero-padded string + inverted_str = str(inverted_seconds).zfill(10) - def __init__(self, s3_bucket: str, update_interval_milliseconds: int, aws_max_concurrency: int): - self._backend_id = datetime.datetime.utcnow().isoformat() + str(uuid.uuid4()) + return inverted_str + "-" + timestamp.isoformat() + + +class SQLiteS3Backend(BackendBase, IndexingBackendMixin, SnapshottingBackendMixin): + def __init__( + self, + s3_bucket: str, + update_interval_milliseconds: int, + aws_max_concurrency: int, + snapshot_interval_milliseconds: int, + load_snapshot_on_start: bool, + prior_snapshots_to_keep: int, + ): + self._backend_id = datetime.datetime.now(datetime.UTC).isoformat() + str(uuid.uuid4()) self._bucket = s3_bucket self._session = session.get_session() self._update_interval_milliseconds = update_interval_milliseconds self._aws_max_concurrency = aws_max_concurrency + self._snapshot_interval_milliseconds = snapshot_interval_milliseconds + self._data_prefix = "data" + self._snapshot_prefix = "snapshots" + self._load_snapshot_on_start = load_snapshot_on_start + self._snapshot_key_history = [] + self._prior_snapshots_to_keep = prior_snapshots_to_keep + + async def load_snapshot(self): + if not self._load_snapshot_on_start: + return + path = settings.DB_PATH + # if it already exists then return + if os.path.exists(path): + return + async with self._session.create_client("s3") as client: + objects = await client.list_objects_v2( + Bucket=self._bucket, Prefix=self._snapshot_prefix, MaxKeys=1 + ) + # nothing there + # TODO -- + if len(objects["Contents"]) == 0: + return + # get the latest snapshot -- it's organized by alphabetical order + latest_snapshot = objects["Contents"][0] + # download the snapshot + response = await client.get_object(Bucket=self._bucket, Key=latest_snapshot["Key"]) + async with response["Body"] as stream: + with open(path, "wb") as file: + file.write(await stream.read()) + + def snapshot_interval_milliseconds(self) -> 
Optional[int]: + return self._snapshot_interval_milliseconds + + @classmethod + def settings_model(cls) -> Type[BaseSettings]: + return S3Settings + + async def snapshot(self): + path = settings.DB_PATH + timestamp = timestamp_to_reverse_alphabetical(datetime.datetime.now(datetime.UTC)) + # latest + s3_key = f"{self._snapshot_prefix}/{timestamp}/{self._backend_id}/snapshot.db" + # TODO -- copy the path at snapshot_path to s3 using aiobotocore + session = self._session + logger.info(f"Saving db snapshot at: {s3_key}") + async with session.create_client("s3") as s3_client: + with open(path, "rb") as file_data: + await s3_client.put_object(Bucket=self._bucket, Key=s3_key, Body=file_data) + + self._snapshot_key_history.append(s3_key) + if len(self._snapshot_key_history) > 5: + old_snapshot_to_remove = self._snapshot_key_history.pop(0) + logger.info(f"Removing old snapshot: {old_snapshot_to_remove}") + async with session.create_client("s3") as s3_client: + await s3_client.delete_object(Bucket=self._bucket, Key=old_snapshot_to_remove) def update_interval_milliseconds(self) -> Optional[int]: return self._update_interval_milliseconds @@ -134,7 +216,10 @@ async def _s3_get_first_write_date(self, project_id: str): async with self._session.create_client("s3") as client: paginator = client.get_paginator("list_objects_v2") async for result in paginator.paginate( - Bucket=self._bucket, Prefix=f"data/{project_id}/", Delimiter="/", MaxKeys=1 + Bucket=self._bucket, + Prefix=f"{self._data_prefix}/{project_id}/", + Delimiter="/", + MaxKeys=1, ): if "Contents" in result: first_object = result["Contents"][0] @@ -150,7 +235,7 @@ async def _update_projects(self): async with self._session.create_client("s3") as client: paginator = client.get_paginator("list_objects_v2") async for result in paginator.paginate( - Bucket=self._bucket, Prefix="data/", Delimiter="/" + Bucket=self._bucket, Prefix=f"{self._data_prefix}/", Delimiter="/" ): for prefix in result.get("CommonPrefixes", []): project_name = prefix.get("Prefix").split("/")[-2] @@ -247,7 +332,7 @@ async def _gather_paths_to_update( paginator = client.get_paginator("list_objects_v2") async for result in paginator.paginate( Bucket=self._bucket, - Prefix=f"data/{project.name}/", + Prefix=f"{self._data_prefix}/{project.name}/", StartAfter=high_watermark_s3_path, ): for content in result.get("Contents", []): @@ -557,3 +642,14 @@ async def indexing_jobs( ) ) return out + + +if __name__ == "__main__": + os.environ["BURR_LOAD_SNAPSHOT_ON_START"] = "True" + import asyncio + + be = SQLiteS3Backend.from_settings(S3Settings()) + # coro = be.snapshot() # save to s3 + # asyncio.run(coro) + coro = be.load_snapshot() # load from s3 + asyncio.run(coro) diff --git a/burr/tracking/server/schema.py b/burr/tracking/server/schema.py index 950a911c..c707c99f 100644 --- a/burr/tracking/server/schema.py +++ b/burr/tracking/server/schema.py @@ -72,7 +72,7 @@ class Step(pydantic.BaseModel): @staticmethod def from_logs(log_lines: List[bytes]) -> List["Step"]: steps_by_sequence_id = collections.defaultdict(PartialStep) - spans_by_id = collections.defaultdict(Span) + spans_by_id = collections.defaultdict(PartialSpan) for line in log_lines: json_line = safe_json_load(line) # TODO -- make these into constants @@ -84,7 +84,7 @@ def from_logs(log_lines: List[bytes]) -> List["Step"]: steps_by_sequence_id[step_end_log.sequence_id].step_end_log = step_end_log elif json_line["type"] == "begin_span": span = BeginSpanModel.parse_obj(json_line) - spans_by_id[span.span_id] = Span( + 
spans_by_id[span.span_id] = PartialSpan( begin_entry=span, end_entry=None, ) @@ -102,7 +102,12 @@ def from_logs(log_lines: List[bytes]) -> List["Step"]: steps_by_sequence_id[sequence_id] if sequence_id in steps_by_sequence_id else None ) if step is not None: - step.spans.append(span) + if span.begin_entry is not None: + full_span = Span( + begin_entry=span.begin_entry, + end_entry=span.end_entry, + ) + step.spans.append(full_span) # filter out all the non-null start steps return [ Step( @@ -145,3 +150,4 @@ class BackendSpec(pydantic.BaseModel): """Generic link for indexing job -- can be exposed in 'admin mode' in the UI""" indexing: bool + snapshotting: bool From ce0f1b450ce236725862bba6fed5cc544c90a954 Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Thu, 1 Aug 2024 17:41:20 -0700 Subject: [PATCH 07/11] Adds partition key to Burr URL You can now: 1. List all apps of a partition key 2. Navigate to a specific partition key Note that the file storage is still not distinct between partition keys. This will change stability of URL but that's OK for now. For null partition keys we just use __none__. --- burr/tracking/server/backend.py | 24 +++++++--- burr/tracking/server/run.py | 26 ++++++++--- burr/tracking/server/s3/backend.py | 28 +++++++++-- telemetry/ui/src/App.tsx | 3 +- telemetry/ui/src/api/models/BackendSpec.ts | 1 + .../ui/src/api/services/DefaultService.ts | 46 +++++++++++-------- .../ui/src/components/routes/AppList.tsx | 29 +++++++++--- .../ui/src/components/routes/app/AppView.tsx | 17 +++++-- telemetry/ui/src/examples/Common.tsx | 3 +- telemetry/ui/src/examples/EmailAssistant.tsx | 2 +- 10 files changed, 128 insertions(+), 51 deletions(-) diff --git a/burr/tracking/server/backend.py b/burr/tracking/server/backend.py index e017b43e..4beccdc0 100644 --- a/burr/tracking/server/backend.py +++ b/burr/tracking/server/backend.py @@ -106,20 +106,21 @@ async def list_projects(self, request: fastapi.Request) -> Sequence[schema.Proje @abc.abstractmethod async def list_apps( - self, request: fastapi.Request, project_id: str + self, request: fastapi.Request, project_id: str, partition_key: Optional[str] ) -> Sequence[schema.ApplicationSummary]: """Lists out all apps (continual state machine runs with shared state) for a given project. :param request: The request object, used for authentication/authorization if needed - :param project_id: - :return: + :param project_id: filter by project id + :param partition_key: filter by partition key + :return: A list of apps """ pass @abc.abstractmethod async def get_application_logs( - self, request: fastapi.Request, project_id: str, app_id: str - ) -> Sequence[schema.Step]: + self, request: fastapi.Request, project_id: str, app_id: str, partition_key: Optional[str] + ) -> ApplicationLogs: """Lists out all steps for a given app. 
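If partition_key is None, the app is looked up without filtering on partition key.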
:param request: The request object, used for authentication/authorization if needed @@ -228,7 +229,7 @@ async def _load_metadata(self, metadata_path: str) -> models.ApplicationMetadata return models.ApplicationMetadataModel() async def list_apps( - self, request: fastapi.Request, project_id: str + self, request: fastapi.Request, project_id: str, partition_key: Optional[str] ) -> Sequence[ApplicationSummary]: project_filepath = os.path.join(self.path, project_id) if not os.path.exists(project_filepath): @@ -244,6 +245,13 @@ async def list_apps( log_path = os.path.join(full_path, "log.jsonl") if os.path.isdir(full_path): metadata = await self._load_metadata(metadata_path) + app_partition_key = metadata.partition_key + # quick, hacky way to do it -- we should really have this be part of the path + # But we load it up anyway. TODO -- add partition key to the path + # If this is slow you'll want to use the s3-based storage system + # Which has an actual index + if partition_key is not None and partition_key != app_partition_key: + continue out.append( schema.ApplicationSummary( app_id=entry, @@ -259,8 +267,10 @@ async def list_apps( return out async def get_application_logs( - self, request: fastapi.Request, project_id: str, app_id: str + self, request: fastapi.Request, project_id: str, app_id: str, partition_key: Optional[str] ) -> ApplicationLogs: + # TODO -- handle partition key here + # This currently assumes uniqueness app_filepath = os.path.join(self.path, project_id, app_id) if not os.path.exists(app_filepath): raise fastapi.HTTPException( diff --git a/burr/tracking/server/run.py b/burr/tracking/server/run.py index bde037a5..3020dbf9 100644 --- a/burr/tracking/server/run.py +++ b/burr/tracking/server/run.py @@ -52,6 +52,8 @@ SERVE_STATIC = os.getenv("BURR_SERVE_STATIC", "true").lower() == "true" +SENTINEL_PARTITION_KEY = "__none__" + backend = BackendBase.create_from_env() @@ -156,19 +158,27 @@ async def get_projects(request: Request) -> Sequence[schema.Project]: return await backend.list_projects(request) -@app.get("/api/v0/{project_id}/apps", response_model=Sequence[schema.ApplicationSummary]) -async def get_apps(request: Request, project_id: str) -> Sequence[schema.ApplicationSummary]: +@app.get( + "/api/v0/{project_id}/{partition_key}/apps", response_model=Sequence[schema.ApplicationSummary] +) +async def get_apps( + request: Request, project_id: str, partition_key: str +) -> Sequence[schema.ApplicationSummary]: """Gets all apps visible by the user :param request: FastAPI request :param project_id: project name :return: a list of projects visible by the user """ - return await backend.list_apps(request, project_id) + if partition_key == SENTINEL_PARTITION_KEY: + partition_key = None + return await backend.list_apps(request, project_id, partition_key=partition_key) -@app.get("/api/v0/{project_id}/{app_id}/apps") -async def get_application_logs(request: Request, project_id: str, app_id: str) -> ApplicationLogs: +@app.get("/api/v0/{project_id}/{app_id}/{partition_key}/apps") +async def get_application_logs( + request: Request, project_id: str, app_id: str, partition_key: str +) -> ApplicationLogs: """Lists steps for a given App. 
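The partition_key path segment accepts the __none__ sentinel for apps logged without a partition key; it is converted back to None before the backend is queried.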
TODO: add streaming capabilities for bi-directional communication TODO: add pagination for quicker loading @@ -178,7 +188,11 @@ async def get_application_logs(request: Request, project_id: str, app_id: str) - :param app_id: ID of the associated application :return: A list of steps with all associated step data """ - return await backend.get_application_logs(request, project_id=project_id, app_id=app_id) + if partition_key == SENTINEL_PARTITION_KEY: + partition_key = None + return await backend.get_application_logs( + request, project_id=project_id, app_id=app_id, partition_key=partition_key + ) @app.get("/api/v0/ready") diff --git a/burr/tracking/server/s3/backend.py b/burr/tracking/server/s3/backend.py index 7f21a57a..41fe847e 100644 --- a/burr/tracking/server/s3/backend.py +++ b/burr/tracking/server/s3/backend.py @@ -542,13 +542,24 @@ async def list_projects(self, request: fastapi.Request) -> Sequence[schema.Proje return out async def list_apps( - self, request: fastapi.Request, project_id: str, limit: int = 100, offset: int = 0 + self, + request: fastapi.Request, + project_id: str, + partition_key: Optional[str], + limit: int = 100, + offset: int = 0, ) -> Sequence[schema.ApplicationSummary]: # TODO -- distinguish between project name and project ID # Currently they're the same in the UI but we'll want to have them decoupled + app_query = ( + Application.filter(project__name=project_id) + if partition_key is None + else Application.filter(project__name=project_id, partition_key=partition_key) + ) + applications = ( - await Application.filter(project__name=project_id) - .annotate( + # Sentinel value for partition_key is __none__ -- making it required makes querying easier + await app_query.annotate( latest_logfile_created_at=functions.Max("log_files__created_at"), logfile_count=functions.Max("log_files__max_sequence_id"), ) @@ -577,10 +588,17 @@ async def list_apps( return out async def get_application_logs( - self, request: fastapi.Request, project_id: str, app_id: str + self, request: fastapi.Request, project_id: str, app_id: str, partition_key: Optional[str] ) -> ApplicationLogs: # TODO -- handle partition keys - applications = await Application.filter(name=app_id, project__name=project_id).all() + query = ( + Application.filter(name=app_id, project__name=project_id) + if partition_key is None + else Application.filter( + name=app_id, project__name=project_id, partition_key=partition_key + ) + ) + applications = await query.all() application = applications[0] application_logs = await LogFile.filter(application__id=application.id).order_by( "-created_at" diff --git a/telemetry/ui/src/App.tsx b/telemetry/ui/src/App.tsx index 3f08361e..905b0f18 100644 --- a/telemetry/ui/src/App.tsx +++ b/telemetry/ui/src/App.tsx @@ -35,7 +35,8 @@ const App = () => { } /> } /> } /> - } /> + } /> + } /> } /> } /> } /> diff --git a/telemetry/ui/src/api/models/BackendSpec.ts b/telemetry/ui/src/api/models/BackendSpec.ts index ff240c5d..6f1a1058 100644 --- a/telemetry/ui/src/api/models/BackendSpec.ts +++ b/telemetry/ui/src/api/models/BackendSpec.ts @@ -7,4 +7,5 @@ */ export type BackendSpec = { indexing: boolean; + snapshotting: boolean; }; diff --git a/telemetry/ui/src/api/services/DefaultService.ts b/telemetry/ui/src/api/services/DefaultService.ts index 26dd2507..a3129212 100644 --- a/telemetry/ui/src/api/services/DefaultService.ts +++ b/telemetry/ui/src/api/services/DefaultService.ts @@ -17,6 +17,17 @@ import type { CancelablePromise } from '../core/CancelablePromise'; import { OpenAPI } from
'../core/OpenAPI'; import { request as __request } from '../core/request'; export class DefaultService { + /** + * Get App Spec + * @returns BackendSpec Successful Response + * @throws ApiError + */ + public static getAppSpecApiV0MetadataAppSpecGet(): CancelablePromise<BackendSpec> { + return __request(OpenAPI, { + method: 'GET', + url: '/api/v0/metadata/app_spec' + }); + } /** * Get Projects * Gets all projects visible by the user. @@ -40,17 +51,20 @@ export class DefaultService { * :param project_id: project name * :return: a list of projects visible by the user * @param projectId + * @param partitionKey * @returns ApplicationSummary Successful Response * @throws ApiError */ - public static getAppsApiV0ProjectIdAppsGet( - projectId: string + public static getAppsApiV0ProjectIdPartitionKeyAppsGet( + projectId: string, + partitionKey: string ): CancelablePromise<Array<ApplicationSummary>> { return __request(OpenAPI, { method: 'GET', - url: '/api/v0/{project_id}/apps', + url: '/api/v0/{project_id}/{partition_key}/apps', path: { - project_id: projectId + project_id: projectId, + partition_key: partitionKey }, errors: { 422: `Validation Error` @@ -65,23 +79,26 @@ export class DefaultService { * * :param request: FastAPI * :param project_id: ID of the project * :param app_id: ID of the associated application * :return: A list of steps with all associated step data * @param projectId * @param appId + * @param partitionKey * @returns ApplicationLogs Successful Response * @throws ApiError */ - public static getApplicationLogsApiV0ProjectIdAppIdAppsGet( + public static getApplicationLogsApiV0ProjectIdAppIdPartitionKeyAppsGet( projectId: string, - appId: string + appId: string, + partitionKey: string ): CancelablePromise<ApplicationLogs> { return __request(OpenAPI, { method: 'GET', - url: '/api/v0/{project_id}/{app_id}/apps', + url: '/api/v0/{project_id}/{app_id}/{partition_key}/apps', path: { project_id: projectId, - app_id: appId + app_id: appId, + partition_key: partitionKey }, errors: { 422: `Validation Error` @@ -99,17 +116,6 @@ export class DefaultService { url: '/api/v0/ready' }); } - /** - * Get App Spec - * @returns BackendSpec Successful Response - * @throws ApiError - */ - public static getAppSpecApiV0MetadataAppSpecGet(): CancelablePromise<BackendSpec> { - return __request(OpenAPI, { - method: 'GET', - url: '/api/v0/metadata/app_spec' - }); - } /** * Get Indexing Jobs * @param offset diff --git a/telemetry/ui/src/components/routes/AppList.tsx b/telemetry/ui/src/components/routes/AppList.tsx index a87e522a..ae16ef3d 100644 --- a/telemetry/ui/src/components/routes/AppList.tsx +++ b/telemetry/ui/src/components/routes/AppList.tsx @@ -6,7 +6,7 @@ import { Table, TableBody, TableCell, TableHead, TableHeader, TableRow } from '. import { DateTimeDisplay } from '../common/dates'; import { useState } from 'react'; import { FunnelIcon, MinusIcon, PlusIcon } from '@heroicons/react/24/outline'; -import { useNavigate } from 'react-router-dom'; +import { Link, useNavigate } from 'react-router-dom'; import { MdForkRight } from 'react-icons/md'; import { RiCornerDownRightLine } from 'react-icons/ri'; @@ -77,12 +77,25 @@ const AppSubList = (props: { key={props.app.app_id} className={`cursor-pointer ${isHighlighted ? 'bg-gray-50' : ''}`} onClick={() => { - props.navigate(`/project/${props.projectId}/${app.app_id}`); + props.navigate(`/project/${props.projectId}/${app.partition_key}/${app.app_id}`); }} > {props.displayPartitionKey && ( {isNullPartitionKey(app.partition_key) ?
'' : app.partition_key} + {isNullPartitionKey(app.partition_key) ? ( + <> + ) : ( + { + props.navigate(`/project/${props.projectId}/${app.partition_key}`); + e.stopPropagation(); + }} + > + {app.partition_key} + + )} )} @@ -223,10 +236,14 @@ export const AppListTable = (props: { apps: ApplicationSummary[]; projectId: str * List of applications. This fetches data from the BE and passes it to the table */ export const AppList = () => { - const { projectId } = useParams(); + const { projectId, partitionKey } = useParams(); const { data, error } = useQuery( - ['apps', projectId], - () => DefaultService.getAppsApiV0ProjectIdAppsGet(projectId as string), + ['apps', projectId, partitionKey], + () => + DefaultService.getAppsApiV0ProjectIdPartitionKeyAppsGet( + projectId as string, + partitionKey ? partitionKey : '__none__' + ), { enabled: projectId !== undefined } ); if (projectId === undefined) { diff --git a/telemetry/ui/src/components/routes/app/AppView.tsx b/telemetry/ui/src/components/routes/app/AppView.tsx index fb37883e..c7ed7620 100644 --- a/telemetry/ui/src/components/routes/app/AppView.tsx +++ b/telemetry/ui/src/components/routes/app/AppView.tsx @@ -90,6 +90,7 @@ const NUM_PREVIOUS_ACTIONS = 6; export const AppView = (props: { projectId: string; appId: string; + partitionKey?: string; orientation: 'stacked_vertical' | 'stacked_horizontal'; defaultAutoRefresh?: boolean; }) => { @@ -102,9 +103,10 @@ export const AppView = (props: { const { data, error } = useQuery( ['steps', appId], () => - DefaultService.getApplicationLogsApiV0ProjectIdAppIdAppsGet( + DefaultService.getApplicationLogsApiV0ProjectIdAppIdPartitionKeyAppsGet( projectId as string, - appId as string + appId as string, + props.partitionKey !== undefined ? props.partitionKey : '__none__' ), { refetchInterval: autoRefresh ? REFRESH_INTERVAL : false, @@ -250,9 +252,16 @@ export const AppView = (props: { }; export const AppViewContainer = () => { - const { projectId, appId } = useParams(); + const { projectId, appId, partitionKey } = useParams(); if (projectId === undefined || appId === undefined) { return
Invalid URL
; } - return ; + return ( + + ); }; diff --git a/telemetry/ui/src/examples/Common.tsx b/telemetry/ui/src/examples/Common.tsx index de5f5874..dc3f4cd1 100644 --- a/telemetry/ui/src/examples/Common.tsx +++ b/telemetry/ui/src/examples/Common.tsx @@ -51,7 +51,8 @@ export const ChatbotAppSelector = (props: { const { projectId, setApp } = props; const { data, refetch } = useQuery( ['apps', projectId], - () => DefaultService.getAppsApiV0ProjectIdAppsGet(projectId as string), + // TODO - use the right partition key + () => DefaultService.getAppsApiV0ProjectIdPartitionKeyAppsGet(projectId as string, '__none__'), { enabled: projectId !== undefined } ); const createAndUpdateMutation = useMutation( diff --git a/telemetry/ui/src/examples/EmailAssistant.tsx b/telemetry/ui/src/examples/EmailAssistant.tsx index e9c4a325..22a4e787 100644 --- a/telemetry/ui/src/examples/EmailAssistant.tsx +++ b/telemetry/ui/src/examples/EmailAssistant.tsx @@ -385,7 +385,7 @@ export const EmailAssistantAppSelector = (props: { const { projectId, setApp } = props; const { data, refetch } = useQuery( ['apps', projectId], - () => DefaultService.getAppsApiV0ProjectIdAppsGet(projectId as string), + () => DefaultService.getAppsApiV0ProjectIdPartitionKeyAppsGet(projectId as string, '__none__'), { enabled: projectId !== undefined } ); const createAndUpdateMutation = useMutation( From 754c10b837023c3a3dcdfb37fe95db5da7dc6de1 Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Thu, 1 Aug 2024 21:18:26 -0700 Subject: [PATCH 08/11] WIP for deployment with docker --- burr/log_setup.py | 27 +++++++++++++ burr/tracking/server/run.py | 4 +- burr/tracking/server/s3/backend.py | 4 +- burr/tracking/server/s3/deployment/Dockerfile | 40 +++++++++++++++++++ burr/tracking/server/s3/deployment/nginx.conf | 17 ++++++++ examples/streaming-fastapi/application.py | 7 +--- pyproject.toml | 12 +++++- 7 files changed, 100 insertions(+), 11 deletions(-) create mode 100644 burr/log_setup.py create mode 100644 burr/tracking/server/s3/deployment/Dockerfile create mode 100644 burr/tracking/server/s3/deployment/nginx.conf diff --git a/burr/log_setup.py b/burr/log_setup.py new file mode 100644 index 00000000..85c9a11b --- /dev/null +++ b/burr/log_setup.py @@ -0,0 +1,27 @@ +import logging +import sys + +LOG_LEVELS = { + "CRITICAL": logging.CRITICAL, + "ERROR": logging.ERROR, + "WARNING": logging.WARNING, + "INFO": logging.INFO, + "DEBUG": logging.DEBUG, +} + + +# this is suboptimal but python has no public mapping of log names to levels + + +def setup_logging(log_level: int = logging.INFO): + """Helper function to setup logging to console. + :param log_level: Log level to use when logging + """ + root_logger = logging.getLogger("") # root logger + formatter = logging.Formatter("[%(levelname)s] %(asctime)s %(name)s(%(lineno)s): %(message)s") + stream_handler = logging.StreamHandler(sys.stdout) + stream_handler.setFormatter(formatter) + if not len(root_logger.handlers): + # assumes we have already been set up. 
+ root_logger.addHandler(stream_handler) + root_logger.setLevel(log_level) diff --git a/burr/tracking/server/run.py b/burr/tracking/server/run.py index 3020dbf9..b289b099 100644 --- a/burr/tracking/server/run.py +++ b/burr/tracking/server/run.py @@ -5,10 +5,10 @@ from importlib.resources import files from typing import Sequence -# TODO -- remove this, just for testing -from hamilton.log_setup import setup_logging from starlette import status +# TODO -- remove this, just for testing +from burr.log_setup import setup_logging from burr.tracking.server.backend import BackendBase, IndexingBackendMixin, SnapshottingBackendMixin setup_logging(logging.INFO) diff --git a/burr/tracking/server/s3/backend.py b/burr/tracking/server/s3/backend.py index 41fe847e..e393ede9 100644 --- a/burr/tracking/server/s3/backend.py +++ b/burr/tracking/server/s3/backend.py @@ -167,12 +167,12 @@ async def load_snapshot(self): # if it already exists then return if os.path.exists(path): return + if not os.path.exists(os.path.dirname(path)): + os.makedirs(os.path.dirname(path)) async with self._session.create_client("s3") as client: objects = await client.list_objects_v2( Bucket=self._bucket, Prefix=self._snapshot_prefix, MaxKeys=1 ) - # nothing there - # TODO -- if len(objects["Contents"]) == 0: return # get the latest snapshot -- it's organized by alphabetical order diff --git a/burr/tracking/server/s3/deployment/Dockerfile b/burr/tracking/server/s3/deployment/Dockerfile new file mode 100644 index 00000000..ae601dfd --- /dev/null +++ b/burr/tracking/server/s3/deployment/Dockerfile @@ -0,0 +1,40 @@ +# Use an official Python runtime as a parent image +FROM python:3.11-slim + +# Set environment variables +ENV PYTHONUNBUFFERED=1 + +# Set working directory +WORKDIR /app + +# Copy the current directory contents into the container at /app +COPY . 
/app + +# Install dependencies and git +RUN apt-get update && apt-get install -y \ + git \ + nginx \ + && apt-get clean + +# Install the dependencies +# TODO -- use the right version +#RUN pip install "git+https://github.com/dagworks-inc/burr.git@tracker-s3#egg=burr[tracking-server-s3]" +RUN pip install "burr[tracking-server-s3]==0.26.0rc4" + +# Copy the nginx config file +COPY nginx.conf /etc/nginx/nginx.conf + +# Expose the port FastAPI will run on and the port NGINX will listen to +EXPOSE 8000 +EXPOSE 80 + +ENV BURR_S3_BUCKET=burr-prod-test +ENV BURR_load_snapshot_on_start=True +ENV BURR_snapshot_interval_milliseconds=3_600_000 +ENV DEBIAN_FRONTEND=noninteractive +ENV BURR_BACKEND_IMPL=burr.tracking.server.s3.backend.SQLiteS3Backend + + +# Command to run FastAPI server and NGINX +CMD ["sh", "-c", "uvicorn burr.tracking.server.run:app --host 0.0.0.0 --port 8000 & nginx -g 'daemon off;'"] diff --git a/burr/tracking/server/s3/deployment/nginx.conf b/burr/tracking/server/s3/deployment/nginx.conf new file mode 100644 index 00000000..0a447a06 --- /dev/null +++ b/burr/tracking/server/s3/deployment/nginx.conf @@ -0,0 +1,17 @@ +events { + worker_connections 1024; +} + +http { + server { + listen 80; + + location / { + proxy_pass http://127.0.0.1:8000; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + } + } +} diff --git a/examples/streaming-fastapi/application.py b/examples/streaming-fastapi/application.py index 63a84d29..e8c4d48f 100644 --- a/examples/streaming-fastapi/application.py +++ b/examples/streaming-fastapi/application.py @@ -7,7 +7,6 @@ from burr.core import ApplicationBuilder, State, default, when from burr.core.action import action, streaming_action from burr.core.graph import GraphBuilder -from burr.tracking.s3client import S3TrackingClient MODES = [ "answer_question", @@ -175,11 +174,7 @@ def application(app_id: Optional[str] = None): .with_state(chat_history=[]) .with_graph(graph) # .with_tracker(project="demo_chatbot_streaming") - .with_tracker( - tracker=S3TrackingClient( - bucket="burr-prod-test", project="demo_chatbot_streaming", non_blocking=True - ) - ) + .with_tracker(project="demo_chatbot_streaming") .with_identifiers(app_id=app_id) .build() ) diff --git a/pyproject.toml b/pyproject.toml index f55e6419..2addc13e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "burr" -version = "0.25.0" +version = "0.26.0rc4" dependencies = [] # yes, there are none requires-python = ">=3.9" authors = [ @@ -85,6 +85,16 @@ tracking-client-s3 = [ "aiobotocore" ] +tracking-server-s3 = [ + "aerich", + "aiobotocore", + "fastapi-utils", + "fastapi", + "tortoise-orm[accel, asyncmy]", + "burr[tracking-server]", + "typing-inspect" +] + tracking-server = [ "click", "fastapi", From 351732816d48ce764a55d2e1f30c3ac7ff461fd0 Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Fri, 2 Aug 2024 22:50:21 -0700 Subject: [PATCH 09/11] PR cleanup --- burr/cli/__main__.py | 1 + burr/cli/demo_data.py | 6 ------ burr/tracking/server/run.py | 10 ---------- 3 files changed, 1 insertion(+), 16 deletions(-) diff --git a/burr/cli/__main__.py b/burr/cli/__main__.py index 2f952278..8607aa96 100644 --- a/burr/cli/__main__.py +++ b/burr/cli/__main__.py @@ -28,6 +28,7 @@ ) +# Quick trick to use loguru for everything so it's all the same color class InterceptHandler(logging.Handler): def
emit(self, record): # Get corresponding Loguru level if it exists diff --git a/burr/cli/demo_data.py b/burr/cli/demo_data.py index 25300931..1d15f278 100644 --- a/burr/cli/demo_data.py +++ b/burr/cli/demo_data.py @@ -130,12 +130,6 @@ def generate_counter_data( .with_entrypoint("counter") .build() ) - # app = counter_application.application( - # count_up_to=count, - # app_id=f"count-to-{count}", - # storage_dir=data_dir, - # partition_key=f"user_{i}", - # ) app.run(halt_after=["result"]) diff --git a/burr/tracking/server/run.py b/burr/tracking/server/run.py index b289b099..51d3a940 100644 --- a/burr/tracking/server/run.py +++ b/burr/tracking/server/run.py @@ -23,9 +23,6 @@ from starlette.templating import Jinja2Templates from burr.tracking.server import schema - - # from burr.tracking.server import backend as backend_module - # from burr.tracking.server.s3 import backend as s3_backend from burr.tracking.server.schema import ApplicationLogs, BackendSpec, IndexingJob # dynamic importing due to the dashes (which make reading the examples on github easier) @@ -122,13 +119,6 @@ def get_app_spec(): logger = logging.getLogger(__name__) -# @repeat_every( -# seconds=update_interval if update_interval is not None else float("inf"), -# wait_first=True, -# logger=logger, -# ) - - if app_spec.indexing: update_interval = backend.update_interval_milliseconds() / 1000 if app_spec.indexing else None sync_index = repeat_every( From 9c62257071df43dc2d18e7dbf619bb0944a6efe3 Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Sun, 4 Aug 2024 14:25:40 -0700 Subject: [PATCH 10/11] Adds indicator for null primary key in breadcrumb --- .../ui/src/components/nav/breadcrumb.tsx | 50 +++++++++++-------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/telemetry/ui/src/components/nav/breadcrumb.tsx b/telemetry/ui/src/components/nav/breadcrumb.tsx index 7487f212..1243fdf3 100644 --- a/telemetry/ui/src/components/nav/breadcrumb.tsx +++ b/telemetry/ui/src/components/nav/breadcrumb.tsx @@ -33,27 +33,35 @@ export const BreadCrumb = () => { - {pages.map((page) => ( -
  • -
    - - - {page.name} - -
    -
  • - ))} + {pages.map((page, index) => { + // Quick trick to catch null primary keys + const isNullPK = page.name === 'null' && index === 2; + return ( +
  • +
    + + {isNullPK ? ( + no primary key + ) : ( + + {page.name} + + )} +
    +
  • + ); + })} ); }; From aa2adbc9c8308c995bffaaaf99464ab3f4a0c980 Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Sun, 4 Aug 2024 17:25:18 -0700 Subject: [PATCH 11/11] Removes unnecessary logging code from CLI --- burr/cli/__main__.py | 22 ++-------------------- examples/streaming-fastapi/application.py | 1 - 2 files changed, 2 insertions(+), 21 deletions(-) diff --git a/burr/cli/__main__.py b/burr/cli/__main__.py index 8607aa96..1d9ab69a 100644 --- a/burr/cli/__main__.py +++ b/burr/cli/__main__.py @@ -15,6 +15,7 @@ from burr import system, telemetry from burr.core.persistence import PersistedStateData from burr.integrations.base import require_plugin +from burr.log_setup import setup_logging try: import click @@ -27,27 +28,8 @@ "start", ) - -# Quick trick to use loguru for everything so it's all the same color class InterceptHandler(logging.Handler): def emit(self, record): # Get corresponding Loguru level if it exists try: level = logger.level(record.levelname).name except ValueError: level = record.levelno # Find caller from where originated the log message frame, depth = logging.currentframe(), 2 while frame.f_code.co_filename == logging.__file__: frame = frame.f_back depth += 1 logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage()) - - # Clear default handlers -logging.basicConfig(handlers=[InterceptHandler()], level=logging.INFO) +setup_logging(logging.INFO) # TODO -- add this as a general callback to the CLI diff --git a/examples/streaming-fastapi/application.py b/examples/streaming-fastapi/application.py index e8c4d48f..8e90df8e 100644 --- a/examples/streaming-fastapi/application.py +++ b/examples/streaming-fastapi/application.py @@ -173,7 +173,6 @@ def application(app_id: Optional[str] = None): .with_entrypoint("prompt") .with_state(chat_history=[]) .with_graph(graph) - .with_tracker(project="demo_chatbot_streaming") .with_tracker(project="demo_chatbot_streaming") .with_identifiers(app_id=app_id) .build()
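
Usage note: a rough client-side sketch of the partition-key routes added in PATCH 07. The route shapes (/api/v0/{project_id}/{partition_key}/apps and /api/v0/{project_id}/{app_id}/{partition_key}/apps) and the __none__ sentinel come from the patch itself; the base URL, project name, partition key, and the "steps" field access below are illustrative assumptions, not anything this series pins down.

# Hypothetical client for the partition-key endpoints -- host/port and names are assumptions.
import requests

BASE_URL = "http://localhost:7241"  # assumed address of a locally running Burr server
PROJECT = "demo_chatbot_streaming"
PARTITION_KEY = "user_123"  # use "__none__" to list apps that have no partition key

# List applications for one partition key within a project
apps = requests.get(f"{BASE_URL}/api/v0/{PROJECT}/{PARTITION_KEY}/apps").json()

for app_summary in apps:
    app_id = app_summary["app_id"]
    # Fetch step-level logs for a single app under that partition key
    logs = requests.get(f"{BASE_URL}/api/v0/{PROJECT}/{app_id}/{PARTITION_KEY}/apps").json()
    print(app_id, len(logs.get("steps", [])))  # "steps" field name assumed from ApplicationLogs

Both endpoints map __none__ back to a null partition key on the server (see SENTINEL_PARTITION_KEY in run.py), so un-partitioned apps stay reachable through the same URL scheme.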
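Similarly, a minimal sketch of writing traces to S3 so the Dockerized server from PATCH 08 can pick them up. Only the S3TrackingClient arguments (bucket, project, non_blocking) mirror values that appear in the series; the counter action, app ID, and partition key are placeholders.

# Minimal sketch, assuming the placeholder action and state below.
from typing import Tuple

from burr.core import ApplicationBuilder, State
from burr.core.action import action
from burr.tracking.s3client import S3TrackingClient


@action(reads=["count"], writes=["count"])
def counter(state: State) -> Tuple[dict, State]:
    # Placeholder action: bump a counter so there is something to track
    result = {"count": state["count"] + 1}
    return result, state.update(**result)


tracker = S3TrackingClient(
    bucket="burr-prod-test",  # the bucket the Dockerfile's BURR_S3_BUCKET points at
    project="demo_chatbot_streaming",
    non_blocking=True,  # keep S3 writes off the application's critical path
)

app = (
    ApplicationBuilder()
    .with_actions(counter)
    .with_transitions(("counter", "counter"))
    .with_entrypoint("counter")
    .with_state(count=0)
    .with_tracker(tracker=tracker)
    .with_identifiers(app_id="s3-tracking-demo", partition_key="user_123")  # placeholder IDs
    .build()
)
app.run(halt_after=["counter"])

Once the data lands in the bucket, a server container pointed at the same BURR_S3_BUCKET indexes it, and the partition key becomes navigable through the URLs described above.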