diff --git a/src/airunner/aihandler/casual_lm_transfformer_base_handler.py b/src/airunner/aihandler/casual_lm_transfformer_base_handler.py
index 40232610f..cbccdd245 100644
--- a/src/airunner/aihandler/casual_lm_transfformer_base_handler.py
+++ b/src/airunner/aihandler/casual_lm_transfformer_base_handler.py
@@ -5,7 +5,9 @@
 from transformers import pipeline as hf_pipeline
 
 from airunner.aihandler.local_agent import LocalAgent
-from airunner.aihandler.llm_tools import QuitApplicationTool, StartVisionCaptureTool, StopVisionCaptureTool
+from airunner.aihandler.llm_tools import QuitApplicationTool, StartVisionCaptureTool, StopVisionCaptureTool, \
+    StartAudioCaptureTool, StopAudioCaptureTool, StartSpeakersTool, StopSpeakersTool, ProcessVisionTool, \
+    ProcessAudioTool
 from airunner.aihandler.tokenizer_handler import TokenizerHandler
 from airunner.enums import SignalCode, LLMAction, LLMChatRole, LLMToolName
 
@@ -41,6 +43,9 @@ def __init__(self, *args, **kwargs):
         self.guardrails_prompt: str = ""
         self.system_instructions: str = ""
         self.agent = None
+        self.tools: dict = self.load_tools()
+        self.restrict_tools_to_additional: bool = True
+        self.return_agent_code: bool = False
         #self.register(SignalCode.LLM_RESPOND_TO_USER, self.on_llm_respond_to_user_signal)
 
     @property
@@ -55,6 +60,27 @@ def botname(self):
             return self._botname
         return "Assistant"
 
+    @staticmethod
+    def load_tools() -> dict:
+        """
+        Build the toolbox that is handed to the agent.
+
+        Returns:
+            dict: One tool instance per entry, keyed by an
+                `LLMToolName` value.
+        """
+        return {
+            LLMToolName.QUIT_APPLICATION.value: QuitApplicationTool(),
+            LLMToolName.VISION_START_CAPTURE.value: StartVisionCaptureTool(),
+            LLMToolName.VISION_STOP_CAPTURE.value: StopVisionCaptureTool(),
+            LLMToolName.STT_START_CAPTURE.value: StartAudioCaptureTool(),
+            LLMToolName.STT_STOP_CAPTURE.value: StopAudioCaptureTool(),
+            LLMToolName.TTS_ENABLE.value: StartSpeakersTool(),
+            LLMToolName.TTS_DISABLE.value: StopSpeakersTool(),
+            LLMToolName.DESCRIBE_IMAGE.value: ProcessVisionTool(),
+            LLMToolName.LLM_PROCESS_STT_AUDIO.value: ProcessAudioTool(),
+        }
+
     def on_clear_history_signal(self):
         self.history = []
 
@@ -97,27 +123,18 @@ def load_agent(self):
         #         description="Agent that can return help results about the application."
         #     )
         # )
-        tools = {
-            LLMToolName.QUIT_APPLICATION.value: QuitApplicationTool(),
-            LLMToolName.VISION_START_CAPTURE.value: StartVisionCaptureTool(),
-            LLMToolName.VISION_STOP_CAPTURE.value: StopVisionCaptureTool(),
-        }
         self.agent = LocalAgent(
             model=self.model,
             tokenizer=self.tokenizer,
-            additional_tools=tools
+            additional_tools=self.tools,
+            restrict_tools_to_additional=self.restrict_tools_to_additional
         )
-        self.agent._toolbox = tools
 
     def chat(self, prompt: AnyStr) -> AnyStr:
         self.logger.info("Chat Stream")
         res = self.agent.chat(
             task=self.prompt,
-            return_code=False
-            #task="add 5 and 5"
-            #message=prompt,
-            #chat_history=self.prepare_messages(LLMAction.CHAT),
-            #task=LLMToolName.ADD.value
+            return_code=self.return_agent_code
         )
         # self.stream_text(
         #     self.streamer,
diff --git a/src/airunner/aihandler/llm_tools.py b/src/airunner/aihandler/llm_tools.py
index d581faec1..a690a846c 100644
--- a/src/airunner/aihandler/llm_tools.py
+++ b/src/airunner/aihandler/llm_tools.py
@@ -1,3 +1,21 @@
+"""
+This module, `llm_tools.py`, contains tools for an LLM agent.
+These tools are used to control the application, analyze images,
+audio and more.
+
+The tools are implemented as classes, which are generated
+using a factory function `create_application_control_tool_class`.
+This function takes a description, a name, and a signal code, and returns a
+class that inherits from `BaseTool` and `MediatorMixin`.
+
+Each tool class has a `__call__` method that emits a signal when the
+tool is used. The application listens for these signals and
+responds accordingly.
+
+Classes:
+    See below for a list of classes and their descriptions.
+"""
+
 from transformers import Tool
 
 from airunner.mediator_mixin import MediatorMixin
@@ -5,34 +23,132 @@
 
 
 class BaseTool(Tool, MediatorMixin):
+    """
+    Base class for all tools. Adds the `MediatorMixin` to the `Tool` class.
+    This allows for signals to be emitted when the tool is used.
+    """
     def __init__(self, *args, **kwargs):
         MediatorMixin.__init__(self)
         super().__init__(*args, **kwargs)
 
 
-class ApplicationControlTool(BaseTool):
-    inputs = ["text"]
-    outputs = ["text"]
-    signal_code = None
+def create_application_control_tool_class(description, name, signal_code):
+    """
+    Factory function to create a class for an application control tool.
+
+    Args:
+        description (str): The description of the tool.
+        name (str): The name of the tool.
+        signal_code (SignalCode): The signal code that the tool emits when used.
+
+    Returns:
+        type: A class that represents the tool.
+    """
+    class ApplicationControlTool(BaseTool):
+        inputs = ["text"]
+        outputs = ["text"]
+        signal_code = None
+
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+
+        def __call__(self, *args, **kwargs):
+            self.emit(self.signal_code)
+            return "emitting signal"
+
+    # `Tool` reads its prompt text from the `description` and `name` class
+    # attributes, so set those explicitly; `__doc__` and `__name__` are kept
+    # in sync for introspection.
+    ApplicationControlTool.description = description
+    ApplicationControlTool.name = name
+    ApplicationControlTool.__doc__ = description
+    ApplicationControlTool.__name__ = name
+    ApplicationControlTool.signal_code = signal_code
+
+    return ApplicationControlTool
+
+
+QuitApplicationTool = create_application_control_tool_class(
+    (
+        "This tool quits the application. It takes no input and returns a "
+        "string."
+    ),
+    LLMToolName.QUIT_APPLICATION.value,
+    SignalCode.QUIT_APPLICATION
+)
+
+
+StartVisionCaptureTool = create_application_control_tool_class(
+    (
+        "This tool turns the camera on - it starts the input feed. It takes no "
+        "input and returns a string."
+    ),
+    LLMToolName.VISION_START_CAPTURE.value,
+    SignalCode.VISION_START_CAPTURE
+)
+
 
-    def __call__(self, *args, **kwargs):
-        self.emit(self.signal_code)
-        return "emitting signal"
+StopVisionCaptureTool = create_application_control_tool_class(
+    (
+        "This tool turns the camera off - it stops the input feed. It takes no "
+        "input and returns a string."
+    ),
+    LLMToolName.VISION_STOP_CAPTURE.value,
+    SignalCode.VISION_STOP_CAPTURE
+)
 
+StartAudioCaptureTool = create_application_control_tool_class(
+    (
+        "This tool turns the microphone on. It takes no input and returns a "
+        "string."
+    ),
+    LLMToolName.STT_START_CAPTURE.value,
+    SignalCode.STT_START_CAPTURE_SIGNAL
+)
 
-class QuitApplicationTool(ApplicationControlTool):
-    description = "This tool quits the application. It takes no input and returns a string."
-    name = LLMToolName.QUIT_APPLICATION.value
-    signal_code = SignalCode.QUIT_APPLICATION
+StopAudioCaptureTool = create_application_control_tool_class(
+    (
+        "This tool turns the microphone off. It takes no input and returns a "
+        "string."
+    ),
+    LLMToolName.STT_STOP_CAPTURE.value,
+    SignalCode.STT_STOP_CAPTURE_SIGNAL
+)
 
+StartSpeakersTool = create_application_control_tool_class(
+    (
+        "This tool turns the speakers on. It takes no input and returns a "
+        "string."
+    ),
+    LLMToolName.TTS_ENABLE.value,
+    SignalCode.TTS_ENABLE_SIGNAL
+)
 
-class StartVisionCaptureTool(ApplicationControlTool):
-    description = "This tool turns the camera on - it starts the input feed. It takes no input and returns a string."
-    name = LLMToolName.VISION_START_CAPTURE.value
-    signal_code = SignalCode.VISION_START_CAPTURE
+StopSpeakersTool = create_application_control_tool_class(
+    (
+        "This tool turns the speakers off. It takes no input and returns a "
+        "string."
+    ),
+    LLMToolName.TTS_DISABLE.value,
+    SignalCode.TTS_DISABLE_SIGNAL
+)
 
+ProcessVisionTool = create_application_control_tool_class(
+    (
+        "This tool processes the images which are captured by the camera. "
+        "These are images that the assistant may use in the context of a "
+        "conversation with the user. It takes no input and returns a string."
+    ),
+    LLMToolName.VISION_PROCESS_IMAGES.value,
+    SignalCode.VISION_PROCESS_IMAGES
+)
 
-class StopVisionCaptureTool(ApplicationControlTool):
-    description = "This tool turns the camera off - it stops the input feed. It takes no input and returns a string."
-    name = LLMToolName.VISION_STOP_CAPTURE.value
-    signal_code = SignalCode.VISION_STOP_CAPTURE
+ProcessAudioTool = create_application_control_tool_class(
+    (
+        "This tool processes the audio which is captured by the microphone. "
+        "This is audio that the assistant may use in the context of a "
+        "conversation with the user. It takes no input and returns a string."
+    ),
+    LLMToolName.LLM_PROCESS_STT_AUDIO.value,
+    SignalCode.LLM_PROCESS_STT_AUDIO_SIGNAL
+)
diff --git a/src/airunner/aihandler/local_agent.py b/src/airunner/aihandler/local_agent.py
index 951be8566..eaab71a5d 100644
--- a/src/airunner/aihandler/local_agent.py
+++ b/src/airunner/aihandler/local_agent.py
@@ -5,6 +5,24 @@
 
 
 class LocalAgent(LocalAgentBase):
+    def __init__(self, *args, **kwargs):
+        """
+        Create the agent, optionally limiting it to the caller's tools.
+
+        Args:
+            restrict_tools_to_additional (bool): Keyword-only. When True the
+                toolbox is replaced by `additional_tools` after the base
+                class has been initialised. Defaults to False.
+
+        All other arguments are passed through to `LocalAgentBase`.
+        """
+        self.restrict_tools_to_additional = kwargs.pop("restrict_tools_to_additional", False)
+        super().__init__(*args, **kwargs)
+        if self.restrict_tools_to_additional:
+            # Fall back to an empty toolbox rather than None so `_toolbox`
+            # is always a mapping, even when no tools were supplied.
+            self._toolbox = kwargs.get("additional_tools") or {}
+
     def format_prompt(self, task, chat_mode=False):
         task = super().format_prompt(task, chat_mode=chat_mode)
         task = task.replace("", "")