Skip to content

Commit

Permalink
LLM tool and agent updates
Browse files Browse the repository at this point in the history
- cleaner code
- adds factory function for tools
- adds more available tools
  • Loading branch information
w4ffl35 committed Feb 1, 2024
1 parent 785cab6 commit f942afb
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 32 deletions.
36 changes: 23 additions & 13 deletions src/airunner/aihandler/casual_lm_transfformer_base_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from transformers import pipeline as hf_pipeline

from airunner.aihandler.local_agent import LocalAgent
from airunner.aihandler.llm_tools import QuitApplicationTool, StartVisionCaptureTool, StopVisionCaptureTool
from airunner.aihandler.llm_tools import QuitApplicationTool, StartVisionCaptureTool, StopVisionCaptureTool, \
StartAudioCaptureTool, StopAudioCaptureTool, StartSpeakersTool, StopSpeakersTool, ProcessVisionTool, \
ProcessAudioTool
from airunner.aihandler.tokenizer_handler import TokenizerHandler
from airunner.enums import SignalCode, LLMAction, LLMChatRole, LLMToolName

Expand Down Expand Up @@ -41,6 +43,9 @@ def __init__(self, *args, **kwargs):
self.guardrails_prompt: str = ""
self.system_instructions: str = ""
self.agent = None
self.tools: dict = self.load_tools()
self.restrict_tools_to_additional: bool = True
self.return_agent_code: bool = False
#self.register(SignalCode.LLM_RESPOND_TO_USER, self.on_llm_respond_to_user_signal)

@property
Expand All @@ -55,6 +60,20 @@ def botname(self):
return self._botname
return "Assistant"

@staticmethod
def load_tools() -> dict:
return {
LLMToolName.QUIT_APPLICATION.value: QuitApplicationTool(),
LLMToolName.VISION_START_CAPTURE.value: StartVisionCaptureTool(),
LLMToolName.VISION_STOP_CAPTURE.value: StopVisionCaptureTool(),
LLMToolName.STT_START_CAPTURE.value: StartAudioCaptureTool(),
LLMToolName.STT_STOP_CAPTURE.value: StopAudioCaptureTool(),
LLMToolName.TTS_ENABLE.value: StartSpeakersTool(),
LLMToolName.TTS_DISABLE.value: StopSpeakersTool(),
LLMToolName.DESCRIBE_IMAGE.value: ProcessVisionTool,
LLMToolName.LLM_PROCESS_STT_AUDIO.value: ProcessAudioTool(),
}

def on_clear_history_signal(self):
self.history = []

Expand Down Expand Up @@ -97,27 +116,18 @@ def load_agent(self):
# description="Agent that can return help results about the application."
# )
# )
tools = {
LLMToolName.QUIT_APPLICATION.value: QuitApplicationTool(),
LLMToolName.VISION_START_CAPTURE.value: StartVisionCaptureTool(),
LLMToolName.VISION_STOP_CAPTURE.value: StopVisionCaptureTool(),
}
self.agent = LocalAgent(
model=self.model,
tokenizer=self.tokenizer,
additional_tools=tools
additional_tools=self.tools,
restrict_tools_to_additional=self.restrict_tools_to_additional
)
self.agent._toolbox = tools

def chat(self, prompt: AnyStr) -> AnyStr:
self.logger.info("Chat Stream")
res = self.agent.chat(
task=self.prompt,
return_code=False
#task="add 5 and 5"
#message=prompt,
#chat_history=self.prepare_messages(LLMAction.CHAT),
#task=LLMToolName.ADD.value
return_code=self.return_agent_code
)
# self.stream_text(
# self.streamer,
Expand Down
149 changes: 130 additions & 19 deletions src/airunner/aihandler/llm_tools.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,149 @@
"""
This module, `llm_tools.py`, contains tools for an LLM agent.
These tools are used to control the application, analyze images,
audio and more.
The tools are implemented as classes, which are generated
using a factory function `create_application_control_tool_class`.
This function takes a description, a name, and a signal code, and returns a
class that inherits from `BaseTool` and `MediatorMixin`.
Each tool class has a `__call__` method that emits a signal when the
tool is used. The application listens for these signals and
responds accordingly.
Classes:
See below for a list of classes and their descriptions.
"""

from transformers import Tool

from airunner.mediator_mixin import MediatorMixin
from airunner.enums import SignalCode, LLMToolName


class BaseTool(Tool, MediatorMixin):
"""
Base class for all tools. Adds the `MediatorMixin` to the `Tool` class.
This allows for signals to be emitted when the tool is used.
"""
def __init__(self, *args, **kwargs):
MediatorMixin.__init__(self)
super().__init__(*args, **kwargs)


class ApplicationControlTool(BaseTool):
inputs = ["text"]
outputs = ["text"]
signal_code = None
def create_application_control_tool_class(description, name, signal_code):
"""
Factory function to create a class for an application control tool.
Args:
description (str): The description of the tool.
name (str): The name of the tool.
signal_code (SignalCode): The signal code that the tool emits when used.
Returns:
type: A class that represents the tool.
"""
class ApplicationControlTool(BaseTool):
inputs = ["text"]
outputs = ["text"]
signal_code = None

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def __call__(self, *args, **kwargs):
self.emit(self.signal_code)
return "emitting signal"

ApplicationControlTool.__doc__ = description
ApplicationControlTool.__name__ = name
ApplicationControlTool.signal_code = signal_code

return ApplicationControlTool


QuitApplicationTool = create_application_control_tool_class(
(
"This tool quits the application. It takes no input and returns a "
"string."
),
LLMToolName.QUIT_APPLICATION.value,
SignalCode.QUIT_APPLICATION
)


StartVisionCaptureTool = create_application_control_tool_class(
(
"This tool turns the camera on - it starts the input feed. It takes no "
"input and returns a string."
),
LLMToolName.VISION_START_CAPTURE.value,
SignalCode.VISION_START_CAPTURE
)


def __call__(self, *args, **kwargs):
self.emit(self.signal_code)
return "emitting signal"
StopVisionCaptureTool = create_application_control_tool_class(
(
"This tool turns the camera off - it stops the input feed. It takes no "
"input and returns a string."
),
LLMToolName.VISION_STOP_CAPTURE.value,
SignalCode.VISION_STOP_CAPTURE
)

StartAudioCaptureTool = create_application_control_tool_class(
(
"This tool turns the microphone on. It takes no input and returns a "
"string."
),
LLMToolName.STT_START_CAPTURE.value,
SignalCode.STT_START_CAPTURE_SIGNAL
)

class QuitApplicationTool(ApplicationControlTool):
description = "This tool quits the application. It takes no input and returns a string."
name = LLMToolName.QUIT_APPLICATION.value
signal_code = SignalCode.QUIT_APPLICATION
StopAudioCaptureTool = create_application_control_tool_class(
(
"This tool turns the microphone off. It takes no input and returns a "
"string."
),
LLMToolName.STT_STOP_CAPTURE.value,
SignalCode.STT_STOP_CAPTURE_SIGNAL
)

StartSpeakersTool = create_application_control_tool_class(
(
"This tool turns the speakers on. It takes no input and returns a "
"string."
),
LLMToolName.TTS_ENABLE.value,
SignalCode.TTS_ENABLE_SIGNAL
)

class StartVisionCaptureTool(ApplicationControlTool):
description = "This tool turns the camera on - it starts the input feed. It takes no input and returns a string."
name = LLMToolName.VISION_START_CAPTURE.value
signal_code = SignalCode.VISION_START_CAPTURE
StopSpeakersTool = create_application_control_tool_class(
(
"This tool turns the speakers off. It takes no input and returns a "
"string."
),
LLMToolName.TTS_DISABLE.value,
SignalCode.TTS_DISABLE_SIGNAL
)

ProcessVisionTool = create_application_control_tool_class(
(
"This tool processes the images which are captured by the camera. "
"These are images that the assistant may use in the context of a "
"conversation with the user. It takes no input and returns a string."
),
LLMToolName.VISION_PROCESS_IMAGES.value,
SignalCode.VISION_PROCESS_IMAGES
)

class StopVisionCaptureTool(ApplicationControlTool):
description = "This tool turns the camera off - it stops the input feed. It takes no input and returns a string."
name = LLMToolName.VISION_STOP_CAPTURE.value
signal_code = SignalCode.VISION_STOP_CAPTURE
ProcessAudioTool = create_application_control_tool_class(
(
"This tool processes the audio which is captured by the microphone. "
"This is audio that the assistant may use in the context of a "
"conversation with the user. It takes no input and returns a string."
),
LLMToolName.LLM_PROCESS_STT_AUDIO.value,
SignalCode.LLM_PROCESS_STT_AUDIO_SIGNAL
)
6 changes: 6 additions & 0 deletions src/airunner/aihandler/local_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@


class LocalAgent(LocalAgentBase):
def __init__(self, *args, **kwargs):
self.restrict_tools_to_additional = kwargs.pop("restrict_tools_to_additional", False)
super().__init__(*args, **kwargs)
if self.restrict_tools_to_additional:
self._toolbox = kwargs.get("additional_tools")

def format_prompt(self, task, chat_mode=False):
task = super().format_prompt(task, chat_mode=chat_mode)
task = task.replace("</s>", "")
Expand Down

0 comments on commit f942afb

Please sign in to comment.