diff --git a/burr/integrations/playground/__init__.py b/burr/integrations/playground/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/burr/integrations/playground/app.py b/burr/integrations/playground/app.py new file mode 100644 index 00000000..69669243 --- /dev/null +++ b/burr/integrations/playground/app.py @@ -0,0 +1,303 @@ +import litellm +import streamlit as st + +from burr.core import Application, ApplicationBuilder, State, action +from burr.integrations.streamlit import ( + application_selectbox, + get_steps, + project_selectbox, + step_selectbox, +) +from burr.tracking import LocalTrackingClient +from burr.tracking.server.backend import LocalBackend +from burr.visibility import TracerFactory + + +@st.cache_data +def instrument(provider: str): + msg = None + if provider == "openai": + try: + from opentelemetry.instrumentation.openai import ( # openai is a dependency of litellm + OpenAIInstrumentor, + ) + + OpenAIInstrumentor().instrument() + except ImportError: + msg = f"Couldn't instrument {provider}. Try installing `opentelemetry-instrumenation-{provider}" + + elif provider == "cohere": + try: + from opentelemetry.instrumentation.cohere import CohereInstrumentor + + CohereInstrumentor().instrument() + except ImportError: + msg = f"Couldn't instrument {provider}. Try installing `opentelemetry-instrumenation-{provider}" + + elif provider == "anthropic": + try: + from opentelemetry.instrumentation.anthropic import AnthropicInstrumentor + + AnthropicInstrumentor().instrument() + except ImportError: + msg = f"Couldn't instrument {provider}. Try installing `opentelemetry-instrumenation-{provider}" + + elif provider == "bedrock": + try: + from opentelemetry.instrumentation.bedrock import BedrockInstrumentor + + BedrockInstrumentor().instrument() + except ImportError: + msg = f"Couldn't instrument {provider}. Try installing `opentelemetry-instrumenation-{provider}" + + elif provider == "ollama": + try: + from opentelemetry.instrumentation.ollama import OllamaInstrumentor + + OllamaInstrumentor().instrument() + except ImportError: + msg = f"Couldn't instrument {provider}. Try installing `opentelemetry-instrumenation-{provider}" + + elif provider == "gemini": + try: + from opentelemetry.instrumentation.google_generativeai import ( + GoogleGenerativeAiInstrumentor, + ) + + GoogleGenerativeAiInstrumentor().instrument() + except ImportError: + msg = f"Couldn't instrument {provider}. Try installing `opentelemetry-instrumenation-{provider}" + + elif provider == "replicate": + try: + from opentelemetry.instrumentation.replicate import ReplicateInstrumentor + + ReplicateInstrumentor().instrument() + except ImportError: + msg = f"Couldn't instrument {provider}. Try installing `opentelemetry-instrumenation-{provider}" + + elif provider == "together_ai": + try: + from opentelemetry.instrumentation.together import TogetherAiInstrumentor + + TogetherAiInstrumentor().instrument() + except ImportError: + msg = f"Couldn't instrument {provider}. Try installing `opentelemetry-instrumenation-{provider}" + + elif provider == "replicate": + try: + from opentelemetry.instrumentation.replicate import ReplicateInstrumentor + + ReplicateInstrumentor().instrument() + except ImportError: + msg = f"Couldn't instrument {provider}. Try installing `opentelemetry-instrumenation-{provider}" + + elif provider == "huggingface": + try: + from opentelemetry.instrumentation.transformers import TransformersInstrumentor + + TransformersInstrumentor().instrument() + except ImportError: + msg = f"Couldn't instrument {provider}. Try installing `opentelemetry-instrumenation-{provider}" + + elif provider == "vertex_ai": + try: + from opentelemetry.instrumentation.vertexai import VertexAIInstrumentor + + VertexAIInstrumentor().instrument() + except ImportError: + msg = f"Couldn't instrument {provider}. Try installing `opentelemetry-instrumenation-{provider}" + + elif provider == "watsonx": + try: + from opentelemetry.instrumentation.watsonx import WatsonxInstrumentor + + WatsonxInstrumentor().instrument() + except ImportError: + msg = f"Couldn't instrument {provider}. Try installing `opentelemetry-instrumenation-{provider}" + + else: + msg = f"Couldn't instrument {provider}. Try installing `opentelemetry-instrumenation-{provider}" + + if msg: + print(msg) + return msg + + +@action(reads=["history"], writes=["history"]) +def generate_answer( + state: State, model: str, messages: list[dict], __tracer: TracerFactory, msg_to_log=None +): + if msg_to_log: + __tracer.log_attribute("message", msg_to_log) + + response = litellm.completion(model=model, messages=messages) + llm_answer = response.choices[0].message.content + + history = state["history"] + if history.get(model) is None: + history[model] = [] + + history[model] += [llm_answer] + return state.update(history=history) + + +def build_burr_app(source_project: str) -> Application: + tracker = LocalTrackingClient(project="burr-playground") + return ( + ApplicationBuilder() + .with_actions(generate_answer) + .with_transitions(("generate_answer", "generate_answer")) + .with_identifiers(app_id=source_project) + .initialize_from( + initializer=tracker, + resume_at_next_action=False, + default_state={"history": {}}, + default_entrypoint="generate_answer", + ) + .with_tracker("local", project="burr-playground", use_otel_tracing=True) + .build() + ) + + +@st.cache_resource +def get_burr_backend(): + return LocalBackend() + + +def normalize_spans(spans: list) -> dict: + nested_dict = {} + for span in spans: + key = span.key + value = span.value + + keys = key.split(".") + d = nested_dict + for k in keys[:-1]: + if k not in d: + d[k] = {} + d = d[k] + d[keys[-1]] = value + return nested_dict + + +def history_component(normalized_spans): + # historical data + with st.expander("History", expanded=True): + prompts = normalized_spans["gen_ai"]["prompt"].values() + answer = normalized_spans["gen_ai"]["completion"].values() + + for message in list(prompts) + list(answer): + role = message.get("role", "assistant") + with st.chat_message(role): + st.markdown(message["content"]) + + +def launcher_component(idx, default_provider=0): + """Menu to select LLM provider and selected model""" + selected_provider = st.selectbox( + "Provider", + options=litellm.models_by_provider.keys(), + index=default_provider, + key=f"provider_{idx}", + ) + selected_model = st.selectbox( + "Model", + options=litellm.models_by_provider[selected_provider], + index=None, + key=f"model_{idx}", + ) + st.session_state[f"selected_provider_{idx}"] = selected_provider + st.session_state[f"selected_model_{idx}"] = selected_model + + +def get_llm_spans(step): + chat_span_ids = set() + for span in step.spans: + if ".chat" in span.begin_entry.span_name: + chat_span_ids.add(span.begin_entry.span_id) + + return [attr for attr in step.attributes if attr.span_id in chat_span_ids] + + +def frontend(): + st.title("🌯 Burr prompt playground") + backend = get_burr_backend() + + # default value; is overriden at the end of `with st.sidebar:` + normalized_spans = None + with st.sidebar: + st.header("Burr playground") + + # project selection + selected_project = project_selectbox(backend=backend) + selected_app = application_selectbox(project=selected_project, backend=backend) + steps = get_steps(project=selected_project, application=selected_app, backend=backend) + + steps_with_llms = [ + step + for step in steps + if any(span for span in step.spans if len(get_llm_spans(step)) > 0) + ] + if len(steps_with_llms) == 0: + st.warning("Select a `Project > Application > Step` that includes LLM requests") + return + + selected_step = step_selectbox(steps=steps_with_llms) + relevant_spans = get_llm_spans(selected_step) + normalized_spans = normalize_spans(relevant_spans) + + # main window + st.header( + selected_project.name + + " : " + + selected_app.app_id[:10] + + " : " + + str(selected_step.step_start_log.sequence_id) + + "-" + + selected_step.step_start_log.action + ) + + history_component(normalized_spans) + + messages = list(normalized_spans["gen_ai"]["prompt"].values()) + + left, right = st.columns([0.85, 0.15]) + with left: + new_prompt = st.text_area("Prompt", value=messages[-1]["content"], height=300) + + launcher_0, launcher_1, launcher_2 = st.columns(3) + with launcher_0: + launcher_component(0, default_provider=0) + placeholder_0 = st.empty() + + with launcher_1: + launcher_component(1, default_provider=3) + placeholder_1 = st.empty() + + with launcher_2: + launcher_component(2, default_provider=1) + placeholder_2 = st.empty() + + with right: + if st.button("Launch"): + for i in range(2): + model = st.session_state.get(f"selected_model_{i}") + if model is None: + continue + + instrumentation_msg = instrument(st.session_state[f"selected_provider_{i}"]) + burr_app = build_burr_app(source_project=selected_project.name) + _, _, state = burr_app.step( + inputs={ + "model": model, + "messages": messages[:-1] + [{"role": "user", "content": new_prompt}], + "msg_to_log": instrumentation_msg, + } + ) + + locals()[f"placeholder_{i}"].container().write(state["history"][model][-1]) + + +if __name__ == "__main__": + frontend() diff --git a/burr/integrations/streamlit.py b/burr/integrations/streamlit.py index d734aae3..957fb313 100644 --- a/burr/integrations/streamlit.py +++ b/burr/integrations/streamlit.py @@ -1,13 +1,22 @@ +import asyncio import colorsys import dataclasses import inspect import json -from typing import List, Optional +from typing import List, Optional, Sequence from burr.core import Application from burr.core.action import FunctionBasedAction from burr.integrations.base import require_plugin -from burr.integrations.hamilton import Hamilton, StateSource +from burr.tracking.server.backend import LocalBackend +from burr.tracking.server.schema import ApplicationLogs, ApplicationSummary, Project, Span, Step + +try: + from burr.integrations.hamilton import Hamilton, StateSource + + HAMILTON_AVAILABLE = True +except ImportError: + HAMILTON_AVAILABLE = False try: import graphviz @@ -178,7 +187,11 @@ def render_action(state: AppState): return st.header(f"`{current_node}`") action_object = actions[current_node] - is_hamilton = isinstance(action_object, Hamilton) + if HAMILTON_AVAILABLE: + is_hamilton = isinstance(action_object, Hamilton) + else: + is_hamilton = False + is_function_api = isinstance(action_object, FunctionBasedAction) def format_read(var): @@ -286,3 +299,152 @@ def stringify(i): with data_view: render_state_results(app_state) update_state(app_state) + + +def get_projects(backend: LocalBackend) -> Sequence[Project]: + """Get projects from Burr backend""" + return asyncio.run(backend.list_projects({})) + + +def get_applications(project: Project, backend: LocalBackend) -> Sequence[ApplicationSummary]: + """Get project's applications from Burr backend""" + return asyncio.run(backend.list_apps({}, project_id=project.id, partition_key=None)) + + +def get_logs( + project: Project, application: ApplicationSummary, backend: LocalBackend +) -> ApplicationLogs: + """Get application's logs from Burr backend""" + return asyncio.run( + backend.get_application_logs( + {}, project_id=project.id, app_id=application.app_id, partition_key=None + ) + ) + + +def get_steps( + project: Project, application: ApplicationSummary, backend: LocalBackend +) -> Sequence[Step]: + """Get application's steps contained in logs from Burr backend""" + logs = get_logs(project=project, application=application, backend=backend) + return logs.steps + + +def get_spans(step: Step) -> Sequence[Span]: + """Get step's spans""" + return step.spans + + +def project_selectbox(projects: Sequence[Project] = None, backend: LocalBackend = None) -> Project: + """Create a streamlit selectbox for projects that automatically updates the URL query params""" + if projects and backend: + raise ValueError("Pass either `projects` OR `backend`, but not both.") + + if backend: + projects = asyncio.run(backend.list_projects({})) + + selection = 0 + if project_name := st.query_params.get("project"): + for idx, project in enumerate(projects): + if project.name == project_name: + selection = idx + break + + selected_project = st.selectbox( + "Project", + options=projects, + format_func=lambda project: project.name, + index=selection, + ) + st.query_params["project"] = selected_project.name + return selected_project + + +def application_selectbox( + applications: Sequence[ApplicationSummary] = None, + project: Project = None, + backend: LocalBackend = None, +) -> ApplicationSummary: + """Create a streamlit selectbox for applications that automatically updates the URL query params""" + if applications and project: + raise ValueError("Pass either `applications` OR `project`, but not both.") + + if project: + if backend is None: + raise ValueError("If passing `project`, you must also pass `backend`") + + applications = get_applications(project=project, backend=backend) + + selection = 0 + if app_id := st.query_params.get("app_id"): + for idx, app in enumerate(applications): + if app.app_id == app_id: + selection = idx + break + + selected_app = st.selectbox( + "Application", + options=applications, + format_func=lambda app: app.app_id, + index=selection, + ) + st.query_params["app_id"] = selected_app.app_id + return selected_app + + +def step_selectbox( + steps: Sequence[Step] = None, + application: ApplicationSummary = None, + project: Project = None, + backend: LocalBackend = None, +) -> Step: + """Create a streamlit selectbox for steps that automatically updates the URL query params""" + if steps and application: + raise ValueError("Pass either `steps` OR `application`, but not both.") + + if application: + if not backend or not project: + raise ValueError("If passing `application`, you must also pass `project` and `backend`") + + steps = get_steps(project=project, application=application, backend=backend) + + selection = 0 + if step_id := st.query_params.get("step"): + for idx, step in enumerate(steps): + if step.step_start_log.sequence_id == step_id: + selection = idx + break + + selected_step = st.selectbox( + "Step", + options=steps, + format_func=lambda step: f"{step.step_start_log.sequence_id}::{step.step_start_log.action}", + index=selection, + ) + st.query_params["step"] = str(selected_step.step_start_log.sequence_id) + return selected_step + + +def span_selectbox(spans: Sequence[Span] = None, step: Step = None) -> Span: + """Create a streamlit selectbox for spans that automatically updates the URL query params""" + if spans and step: + raise ValueError("Pass either `spans` OR `step`, but not both.") + + if step: + spans = get_spans(step) + + selection = 0 + if span_id := st.query_params.get("span"): + for idx, span in enumerate(spans): + if span.begin_entry.span_id == span_id: + selection = idx + break + + selected_span = st.selectbox( + "Span", + options=spans, + format_func=lambda span: f"{span.begin_entry.sequence_id}::{span.begin_entry.span_name}", + index=selection, + ) + st.query_params["span"] = str(selected_span.begin_entry.span_id) + return selected_span diff --git a/pyproject.toml b/pyproject.toml index 3c28cdea..ed69289e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -158,6 +158,16 @@ opentelemetry = [ "opentelemetry-api", "opentelemetry-sdk", ] + +playground = [ + "asyncio", + "litellm", + "streamlit", + "burr", + "burr[opentelemetry]", + "burr[tracking]", +] + [tool.setuptools] include-package-data = true