From c6ca7a85c6f98896ab527affd5925f83a1118d61 Mon Sep 17 00:00:00 2001
From: Rob Royce
Date: Thu, 10 Oct 2024 11:54:56 -0700
Subject: [PATCH] Support for Llama 3.1 (#25)

* feat(ROS1): fix roslaunch_list and rosnode_kill tools.
* chore: bump langchain package versions.
* feat(LLM): add official support for ChatOllama model.
* fix: rosservice tool causing invalid parameter validation with updated langchain libs.
---
 demo.sh                                   | 16 ++----
 pyproject.toml                            | 12 ++--
 src/rosa/__init__.py                      |  4 +-
 src/rosa/rosa.py                          |  7 ++-
 src/rosa/tools/ros1.py                    | 71 +++++++++++++++---------
 src/turtle_agent/scripts/turtle_agent.py  | 14 ++++-
 6 files changed, 75 insertions(+), 49 deletions(-)

diff --git a/demo.sh b/demo.sh
index 2868ad9..cf9c530 100755
--- a/demo.sh
+++ b/demo.sh
@@ -27,16 +27,10 @@ DEVELOPMENT=${DEVELOPMENT:-false}
 
 # Enable X11 forwarding based on OS
 case "$(uname)" in
-    Linux)
-        echo "Enabling X11 forwarding for Linux..."
-        export DISPLAY=:0
-        xhost +local:docker
-        ;;
-    Darwin)
-        echo "Enabling X11 forwarding for macOS..."
-        ip=$(ifconfig en0 | awk '$1=="inet" {print $2}')
-        export DISPLAY=$ip:0
-        xhost + $ip
+    Linux*|Darwin*)
+        echo "Enabling X11 forwarding..."
+        export DISPLAY=host.docker.internal:0
+        xhost +
         ;;
     MINGW*|CYGWIN*|MSYS*)
         echo "Enabling X11 forwarding for Windows..."
@@ -73,4 +67,4 @@ docker run -it --rm --name $CONTAINER_NAME \
 
 # Disable X11 forwarding
 xhost -
-exit 0
\ No newline at end of file
+exit 0
diff --git a/pyproject.toml b/pyproject.toml
index cfb20a5..6d4a604 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,18 +25,18 @@ requires-python = ">=3.9, <4"
 dependencies = [
     "PyYAML==6.0.1",
     "python-dotenv>=1.0.1",
-    "langchain==0.2.14",
-    "langchain-community==0.2.12",
-    "langchain-core==0.2.34",
-    "langchain-openai==0.1.22",
-    "langchain-ollama",
+    "langchain==0.3.2",
+    "langchain-community==0.3.1",
+    "langchain-core==0.3.9",
+    "langchain-openai==0.2.2",
+    "langchain-ollama==0.2.0",
     "pydantic",
     "pyinputplus",
     "azure-identity",
     "cffi",
     "rich",
     "pillow>=10.4.0",
-    "numpy>=1.21.2",
+    "numpy>=1.26.4",
 ]
 
 [project.urls]
diff --git a/src/rosa/__init__.py b/src/rosa/__init__.py
index 4c326da..a490e36 100644
--- a/src/rosa/__init__.py
+++ b/src/rosa/__init__.py
@@ -13,6 +13,6 @@
 # limitations under the License.
 
 from .prompts import RobotSystemPrompts
-from .rosa import ROSA
+from .rosa import ROSA, ChatModel
 
-__all__ = ["ROSA", "RobotSystemPrompts"]
+__all__ = ["ROSA", "RobotSystemPrompts", "ChatModel"]
diff --git a/src/rosa/rosa.py b/src/rosa/rosa.py
index 54d585d..703b48c 100644
--- a/src/rosa/rosa.py
+++ b/src/rosa/rosa.py
@@ -23,11 +23,14 @@
 from langchain_community.callbacks import get_openai_callback
 from langchain_core.messages import AIMessage, HumanMessage
 from langchain_core.prompts import ChatPromptTemplate
+from langchain_ollama import ChatOllama
 from langchain_openai import AzureChatOpenAI, ChatOpenAI
 
 from .prompts import RobotSystemPrompts, system_prompts
 from .tools import ROSATools
 
+ChatModel = Union[ChatOpenAI, AzureChatOpenAI, ChatOllama]
+
 
 class ROSA:
     """ROSA (Robot Operating System Agent) is a class that encapsulates the logic for interacting with ROS systems
@@ -35,7 +38,7 @@ class ROSA:
 
     Args:
         ros_version (Literal[1, 2]): The version of ROS that the agent will interact with.
-        llm (Union[AzureChatOpenAI, ChatOpenAI]): The language model to use for generating responses.
+        llm (Union[AzureChatOpenAI, ChatOpenAI, ChatOllama]): The language model to use for generating responses.
        tools (Optional[list]): A list of additional LangChain tool functions to use with the agent.
        tool_packages (Optional[list]): A list of Python packages containing LangChain tool functions to use.
        prompts (Optional[RobotSystemPrompts]): Custom prompts to use with the agent.
@@ -63,7 +66,7 @@ class ROSA:
     def __init__(
         self,
         ros_version: Literal[1, 2],
-        llm: Union[AzureChatOpenAI, ChatOpenAI],
+        llm: ChatModel,
         tools: Optional[list] = None,
         tool_packages: Optional[list] = None,
         prompts: Optional[RobotSystemPrompts] = None,
diff --git a/src/rosa/tools/ros1.py b/src/rosa/tools/ros1.py
index e6e5e6e..0b972f9 100644
--- a/src/rosa/tools/ros1.py
+++ b/src/rosa/tools/ros1.py
@@ -470,12 +470,14 @@ def rosservice_info(services: List[str]) -> dict:
 
 
 @tool
-def rosservice_call(service: str, args: List[str]) -> dict:
+def rosservice_call(service: str, args: Optional[List[Any]] = None) -> dict:
     """Calls a specific ROS service with the provided arguments.
 
     :param service: The name of the ROS service to call.
     :param args: A list of arguments to pass to the service.
     """
+    if not args:
+        args = []
     try:
         response = rosservice.call_service(service, args)
         return response
@@ -738,43 +740,60 @@ def roslaunch(package: str, launch_file: str) -> str:
 
 
 @tool
-def roslaunch_list(package: str) -> dict:
-    """Returns a list of available ROS launch files in a package.
+def roslaunch_list(packages: List[str]) -> dict:
+    """Returns a list of available ROS launch files in the specified packages.
 
-    :param package: The name of the ROS package to list launch files for.
+    :param packages: A list of ROS package names to list launch files for.
     """
-    try:
-        rospack = rospkg.RosPack()
-        directory = rospack.get_path(package)
-        launch = os.path.join(directory, "launch")
-
-        launch_files = []
+    results = {}
+    errors = []
 
-        # Get all files in the launch directory
-        if os.path.exists(launch):
-            launch_files = [
-                f for f in os.listdir(launch) if os.path.isfile(os.path.join(launch, f))
-            ]
+    rospack = rospkg.RosPack()
+    for package in packages:
+        try:
+            directory = rospack.get_path(package)
+            launch = os.path.join(directory, "launch")
+
+            launch_files = []
+
+            # Get all files in the launch directory
+            if os.path.exists(launch):
+                launch_files = [
+                    f
+                    for f in os.listdir(launch)
+                    if os.path.isfile(os.path.join(launch, f))
+                ]
+
+            results[package] = {
+                "directory": directory,
+                "total": len(launch_files),
+                "launch_files": launch_files,
+            }
+        except Exception as e:
+            errors.append(
+                f"Failed to get ROS launch files for package '{package}': {e}"
+            )
 
+    if not results:
         return {
-            "package": package,
-            "directory": directory,
-            "total": len(launch_files),
-            "launch_files": launch_files,
+            "error": "Failed to get ROS launch files for all specified packages.",
+            "details": errors,
         }
-    except Exception as e:
-        return {"error": f"Failed to get ROS launch files in package '{package}': {e}"}
+
+    return {"results": results, "errors": errors}
 
 
 @tool
-def rosnode_kill(node: str) -> str:
-    """Kills a specific ROS node.
+def rosnode_kill(node_names: List[str]) -> dict:
+    """Kills the specified ROS nodes.
 
-    :param node: The name of the ROS node to kill.
+    :param node_names: A list of node names to kill.
     """
+    if not node_names:
+        return {"error": "Please provide the name(s) of the ROS node to kill."}
+
     try:
-        os.system(f"rosnode kill {node}")
-        return f"Killed ROS node '{node}'."
+        successes, failures = rosnode.kill_nodes(node_names)
+        return dict(successfully_killed=successes, failed_to_kill=failures)
     except Exception as e:
-        return f"Failed to kill ROS node '{node}': {e}"
+        return {"error": f"Failed to kill ROS node(s): {e}"}
 
diff --git a/src/turtle_agent/scripts/turtle_agent.py b/src/turtle_agent/scripts/turtle_agent.py
index 11f56e9..ca19978 100755
--- a/src/turtle_agent/scripts/turtle_agent.py
+++ b/src/turtle_agent/scripts/turtle_agent.py
@@ -21,8 +21,9 @@
 import pyinputplus as pyip
 import rospy
 from langchain.agents import tool, Tool
-from rich.console import Group  # Add this import
+# from langchain_ollama import ChatOllama
 from rich.console import Console
+from rich.console import Group
 from rich.live import Live
 from rich.markdown import Markdown
 from rich.panel import Panel
@@ -35,9 +36,10 @@
 from prompts import get_prompts
 
 
+# Typical method for defining tools in ROSA
 @tool
 def cool_turtle_tool():
-    """A cool turtle tool."""
+    """A cool turtle tool that doesn't really do anything."""
     return "This is a cool turtle tool! It doesn't do anything, but it's cool."
 
 
@@ -47,6 +49,14 @@ def __init__(self, streaming: bool = False, verbose: bool = True):
         self.__blacklist = ["master", "docker"]
         self.__prompts = get_prompts()
         self.__llm = get_llm(streaming=streaming)
+
+        # self.__llm = ChatOllama(
+        #     base_url="host.docker.internal:11434",
+        #     model="llama3.1",
+        #     temperature=0,
+        #     num_ctx=8192,
+        # )
+
         self.__streaming = streaming
 
         # Another method for adding tools
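With this change, ROSA accepts a ChatOllama instance alongside ChatOpenAI and AzureChatOpenAI (the exported ChatModel alias covers all three). The snippet below is an illustrative sketch, not part of the patch: it assumes an Ollama server is reachable at localhost:11434 (from inside the demo container, use host.docker.internal:11434, as in the commented-out block above) and that the llama3.1 model has already been pulled.

    from langchain_ollama import ChatOllama

    from rosa import ROSA

    # Assumed local Ollama setup; host, model, and context size mirror the
    # commented-out example in turtle_agent.py.
    llm = ChatOllama(
        base_url="localhost:11434",
        model="llama3.1",
        temperature=0,
        num_ctx=8192,  # Ollama defaults to a 2048-token context window.
    )

    agent = ROSA(ros_version=1, llm=llm)
    print(agent.invoke("Which launch files does the turtlesim package provide?"))

Raising num_ctx is the important knob here: at Ollama's 2048-token default, ROSA's system prompt and tool schemas can crowd out the conversation itself.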