From 1134473794ea9c483ab81b508fe2588a4eb92b4b Mon Sep 17 00:00:00 2001
From: Rob Royce
Date: Thu, 10 Oct 2024 10:08:20 -0700
Subject: [PATCH] feat(LLM): add official support for ChatOllama model.

---
 demo.sh                                  |  3 ++-
 src/rosa/__init__.py                     |  4 ++--
 src/rosa/rosa.py                         |  7 +++++--
 src/turtle_agent/scripts/turtle_agent.py | 11 ++++++++++-
 4 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/demo.sh b/demo.sh
index d13a3a1..cf9c530 100755
--- a/demo.sh
+++ b/demo.sh
@@ -29,6 +29,7 @@ DEVELOPMENT=${DEVELOPMENT:-false}
 case "$(uname)" in
     Linux*|Darwin*)
         echo "Enabling X11 forwarding..."
+        export DISPLAY=host.docker.internal:0
         xhost +
         ;;
     MINGW*|CYGWIN*|MSYS*)
@@ -66,4 +67,4 @@ docker run -it --rm --name $CONTAINER_NAME \
 # Disable X11 forwarding
 xhost -
 
-exit 0
\ No newline at end of file
+exit 0
diff --git a/src/rosa/__init__.py b/src/rosa/__init__.py
index 4c326da..a490e36 100644
--- a/src/rosa/__init__.py
+++ b/src/rosa/__init__.py
@@ -13,6 +13,6 @@
 # limitations under the License.
 
 from .prompts import RobotSystemPrompts
-from .rosa import ROSA
+from .rosa import ROSA, ChatModel
 
-__all__ = ["ROSA", "RobotSystemPrompts"]
+__all__ = ["ROSA", "RobotSystemPrompts", "ChatModel"]
diff --git a/src/rosa/rosa.py b/src/rosa/rosa.py
index 54d585d..703b48c 100644
--- a/src/rosa/rosa.py
+++ b/src/rosa/rosa.py
@@ -23,11 +23,14 @@
 from langchain_community.callbacks import get_openai_callback
 from langchain_core.messages import AIMessage, HumanMessage
 from langchain_core.prompts import ChatPromptTemplate
+from langchain_ollama import ChatOllama
 from langchain_openai import AzureChatOpenAI, ChatOpenAI
 
 from .prompts import RobotSystemPrompts, system_prompts
 from .tools import ROSATools
 
+ChatModel = Union[ChatOpenAI, AzureChatOpenAI, ChatOllama]
+
 
 class ROSA:
     """ROSA (Robot Operating System Agent) is a class that encapsulates the logic for interacting with ROS systems
@@ -35,7 +38,7 @@ class ROSA:
 
     Args:
         ros_version (Literal[1, 2]): The version of ROS that the agent will interact with.
-        llm (Union[AzureChatOpenAI, ChatOpenAI]): The language model to use for generating responses.
+        llm (Union[AzureChatOpenAI, ChatOpenAI, ChatOllama]): The language model to use for generating responses.
         tools (Optional[list]): A list of additional LangChain tool functions to use with the agent.
         tool_packages (Optional[list]): A list of Python packages containing LangChain tool functions to use.
         prompts (Optional[RobotSystemPrompts]): Custom prompts to use with the agent.
@@ -63,7 +66,7 @@ class ROSA:
     def __init__(
         self,
         ros_version: Literal[1, 2],
-        llm: Union[AzureChatOpenAI, ChatOpenAI],
+        llm: ChatModel,
         tools: Optional[list] = None,
         tool_packages: Optional[list] = None,
         prompts: Optional[RobotSystemPrompts] = None,
diff --git a/src/turtle_agent/scripts/turtle_agent.py b/src/turtle_agent/scripts/turtle_agent.py
index 2b4742a..ca19978 100755
--- a/src/turtle_agent/scripts/turtle_agent.py
+++ b/src/turtle_agent/scripts/turtle_agent.py
@@ -21,8 +21,9 @@
 import pyinputplus as pyip
 import rospy
 from langchain.agents import tool, Tool
-from rich.console import Group  # Add this import
+# from langchain_ollama import ChatOllama
 from rich.console import Console
+from rich.console import Group
 from rich.live import Live
 from rich.markdown import Markdown
 from rich.panel import Panel
@@ -48,6 +49,14 @@ def __init__(self, streaming: bool = False, verbose: bool = True):
         self.__blacklist = ["master", "docker"]
         self.__prompts = get_prompts()
         self.__llm = get_llm(streaming=streaming)
+
+        # self.__llm = ChatOllama(
+        #     base_url="host.docker.internal:11434",
+        #     model="llama3.1",
+        #     temperature=0,
+        #     num_ctx=8192,
+        # )
+
         self.__streaming = streaming
 
     # Another method for adding tools
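Note for reviewers: a minimal end-to-end sketch of what this patch enables. The ChatOllama import, the ChatModel alias, and the ROSA constructor signature come from the diff above; the server address, model name, and the invoke() call are assumptions for illustration, presuming a local Ollama server with llama3.1 already pulled.

from langchain_ollama import ChatOllama
from rosa import ROSA, ChatModel

# Assumption: an Ollama server is reachable on the default port and the
# llama3.1 model has already been pulled (`ollama pull llama3.1`).
llm: ChatModel = ChatOllama(
    base_url="http://localhost:11434",
    model="llama3.1",
    temperature=0,
    num_ctx=8192,
)

# The llm parameter now accepts ChatOllama via the ChatModel union.
agent = ROSA(ros_version=1, llm=llm)

# Hypothetical query; the invoke() entry point is assumed here and is
# not part of this diff.
print(agent.invoke("List the active ROS topics."))

Re-exporting the ChatModel alias from rosa.__init__ lets downstream code annotate against a single name instead of tracking the union as new backends are added.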