Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prod Burr using S3-backed API, initial scaffolding/implementation #288

Merged
merged 11 commits on Aug 5, 2024 (source and target branch names lost in page extraction)
63 changes: 52 additions & 11 deletions burr/cli/__main__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import importlib.util
import json
import logging
import os
import shutil
import subprocess
Expand All @@ -14,6 +15,7 @@
from burr import system, telemetry
from burr.core.persistence import PersistedStateData
from burr.integrations.base import require_plugin
from burr.log_setup import setup_logging

try:
import click
Expand All @@ -26,30 +28,37 @@
"start",
)

# Clear default handlers
setup_logging(logging.INFO)


# TODO -- add this as a general callback to the CLI
def _telemetry_if_enabled(event: str):
if telemetry.is_telemetry_enabled():
telemetry.create_and_send_cli_event(event)


def _command(command: str, capture_output: bool) -> str:
def _command(command: str, capture_output: bool, addl_env: dict = None) -> str:
"""Runs a simple command"""
if addl_env is None:
addl_env = {}
env = os.environ.copy()
env.update(addl_env)
logger.info(f"Running command: {command}")
if isinstance(command, str):
command = command.split(" ")
if capture_output:
try:
return (
subprocess.check_output(command, stderr=subprocess.PIPE, shell=False)
subprocess.check_output(command, stderr=subprocess.PIPE, shell=False, env=env)
.decode()
.strip()
)
except subprocess.CalledProcessError as e:
print(e.stdout.decode())
print(e.stderr.decode())
raise e
subprocess.run(command, shell=False, check=True)
subprocess.run(command, shell=False, check=True, env=env)


def _get_git_root() -> str:
Expand Down Expand Up @@ -102,13 +111,20 @@ def build_ui():
_build_ui()


BACKEND_MODULES = {
"local": "burr.tracking.server.backend.LocalBackend",
"s3": "burr.tracking.server.s3.backend.SQLiteS3Backend",
}


def _run_server(
port: int,
dev_mode: bool,
no_open: bool,
no_copy_demo_data: bool,
initial_page="",
host: str = "127.0.0.1",
backend: str = "local",
):
_telemetry_if_enabled("run_server")
# TODO: Implement server running logic here
Expand Down Expand Up @@ -142,7 +158,10 @@ def _run_server(
daemon=True,
)
thread.start()
_command(cmd, capture_output=False)
env = {
"BURR_BACKEND_IMPL": BACKEND_MODULES[backend],
}
_command(cmd, capture_output=False, addl_env=env)


@cli.command()
Expand All @@ -156,8 +175,16 @@ def _run_server(
help="Host to run the server on -- use 0.0.0.0 if you want "
"to expose it to the network (E.G. in a docker image)",
)
def run_server(port: int, dev_mode: bool, no_open: bool, no_copy_demo_data: bool, host: str):
_run_server(port, dev_mode, no_open, no_copy_demo_data, host=host)
@click.option(
"--backend",
default="local",
help="Backend to use for the server.",
type=click.Choice(["local", "s3"]),
)
def run_server(
port: int, dev_mode: bool, no_open: bool, no_copy_demo_data: bool, host: str, backend: str
):
_run_server(port, dev_mode, no_open, no_copy_demo_data, host=host, backend=backend)


@cli.command()
Expand Down Expand Up @@ -186,18 +213,32 @@ def build_and_publish(prod: bool, no_wipe_dist: bool):


@cli.command(help="generate demo data for the UI")
def generate_demo_data():
@click.option(
"--s3-bucket", help="S3 URI to save to, will use the s3 tracker, not local mode", required=False
)
@click.option(
"--data-dir",
help="Local directory to save to",
required=False,
default="burr/tracking/server/demo_data",
)
@click.option("--unique-app-names", help="Use unique app names", is_flag=True)
def generate_demo_data(s3_bucket, data_dir, unique_app_names: bool):
_telemetry_if_enabled("generate_demo_data")
git_root = _get_git_root()
# We need to add the examples directory to the path so we have all the imports
# The GPT-one relies on a local import
sys.path.extend([git_root, f"{git_root}/examples/multi-modal-chatbot"])
from burr.cli.demo_data import generate_all

with cd(git_root):
logger.info("Removing old demo data")
shutil.rmtree("burr/tracking/server/demo_data", ignore_errors=True)
generate_all("burr/tracking/server/demo_data")
# local mode
if s3_bucket is None:
with cd(git_root):
logger.info("Removing old demo data")
shutil.rmtree(data_dir, ignore_errors=True)
generate_all(data_dir=data_dir, unique_app_names=unique_app_names)
else:
generate_all(s3_bucket=s3_bucket, unique_app_names=unique_app_names)


def _transform_state_to_test_case(state: dict, action_name: str, test_name: str) -> dict:
Expand Down
124 changes: 101 additions & 23 deletions burr/cli/demo_data.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,32 @@
import importlib
import logging
import os
import uuid
from typing import Optional

from application import logger
from burr.core import ApplicationBuilder, Result, default, expr
from burr.core.graph import GraphBuilder
from burr.tracking import LocalTrackingClient
from burr.tracking.s3client import S3TrackingClient

conversational_rag_application = importlib.import_module("examples.conversational-rag.application")
logger = logging.getLogger(__name__)

conversational_rag_application = importlib.import_module(
"examples.conversational-rag.simple_example.application"
)
counter_application = importlib.import_module("examples.hello-world-counter.application")
chatbot_application = importlib.import_module("examples.multi-modal-chatbot.application")
chatbot_application_with_traces = importlib.import_module("examples.tracing-and-spans.application")


def generate_chatbot_data(data_dir: str, use_traces: bool):
def generate_chatbot_data(
data_dir: Optional[str] = None,
s3_bucket: Optional[str] = None,
use_traces: bool = False,
unique_app_names: bool = False,
):
project_id = "demo_chatbot" if not use_traces else "demo_chatbot_with_traces"
run_prefix = str(uuid.uuid4())[0:8] + "-" if unique_app_names else ""
working_conversations = {
"chat-1-giraffe": [
"Please draw a giraffe.", # Answered by the image mode
Expand Down Expand Up @@ -44,39 +61,81 @@ def generate_chatbot_data(data_dir: str, use_traces: bool):
}
broken_conversations = {"chat-6-demonstrate-errors": working_conversations["chat-1-giraffe"]}

def _modify(app_id: str) -> str:
return run_prefix + app_id

def _run_conversation(app_id, prompts):
app = (chatbot_application_with_traces if use_traces else chatbot_application).application(
app_id=app_id,
storage_dir=data_dir,
tracker = (
LocalTrackingClient(project=project_id, storage_dir=data_dir)
if not s3_bucket
else S3TrackingClient(project=project_id, bucket=s3_bucket)
)
graph = (chatbot_application_with_traces if use_traces else chatbot_application).graph
app = (
ApplicationBuilder()
.with_graph(graph)
.with_identifiers(app_id=app_id)
.with_tracker(tracker)
.with_entrypoint("prompt")
.build()
)
for prompt in prompts:
app.run(halt_after=["response"], inputs={"prompt": prompt})

for app_id, prompts in sorted(working_conversations.items()):
_run_conversation(app_id, prompts)
_run_conversation(_modify(app_id), prompts)
old_api_key = os.environ.get("OPENAI_API_KEY")
os.environ["OPENAI_API_KEY"] = "fake"
for app_id, prompts in sorted(broken_conversations.items()):
try:
_run_conversation(app_id, prompts)
_run_conversation(_modify(app_id), prompts)
except Exception as e:
print(f"Got an exception: {e}")
os.environ["OPENAI_API_KEY"] = old_api_key


def generate_counter_data(data_dir: str = "~/.burr"):
def generate_counter_data(
data_dir: str = "~/.burr", s3_bucket: Optional[str] = None, unique_app_names: bool = False
):
counter = counter_application.counter
tracker = (
LocalTrackingClient(project="demo_counter", storage_dir=data_dir)
if not s3_bucket
else S3TrackingClient(project="demo_counter", bucket=s3_bucket)
)

counts = [1, 10, 100, 50, 42]
# This is just cause we don't want to change the code
# TODO -- add ability to grab graph from application or something like that
graph = (
GraphBuilder()
.with_actions(counter=counter, result=Result("counter"))
.with_transitions(
("counter", "counter", expr("counter < count_to")),
("counter", "result", default),
)
.build()
)
for i, count in enumerate(counts):
app = counter_application.application(
count_up_to=count,
app_id=f"count-to-{count}",
storage_dir=data_dir,
partition_key=f"user_{i}",
app_id = f"count-to-{count}"
if unique_app_names:
suffix = str(uuid.uuid4())[0:8]
app_id = f"{app_id}-{suffix}"
app = (
ApplicationBuilder()
.with_graph(graph)
.with_identifiers(app_id=app_id, partition_key=f"user_{i}")
.with_state(count_to=count, counter=0)
.with_tracker(tracker)
.with_entrypoint("counter")
.build()
)
app.run(halt_after=["result"])


def generate_rag_data(data_dir: str = "~/.burr"):
def generate_rag_data(
data_dir: Optional[str] = None, s3_bucket: Optional[str] = None, unique_app_names: bool = False
):
conversations = {
"rag-1-food": [
"What is Elijah's favorite food?",
Expand Down Expand Up @@ -105,24 +164,43 @@ def generate_rag_data(data_dir: str = "~/.burr"):
"Whose favorite food is better, Elijah's or Stefan's?" "exit",
],
}
prefix = str(uuid.uuid4())[0:8] + "-" if unique_app_names else ""
for app_id, prompts in sorted(conversations.items()):
app = conversational_rag_application.application(
app_id=app_id,
storage_dir=data_dir,
graph = conversational_rag_application.graph()
tracker = (
LocalTrackingClient(project="demo_conversational-rag", storage_dir=data_dir)
if not s3_bucket
else S3TrackingClient(project="demo_conversational-rag", bucket=s3_bucket)
)
app_id = f"{prefix}{app_id}" if unique_app_names else app_id
app = (
ApplicationBuilder()
.with_graph(graph)
.with_identifiers(app_id=app_id)
.with_tracker(tracker)
.with_entrypoint("human_converse")
.build()
)
logger.warning(f"Running {app_id}...")
for prompt in prompts:
app.run(halt_after=["ai_converse", "terminal"], inputs={"user_question": prompt})


def generate_all(data_dir: str):
def generate_all(
data_dir: Optional[str] = None, s3_bucket: Optional[str] = None, unique_app_names: bool = False
):
logger.info("Generating chatbot data")
generate_chatbot_data(data_dir, False)
generate_chatbot_data(
data_dir=data_dir, s3_bucket=s3_bucket, use_traces=False, unique_app_names=unique_app_names
)
logger.info("Generating chatbot data with traces")
generate_chatbot_data(data_dir, True)
generate_chatbot_data(
data_dir=data_dir, s3_bucket=s3_bucket, use_traces=True, unique_app_names=unique_app_names
)
logger.info("Generating counter data")
generate_counter_data(data_dir)
generate_counter_data(data_dir=data_dir, s3_bucket=s3_bucket, unique_app_names=unique_app_names)
logger.info("Generating RAG data")
generate_rag_data(data_dir)
generate_rag_data(data_dir=data_dir, s3_bucket=s3_bucket, unique_app_names=unique_app_names)


#
5 changes: 4 additions & 1 deletion burr/core/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,10 @@ def __init__(
self._parent_pointer = fork_parent_pointer
self.dependency_factory = {
"__tracer": functools.partial(
visibility.tracing.TracerFactory, lifecycle_adapters=self._adapter_set
visibility.tracing.TracerFactory,
lifecycle_adapters=self._adapter_set,
app_id=self._uid,
partition_key=self._partition_key,
),
"__context": self._context_factory,
}
Expand Down
8 changes: 8 additions & 0 deletions burr/lifecycle/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ def pre_start_span(
action_sequence_id: int,
span: "ActionSpan",
span_dependencies: list[str],
app_id: str,
partition_key: Optional[str],
**future_kwargs: Any,
):
pass
Expand All @@ -182,6 +184,8 @@ async def pre_start_span(
action_sequence_id: int,
span: "ActionSpan",
span_dependencies: list[str],
app_id: str,
partition_key: Optional[str],
**future_kwargs: Any,
):
pass
Expand All @@ -200,6 +204,8 @@ def post_end_span(
action_sequence_id: int,
span: "ActionSpan",
span_dependencies: list[str],
app_id: str,
partition_key: Optional[str],
**future_kwargs: Any,
):
pass
Expand All @@ -215,6 +221,8 @@ async def post_end_span(
action_sequence_id: int,
span: "ActionSpan",
span_dependencies: list[str],
app_id: str,
partition_key: Optional[str],
**future_kwargs: Any,
):
pass
Expand Down
Loading
Loading