Skip to content

Commit

Permalink
Support for Llama 3.1 (#25)
Browse files Browse the repository at this point in the history
* feat(ROS1): fix roslaunch_list and rosnode_kill tools.

* chore: bump langchain package versions.

* feat(LLM): add official support for ChatOllama model.

* fix: rosservice tool causing invalid parameter validation with updated langchain libs.
  • Loading branch information
RobRoyce authored Oct 10, 2024
1 parent bf36b5d commit c6ca7a8
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 49 deletions.
16 changes: 5 additions & 11 deletions demo.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,10 @@ DEVELOPMENT=${DEVELOPMENT:-false}

# Enable X11 forwarding based on OS
case "$(uname)" in
Linux)
echo "Enabling X11 forwarding for Linux..."
export DISPLAY=:0
xhost +local:docker
;;
Darwin)
echo "Enabling X11 forwarding for macOS..."
ip=$(ifconfig en0 | awk '$1=="inet" {print $2}')
export DISPLAY=$ip:0
xhost + $ip
Linux*|Darwin*)
echo "Enabling X11 forwarding..."
export DISPLAY=host.docker.internal:0
xhost +
;;
MINGW*|CYGWIN*|MSYS*)
echo "Enabling X11 forwarding for Windows..."
Expand Down Expand Up @@ -73,4 +67,4 @@ docker run -it --rm --name $CONTAINER_NAME \
# Disable X11 forwarding
xhost -

exit 0
exit 0
12 changes: 6 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,18 @@ requires-python = ">=3.9, <4"
dependencies = [
"PyYAML==6.0.1",
"python-dotenv>=1.0.1",
"langchain==0.2.14",
"langchain-community==0.2.12",
"langchain-core==0.2.34",
"langchain-openai==0.1.22",
"langchain-ollama",
"langchain==0.3.2",
"langchain-community==0.3.1",
"langchain-core==0.3.9",
"langchain-openai==0.2.2",
"langchain-ollama==0.2.0",
"pydantic",
"pyinputplus",
"azure-identity",
"cffi",
"rich",
"pillow>=10.4.0",
"numpy>=1.21.2",
"numpy>=1.26.4",
]

[project.urls]
Expand Down
4 changes: 2 additions & 2 deletions src/rosa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@
# limitations under the License.

from .prompts import RobotSystemPrompts
from .rosa import ROSA, ChatModel

# Public API of the rosa package: the agent class, the prompts container,
# and the ChatModel type alias (Union of supported chat model classes).
__all__ = ["ROSA", "RobotSystemPrompts", "ChatModel"]
7 changes: 5 additions & 2 deletions src/rosa/rosa.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,22 @@
from langchain_community.callbacks import get_openai_callback
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama
from langchain_openai import AzureChatOpenAI, ChatOpenAI

from .prompts import RobotSystemPrompts, system_prompts
from .tools import ROSATools

ChatModel = Union[ChatOpenAI, AzureChatOpenAI, ChatOllama]


class ROSA:
"""ROSA (Robot Operating System Agent) is a class that encapsulates the logic for interacting with ROS systems
using natural language.
Args:
ros_version (Literal[1, 2]): The version of ROS that the agent will interact with.
llm (Union[AzureChatOpenAI, ChatOpenAI]): The language model to use for generating responses.
llm (Union[AzureChatOpenAI, ChatOpenAI, ChatOllama]): The language model to use for generating responses.
tools (Optional[list]): A list of additional LangChain tool functions to use with the agent.
tool_packages (Optional[list]): A list of Python packages containing LangChain tool functions to use.
prompts (Optional[RobotSystemPrompts]): Custom prompts to use with the agent.
Expand Down Expand Up @@ -63,7 +66,7 @@ class ROSA:
def __init__(
self,
ros_version: Literal[1, 2],
llm: Union[AzureChatOpenAI, ChatOpenAI],
llm: ChatModel,
tools: Optional[list] = None,
tool_packages: Optional[list] = None,
prompts: Optional[RobotSystemPrompts] = None,
Expand Down
71 changes: 45 additions & 26 deletions src/rosa/tools/ros1.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,12 +470,14 @@ def rosservice_info(services: List[str]) -> dict:


@tool
def rosservice_call(service: str, args: List[str]) -> dict:
def rosservice_call(service: str, args: Optional[List[any]] = None) -> dict:
"""Calls a specific ROS service with the provided arguments.
:param service: The name of the ROS service to call.
:param args: A list of arguments to pass to the service.
"""
if not args:
args = []
try:
response = rosservice.call_service(service, args)
return response
Expand Down Expand Up @@ -738,43 +740,60 @@ def roslaunch(package: str, launch_file: str) -> str:


@tool
def roslaunch_list(packages: List[str]) -> dict:
    """Returns the available ROS launch files for each of the specified packages.

    :param packages: A list of ROS package names to list launch files for.
    :return: On any success, a dict with ``results`` (per-package directory,
        launch-file count, and file names) and ``errors`` (per-package failure
        messages). If no package could be processed, a dict with ``error`` and
        ``details`` instead.
    """
    results = {}
    errors = []

    rospack = rospkg.RosPack()
    for package in packages:
        try:
            directory = rospack.get_path(package)
            launch = os.path.join(directory, "launch")

            launch_files = []

            # Only regular files directly inside the package's launch/
            # directory are reported; subdirectories are not recursed into.
            if os.path.exists(launch):
                launch_files = [
                    f
                    for f in os.listdir(launch)
                    if os.path.isfile(os.path.join(launch, f))
                ]

            results[package] = {
                "directory": directory,
                "total": len(launch_files),
                "launch_files": launch_files,
            }
        except Exception as e:
            # A single bad package name should not abort the whole query;
            # record the failure and keep going.
            errors.append(
                f"Failed to get ROS launch files for package '{package}': {e}"
            )

    # Report total failure only when every requested package failed.
    if not results:
        return {
            "error": "Failed to get ROS launch files for all specified packages.",
            "details": errors,
        }

    return {"results": results, "errors": errors}


@tool
def rosnode_kill(node_names: List[str]) -> dict:
    """Kills the specified ROS nodes.

    :param node_names: A list of node names to kill.
    :return: A dict listing which nodes were successfully killed and which
        were not, or an ``error`` entry when the input is empty or the
        underlying call raises.
    """
    # Guard against an empty or missing argument.
    if not node_names:
        return {"error": "Please provide the name(s) of the ROS node to kill."}

    try:
        successes, failures = rosnode.kill_nodes(node_names)
        # NOTE(review): fixed key typo — was "successesfully_killed".
        return dict(successfully_killed=successes, failed_to_kill=failures)
    except Exception as e:
        return {"error": f"Failed to kill ROS node(s): {e}"}
14 changes: 12 additions & 2 deletions src/turtle_agent/scripts/turtle_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
import pyinputplus as pyip
import rospy
from langchain.agents import tool, Tool
from rich.console import Group # Add this import
# from langchain_ollama import ChatOllama
from rich.console import Console
from rich.console import Group
from rich.live import Live
from rich.markdown import Markdown
from rich.panel import Panel
Expand All @@ -35,9 +36,10 @@
from prompts import get_prompts


# Typical method for defining tools in ROSA
@tool
def cool_turtle_tool():
    """A cool turtle tool that doesn't really do anything."""
    return "This is a cool turtle tool! It doesn't do anything, but it's cool."


Expand All @@ -47,6 +49,14 @@ def __init__(self, streaming: bool = False, verbose: bool = True):
self.__blacklist = ["master", "docker"]
self.__prompts = get_prompts()
self.__llm = get_llm(streaming=streaming)

# self.__llm = ChatOllama(
# base_url="host.docker.internal:11434",
# model="llama3.1",
# temperature=0,
# num_ctx=8192,
# )

self.__streaming = streaming

# Another method for adding tools
Expand Down

0 comments on commit c6ca7a8

Please sign in to comment.