From cb4ffc97ed06b78694d73c37aa731076ace2accb Mon Sep 17 00:00:00 2001
From: zilto
Date: Fri, 6 Sep 2024 11:47:23 -0400
Subject: [PATCH 1/4] added playground

---
 burr/integrations/playground/__init__.py |   0
 burr/integrations/playground/app.py      | 341 +++++++++++++++++++++++
 pyproject.toml                           |  10 +
 3 files changed, 351 insertions(+)
 create mode 100644 burr/integrations/playground/__init__.py
 create mode 100644 burr/integrations/playground/app.py

diff --git a/burr/integrations/playground/__init__.py b/burr/integrations/playground/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/burr/integrations/playground/app.py b/burr/integrations/playground/app.py
new file mode 100644
index 00000000..a7c9f503
--- /dev/null
+++ b/burr/integrations/playground/app.py
@@ -0,0 +1,341 @@
+import asyncio
+
+import litellm
+import streamlit as st
+
+from burr.core import Application, ApplicationBuilder, State, action
+from burr.tracking import LocalTrackingClient
+from burr.tracking.server.backend import LocalBackend
+from burr.visibility import TracerFactory
+
+
+@st.cache_data
+def instrument(provider: str):
+    msg = None
+    if provider == "openai":
+        try:
+            from opentelemetry.instrumentation.openai import (  # openai is a dependency of litellm
+                OpenAIInstrumentor,
+            )
+
+            OpenAIInstrumentor().instrument()
+        except ImportError:
+            msg = f"Couldn't instrument {provider}. Try installing `opentelemetry-instrumentation-{provider}`"
+
+    elif provider == "cohere":
+        try:
+            from opentelemetry.instrumentation.cohere import CohereInstrumentor
+
+            CohereInstrumentor().instrument()
+        except ImportError:
+            msg = f"Couldn't instrument {provider}. Try installing `opentelemetry-instrumentation-{provider}`"
+
+    elif provider == "anthropic":
+        try:
+            from opentelemetry.instrumentation.anthropic import AnthropicInstrumentor
+
+            AnthropicInstrumentor().instrument()
+        except ImportError:
+            msg = f"Couldn't instrument {provider}. Try installing `opentelemetry-instrumentation-{provider}`"
+
+    elif provider == "bedrock":
+        try:
+            from opentelemetry.instrumentation.bedrock import BedrockInstrumentor
+
+            BedrockInstrumentor().instrument()
+        except ImportError:
+            msg = f"Couldn't instrument {provider}. Try installing `opentelemetry-instrumentation-{provider}`"
+
+    elif provider == "ollama":
+        try:
+            from opentelemetry.instrumentation.ollama import OllamaInstrumentor
+
+            OllamaInstrumentor().instrument()
+        except ImportError:
+            msg = f"Couldn't instrument {provider}. Try installing `opentelemetry-instrumentation-{provider}`"
+
+    elif provider == "gemini":
+        try:
+            from opentelemetry.instrumentation.google_generativeai import (
+                GoogleGenerativeAiInstrumentor,
+            )
+
+            GoogleGenerativeAiInstrumentor().instrument()
+        except ImportError:
+            msg = f"Couldn't instrument {provider}. Try installing `opentelemetry-instrumentation-{provider}`"
+
+    elif provider == "replicate":
+        try:
+            from opentelemetry.instrumentation.replicate import ReplicateInstrumentor
+
+            ReplicateInstrumentor().instrument()
+        except ImportError:
+            msg = f"Couldn't instrument {provider}. Try installing `opentelemetry-instrumentation-{provider}`"
+
+    elif provider == "together_ai":
+        try:
+            from opentelemetry.instrumentation.together import TogetherAiInstrumentor
+
+            TogetherAiInstrumentor().instrument()
+        except ImportError:
+            msg = f"Couldn't instrument {provider}. Try installing `opentelemetry-instrumentation-{provider}`"
+
+    elif provider == "huggingface":
+        try:
+            from opentelemetry.instrumentation.transformers import TransformersInstrumentor
+
+            TransformersInstrumentor().instrument()
+        except ImportError:
+            msg = f"Couldn't instrument {provider}. Try installing `opentelemetry-instrumentation-{provider}`"
+
+    elif provider == "vertex_ai":
+        try:
+            from opentelemetry.instrumentation.vertexai import VertexAIInstrumentor
+
+            VertexAIInstrumentor().instrument()
+        except ImportError:
+            msg = f"Couldn't instrument {provider}. Try installing `opentelemetry-instrumentation-{provider}`"
+
+    elif provider == "watsonx":
+        try:
+            from opentelemetry.instrumentation.watsonx import WatsonxInstrumentor
+
+            WatsonxInstrumentor().instrument()
+        except ImportError:
+            msg = f"Couldn't instrument {provider}. Try installing `opentelemetry-instrumentation-{provider}`"
+
+    else:
+        msg = f"Couldn't instrument {provider}. Try installing `opentelemetry-instrumentation-{provider}`"
+
+    if msg:
+        print(msg)
+    return msg
+
+
+@action(reads=["history"], writes=["history"])
+def generate_answer(
+    state: State, model: str, messages: list[dict], __tracer: TracerFactory, msg_to_log=None
+):
+    if msg_to_log:
+        __tracer.log_attribute("message", msg_to_log)
+
+    response = litellm.completion(model=model, messages=messages)
+    llm_answer = response.choices[0].message.content
+
+    history = state["history"]
+    if history.get(model) is None:
+        history[model] = []
+
+    history[model] += [llm_answer]
+    return state.update(history=history)
+
+
+def build_burr_app(source_project: str) -> Application:
+    tracker = LocalTrackingClient(project="burr-playground")
+    return (
+        ApplicationBuilder()
+        .with_actions(generate_answer)
+        .with_transitions(("generate_answer", "generate_answer"))
+        .with_identifiers(app_id=source_project)
+        .initialize_from(
+            initializer=tracker,
+            resume_at_next_action=False,
+            default_state={"history": {}},
+            default_entrypoint="generate_answer",
+        )
+        .with_tracker("local", project="burr-playground", use_otel_tracing=True)
+        .build()
+    )
+
+
+@st.cache_resource
+def get_burr_backend():
+    return LocalBackend()
+
+
+def normalize_spans(spans: list) -> dict:
+    nested_dict = {}
+    for span in spans:
+        key = span.key
+        value = span.value
+
+        keys = key.split(".")
+        d = nested_dict
+        for k in keys[:-1]:
+            if k not in d:
+                d[k] = {}
+            d = d[k]
+        d[keys[-1]] = value
+    return nested_dict
+
+
+def selector_with_params_component(items, query_param: str, item_key: str):
+    selection = 0
+    if item_id := st.query_params.get(query_param):
+        for item_idx, i in enumerate(items):
+            if getattr(i, item_key) == item_id:
+                selection = item_idx
+                break
+
+    selected_item = st.selectbox(
+        query_param.capitalize().split("_")[0],
+        options=items,
+        format_func=lambda i: getattr(i, item_key),
+        index=selection,
+    )
+    st.query_params[query_param] = getattr(selected_item, item_key)
+    return selected_item
+
+
+def history_component(normalized_spans):
+    # historical data
+    with st.expander("History", expanded=True):
+        prompts = normalized_spans["gen_ai"]["prompt"].values()
+        answer = normalized_spans["gen_ai"]["completion"].values()
+
+        for message in list(prompts) + list(answer):
+            with st.chat_message(message["role"]):
+                st.markdown(message["content"])
+
+
+def launcher_component(idx, default_provider=0):
+    """Menu to select the LLM provider and model"""
+    selected_provider = st.selectbox(
+        "Provider",
+        options=litellm.models_by_provider.keys(),
+        index=default_provider,
+        key=f"provider_{idx}",
+    )
+    selected_model = st.selectbox(
+        "Model",
+        options=litellm.models_by_provider[selected_provider],
+        index=None,
+        key=f"model_{idx}",
+    )
+    st.session_state[f"selected_provider_{idx}"] = selected_provider
+    st.session_state[f"selected_model_{idx}"] = selected_model
+
+
+def get_llm_spans(step):
+    chat_span_ids = set()
+    for span in step.spans:
+        if ".chat" in span.begin_entry.span_name:
+            chat_span_ids.add(span.begin_entry.span_id)
+
+    return [attr for attr in step.attributes if attr.span_id in chat_span_ids]
+
+
+def frontend():
+    st.title("🌯 Burr prompt playground")
+    backend = get_burr_backend()
+
+    # default value; is overridden at the end of `with st.sidebar:`
+    normalized_spans = None
+    with st.sidebar:
+        st.header("Burr playground")
+
+        # project selection
+        projects = asyncio.run(backend.list_projects({}))
+        selected_project = selector_with_params_component(projects, "project", "name")
+
+        # app selection
+        apps, _ = asyncio.run(
+            backend.list_apps({}, project_id=selected_project.id, partition_key=None)
+        )
+        selected_app = selector_with_params_component(apps, "app_id", "app_id")
+
+        # logs selection
+        logs = asyncio.run(
+            backend.get_application_logs(
+                {}, project_id=selected_project.id, app_id=selected_app.app_id, partition_key=None
+            )
+        )
+        steps_with_llms = [
+            step
+            for step in logs.steps
+            if any(span for span in step.spans if span.begin_entry.span_name == "openai.chat")
+        ]
+        if len(steps_with_llms) == 0:
+            st.warning("Select a `Project > Application > Step` that includes LLM requests")
+            return
+
+        step_selection = 0
+        if step_id := st.query_params.get("step"):
+            for step_idx, step in enumerate(steps_with_llms):
+                if step.step_start_log.sequence_id == step_id:
+                    step_selection = step_idx
+                    break
+
+        selected_step = st.selectbox(
+            "Step",
+            options=steps_with_llms,
+            index=step_selection,
+            format_func=lambda step: f"{step.step_start_log.sequence_id}: {step.step_start_log.action}",
+        )
+        st.query_params["step"] = selected_step.step_start_log.sequence_id
+
+        relevant_spans = get_llm_spans(selected_step)
+        normalized_spans = normalize_spans(relevant_spans)
+
+    # main window
+    st.header(
+        selected_project.name
+        + " : "
+        + selected_app.app_id[:10]
+        + " : "
+        + str(selected_step.step_start_log.sequence_id)
+        + "-"
+        + selected_step.step_start_log.action
+    )
+
+    history_component(normalized_spans)
+
+    messages = list(normalized_spans["gen_ai"]["prompt"].values())
+
+    left, right = st.columns([0.85, 0.15])
+    with left:
+        new_prompt = st.text_area("Prompt", value=messages[-1]["content"], height=300)
+
+    launcher_0, launcher_1, launcher_2 = st.columns(3)
+    with launcher_0:
+        launcher_component(0, default_provider=0)
+        placeholder_0 = st.empty()
+
+    with launcher_1:
+        launcher_component(1, default_provider=3)
+        placeholder_1 = st.empty()
+
+    with launcher_2:
+        launcher_component(2, default_provider=1)
+        placeholder_2 = st.empty()
+
+    with right:
+        if st.button("Launch"):
+            for i in range(3):
+                model = st.session_state.get(f"selected_model_{i}")
+                if model is None:
+                    continue
+
+                instrumentation_msg = instrument(st.session_state[f"selected_provider_{i}"])
+                burr_app = build_burr_app(source_project=selected_project.name)
+                _, _, state = burr_app.step(
+                    inputs={
+                        "model": model,
"messages": messages[:-1] + [{"role": "user", "content": new_prompt}], + "msg_to_log": instrumentation_msg, + } + ) + + locals()[f"placeholder_{i}"].container().write(state["history"][model][-1]) + + +if __name__ == "__main__": + frontend() diff --git a/pyproject.toml b/pyproject.toml index 3c28cdea..ed69289e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -158,6 +158,16 @@ opentelemetry = [ "opentelemetry-api", "opentelemetry-sdk", ] + +playground = [ + "asyncio", + "litellm", + "streamlit", + "burr", + "burr[opentelemetry]", + "burr[tracking]", +] + [tool.setuptools] include-package-data = true From e08509ef0acb04fc5bff2feb66bba5c770257a30 Mon Sep 17 00:00:00 2001 From: zilto Date: Fri, 6 Sep 2024 17:01:00 -0400 Subject: [PATCH 2/4] added generic streamlit components; removed streamlit integration dep on hamilton --- burr/integrations/playground/app.py | 61 ++-------- burr/integrations/streamlit.py | 168 +++++++++++++++++++++++++++- 2 files changed, 176 insertions(+), 53 deletions(-) diff --git a/burr/integrations/playground/app.py b/burr/integrations/playground/app.py index a7c9f503..c48a1353 100644 --- a/burr/integrations/playground/app.py +++ b/burr/integrations/playground/app.py @@ -1,9 +1,13 @@ -import asyncio - import litellm import streamlit as st from burr.core import Application, ApplicationBuilder, State, action +from burr.integrations.streamlit import ( + application_selectbox, + get_steps, + project_selectbox, + step_selectbox, +) from burr.tracking import LocalTrackingClient from burr.tracking.server.backend import LocalBackend from burr.visibility import TracerFactory @@ -177,24 +181,6 @@ def normalize_spans(spans: list) -> dict: return nested_dict -def selector_with_params_component(items, query_param: str, item_key: str): - selection = 0 - if item_id := st.query_params.get(query_param): - for item_idx, i in enumerate(items): - if getattr(i, item_key) == item_id: - selection = item_idx - break - - selected_item = st.selectbox( - query_param.capitalize().split("_")[0], - options=items, - format_func=lambda i: getattr(i, item_key), - index=selection, - ) - st.query_params[query_param] = getattr(selected_item, item_key) - return selected_item - - def history_component(normalized_spans): # historical data with st.expander("History", expanded=True): @@ -243,45 +229,20 @@ def frontend(): st.header("Burr playground") # project selection - projects = asyncio.run(backend.list_projects({})) - selected_project = selector_with_params_component(projects, "project", "name") + selected_project = project_selectbox(backend=backend) + selected_app = application_selectbox(project=selected_project, backend=backend) + steps = get_steps(project=selected_project, application=selected_app, backend=backend) - # app selection - apps, _ = asyncio.run( - backend.list_apps({}, project_id=selected_project.id, partition_key=None) - ) - selected_app = selector_with_params_component(apps, "app_id", "app_id") - - # logs selection - logs = asyncio.run( - backend.get_application_logs( - {}, project_id=selected_project.id, app_id=selected_app.app_id, partition_key=None - ) - ) steps_with_llms = [ step - for step in logs.steps + for step in steps if any(span for span in step.spans if span.begin_entry.span_name == "openai.chat") ] if len(steps_with_llms) == 0: st.warning("Select a `Project > Application > Step` that includes LLM requests") return - step_selection = 0 - if step_id := st.query_params.get("step"): - for step_idx, step in enumerate(steps_with_llms): - if step.step_start_log.sequence_id == 
-                    step_selection = step_idx
-                    break
-
-        selected_step = st.selectbox(
-            "Step",
-            options=steps_with_llms,
-            index=step_selection,
-            format_func=lambda step: f"{step.step_start_log.sequence_id}: {step.step_start_log.action}",
-        )
-        st.query_params["step"] = selected_step.step_start_log.sequence_id
-
+        selected_step = step_selectbox(steps=steps_with_llms)
         relevant_spans = get_llm_spans(selected_step)
         normalized_spans = normalize_spans(relevant_spans)
diff --git a/burr/integrations/streamlit.py b/burr/integrations/streamlit.py
index d734aae3..957fb313 100644
--- a/burr/integrations/streamlit.py
+++ b/burr/integrations/streamlit.py
@@ -1,13 +1,22 @@
+import asyncio
 import colorsys
 import dataclasses
 import inspect
 import json
-from typing import List, Optional
+from typing import List, Optional, Sequence

 from burr.core import Application
 from burr.core.action import FunctionBasedAction
 from burr.integrations.base import require_plugin
-from burr.integrations.hamilton import Hamilton, StateSource
+from burr.tracking.server.backend import LocalBackend
+from burr.tracking.server.schema import ApplicationLogs, ApplicationSummary, Project, Span, Step
+
+try:
+    from burr.integrations.hamilton import Hamilton, StateSource
+
+    HAMILTON_AVAILABLE = True
+except ImportError:
+    HAMILTON_AVAILABLE = False

 try:
     import graphviz
@@ -178,7 +187,11 @@ def render_action(state: AppState):
         return
     st.header(f"`{current_node}`")
     action_object = actions[current_node]
-    is_hamilton = isinstance(action_object, Hamilton)
+    if HAMILTON_AVAILABLE:
+        is_hamilton = isinstance(action_object, Hamilton)
+    else:
+        is_hamilton = False
+
     is_function_api = isinstance(action_object, FunctionBasedAction)

     def format_read(var):
@@ -286,3 +299,152 @@ def stringify(i):
     with data_view:
         render_state_results(app_state)
     update_state(app_state)
+
+
+def get_projects(backend: LocalBackend) -> Sequence[Project]:
+    """Get projects from Burr backend"""
+    return asyncio.run(backend.list_projects({}))
+
+
+def get_applications(project: Project, backend: LocalBackend) -> Sequence[ApplicationSummary]:
+    """Get project's applications from Burr backend"""
+    apps, _ = asyncio.run(backend.list_apps({}, project_id=project.id, partition_key=None))
+    return apps
+
+
+def get_logs(
+    project: Project, application: ApplicationSummary, backend: LocalBackend
+) -> ApplicationLogs:
+    """Get application's logs from Burr backend"""
+    return asyncio.run(
+        backend.get_application_logs(
+            {}, project_id=project.id, app_id=application.app_id, partition_key=None
+        )
+    )
+
+
+def get_steps(
+    project: Project, application: ApplicationSummary, backend: LocalBackend
+) -> Sequence[Step]:
+    """Get application's steps contained in logs from Burr backend"""
+    logs = get_logs(project=project, application=application, backend=backend)
+    return logs.steps
+
+
+def get_spans(step: Step) -> Sequence[Span]:
+    """Get step's spans"""
+    return step.spans
+
+
+def project_selectbox(projects: Sequence[Project] = None, backend: LocalBackend = None) -> Project:
+    """Create a streamlit selectbox for projects that automatically updates the URL query params"""
+    if projects and backend:
+        raise ValueError("Pass either `projects` OR `backend`, but not both.")
+
+    if backend:
+        projects = asyncio.run(backend.list_projects({}))
+
+    selection = 0
+    if project_name := st.query_params.get("project"):
+        for idx, project in enumerate(projects):
+            if project.name == project_name:
+                selection = idx
+                break
+
+    selected_project = st.selectbox(
+        "Project",
+        options=projects,
+        format_func=lambda project: project.name,
+        index=selection,
+    )
+    st.query_params["project"] = selected_project.name
+    return selected_project
+
+
+def application_selectbox(
+    applications: Sequence[ApplicationSummary] = None,
+    project: Project = None,
+    backend: LocalBackend = None,
+) -> ApplicationSummary:
+    """Create a streamlit selectbox for applications that automatically updates the URL query params"""
+    if applications and project:
+        raise ValueError("Pass either `applications` OR `project`, but not both.")
+
+    if project:
+        if backend is None:
+            raise ValueError("If passing `project`, you must also pass `backend`")
+
+        applications = get_applications(project=project, backend=backend)
+
+    selection = 0
+    if app_id := st.query_params.get("app_id"):
+        for idx, app in enumerate(applications):
+            if app.app_id == app_id:
+                selection = idx
+                break
+
+    selected_app = st.selectbox(
+        "Application",
+        options=applications,
+        format_func=lambda app: app.app_id,
+        index=selection,
+    )
+    st.query_params["app_id"] = selected_app.app_id
+    return selected_app
+
+
+def step_selectbox(
+    steps: Sequence[Step] = None,
+    application: ApplicationSummary = None,
+    project: Project = None,
+    backend: LocalBackend = None,
+) -> Step:
+    """Create a streamlit selectbox for steps that automatically updates the URL query params"""
+    if steps and application:
+        raise ValueError("Pass either `steps` OR `application`, but not both.")
+
+    if application:
+        if not backend or not project:
+            raise ValueError("If passing `application`, you must also pass `project` and `backend`")
+
+        steps = get_steps(project=project, application=application, backend=backend)
+
+    selection = 0
+    if step_id := st.query_params.get("step"):
+        for idx, step in enumerate(steps):
+            if step.step_start_log.sequence_id == step_id:
+                selection = idx
+                break
+
+    selected_step = st.selectbox(
+        "Step",
+        options=steps,
+        format_func=lambda step: f"{step.step_start_log.sequence_id}::{step.step_start_log.action}",
+        index=selection,
+    )
+    st.query_params["step"] = str(selected_step.step_start_log.sequence_id)
+    return selected_step
+
+
+def span_selectbox(spans: Sequence[Span] = None, step: Step = None) -> Span:
+    """Create a streamlit selectbox for spans that automatically updates the URL query params"""
+    if spans and step:
+        raise ValueError("Pass either `spans` OR `step`, but not both.")
+
+    if step:
+        spans = get_spans(step)
+
+    selection = 0
+    if span_id := st.query_params.get("span"):
+        for idx, span in enumerate(spans):
+            if span.begin_entry.span_id == span_id:
+                selection = idx
+                break
+
+    selected_span = st.selectbox(
+        "Span",
+        options=spans,
+        format_func=lambda span: f"{span.begin_entry.sequence_id}::{span.begin_entry.span_name}",
+        index=selection,
+    )
+    st.query_params["span"] = str(selected_span.begin_entry.span_id)
+    return selected_span

From bc790ef952e3413d34f58441fc7c141de79d1e28 Mon Sep 17 00:00:00 2001
From: zilto
Date: Fri, 6 Sep 2024 17:04:25 -0400
Subject: [PATCH 3/4] fixed improper filtering

---
 burr/integrations/playground/app.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/burr/integrations/playground/app.py b/burr/integrations/playground/app.py
index c48a1353..ba7f2de3 100644
--- a/burr/integrations/playground/app.py
+++ b/burr/integrations/playground/app.py
@@ -236,7 +236,7 @@ def frontend():
         steps_with_llms = [
             step
             for step in steps
-            if any(span for span in step.spans if span.begin_entry.span_name == "openai.chat")
+            if len(get_llm_spans(step)) > 0
         ]
         if len(steps_with_llms) == 0:
             st.warning("Select a `Project > Application > Step` that includes LLM requests")

From ec2f69ccf081040bdb0b8403d988fbbe22cc0c56 Mon Sep 17 00:00:00 2001
From: zilto
Date: Fri, 6 Sep 2024 17:16:47 -0400
Subject: [PATCH 4/4] fix bug for Anthropic that doesn't return a 'role'

---
 burr/integrations/playground/app.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/burr/integrations/playground/app.py b/burr/integrations/playground/app.py
index ba7f2de3..69669243 100644
--- a/burr/integrations/playground/app.py
+++ b/burr/integrations/playground/app.py
@@ -188,7 +188,8 @@ def history_component(normalized_spans):
         answer = normalized_spans["gen_ai"]["completion"].values()

         for message in list(prompts) + list(answer):
-            with st.chat_message(message["role"]):
+            role = message.get("role", "assistant")
+            with st.chat_message(role):
                 st.markdown(message["content"])
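
A minimal sketch of how the generic selectbox components added in PATCH 2/4 could be reused
outside the playground, in a standalone Streamlit page. It only wires together names introduced
by these patches (project_selectbox, application_selectbox, step_selectbox, LocalBackend); the
file name playground_minimal.py and running it via `streamlit run` are assumptions for
illustration, not part of the patches, and the snippet has not been exercised against a real
tracking directory.

    # playground_minimal.py -- run with: streamlit run playground_minimal.py
    import streamlit as st

    from burr.integrations.streamlit import (
        application_selectbox,
        project_selectbox,
        step_selectbox,
    )
    from burr.tracking.server.backend import LocalBackend

    backend = LocalBackend()  # reads the local Burr tracking directory

    # each selectbox keeps the URL query params in sync, as in the playground app
    project = project_selectbox(backend=backend)
    application = application_selectbox(project=project, backend=backend)
    step = step_selectbox(project=project, application=application, backend=backend)

    st.write(f"Selected step: {step.step_start_log.sequence_id}::{step.step_start_log.action}")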