
Proof of concept tracing with hamilton and langchain
This shows how to quickly implement something that
will create "spans" in the Burr UI for a given framework (here, Hamilton and LangChain).
skrawcz committed Mar 16, 2024
1 parent 9cdd387 commit f157f50
Showing 3 changed files with 128 additions and 25 deletions.
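For orientation before the diffs: the pattern this commit applies to both frameworks is a small adapter that opens a Burr span when the framework starts a unit of work and closes it when that unit finishes. A minimal sketch of that pattern, assuming only the Burr imports the diffs below use; `SpanBridge` and its `before`/`after` methods are hypothetical stand-ins for whichever hook or callback interface the target framework exposes:

# Minimal sketch of the span-bridging pattern in this commit.
# Assumes burr.visibility.TracerFactory is callable with a span name and
# returns a context manager, as the diffs below use it; SpanBridge and its
# method names are hypothetical stand-ins for a framework hook interface.
from burr.visibility import ActionSpanTracer, TracerFactory


class SpanBridge:
    def __init__(self, tracer: TracerFactory):
        self._tracer = tracer
        self._active: dict[str, ActionSpanTracer] = {}

    def before(self, unit_name: str) -> None:
        span: ActionSpanTracer = self._tracer(unit_name)
        span.__enter__()  # open the span manually so it can outlive this call
        self._active[unit_name] = span

    def after(self, unit_name: str) -> None:
        # close the span opened for this unit of work
        self._active.pop(unit_name).__exit__(None, None, None)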
11 changes: 6 additions & 5 deletions examples/multi-agent-collaboration/func_agent.py
@@ -92,9 +92,11 @@ def base_system_prompt(tool_names: list[str], system_message: str) -> str:
" Use the provided tools to progress towards answering the question."
" If you are unable to fully answer, that's OK, another assistant with different tools "
" will help where you left off. Execute what you can to make progress.\n\n"
" If you or any of the other assistants have the final answer or deliverable,"
" prefix your response with FINAL ANSWER so the team knows to stop.\n\n"
f" You have access to the following tools: {tool_names}.\n{system_message}"
"If you or any of the other assistants have the final answer or deliverable,"
" prefix your response with 'FINAL ANSWER' so the team knows to stop.\n\n"
f"You have access to the following tools: {tool_names}.\n{system_message}\n\n"
"Remember to prefix your response with 'FINAL ANSWER' if you or another assistant "
"thinks the task is complete; assume the user can visualize the result."
)


@@ -138,9 +140,8 @@ def llm_function_response(
) -> object:
"""Creates the function response from the LLM model for the given prompt & functions.
:param base_system_prompt:
:param message_history:
:param tool_function_specs:
:param user_query:
:param llm_client:
:return:
"""
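The 'FINAL ANSWER' prefix in the prompt above is what lets the application decide when to stop hopping between agents. A hedged sketch of how a Burr transition could key off it, using `expr` and `default` from `burr.core` (which the LCEL file below imports); the state key and action names here are assumptions for illustration:

# Hypothetical transition wiring: route to a terminal action once an agent
# emits the 'FINAL ANSWER' prefix. `expr` evaluates a Python expression
# against application state; the `messages` key is assumed.
from burr.core import default, expr

transitions = [
    ("researcher", "terminal", expr("'FINAL ANSWER' in str(messages[-1])")),
    ("researcher", "chart_generator", default),
    ("chart_generator", "terminal", expr("'FINAL ANSWER' in str(messages[-1])")),
    ("chart_generator", "researcher", default),
]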
83 changes: 70 additions & 13 deletions examples/multi-agent-collaboration/hamilton_application.py
@@ -1,17 +1,66 @@
import json
+from typing import Any, Dict, Optional

import func_agent
-from hamilton import driver
+from hamilton import driver, lifecycle
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_experimental.utilities import PythonREPL

from burr import core
from burr.core import Action, ApplicationBuilder, State, action, default
from burr.lifecycle import PostRunStepHook
from burr.tracking import client as burr_tclient
+from burr.visibility import ActionSpanTracer, TracerFactory


class SimpleTracer(lifecycle.NodeExecutionHook):
"""Simple Hamilton Tracer"""

def __init__(self, tracer: TracerFactory):
self._tracer: TracerFactory = tracer
self.active_spans = {}

def run_before_node_execution(
self,
*,
node_name: str,
node_tags: Dict[str, Any],
node_kwargs: Dict[str, Any],
node_return_type: type,
task_id: Optional[str],
run_id: str,
node_input_types: Dict[str, Any],
**future_kwargs: Any,
):
context_manager: ActionSpanTracer = self._tracer(node_name)
context_manager.__enter__()
self.active_spans[node_name] = context_manager

def run_after_node_execution(
self,
*,
node_name: str,
node_tags: Dict[str, Any],
node_kwargs: Dict[str, Any],
node_return_type: type,
result: Any,
error: Optional[Exception],
success: bool,
task_id: Optional[str],
run_id: str,
**future_kwargs: Any,
):
context_manager = self.active_spans.pop(node_name)
context_manager.__exit__(None, None, None)


def initialize_tool_dag(agent_name: str, tracer: TracerFactory) -> driver.Driver:
tracer = SimpleTracer(tracer)
# Initialize some things needed for tools.
tool_dag = driver.Builder().with_modules(func_agent).with_adapters(tracer).build()
return tool_dag


-# Initialize some things needed for tools.
-tool_dag = driver.Builder().with_modules(func_agent).build()
repl = PythonREPL()


@@ -30,13 +79,14 @@ def python_repl(code: str) -> dict:


@action(reads=["query", "messages"], writes=["messages", "next_hop"])
-def chart_generator(state: State) -> tuple[dict, State]:
+def chart_generator(state: State, __tracer: TracerFactory) -> tuple[dict, State]:
query = state["query"]
+tool_dag = initialize_tool_dag("chart_generator", __tracer)
result = tool_dag.execute(
["parsed_tool_calls", "llm_function_message"],
inputs={
"tools": [python_repl],
"system_message": "Any charts you display will be visible by the user.",
"system_message": "Any charts you display will be visible by the user. When done say 'FINAL ANSWER'.",
"user_query": query,
"messages": state["messages"],
},
@@ -54,13 +104,14 @@ def chart_generator(state: State) -> tuple[dict, State]:


@action(reads=["query", "messages"], writes=["messages", "next_hop"])
-def researcher(state: State) -> tuple[dict, State]:
+def researcher(state: State, __tracer: TracerFactory) -> tuple[dict, State]:
query = state["query"]
+tool_dag = initialize_tool_dag("researcher", __tracer)
result = tool_dag.execute(
["parsed_tool_calls", "llm_function_message"],
inputs={
"tools": [tavily_tool],
"system_message": "You should provide accurate data for the chart generator to use.",
"system_message": "You should provide accurate data for the chart generator to use. When done say 'FINAL ANSWER'.",
"user_query": query,
"messages": state["messages"],
},
@@ -134,21 +185,23 @@ def default_state_and_entry_point() -> tuple[dict, str]:
"messages": [],
"query": "Fetch the UK's GDP over the past 5 years,"
" then draw a line graph of it."
" Once you code it up, finish.",
" Once the python code has been written and the graph drawn, the task is complete.",
"sender": "",
"parsed_tool_calls": [],
}, "researcher"


-def main(app_instance_id: str = None):
+def main(app_instance_id: str = None, sequence_number: int = -1):
project_name = "demo:hamilton-multi-agent-v1"
if app_instance_id:
state, entry_point = burr_tclient.LocalTrackingClient.load_state(
-project_name, app_instance_id
+project_name, app_instance_id, sequence_no=sequence_number
)
else:
state, entry_point = default_state_and_entry_point()

# Look up the app_id for a particular user;
# if None, proceed with defaults,
# else load from state and set the entry point.
app = (
ApplicationBuilder()
.with_state(**state)
@@ -190,8 +243,12 @@ def main(app_instance_id: str = None):
if __name__ == "__main__":
# Add an app_id to restart from last sequence in that state
# e.g. find the ID in the UI and then put it here "app_f0e4a918-b49c-4ee1-9d2b-30c15104c51c"
-_app_id = None  # "app_fb52ca3b-9198-4e5a-9ee0-9c07df1d7edd"
-main(_app_id)
+# _app_id = None  # "app_4ed5b3b3-0f38-4b37-aed7-559d506174c7"
+_app_id = "app_4ed5b3b3-0f38-4b37-aed7-559d506174c7"
+# _sequence_no = None  # 23
+_sequence_no = 23
+main(None, None)
+# main(_app_id, _sequence_no)

# some test code
# tavily_tool = TavilySearchResults(max_results=5)
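One design note on SimpleTracer above: it keys open spans by node_name alone, which is fine for this sequential example but would collide if Hamilton executed same-named nodes concurrently (e.g. under task-based execution, where task_id distinguishes runs). A hedged variant keyed by (node_name, task_id); KeyedTracer is a hypothetical name, and the hook interface is the same one the diff uses:

# Sketch: key active spans by (node_name, task_id) so concurrent executions
# of the same node don't overwrite each other's span. Same lifecycle hook
# interface as SimpleTracer above; only the bookkeeping differs.
from typing import Any, Dict, Optional, Tuple

from hamilton import lifecycle

from burr.visibility import ActionSpanTracer, TracerFactory


class KeyedTracer(lifecycle.NodeExecutionHook):
    def __init__(self, tracer: TracerFactory):
        self._tracer = tracer
        self._spans: Dict[Tuple[str, Optional[str]], ActionSpanTracer] = {}

    def run_before_node_execution(
        self, *, node_name: str, task_id: Optional[str], **future_kwargs: Any
    ):
        span = self._tracer(node_name)
        span.__enter__()
        self._spans[(node_name, task_id)] = span

    def run_after_node_execution(
        self, *, node_name: str, task_id: Optional[str], **future_kwargs: Any
    ):
        # pop with the same composite key used at span creation
        self._spans.pop((node_name, task_id)).__exit__(None, None, None)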
59 changes: 52 additions & 7 deletions examples/multi-agent-collaboration/lcel_application.py
@@ -1,8 +1,11 @@
import json
-from typing import Annotated
+from typing import Annotated, Any, Optional
+from uuid import UUID

from langchain_community.tools.tavily_search import TavilySearchResults
+from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import FunctionMessage, HumanMessage
+from langchain_core.outputs import LLMResult
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import tool
from langchain_core.utils.function_calling import convert_to_openai_function
@@ -14,6 +17,7 @@
from burr.core import Action, State, action, default, expr
from burr.lifecycle import PostRunStepHook
from burr.tracking import client as burr_tclient
+from burr.visibility import ActionSpanTracer, TracerFactory

tavily_tool = TavilySearchResults(max_results=5)

@@ -22,6 +26,44 @@
repl = PythonREPL()


class LangChainTracer(BaseCallbackHandler):
def __init__(self, tracer: TracerFactory):
self._tracer: TracerFactory = tracer
self.active_spans = {}

def on_llm_start(self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any) -> Any:
"""Run when LLM starts running."""
# print("LLM STARTED")
# print(prompts)
# print(serialized)
# print(kwargs)
model_name = kwargs["invocation_params"]["model_name"]
run_id = kwargs["run_id"]
name = (model_name + "_" + str(run_id))[:30]
context_manager: ActionSpanTracer = self._tracer(name)
context_manager.__enter__()
self.active_spans[name] = context_manager

def on_llm_end(
self,
response: LLMResult,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when LLM ends running."""
# print("LLM ENDED")
# print(response)
# print(run_id)
# print(parent_run_id)
# print(kwargs)
model_name = response.llm_output["model_name"]
name = (model_name + "_" + str(run_id))[:30]
context_manager = self.active_spans.pop(name)
context_manager.__exit__(None, None, None)


def create_agent(llm, tools, system_message: str):
"""Helper function to create an agent with a system message and tools."""
functions = [convert_to_openai_function(t) for t in tools]
@@ -58,8 +100,9 @@ def python_repl(code: Annotated[str, "The python code to execute to generate you


# Helper function to create a node for a given agent
-def _agent_node(messages: list, sender: str, agent, name: str) -> dict:
-result = agent.invoke({"messages": messages, "sender": sender})
+def _agent_node(messages: list, sender: str, agent, name: str, tracer: TracerFactory) -> dict:
+tracer = LangChainTracer(tracer)
+result = agent.invoke({"messages": messages, "sender": sender}, config={"callbacks": [tracer]})
# We convert the agent output into a format that is suitable to append to the global state
if isinstance(result, FunctionMessage):
pass
@@ -88,15 +131,17 @@ def _agent_node(messages: list, sender: str, agent, name: str) -> dict:


@action(reads=["messages", "sender"], writes=["messages", "sender"])
-def research_node(state: State) -> tuple[dict, State]:
+def research_node(state: State, __tracer: TracerFactory) -> tuple[dict, State]:
# Research agent and node
-result = _agent_node(state["messages"], state["sender"], research_agent, "Researcher")
+result = _agent_node(state["messages"], state["sender"], research_agent, "Researcher", __tracer)
return result, state.append(messages=result["messages"]).update(sender="Researcher")


@action(reads=["messages", "sender"], writes=["messages", "sender"])
-def chart_node(state: State) -> tuple[dict, State]:
-result = _agent_node(state["messages"], state["sender"], chart_agent, "Chart Generator")
+def chart_node(state: State, __tracer: TracerFactory) -> tuple[dict, State]:
+result = _agent_node(
+state["messages"], state["sender"], chart_agent, "Chart Generator", __tracer
+)
return result, state.append(messages=result["messages"]).update(sender="Chart Generator")


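A caveat on LangChainTracer above: it opens a span in on_llm_start and closes it only in on_llm_end, so a failed LLM call would leave the span dangling. BaseCallbackHandler also defines on_llm_error; a hedged sketch that keys spans by run_id (avoiding the name reconstruction from llm_output) and closes them on both paths — SafeLangChainTracer is a hypothetical name:

# Sketch: track spans by run_id and close them on success or error so a
# failing LLM call can't leak an open span. Callback signatures follow
# langchain_core's BaseCallbackHandler.
from typing import Any, Optional
from uuid import UUID

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult

from burr.visibility import ActionSpanTracer, TracerFactory


class SafeLangChainTracer(BaseCallbackHandler):
    def __init__(self, tracer: TracerFactory):
        self._tracer = tracer
        self._spans: dict[UUID, ActionSpanTracer] = {}

    def on_llm_start(self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any) -> Any:
        model_name = kwargs["invocation_params"]["model_name"]
        run_id: UUID = kwargs["run_id"]
        span = self._tracer((model_name + "_" + str(run_id))[:30])
        span.__enter__()
        self._spans[run_id] = span  # key by run_id, not by display name

    def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Any:
        self._close(run_id, None)

    def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> Any:
        self._close(run_id, error)  # don't leak the span on failure

    def _close(self, run_id: UUID, error: Optional[BaseException]) -> None:
        span = self._spans.pop(run_id, None)
        if span is not None:
            exc_type = type(error) if error is not None else None
            span.__exit__(exc_type, error, getattr(error, "__traceback__", None))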

