From 03a0e6f61a805580f60d47f3417eb393bf827412 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?=
Date: Mon, 3 Feb 2025 14:59:00 +0100
Subject: [PATCH 01/30] feat: add base impl of tts agent and start moving tts
 models into new format

---
 src/rai_core/rai/agents/tts_agent.py       | 50 +++++++++++++++
 src/rai_core/rai/agents/voice_agent.py     |  2 +-
 src/rai_core/rai/communication/__init__.py |  2 +
 .../communication/sound_device/__init__.py |  3 +-
 src/rai_tts/rai_tts/models/__init__.py     | 18 ++++++
 src/rai_tts/rai_tts/models/base.py         | 33 ++++++++++
 src/rai_tts/rai_tts/models/open_tts.py     | 62 +++++++++++++++++++
 7 files changed, 168 insertions(+), 2 deletions(-)
 create mode 100644 src/rai_core/rai/agents/tts_agent.py
 create mode 100644 src/rai_tts/rai_tts/models/__init__.py
 create mode 100644 src/rai_tts/rai_tts/models/base.py
 create mode 100644 src/rai_tts/rai_tts/models/open_tts.py

diff --git a/src/rai_core/rai/agents/tts_agent.py b/src/rai_core/rai/agents/tts_agent.py
new file mode 100644
index 000000000..8a7498062
--- /dev/null
+++ b/src/rai_core/rai/agents/tts_agent.py
@@ -0,0 +1,50 @@
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+from typing import TYPE_CHECKING, Optional
+
+from rai.agents.base import BaseAgent
+from rai.communication import ROS2ARIConnector, SoundDeviceConnector
+
+if TYPE_CHECKING:
+    from rai.communication.sound_device.api import SoundDeviceConfig
+
+
+class TextToSpeechAgent(BaseAgent):
+    def __init__(
+        self,
+        speaker_config: SoundDeviceConfig,
+        ros2_name: str,
+        logger: Optional[logging.Logger] = None,
+    ):
+        if logger is None:
+            self.logger = logging.getLogger(__name__)
+        else:
+            self.logger = logger
+
+        speaker = SoundDeviceConnector(
+            targets=[("speaker", speaker_config)], sources=[]
+        )
+        ros2_connector = ROS2ARIConnector(ros2_name)
+        super().__init__(connectors={"ros2": ros2_connector, "speaker": speaker})
+        self.running = False
+
+    def __call__(self):
+        self.run()
+
+    def run(self):
+        self.running = True
+        self.logger.info("TextToSpeechAgent started")
diff --git a/src/rai_core/rai/agents/voice_agent.py b/src/rai_core/rai/agents/voice_agent.py
index 339db49da..a49fde57b 100644
--- a/src/rai_core/rai/agents/voice_agent.py
+++ b/src/rai_core/rai/agents/voice_agent.py
@@ -26,10 +26,10 @@
 from rai.communication import (
     ROS2ARIConnector,
     ROS2ARIMessage,
+    SoundDeviceConfig,
     SoundDeviceConnector,
     SoundDeviceMessage,
 )
-from rai.communication.sound_device.api import SoundDeviceConfig
 from rai_asr.models import BaseTranscriptionModel, BaseVoiceDetectionModel
diff --git a/src/rai_core/rai/communication/__init__.py b/src/rai_core/rai/communication/__init__.py
index f18324d79..b1a039a3a 100644
--- a/src/rai_core/rai/communication/__init__.py
+++ b/src/rai_core/rai/communication/__init__.py
@@ -17,6 +17,7 @@
 from .hri_connector import HRIConnector, HRIMessage, HRIPayload
 from .ros2.connectors import ROS2ARIConnector, ROS2ARIMessage
 from .sound_device.connector import (
+    SoundDeviceConfig,
     SoundDeviceConnector,
     SoundDeviceError,
     SoundDeviceMessage,
@@ -35,4 +36,5 @@
     "SoundDeviceConnector",
     "SoundDeviceError",
     "SoundDeviceMessage",
+    "SoundDeviceConfig",
 ]
diff --git a/src/rai_core/rai/communication/sound_device/__init__.py b/src/rai_core/rai/communication/sound_device/__init__.py
index 503c274d9..e93e76bf3 100644
--- a/src/rai_core/rai/communication/sound_device/__init__.py
+++ b/src/rai_core/rai/communication/sound_device/__init__.py
@@ -13,11 +13,12 @@
 # limitations under the License.

 from .api import SoundDeviceAPI, SoundDeviceConfig, SoundDeviceError
-from .connector import SoundDeviceConnector
+from .connector import SoundDeviceConnector, SoundDeviceMessage

 __all__ = [
     "SoundDeviceAPI",
     "SoundDeviceConfig",
     "SoundDeviceConnector",
     "SoundDeviceError",
+    "SoundDeviceMessage",
 ]
diff --git a/src/rai_tts/rai_tts/models/__init__.py b/src/rai_tts/rai_tts/models/__init__.py
new file mode 100644
index 000000000..bcc4a3822
--- /dev/null
+++ b/src/rai_tts/rai_tts/models/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base import SpeechResult, TTSModel, TTSModelError
+from .open_tts import OpenTTS
+
+__all__ = ["TTSModel", "SpeechResult", "TTSModelError", "OpenTTS"]
diff --git a/src/rai_tts/rai_tts/models/base.py b/src/rai_tts/rai_tts/models/base.py
new file mode 100644
index 000000000..0f498edc6
--- /dev/null
+++ b/src/rai_tts/rai_tts/models/base.py
@@ -0,0 +1,33 @@
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from typing import NamedTuple
+
+from numpy.typing import NDArray
+
+
+class SpeechResult(NamedTuple):
+    speech: NDArray
+    sample_rate: int
+
+
+class TTSModelError(Exception):
+    pass
+
+
+class TTSModel(ABC):
+    @abstractmethod
+    def get_speech(self, text: str) -> SpeechResult:
+        pass
diff --git a/src/rai_tts/rai_tts/models/open_tts.py b/src/rai_tts/rai_tts/models/open_tts.py
new file mode 100644
index 000000000..ab56b0037
--- /dev/null
+++ b/src/rai_tts/rai_tts/models/open_tts.py
@@ -0,0 +1,62 @@
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from io import BytesIO
+
+import numpy as np
+import requests
+from scipy.io.wavfile import read
+
+from rai_tts.models import SpeechResult, TTSModel, TTSModelError
+
+
+class OpenTTS(TTSModel):
+    def __init__(
+        self,
+        url: str = "http://localhost:5500/api/tts",
+        voice: str = "larynx:blizzard_lessac-glow_tts",
+    ):
+        self.url = url
+        self.voice = voice
+
+    def get_speech(self, text: str) -> SpeechResult:
+        params = {
+            "voice": self.voice,
+            "text": text,
+        }
+        try:
+            response = requests.get("http://localhost:5500/api/tts", params=params)
+        except requests.exceptions.RequestException as e:
+            raise TTSModelError(
+                f"Error occurred while fetching audio: {e}, check if OpenTTS server is running correctly."
+            ) from e
+
+        content_type = response.headers.get("Content-Type", "")
+
+        if "audio" not in content_type:
+            raise ValueError("Response does not contain audio data")
+
+        # Load audio into memory
+        audio_bytes = BytesIO(response.content)
+        sample_rate, data = read(audio_bytes)
+        if data.dtype == np.int32:
+            data = (data / 2**16).astype(np.int16)  # Scale down from int32
+        elif data.dtype == np.uint8:
+            data = (data - 128).astype(np.int16) * 256  # Convert uint8 to int16
+        elif data.dtype == np.float32:
+            data = (
+                (data * 32768).clip(-32768, 32767).astype(np.int16)
+            )  # Convert float32 to int16
+
+        return SpeechResult(data, sample_rate)

From ef39181159834e6a0acd430118908d08871da1c3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?=
Date: Tue, 4 Feb 2025 13:16:55 +0100
Subject: [PATCH 02/30] feat: change connector api to support AudioSegment

---
 .../rai/communication/sound_device/api.py   | 24 ++++++++++++++-----
 .../communication/sound_device/connector.py | 23 ++++--------------
 2 files changed, 22 insertions(+), 25 deletions(-)

diff --git a/src/rai_core/rai/communication/sound_device/api.py b/src/rai_core/rai/communication/sound_device/api.py
index 98d554e4b..31439f9ed 100644
--- a/src/rai_core/rai/communication/sound_device/api.py
+++ b/src/rai_core/rai/communication/sound_device/api.py
@@ -16,7 +16,9 @@
 from dataclasses import dataclass
 from typing import Any, Callable, Optional

+import numpy as np
 from numpy._typing import NDArray
+from pydub import AudioSegment
 from scipy.signal import resample

 try:
@@ -77,7 +79,7 @@ def __init__(self, config: SoundDeviceConfig):
         self.in_stream = None
         self.out_stream = None

-    def write(self, data: NDArray, blocking: bool = False, loop: bool = False):
+    def write(self, data: AudioSegment, blocking: bool = False, loop: bool = False):
         """
         Write data to the sound device.
@@ -101,15 +103,16 @@ def write(self, data: NDArray, blocking: bool = False, loop: bool = False):
         if not self.write_flag:
             raise SoundDeviceError(f"{self.device_name} does not support writing!")
         assert sd is not None
+        audio = np.array(data.get_array_of_samples())
         sd.play(
-            data,
-            samplerate=self.sample_rate,
+            audio,
+            samplerate=data.frame_rate,
             blocking=blocking,
             loop=loop,
             device=self.device_number,
         )

-    def read(self, time: float, blocking: bool = False) -> NDArray:
+    def read(self, time: float, blocking: bool = False) -> AudioSegment:
         """
         Read data from the sound device.

@@ -142,7 +145,7 @@ def read(self, time: float, blocking: bool = False) -> NDArray:
             raise SoundDeviceError(f"{self.device_name} does not support reading!")
         assert sd is not None
         frames = int(time * self.sample_rate)
-        return sd.rec(
+        recording = sd.rec(
             frames=frames,
             samplerate=self.sample_rate,
             channels=self.config.channels,
@@ -151,6 +154,13 @@
             dtype=self.config.dtype,
         )

+        return AudioSegment(
+            data=recording.flatten(),
+            sample_width=recording.dtype.itemsize,
+            frame_rate=self.sample_rate,
+            channels=self.config.channels,
+        )
+
     def stop(self):
         """
         Stop the sound device from playing or recording.
@@ -179,6 +189,7 @@ def open_write_stream(
         self,
         feed_data: Callable[[NDArray, int, Any, Any], None],
         on_done: Callable = lambda _: None,
+        sample_rate: Optional[int] = None,
     ):
         if not self.write_flag or not self.stream_flag:
             raise SoundDeviceError(
@@ -201,8 +212,9 @@ def callback(indata: NDArray, frames: int, time: Any, status: CallbackFlags):
         try:
             assert sd is not None
+            sample_rate = self.sample_rate if sample_rate is None else sample_rate
             self.out_stream = sd.OutputStream(
-                samplerate=self.sample_rate,
+                samplerate=sample_rate,
                 channels=self.config.channels,
                 device=self.device_number,
                 dtype=self.config.dtype,
diff --git a/src/rai_core/rai/communication/sound_device/connector.py b/src/rai_core/rai/communication/sound_device/connector.py
index 09daf4b98..65f4165bf 100644
--- a/src/rai_core/rai/communication/sound_device/connector.py
+++ b/src/rai_core/rai/communication/sound_device/connector.py
@@ -13,13 +13,8 @@
 # limitations under the License.
-import base64
-import io
 from typing import Callable, Literal, Optional, Tuple

-import numpy as np
-from scipy.io import wavfile
-
 try:
     import sounddevice as sd
 except ImportError as e:
@@ -100,10 +95,7 @@ def send_message(self, message: SoundDeviceMessage, target: str, **kwargs) -> No
             )
         else:
             if message.audios is not None:
-                wav_bytes = base64.b64decode(message.audios[0])
-                wav_buffer = io.BytesIO(wav_bytes)
-                _, audio_data = wavfile.read(wav_buffer)
-                self.devices[target].write(audio_data)
+                self.devices[target].write(message.audios[0])
             else:
                 raise SoundDeviceError("Failed to provice audios in message to play")

@@ -127,22 +119,15 @@ def service_call(
             raise SoundDeviceError("For stopping use send_message with stop=True.")
         elif message.read:
             recording = self.devices[target].read(duration, blocking=True)
+
             payload = HRIPayload(
                 text="",
-                audios=[
-                    base64.b64encode(recording).decode("utf-8")
-                ],  # TODO: refactor once utility functions for encoding/decoding are available
+                audios=[recording],
             )
             ret = SoundDeviceMessage(payload)
         else:
             if message.audios is not None:
-                wav_bytes = base64.b64decode(
-                    message.audios[0]
-                )  # TODO: refactor once utility functions for encoding/decoding are available
-                wav_buffer = io.BytesIO(wav_bytes)
-                _, audio_data = wavfile.read(wav_buffer)
-                audio_data = np.array(audio_data)
-                self.devices[target].write(audio_data, blocking=True)
+                self.devices[target].write(message.audios[0], blocking=True)
             else:
                 raise SoundDeviceError("Failed to provice audios in message to play")
             ret = SoundDeviceMessage()

From 983d148d11d403caf413b4f72e1ed2ac0c39ae53 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?=
Date: Mon, 10 Feb 2025 17:33:48 +0100
Subject: [PATCH 03/30] feat: working TTS, with pausing

---
 src/rai_core/rai/agents/__init__.py          |   2 +
 src/rai_core/rai/agents/tts_agent.py         | 147 +++++++++++++++++-
 src/rai_core/rai/communication/__init__.py   |  13 +-
 .../rai/communication/ros2/__init__.py       |  12 +-
 src/rai_core/rai/communication/ros2/api.py   |  24 ++-
 .../rai/communication/ros2/connectors.py     |  30 ++--
 .../rai/communication/ros2/messages.py       |  27 ++++
 .../rai/communication/sound_device/api.py    |   5 +-
 .../communication/sound_device/connector.py  |   7 +-
 src/rai_tts/rai_tts/models/__init__.py       |   4 +-
 src/rai_tts/rai_tts/models/base.py           |  15 +-
 src/rai_tts/rai_tts/models/open_tts.py       |  15 +-
 12 files changed, 262 insertions(+), 39 deletions(-)
 create mode 100644 src/rai_core/rai/communication/ros2/messages.py

diff --git a/src/rai_core/rai/agents/__init__.py b/src/rai_core/rai/agents/__init__.py
index e28822100..fcbc5517f 100644
--- a/src/rai_core/rai/agents/__init__.py
+++ b/src/rai_core/rai/agents/__init__.py
@@ -15,9 +15,11 @@
 from rai.agents.conversational_agent import create_conversational_agent
 from rai.agents.state_based import create_state_based_agent
 from rai.agents.tool_runner import ToolRunner
+from rai.agents.tts_agent import TextToSpeechAgent
 from rai.agents.voice_agent import VoiceRecognitionAgent

 __all__ = [
+    "TextToSpeechAgent",
     "ToolRunner",
     "VoiceRecognitionAgent",
     "create_conversational_agent",
diff --git a/src/rai_core/rai/agents/tts_agent.py b/src/rai_core/rai/agents/tts_agent.py
index 8a7498062..d41b1a8ef 100644
--- a/src/rai_core/rai/agents/tts_agent.py
+++ b/src/rai_core/rai/agents/tts_agent.py
@@ -14,20 +14,54 @@

 import logging
+from dataclasses import dataclass
+from threading import Event, Thread
 from typing import TYPE_CHECKING, Optional

+from numpy._typing import NDArray
+from pydub import AudioSegment
+
 from rai.agents.base import BaseAgent
-from rai.communication import ROS2ARIConnector, SoundDeviceConnector
+from rai.communication import (
+    ROS2HRIConnector,
+    SoundDeviceConfig,
+    SoundDeviceConnector,
+    TopicConfig,
+)
+from rai.communication.ros2.connectors import ROS2HRIMessage
+from rai.communication.sound_device.connector import SoundDeviceMessage
+from rai_tts.models.base import TTSModel

 if TYPE_CHECKING:
     from rai.communication.sound_device.api import SoundDeviceConfig

+from queue import Empty, Queue
+
+import numpy as np
+
+# This file contains every concurrent programming antipattern known to man
+# The words callback hell are insufficient to describe the cacophony of function calls
+# wreaking havoc along the 9 circles of threads
+# Ye who enter here abandon all hope
+#
+# It works tho
+
+
+@dataclass
+class PlayData:
+    playing: bool = False
+    current_segment: Optional[AudioSegment] = None
+    data: Optional[NDArray] = None
+    channels: int = 1
+    current_frame: int = 0
+

 class TextToSpeechAgent(BaseAgent):
     def __init__(
         self,
         speaker_config: SoundDeviceConfig,
         ros2_name: str,
+        tts: TTSModel,
         logger: Optional[logging.Logger] = None,
     ):
         if logger is None:
@@ -38,13 +72,122 @@ def __init__(
         speaker = SoundDeviceConnector(
             targets=[("speaker", speaker_config)], sources=[]
         )
-        ros2_connector = ROS2ARIConnector(ros2_name)
+        self.node_base_name = ros2_name
+        self.model = tts
+        ros2_connector = self._setup_ros2_connector()
         super().__init__(connectors={"ros2": ros2_connector, "speaker": speaker})
+
+        self.text_queue = Queue()
+        self.audio_queue = Queue()
+
+        self.tog_play_event = Event()
+        self.stop_event = Event()
+        self.current_audio = None
+
+        self.terminate_agent = Event()
+        self.transcription_thread = None
         self.running = False
+        self.playback_data = PlayData()
+
     def __call__(self):
         self.run()

     def run(self):
         self.running = True
         self.logger.info("TextToSpeechAgent started")
+        self.transcription_thread = Thread(target=self._transcription_thread)
+        self.transcription_thread.start()
+        sample_rate, channels = self.model.get_tts_params()
+
+        msg = SoundDeviceMessage(read=False)
+        assert isinstance(self.connectors["speaker"], SoundDeviceConnector)
+        self.connectors["speaker"].start_action(
+            msg,
+            "speaker",
+            on_feedback=self._speaker_callback,
+            on_done=lambda: None,
+            sample_rate=sample_rate,
+            channels=channels,
+        )
+
+    def _speaker_callback(self, outdata, frames, time, status_dict):
+        set_flags = [flag for flag, status in status_dict.items() if status]
+
+        if set_flags:
+            self.logger.warning("Flags set:", ", ".join(set_flags))
+        if self.playback_data.playing:
+            if self.playback_data.current_segment is None:
+                try:
+                    self.playback_data.current_segment = self.audio_queue.get(
+                        block=False
+                    )
+                    self.playback_data.data = np.array(
+                        self.playback_data.current_segment.get_array_of_samples()  # type: ignore
+                    ).reshape(-1, self.playback_data.channels)
+                except Empty:
+                    pass
+            if self.playback_data.data is not None:
+                current_frame = self.playback_data.current_frame
+                chunksize = min(len(self.playback_data.data) - current_frame, frames)
+                outdata[:chunksize] = self.playback_data.data[
+                    current_frame : current_frame + chunksize
+                ]
+                if chunksize < frames:
+                    outdata[chunksize:] = 0
+                    self.playback_data.current_frame = 0
+                    self.playback_data.current_segment = None
+                    self.playback_data.data = None
+                else:
+                    self.playback_data.current_frame += chunksize
+
+        if not self.playback_data.playing:
+            outdata[:] = np.zeros(outdata.size).reshape(outdata.shape)
+
+    def stop(self):
+        self.terminate_agent.set()
+        if self.transcription_thread is not None:
+            self.transcription_thread.join()
+
+    def _transcription_thread(self):
+        while not self.terminate_agent.wait(timeout=0.01):
+            try:
+                data = self.text_queue.get(block=False)
+            except Empty:
+                continue
+            audio = self.model.get_speech(data)
+            self.audio_queue.put(audio)
+            self.playback_data.playing = True
+
+    def _setup_ros2_connector(self):
+        to_human = TopicConfig(
+            name="/to_human",
+            msg_type="std_msgs/msg/String",
+            auto_qos_matching=True,
+            is_subscriber=True,
+            subscriber_callback=self._on_to_human_message,
+            source_author="ai",
+        )
+        voice_commands = TopicConfig(
+            name="/voice_commands",
+            msg_type="std_msgs/msg/String",
+            auto_qos_matching=True,
+            is_subscriber=True,
+            subscriber_callback=self._on_command_message,
+            source_author="human",
+        )
+        return ROS2HRIConnector(
+            node_name=self.node_base_name,
+            sources=[("/to_human", to_human), ("/voice_commands", voice_commands)],
+        )
+
+    def _on_to_human_message(self, message: ROS2HRIMessage):
+        self.text_queue.put(message.text)
+
+    def _on_command_message(self, message: ROS2HRIMessage):
+        if message.text == "tog_play":
+            self.playback_data.playing = not self.playback_data.playing
+        elif message.text == "stop":
+            self.playback_data.playing = False
+            while not self.audio_queue.empty():
+                _ = self.audio_queue.get()
diff --git a/src/rai_core/rai/communication/__init__.py b/src/rai_core/rai/communication/__init__.py
index b1a039a3a..4b76ee303 100644
--- a/src/rai_core/rai/communication/__init__.py
+++ b/src/rai_core/rai/communication/__init__.py
@@ -15,7 +15,13 @@
 from .ari_connector import ARIConnector, ARIMessage
 from .base_connector import BaseConnector, BaseMessage
 from .hri_connector import HRIConnector, HRIMessage, HRIPayload
-from .ros2.connectors import ROS2ARIConnector, ROS2ARIMessage
+from .ros2.api import TopicConfig
+from .ros2.connectors import (
+    ROS2ARIConnector,
+    ROS2ARIMessage,
+    ROS2HRIConnector,
+    ROS2HRIMessage,
+)
 from .sound_device.connector import (
     SoundDeviceConfig,
     SoundDeviceConnector,
@@ -33,8 +39,11 @@
     "HRIPayload",
     "ROS2ARIConnector",
     "ROS2ARIMessage",
+    "ROS2HRIConnector",
+    "ROS2HRIMessage",
+    "SoundDeviceConfig",
     "SoundDeviceConnector",
     "SoundDeviceError",
     "SoundDeviceMessage",
-    "SoundDeviceConfig",
+    "TopicConfig",
 ]
diff --git a/src/rai_core/rai/communication/ros2/__init__.py b/src/rai_core/rai/communication/ros2/__init__.py
index 3334d9673..587ac0296 100644
--- a/src/rai_core/rai/communication/ros2/__init__.py
+++ b/src/rai_core/rai/communication/ros2/__init__.py
@@ -12,6 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .api import ROS2ActionAPI, ROS2ServiceAPI, ROS2TopicAPI
+from .api import ConfigurableROS2TopicAPI, ROS2ActionAPI, ROS2ServiceAPI, ROS2TopicAPI
+from .messages import ROS2ARIMessage, ROS2HRIMessage

-__all__ = ["ROS2ActionAPI", "ROS2ServiceAPI", "ROS2TopicAPI"]
+__all__ = [
+    "ConfigurableROS2TopicAPI",
+    "ROS2ARIMessage",
+    "ROS2ActionAPI",
+    "ROS2HRIMessage",
+    "ROS2ServiceAPI",
+    "ROS2TopicAPI",
+]
diff --git a/src/rai_core/rai/communication/ros2/api.py b/src/rai_core/rai/communication/ros2/api.py
index 1f32f45a0..b9f2b3db2 100644
--- a/src/rai_core/rai/communication/ros2/api.py
+++ b/src/rai_core/rai/communication/ros2/api.py
@@ -25,6 +25,7 @@
     Callable,
     Dict,
     List,
+    Literal,
     Optional,
     Tuple,
     Type,
@@ -55,6 +56,8 @@
 from rclpy.task import Future
 from rclpy.topic_endpoint_info import TopicEndpointInfo

+from rai.communication.hri_connector import HRIPayload
+from rai.communication.ros2.messages import ROS2HRIMessage
 from rai.tools.ros.utils import import_message_from_str, wait_for_message
@@ -323,7 +326,7 @@ class TopicConfig:
     auto_qos_matching: bool = True
     qos_profile: Optional[QoSProfile] = None
     is_subscriber: bool = False
-    subscriber_callback: Optional[Callable[[Any], None]] = None
+    subscriber_callback: Optional[Callable[[ROS2HRIMessage], None]] = None

     def __post_init__(self):
         if not self.auto_qos_matching and self.qos_profile is None:
@@ -375,11 +378,22 @@ def configure_subscriber(
                 f"Failed to reconfigure existing subscriber to {topic}"
             )

-        assert config.subscriber_callback is not None
+        msg_type = import_message_from_str(config.msg_type)
+
+        def callback_wrapper(message):
+            text = message.data
+            print(text)
+            assert config.subscriber_callback is not None
+            config.subscriber_callback(
+                ROS2HRIMessage(
+                    HRIPayload(text=text), message_author=config.source_author
+                )
+            )
+
         self._subscribtions[topic] = self._node.create_subscription(
-            msg_type=import_message_from_str(config.msg_type),
+            msg_type=msg_type,
             topic=topic,
-            callback=config.subscriber_callback,
+            callback=callback_wrapper,
             qos_profile=qos_profile,
         )
@@ -398,7 +412,7 @@ def publish_configured(self, topic: str, msg_content: dict[str, Any]) -> None:
         except Exception as e:
             raise ValueError(f"{topic} has not been configured for publishing") from e
         msg_type = publisher.msg_type
-        msg = build_ros2_msg(msg_type, msg_content)  # type: ignore
+        msg = build_ros2_msg(msg_type, {"data": msg_content.text})  # type: ignore
         publisher.publish(msg)
diff --git a/src/rai_core/rai/communication/ros2/connectors.py b/src/rai_core/rai/communication/ros2/connectors.py
index 452dae22e..93dda60c9 100644
--- a/src/rai_core/rai/communication/ros2/connectors.py
+++ b/src/rai_core/rai/communication/ros2/connectors.py
@@ -19,6 +19,7 @@
 from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, cast

 import numpy as np
+import rai_interfaces.msg
 import rclpy
 import rclpy.executors
 import rclpy.node
@@ -27,6 +28,8 @@
 from cv_bridge import CvBridge
 from PIL import Image
 from pydub import AudioSegment
+from rai_interfaces.msg import HRIMessage as ROS2HRIMessage_
+from rai_interfaces.msg._audio_message import AudioMessage as ROS2HRIMessage__Audio
 from rclpy.duration import Duration
 from rclpy.executors import MultiThreadedExecutor
 from rclpy.node import Node
@@ -34,7 +37,6 @@
 from sensor_msgs.msg import Image as ROS2Image
 from tf2_ros import Buffer, LookupException, TransformListener, TransformStamped

-import rai_interfaces.msg
 from rai.communication import (
     ARIConnector,
     ARIMessage,
@@ -45,13 +47,10 @@ from rai.communication.ros2.api import (
     ConfigurableROS2TopicAPI,
     ROS2ActionAPI,
+    ROS2ARIMessage,
+    ROS2HRIMessage,
     ROS2ServiceAPI,
     ROS2TopicAPI,
-    TopicConfig,
-)
-from rai_interfaces.msg import HRIMessage as ROS2HRIMessage_
-from rai_interfaces.msg._audio_message import (
-    AudioMessage as ROS2HRIMessage__Audio,
 )
@@ -279,15 +278,19 @@ def __init__(
         ]

         _targets = [
-            target
-            if isinstance(target, tuple)
-            else (target, TopicConfig(is_subscriber=False))
+            (
+                target
+                if isinstance(target, tuple)
+                else (target, TopicConfig(is_subscriber=False))
+            )
             for target in targets
         ]
         _sources = [
-            source
-            if isinstance(source, tuple)
-            else (source, TopicConfig(is_subscriber=True))
+            (
+                source
+                if isinstance(source, tuple)
+                else (source, TopicConfig(is_subscriber=True))
+            )
             for source in sources
         ]
@@ -306,6 +309,9 @@
         self._thread = threading.Thread(target=self._executor.spin)
         self._thread.start()

+    # def run(self):
+    #     self._executor.spin()
+
     def _configure_publishers(self, targets: List[Tuple[str, TopicConfig]]):
         for target in targets:
             self._topic_api.configure_publisher(target[0], target[1])
diff --git a/src/rai_core/rai/communication/ros2/messages.py b/src/rai_core/rai/communication/ros2/messages.py
new file mode 100644
index 000000000..3c5a3a8c2
--- /dev/null
+++ b/src/rai_core/rai/communication/ros2/messages.py
@@ -0,0 +1,27 @@
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Literal, Optional
+
+from rai.communication import ARIMessage, HRIMessage, HRIPayload
+
+
+class ROS2ARIMessage(ARIMessage):
+    def __init__(self, payload: Any, metadata: Optional[Dict[str, Any]] = None):
+        super().__init__(payload, metadata)
+
+
+class ROS2HRIMessage(HRIMessage):
+    def __init__(self, payload: HRIPayload, message_author: Literal["ai", "human"]):
+        super().__init__(payload, message_author)
diff --git a/src/rai_core/rai/communication/sound_device/api.py b/src/rai_core/rai/communication/sound_device/api.py
index 31439f9ed..d1c83a7de 100644
--- a/src/rai_core/rai/communication/sound_device/api.py
+++ b/src/rai_core/rai/communication/sound_device/api.py
@@ -190,6 +190,7 @@ def open_write_stream(
         feed_data: Callable[[NDArray, int, Any, Any], None],
         on_done: Callable = lambda _: None,
         sample_rate: Optional[int] = None,
+        channels: Optional[int] = None,
     ):
         if not self.write_flag or not self.stream_flag:
             raise SoundDeviceError(
@@ -213,9 +214,11 @@ def callback(indata: NDArray, frames: int, time: Any, status: CallbackFlags):
         try:
             assert sd is not None
             sample_rate = self.sample_rate if sample_rate is None else sample_rate
+            print(sample_rate)
+            channels = self.config.channels if channels is None else channels
             self.out_stream = sd.OutputStream(
                 samplerate=sample_rate,
-                channels=self.config.channels,
+                channels=channels,
                 device=self.device_number,
                 dtype=self.config.dtype,
                 callback=callback,
diff --git a/src/rai_core/rai/communication/sound_device/connector.py b/src/rai_core/rai/communication/sound_device/connector.py
index 65f4165bf..48f29b28b 100644
--- a/src/rai_core/rai/communication/sound_device/connector.py
+++ b/src/rai_core/rai/communication/sound_device/connector.py
@@ -149,7 +149,12 @@ def start_action(
             self.devices[target].open_read_stream(on_feedback, on_done)
             self.action_handles[handle] = (target, True)
         else:
-            self.devices[target].open_write_stream(on_feedback, on_done)
+            sample_rate = kwargs.get("sample_rate", None)
+            channels = kwargs.get("channels", None)
+
+            self.devices[target].open_write_stream(
+                on_feedback, on_done, sample_rate, channels
+            )
             self.action_handles[handle] = (target, False)

         return handle
diff --git a/src/rai_tts/rai_tts/models/__init__.py b/src/rai_tts/rai_tts/models/__init__.py
index bcc4a3822..b1187962b 100644
--- a/src/rai_tts/rai_tts/models/__init__.py
+++ b/src/rai_tts/rai_tts/models/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .base import SpeechResult, TTSModel, TTSModelError
+from .base import TTSModel, TTSModelError
 from .open_tts import OpenTTS

-__all__ = ["TTSModel", "SpeechResult", "TTSModelError", "OpenTTS"]
+__all__ = ["OpenTTS", "TTSModel", "TTSModelError"]
diff --git a/src/rai_tts/rai_tts/models/base.py b/src/rai_tts/rai_tts/models/base.py
index 0f498edc6..6014cc50c 100644
--- a/src/rai_tts/rai_tts/models/base.py
+++ b/src/rai_tts/rai_tts/models/base.py
@@ -13,14 +13,9 @@
 # limitations under the License.
 from abc import ABC, abstractmethod
-from typing import NamedTuple
+from typing import Tuple

-from numpy.typing import NDArray
-
-
-class SpeechResult(NamedTuple):
-    speech: NDArray
-    sample_rate: int
+from pydub import AudioSegment


 class TTSModelError(Exception):
@@ -29,5 +24,9 @@

 class TTSModel(ABC):
     @abstractmethod
-    def get_speech(self, text: str) -> SpeechResult:
+    def get_speech(self, text: str) -> AudioSegment:
+        pass
+
+    @abstractmethod
+    def get_tts_params(self) -> Tuple[int, int]:
         pass
diff --git a/src/rai_tts/rai_tts/models/open_tts.py b/src/rai_tts/rai_tts/models/open_tts.py
index ab56b0037..641ab0f7e 100644
--- a/src/rai_tts/rai_tts/models/open_tts.py
+++ b/src/rai_tts/rai_tts/models/open_tts.py
@@ -13,12 +13,14 @@
 # limitations under the License.

 from io import BytesIO
+from typing import Tuple

 import numpy as np
 import requests
+from pydub import AudioSegment
 from scipy.io.wavfile import read

-from rai_tts.models import SpeechResult, TTSModel, TTSModelError
+from rai_tts.models import TTSModel, TTSModelError


 class OpenTTS(TTSModel):
@@ -30,13 +32,13 @@ def __init__(
         self.url = url
         self.voice = voice

-    def get_speech(self, text: str) -> SpeechResult:
+    def get_speech(self, text: str) -> AudioSegment:
         params = {
             "voice": self.voice,
             "text": text,
         }
         try:
-            response = requests.get("http://localhost:5500/api/tts", params=params)
+            response = requests.get(self.url, params=params)
         except requests.exceptions.RequestException as e:
             raise TTSModelError(
                 f"Error occurred while fetching audio: {e}, check if OpenTTS server is running correctly."
             ) from e
@@ -59,4 +61,9 @@ def get_speech(self, text: str) -> SpeechResult:
                 (data * 32768).clip(-32768, 32767).astype(np.int16)
             )  # Convert float32 to int16

-        return SpeechResult(data, sample_rate)
+        return AudioSegment(data, frame_rate=sample_rate, sample_width=2, channels=1)
+
+    def get_tts_params(self) -> Tuple[int, int]:
+        data = self.get_speech("A")
+        print(data.frame_rate)
+        return data.frame_rate, 1

From 87adc6518ae24dddcefcf4dc0f1d635d3c67facf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?=
Date: Tue, 11 Feb 2025 14:35:44 +0100
Subject: [PATCH 04/30] feat: working S2S

---
 src/rai_core/rai/agents/tts_agent.py       | 10 +++++++++-
 src/rai_core/rai/agents/voice_agent.py     | 20 ++++++++++++--------
 src/rai_core/rai/communication/ros2/api.py |  1 -
 src/rai_tts/rai_tts/models/open_tts.py     |  1 -
 4 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/src/rai_core/rai/agents/tts_agent.py b/src/rai_core/rai/agents/tts_agent.py
index d41b1a8ef..d543302be 100644
--- a/src/rai_core/rai/agents/tts_agent.py
+++ b/src/rai_core/rai/agents/tts_agent.py
@@ -115,7 +115,7 @@ def _speaker_callback(self, outdata, frames, time, status_dict):
         set_flags = [flag for flag, status in status_dict.items() if status]

         if set_flags:
-            self.logger.warning("Flags set:", ", ".join(set_flags))
+            self.logger.warning("Flags set:" + ", ".join(set_flags))
         if self.playback_data.playing:
             if self.playback_data.current_segment is None:
                 try:
@@ -182,12 +182,20 @@ def _setup_ros2_connector(self):
         )

     def _on_to_human_message(self, message: ROS2HRIMessage):
+        self.logger.info(f"Received message from human: {message.text}")
         self.text_queue.put(message.text)

     def _on_command_message(self, message: ROS2HRIMessage):
         if message.text == "tog_play":
             self.playback_data.playing = not self.playback_data.playing
+        elif message.text == "play":
+            self.playback_data.playing = True
+        elif message.text == "pause":
+            self.playback_data.playing = False
         elif message.text == "stop":
message.text == "stop": self.playback_data.playing = False while not self.audio_queue.empty(): _ = self.audio_queue.get() + self.playback_data.data = None + self.playback_data.current_frame = 0 + self.playback_data.current_segment = None diff --git a/src/rai_core/rai/agents/voice_agent.py b/src/rai_core/rai/agents/voice_agent.py index a49fde57b..c72540fe4 100644 --- a/src/rai_core/rai/agents/voice_agent.py +++ b/src/rai_core/rai/agents/voice_agent.py @@ -159,7 +159,7 @@ def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]): if voice_detected: self.logger.debug("Voice detected... resetting grace period") self.grace_period_start = sample_time - + self._send_ros2_message("pause", "/voice_commands") if ( self.recording_started and sample_time - self.grace_period_start > self.grace_period @@ -174,12 +174,16 @@ def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]): self.sample_buffer = [] self.transcription_threads[self.active_thread]["thread"].start() self.active_thread = "" + self._send_ros2_message("stop", "/voice_commands") + elif sample_time - self.grace_period_start > self.grace_period: + self._send_ros2_message("play", "/voice_commands") def should_record( self, audio_data: NDArray, input_parameters: dict[str, Any] ) -> bool: for model in self.should_record_pipeline: detected, output = model(audio_data, input_parameters) + self.logger.info(f"detected {detected}, output {output}") if detected: return True return False @@ -191,13 +195,13 @@ def transcription_thread(self, identifier: str): self.transcription_lock ): # this is only necessary for the local model... TODO: fix this somehow transcription = self.transcription_model.transcribe(audio_data) - assert isinstance(self.connectors["ros2"], ROS2ARIConnector) + self._send_ros2_message(transcription, "/from_human") + self.transcription_threads[identifier]["transcription"] = transcription + self.transcription_threads[identifier]["event"].set() + + def _send_ros2_message(self, data: str, topic: str): self.connectors["ros2"].send_message( - ROS2ARIMessage( - {"data": transcription}, {"msg_type": "std_msgs/msg/String"} - ), - "/from_human", + ROS2ARIMessage({"data": data}, {"msg_type": "std_msgs/msg/String"}), + topic, msg_type="std_msgs/msg/String", ) - self.transcription_threads[identifier]["transcription"] = transcription - self.transcription_threads[identifier]["event"].set() diff --git a/src/rai_core/rai/communication/ros2/api.py b/src/rai_core/rai/communication/ros2/api.py index b9f2b3db2..a0e2cec86 100644 --- a/src/rai_core/rai/communication/ros2/api.py +++ b/src/rai_core/rai/communication/ros2/api.py @@ -382,7 +382,6 @@ def configure_subscriber( def callback_wrapper(message): text = message.data - print(text) assert config.subscriber_callback is not None config.subscriber_callback( ROS2HRIMessage( diff --git a/src/rai_tts/rai_tts/models/open_tts.py b/src/rai_tts/rai_tts/models/open_tts.py index 641ab0f7e..e1e80d33a 100644 --- a/src/rai_tts/rai_tts/models/open_tts.py +++ b/src/rai_tts/rai_tts/models/open_tts.py @@ -65,5 +65,4 @@ def get_speech(self, text: str) -> AudioSegment: def get_tts_params(self) -> Tuple[int, int]: data = self.get_speech("A") - print(data.frame_rate) return data.frame_rate, 1 From f2eef35e79a73b15d03455fc1ff8a83e3728290b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?= Date: Tue, 11 Feb 2025 14:47:44 +0100 Subject: [PATCH 05/30] feat: add agent runner --- src/rai_core/rai/agents/base.py | 6 ++- src/rai_core/rai/agents/runner.py | 64 
 2 files changed, 69 insertions(+), 1 deletion(-)
 create mode 100644 src/rai_core/rai/agents/runner.py

diff --git a/src/rai_core/rai/agents/base.py b/src/rai_core/rai/agents/base.py
index c2dd4fe50..838b8e044 100644
--- a/src/rai_core/rai/agents/base.py
+++ b/src/rai_core/rai/agents/base.py
@@ -28,5 +28,9 @@ def __init__(
         self.connectors: dict[str, BaseConnector] = connectors

     @abstractmethod
-    def run(self, *args, **kwargs):
+    def run(self):
+        pass
+
+    @abstractmethod
+    def stop(self):
         pass
diff --git a/src/rai_core/rai/agents/runner.py b/src/rai_core/rai/agents/runner.py
new file mode 100644
index 000000000..16a69848e
--- /dev/null
+++ b/src/rai_core/rai/agents/runner.py
@@ -0,0 +1,64 @@
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import threading
+
+from rai.agents.base import BaseAgent
+
+
+class AgentRunner:
+    """
+    Manages and runs a collection of agents.
+
+    Parameters
+    ----------
+    agents : list of BaseAgent
+        A list of agent instances that implement `run` and `stop` methods.
+
+    Attributes
+    ----------
+    agents : list of BaseAgent
+        The list of agents managed by this runner.
+    """
+
+    def __init__(self, agents: list[BaseAgent]):
+        """
+        Initializes the AgentRunner with a list of agents.
+
+        Parameters
+        ----------
+        agents : list of BaseAgent
+            A list of agent instances that will be managed and executed.
+        """
+        self.agents = agents
+
+    def run_indefinitely(self):
+        """
+        Starts all agents and keeps them running indefinitely.
+
+        This method runs each agent's `run` method and waits indefinitely for a stop signal.
+        If a `KeyboardInterrupt` is received, it logs the interruption and stops all agents.
+        """
+        for agent in self.agents:
+            agent.run()
+
+        stop_signal = threading.Event()
+
+        try:
+            stop_signal.wait()
+        except KeyboardInterrupt:
+            logging.info("KeyboardInterrupt received! Shutting down...")
+
+        for agent in self.agents:
+            agent.stop()

From 1591a596ce2a69395c5e634e29ce5fec2a0d05c4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?=
Date: Tue, 11 Feb 2025 17:44:29 +0100
Subject: [PATCH 06/30] chore: add runner to __init__

---
 src/rai_core/rai/agents/__init__.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/rai_core/rai/agents/__init__.py b/src/rai_core/rai/agents/__init__.py
index fcbc5517f..54037b49d 100644
--- a/src/rai_core/rai/agents/__init__.py
+++ b/src/rai_core/rai/agents/__init__.py
@@ -13,12 +13,14 @@
 # limitations under the License.
 from rai.agents.conversational_agent import create_conversational_agent
+from rai.agents.runner import AgentRunner
 from rai.agents.state_based import create_state_based_agent
 from rai.agents.tool_runner import ToolRunner
 from rai.agents.tts_agent import TextToSpeechAgent
 from rai.agents.voice_agent import VoiceRecognitionAgent

 __all__ = [
+    "AgentRunner",
     "TextToSpeechAgent",
     "ToolRunner",
     "VoiceRecognitionAgent",

From b018bb5b5fc445c368d82b9e44fcedc3b2213110 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?=
Date: Tue, 11 Feb 2025 18:00:31 +0100
Subject: [PATCH 07/30] fix: working demo after rebase

---
 src/rai_core/rai/agents/tts_agent.py     |  2 -
 src/rai_core/rai/agents/voice_agent.py   |  2 +-
 src/rai_core/rai/communication/ros2/api.py | 1 +
 .../rai/communication/ros2/connectors.py | 85 +------------------
 .../rai/communication/ros2/messages.py   | 66 +++++++++++++-
 5 files changed, 70 insertions(+), 86 deletions(-)

diff --git a/src/rai_core/rai/agents/tts_agent.py b/src/rai_core/rai/agents/tts_agent.py
index d543302be..7e172de80 100644
--- a/src/rai_core/rai/agents/tts_agent.py
+++ b/src/rai_core/rai/agents/tts_agent.py
@@ -161,7 +161,6 @@ def _transcription_thread(self):

     def _setup_ros2_connector(self):
         to_human = TopicConfig(
-            name="/to_human",
             msg_type="std_msgs/msg/String",
             auto_qos_matching=True,
             is_subscriber=True,
@@ -169,7 +168,6 @@ def _setup_ros2_connector(self):
             source_author="ai",
         )
         voice_commands = TopicConfig(
-            name="/voice_commands",
             msg_type="std_msgs/msg/String",
             auto_qos_matching=True,
             is_subscriber=True,
diff --git a/src/rai_core/rai/agents/voice_agent.py b/src/rai_core/rai/agents/voice_agent.py
index c72540fe4..2725c6dad 100644
--- a/src/rai_core/rai/agents/voice_agent.py
+++ b/src/rai_core/rai/agents/voice_agent.py
@@ -183,7 +183,7 @@ def should_record(
     ) -> bool:
         for model in self.should_record_pipeline:
             detected, output = model(audio_data, input_parameters)
-            self.logger.info(f"detected {detected}, output {output}")
+            self.logger.debug(f"detected {detected}, output {output}")
             if detected:
                 return True
         return False
diff --git a/src/rai_core/rai/communication/ros2/api.py b/src/rai_core/rai/communication/ros2/api.py
index a0e2cec86..4cb9d922e 100644
--- a/src/rai_core/rai/communication/ros2/api.py
+++ b/src/rai_core/rai/communication/ros2/api.py
@@ -327,6 +327,7 @@ class TopicConfig:
     qos_profile: Optional[QoSProfile] = None
     is_subscriber: bool = False
     subscriber_callback: Optional[Callable[[ROS2HRIMessage], None]] = None
+    source_author: Literal["human", "ai"] = "ai"

     def __post_init__(self):
         if not self.auto_qos_matching and self.qos_profile is None:
diff --git a/src/rai_core/rai/communication/ros2/connectors.py b/src/rai_core/rai/communication/ros2/connectors.py
index 93dda60c9..170781326 100644
--- a/src/rai_core/rai/communication/ros2/connectors.py
+++ b/src/rai_core/rai/communication/ros2/connectors.py
@@ -15,48 +15,27 @@
 import threading
 import time
 import uuid
-from collections import OrderedDict
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, cast
+from typing import Any, Callable, List, Literal, Optional, Tuple, Union

-import numpy as np
-import rai_interfaces.msg
 import rclpy
 import rclpy.executors
 import rclpy.node
 import rclpy.time
-import rosidl_runtime_py.convert
-from cv_bridge import CvBridge
-from PIL import Image
-from pydub import AudioSegment
-from rai_interfaces.msg import HRIMessage as ROS2HRIMessage_
-from rai_interfaces.msg._audio_message import AudioMessage as ROS2HRIMessage__Audio
 from rclpy.duration import Duration
 from rclpy.executors import MultiThreadedExecutor
 from rclpy.node import Node
 from rclpy.qos import QoSProfile
-from sensor_msgs.msg import Image as ROS2Image
 from tf2_ros import Buffer, LookupException, TransformListener, TransformStamped

-from rai.communication import (
-    ARIConnector,
-    ARIMessage,
-    HRIConnector,
-    HRIMessage,
-    HRIPayload,
-)
+from rai.communication import ARIConnector, HRIConnector
 from rai.communication.ros2.api import (
     ConfigurableROS2TopicAPI,
     ROS2ActionAPI,
-    ROS2ARIMessage,
-    ROS2HRIMessage,
     ROS2ServiceAPI,
     ROS2TopicAPI,
+    TopicConfig,
 )
-
-
-class ROS2ARIMessage(ARIMessage):
-    def __init__(self, payload: Any, metadata: Optional[Dict[str, Any]] = None):
-        super().__init__(payload, metadata)
+from rai.communication.ros2.messages import ROS2ARIMessage, ROS2HRIMessage


 class ROS2ARIConnector(ARIConnector[ROS2ARIMessage]):
@@ -207,62 +186,6 @@ def shutdown(self):
         self._node.destroy_node()


-class ROS2HRIMessage(HRIMessage):
-    def __init__(self, payload: HRIPayload, message_author: Literal["ai", "human"]):
-        super().__init__(payload, message_author)
-
-    @classmethod
-    def from_ros2(
-        cls, msg: rai_interfaces.msg.HRIMessage, message_author: Literal["ai", "human"]
-    ):
-        cv_bridge = CvBridge()
-        images = [
-            cv_bridge.imgmsg_to_cv2(img_msg, "rgb8")
-            for img_msg in cast(List[ROS2Image], msg.images)
-        ]
-        pil_images = [Image.fromarray(img) for img in images]
-        audio_segments = [
-            AudioSegment(
-                data=audio_msg.audio,
-                frame_rate=audio_msg.sample_rate,
-                sample_width=2,  # bytes, int16
-                channels=audio_msg.channels,
-            )
-            for audio_msg in msg.audios
-        ]
-        return ROS2HRIMessage(
-            payload=HRIPayload(text=msg.text, images=pil_images, audios=audio_segments),
-            message_author=message_author,
-        )
-
-    def to_ros2_dict(self) -> OrderedDict[str, Any]:
-        cv_bridge = CvBridge()
-        assert isinstance(self.payload, HRIPayload)
-        img_msgs = [
-            cv_bridge.cv2_to_imgmsg(np.array(img), "rgb8")
-            for img in self.payload.images
-        ]
-        audio_msgs = [
-            ROS2HRIMessage__Audio(
-                audio=audio.raw_data,
-                sample_rate=audio.frame_rate,
-                channels=audio.channels,
-            )
-            for audio in self.payload.audios
-        ]
-
-        return cast(
-            OrderedDict[str, Any],
-            rosidl_runtime_py.convert.message_to_ordereddict(
-                ROS2HRIMessage_(
-                    text=self.payload.text,
-                    images=img_msgs,
-                    audios=audio_msgs,
-                )
-            ),
-        )
-
 class ROS2HRIConnector(HRIConnector[ROS2HRIMessage]):
     def __init__(
diff --git a/src/rai_core/rai/communication/ros2/messages.py b/src/rai_core/rai/communication/ros2/messages.py
index 3c5a3a8c2..5e013f500 100644
--- a/src/rai_core/rai/communication/ros2/messages.py
+++ b/src/rai_core/rai/communication/ros2/messages.py
@@ -12,9 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Literal, Optional
+from collections import OrderedDict
+from typing import Any, Dict, List, Literal, Optional, cast
+
+import numpy as np
+import rosidl_runtime_py.convert
+from cv_bridge import CvBridge
+from PIL import Image
+from pydub import AudioSegment
+from sensor_msgs.msg import Image as ROS2Image
+
+import rai_interfaces.msg
 from rai.communication import ARIMessage, HRIMessage, HRIPayload
+from rai_interfaces.msg import HRIMessage as ROS2HRIMessage_
+from rai_interfaces.msg._audio_message import AudioMessage as ROS2HRIMessage__Audio


 class ROS2ARIMessage(ARIMessage):
@@ -24,4 +35,55 @@ def __init__(self, payload: Any, metadata: Optional[Dict[str, Any]] = None):

 class ROS2HRIMessage(HRIMessage):
     def __init__(self, payload: HRIPayload, message_author: Literal["ai", "human"]):
-        super().__init__(payload, message_author)
+        super().__init__(payload, {}, message_author)
+
+    @classmethod
+    def from_ros2(
+        cls, msg: rai_interfaces.msg.HRIMessage, message_author: Literal["ai", "human"]
+    ):
+        cv_bridge = CvBridge()
+        images = [
+            cv_bridge.imgmsg_to_cv2(img_msg, "rgb8")
+            for img_msg in cast(List[ROS2Image], msg.images)
+        ]
+        pil_images = [Image.fromarray(img) for img in images]
+        audio_segments = [
+            AudioSegment(
+                data=audio_msg.audio,
+                frame_rate=audio_msg.sample_rate,
+                sample_width=2,  # bytes, int16
+                channels=audio_msg.channels,
+            )
+            for audio_msg in msg.audios
+        ]
+        return ROS2HRIMessage(
+            payload=HRIPayload(text=msg.text, images=pil_images, audios=audio_segments),
+            message_author=message_author,
+        )
+
+    def to_ros2_dict(self) -> OrderedDict[str, Any]:
+        cv_bridge = CvBridge()
+        assert isinstance(self.payload, HRIPayload)
+        img_msgs = [
+            cv_bridge.cv2_to_imgmsg(np.array(img), "rgb8")
+            for img in self.payload.images
+        ]
+        audio_msgs = [
+            ROS2HRIMessage__Audio(
+                audio=audio.raw_data,
+                sample_rate=audio.frame_rate,
+                channels=audio.channels,
+            )
+            for audio in self.payload.audios
+        ]
+
+        return cast(
+            OrderedDict[str, Any],
+            rosidl_runtime_py.convert.message_to_ordereddict(
+                ROS2HRIMessage_(
+                    text=self.payload.text,
+                    images=img_msgs,
+                    audios=audio_msgs,
+                )
+            ),
+        )

From 8644b5b6eeb157a0359c21946c02d43070576bb7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?=
Date: Wed, 12 Feb 2025 11:54:26 +0100
Subject: [PATCH 08/30] feat: add runners to create configurable, multi-agent
 deployments

---
 examples/s2s-demo.py                          | 19 +++++
 src/rai_core/rai/agents/__init__.py           |  2 -
 src/rai_core/rai/runners/__init__.py          | 18 +++++
 .../rai/{agents/runner.py => runners/base.py} |  4 +-
 src/rai_core/rai/runners/s2s.py               | 81 +++++++++++++++++++
 5 files changed, 120 insertions(+), 4 deletions(-)
 create mode 100644 examples/s2s-demo.py
 create mode 100644 src/rai_core/rai/runners/__init__.py
 rename src/rai_core/rai/{agents/runner.py => runners/base.py} (95%)
 create mode 100644 src/rai_core/rai/runners/s2s.py

diff --git a/examples/s2s-demo.py b/examples/s2s-demo.py
new file mode 100644
index 000000000..33618c621
--- /dev/null
+++ b/examples/s2s-demo.py
@@ -0,0 +1,19 @@
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import rclpy
+from rai.runners import Speech2SpeechRunner
+
+rclpy.init()
+runner = Speech2SpeechRunner()
+runner.run_indefinitely()
diff --git a/src/rai_core/rai/agents/__init__.py b/src/rai_core/rai/agents/__init__.py
index 54037b49d..fcbc5517f 100644
--- a/src/rai_core/rai/agents/__init__.py
+++ b/src/rai_core/rai/agents/__init__.py
@@ -13,14 +13,12 @@
 # limitations under the License.

 from rai.agents.conversational_agent import create_conversational_agent
-from rai.agents.runner import AgentRunner
 from rai.agents.state_based import create_state_based_agent
 from rai.agents.tool_runner import ToolRunner
 from rai.agents.tts_agent import TextToSpeechAgent
 from rai.agents.voice_agent import VoiceRecognitionAgent

 __all__ = [
-    "AgentRunner",
     "TextToSpeechAgent",
     "ToolRunner",
     "VoiceRecognitionAgent",
diff --git a/src/rai_core/rai/runners/__init__.py b/src/rai_core/rai/runners/__init__.py
new file mode 100644
index 000000000..de1f7fae3
--- /dev/null
+++ b/src/rai_core/rai/runners/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base import BaseRunner
+from .s2s import Speech2SpeechRunner
+
+__all__ = ["BaseRunner", "Speech2SpeechRunner"]
diff --git a/src/rai_core/rai/agents/runner.py b/src/rai_core/rai/runners/base.py
similarity index 95%
rename from src/rai_core/rai/agents/runner.py
rename to src/rai_core/rai/runners/base.py
index 16a69848e..d0dfb4f20 100644
--- a/src/rai_core/rai/agents/runner.py
+++ b/src/rai_core/rai/runners/base.py
@@ -17,7 +17,7 @@
 from rai.agents.base import BaseAgent


-class AgentRunner:
+class BaseRunner:
     """
     Manages and runs a collection of agents.

@@ -34,7 +34,7 @@ class AgentRunner:
     def __init__(self, agents: list[BaseAgent]):
         """
-        Initializes the AgentRunner with a list of agents.
+        Initializes the BaseRunner with a list of agents.

         Parameters
diff --git a/src/rai_core/rai/runners/s2s.py b/src/rai_core/rai/runners/s2s.py
new file mode 100644
index 000000000..af4678ca5
--- /dev/null
+++ b/src/rai_core/rai/runners/s2s.py
@@ -0,0 +1,81 @@
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Optional
+
+from rai.agents import TextToSpeechAgent, VoiceRecognitionAgent
+from rai.agents.base import BaseAgent
+from rai.communication.sound_device import SoundDeviceConfig
+from rai.runners import BaseRunner
+from rai_asr.models import LocalWhisper, OpenWakeWord, SileroVAD
+from rai_tts.models import OpenTTS
+
+
+@dataclass
+class TTSConfig:
+    device_name: str = "default"
+    url: str = "http://localhost:5500/api/tts"
+    voice: str = "larynx:blizzard_lessac-glow_tts"
+
+
+@dataclass
+class ASRConfig:
+    device_name: str = "default"
+    vad_threshold: float = 0.5
+    oww_threshold: float = 0.1
+    whisper_model: str = "tiny"
+    oww_model: str = "hey jarvis"
+
+
+class Speech2SpeechRunner(BaseRunner):
+    def __init__(
+        self,
+        agents: Optional[list[BaseAgent]] = None,
+        tts_cfg: TTSConfig = TTSConfig(),
+        asr_config: ASRConfig = ASRConfig(),
+    ):
+        if agents is None:
+            agents = []
+        super().__init__(agents)
+        tts = self._setup_tts_agent(tts_cfg)
+        asr = self._setup_asr_agent(asr_config)
+        self.agents.extend([tts, asr])
+
+    def _setup_tts_agent(self, cfg: TTSConfig):
+        speaker_config = SoundDeviceConfig(
+            stream=True,
+            is_output=True,
+            device_name=cfg.device_name,
+        )
+        tts = OpenTTS(cfg.url, cfg.voice)
+        return TextToSpeechAgent(speaker_config, "text_to_speech", tts)
+
+    def _setup_asr_agent(self, cfg: ASRConfig):
+        vad = SileroVAD(threshold=cfg.vad_threshold)
+        oww = OpenWakeWord(cfg.oww_model, cfg.oww_threshold)
+        whisper = LocalWhisper(
+            cfg.whisper_model, vad.sampling_rate
+        )  # models should have compatible sampling rate
+        microphone_config = SoundDeviceConfig(
+            stream=True,
+            device_name=cfg.device_name,
+            consumer_sampling_rate=vad.sampling_rate,
+            is_input=True,
+        )
+        asr_agent = VoiceRecognitionAgent(
+            microphone_config, "automatic_speech_recognition", whisper, vad
+        )
+        asr_agent.add_detection_model(oww, pipeline="record")
+        return asr_agent

From 64b55b89f74df1e4c0ffee0000000000000000 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?=
Date: Wed, 12 Feb 2025 11:56:26 +0100
Subject: [PATCH 09/30] docs: add docstrings for affected classes

---
 src/rai_asr/rai_asr/models/base.py           | 102 ++++++++++++++++++
 src/rai_asr/rai_asr/models/local_whisper.py  |  81 ++++++++++++++
 src/rai_asr/rai_asr/models/open_wake_word.py |  68 +++++++++++-
 src/rai_asr/rai_asr/models/silero_vad.py     |  58 +++++++++-
 src/rai_core/rai/agents/tts_agent.py         |  21 ++++
 src/rai_core/rai/agents/voice_agent.py       |  54 ++++++++--
 .../rai/communication/sound_device/api.py    |  34 ++++++
 src/rai_core/rai/runners/s2s.py              |  13 +++
 src/rai_tts/rai_tts/models/open_tts.py       |  52 ++++++++-
 9 files changed, 472 insertions(+), 11 deletions(-)

diff --git a/src/rai_asr/rai_asr/models/base.py b/src/rai_asr/rai_asr/models/base.py
index 1cf62a14c..a1c9e4c5b 100644
--- a/src/rai_asr/rai_asr/models/base.py
+++ b/src/rai_asr/rai_asr/models/base.py
@@ -21,20 +21,107 @@

 class BaseVoiceDetectionModel(ABC):
+    """
+    Abstract base class for voice detection models.
+
+    This class provides a standard interface for voice detection models, where
+    subclasses must implement the `detect` method to process audio data and determine
+    whether a specific event (e.g., a wake word or other voice activity) has occurred.
+    """
+
     def __call__(
         self, audio_data: NDArray, input_parameters: dict[str, Any]
     ) -> Tuple[bool, dict[str, Any]]:
+        """
+        Invokes the model to detect a voice event.
+
+        This method calls the `detect` function, allowing the model instance to be used
+        as a callable.
+ + Parameters + ---------- + audio_data : NDArray + A NumPy array containing audio input data. + input_parameters : dict of str to Any + Additional parameters for detection. + + Returns + ------- + Tuple[bool, dict] + A tuple where the first value is a boolean indicating whether a voice event + was detected (`True` if detected, `False` otherwise). The second value is + a dictionary containing additional detection information. + """ return self.detect(audio_data, input_parameters) @abstractmethod def detect( self, audio_data: NDArray, input_parameters: dict[str, Any] ) -> Tuple[bool, dict[str, Any]]: + """ + Abstract method for detecting a voice event. + + Subclasses must implement this method to analyze audio data and determine + whether a specific voice-related event has occurred. + + Parameters + ---------- + audio_data : NDArray + A NumPy array containing audio input data. + input_parameters : dict of str to Any + Additional parameters for detection. + + Returns + ------- + Tuple[bool, dict] + A tuple where the first value is a boolean indicating whether a voice event + was detected (`True` if detected, `False` otherwise). The second value is + a dictionary containing additional detection information. + """ pass class BaseTranscriptionModel(ABC): + """ + Abstract base class for speech transcription models. + + This class provides a standardized interface for speech-to-text models, where + subclasses must implement the `transcribe` method to convert audio input into text. + + Parameters + ---------- + model_name : str + The name of the transcription model. + sample_rate : int + The sample rate of the input audio, in Hz. + language : str, optional + The language of the transcription output. Default is "en" (English). + + Attributes + ---------- + model_name : str + The name of the transcription model. + sample_rate : int + The sample rate of the input audio, in Hz. + language : str + The language of the transcription output. + latest_transcription : str + Stores the latest transcribed text. + """ + def __init__(self, model_name: str, sample_rate: int, language: str = "en"): + """ + Initializes the BaseTranscriptionModel with the model name, sample rate, and language. + + Parameters + ---------- + model_name : str + The name of the transcription model. + sample_rate : int + The sample rate of the input audio, in Hz. + language : str, optional + The language of the transcription output. Default is "en" (English). + """ self.model_name = model_name self.sample_rate = sample_rate self.language = language @@ -43,4 +130,19 @@ def __init__(self, model_name: str, sample_rate: int, language: str = "en"): @abstractmethod def transcribe(self, data: NDArray[np.int16]) -> str: + """ + Abstract method for transcribing speech from audio data. + + Subclasses must implement this method to convert the provided audio input into text. + + Parameters + ---------- + data : NDArray[np.int16] + A NumPy array containing the raw audio waveform data. + + Returns + ------- + str + The transcribed text from the audio input. + """ pass diff --git a/src/rai_asr/rai_asr/models/local_whisper.py b/src/rai_asr/rai_asr/models/local_whisper.py index f8292e339..afd16433d 100644 --- a/src/rai_asr/rai_asr/models/local_whisper.py +++ b/src/rai_asr/rai_asr/models/local_whisper.py @@ -25,6 +25,31 @@ class LocalWhisper(BaseTranscriptionModel): + """ + A transcription model using OpenAI's Whisper, running locally. + + This class loads a Whisper model and performs speech-to-text transcription + on audio data. 
It supports GPU acceleration if available. + + Parameters + ---------- + model_name : str + The name of the Whisper model to load. + sample_rate : int + The sample rate of the input audio, in Hz. + language : str, optional + The language of the transcription output. Default is "en" (English). + **kwargs : dict, optional + Additional keyword arguments for loading the Whisper model. + + Attributes + ---------- + whisper : whisper.Whisper + The loaded Whisper model for transcription. + logger : logging.Logger + Logger instance for logging transcription results. + """ + def __init__( self, model_name: str, sample_rate: int, language: str = "en", **kwargs ): @@ -37,6 +62,22 @@ def __init__( self.logger = logging.getLogger(__name__) def transcribe(self, data: NDArray[np.int16]) -> str: + """ + Transcribes speech from the given audio data using Whisper. + + This method normalizes the input audio, processes it using the Whisper model, + and returns the transcribed text. + + Parameters + ---------- + data : NDArray[np.int16] + A NumPy array containing the raw audio waveform data. + + Returns + ------- + str + The transcribed text from the audio input. + """ normalized_data = data.astype(np.float32) / 32768.0 result = whisper.transcribe( self.whisper, normalized_data @@ -49,6 +90,30 @@ def transcribe(self, data: NDArray[np.int16]) -> str: class FasterWhisper(BaseTranscriptionModel): + """ + A transcription model using Faster Whisper for efficient speech-to-text conversion. + + This class loads a Faster Whisper model, optimized for speed and efficiency. + + Parameters + ---------- + model_name : str + The name of the Faster Whisper model to load. + sample_rate : int + The sample rate of the input audio, in Hz. + language : str, optional + The language of the transcription output. Default is "en" (English). + **kwargs : dict, optional + Additional keyword arguments for loading the Faster Whisper model. + + Attributes + ---------- + model : WhisperModel + The loaded Faster Whisper model instance. + logger : logging.Logger + Logger instance for logging transcription results. + """ + def __init__( self, model_name: str, sample_rate: int, language: str = "en", **kwargs ): @@ -57,6 +122,22 @@ def __init__( self.logger = logging.getLogger(__name__) def transcribe(self, data: NDArray[np.int16]) -> str: + """ + Transcribes speech from the given audio data using Faster Whisper. + + This method normalizes the input audio, processes it using the Faster Whisper model, + and returns the transcribed text. + + Parameters + ---------- + data : NDArray[np.int16] + A NumPy array containing the raw audio waveform data. + + Returns + ------- + str + The transcribed text from the audio input. + """ normalized_data = data.astype(np.float32) / 32768.0 segments, _ = self.model.transcribe(normalized_data) transcription = " ".join(segment.text for segment in segments) diff --git a/src/rai_asr/rai_asr/models/open_wake_word.py b/src/rai_asr/rai_asr/models/open_wake_word.py index 1fb4211e0..c084eaff7 100644 --- a/src/rai_asr/rai_asr/models/open_wake_word.py +++ b/src/rai_asr/rai_asr/models/open_wake_word.py @@ -22,7 +22,41 @@ class OpenWakeWord(BaseVoiceDetectionModel): - def __init__(self, wake_word_model_path: str, threshold: float = 0.5): + """ + A wake word detection model using the Open Wake Word framework. + + This class loads a specified wake word model and detects whether a wake word is present + in the provided audio input. 
+ + Parameters + ---------- + wake_word_model_path : str + Path to the wake word model file or name of a standard one. + threshold : float, optional + The confidence threshold for wake word detection. If a prediction surpasses this + value, the model will trigger a wake word detection. Default is 0.1. + + Attributes + ---------- + model_name : str + The name of the model, set to `"open_wake_word"`. + model : OWWModel + The Open Wake Word model instance used for inference. + threshold : float + The confidence threshold for determining wake word detection. + """ + + def __init__(self, wake_word_model_path: str, threshold: float = 0.1): + """ + Initializes the OpenWakeWord detection model. + + Parameters + ---------- + wake_word_model_path : str + Path to the wake word model file. + threshold : float, optional + Confidence threshold for wake word detection. Default is 0.1. + """ super(OpenWakeWord, self).__init__() self.model_name = "open_wake_word" download_models() @@ -37,10 +71,40 @@ def __init__(self, wake_word_model_path: str, threshold: float = 0.5): def detect( self, audio_data: NDArray, input_parameters: dict[str, Any] ) -> Tuple[bool, dict[str, Any]]: + """ + Detects whether a wake word is present in the given audio data. + + This method runs inference on the provided audio data and determines whether + the detected confidence surpasses the threshold. If so, it resets the model + and returns `True`, indicating a wake word detection. + + Parameters + ---------- + audio_data : NDArray + A NumPy array representing the input audio data. + input_parameters : dict of str to Any + Additional input parameters to be included in the output. + + Returns + ------- + Tuple[bool, dict] + A tuple where the first value is a boolean indicating whether the wake word + was detected (`True` if detected, `False` otherwise). The second value is + a dictionary containing predictions and confidence values for them. + + Raises + ------ + Exception + If the predictions returned by the model are not in the expected dictionary format. + """ predictions = self.model.predict(audio_data) ret = input_parameters.copy() ret.update({self.model_name: {"predictions": predictions}}) - for key, value in predictions.items(): + if not isinstance(predictions, dict): + raise Exception( + f"Unexpected format from model predict {type(predictions)}:{predictions}" + ) + for _, value in predictions.items(): # type ignore if value > self.threshold: self.model.reset() return True, ret diff --git a/src/rai_asr/rai_asr/models/silero_vad.py b/src/rai_asr/rai_asr/models/silero_vad.py index fdecb8b5b..dd325f836 100644 --- a/src/rai_asr/rai_asr/models/silero_vad.py +++ b/src/rai_asr/rai_asr/models/silero_vad.py @@ -22,6 +22,41 @@ class SileroVAD(BaseVoiceDetectionModel): + """ + Voice Activity Detection (VAD) model using SileroVAD. + + This class loads the SileroVAD model from Torch Hub and detects speech presence in an audio signal. + It supports two sampling rates: 8000 Hz and 16000 Hz. + + Parameters + ---------- + sampling_rate : Literal[8000, 16000], optional + The sampling rate of the input audio. Must be either 8000 or 16000. Default is 16000. + threshold : float, optional + Confidence threshold for voice detection. If the VAD confidence exceeds this threshold, + the method returns `True` (indicating voice presence). Default is 0.5. + + Attributes + ---------- + model_name : str + Name of the VAD model, set to `"silero_vad"`. + model : torch.nn.Module + The loaded SileroVAD model. 
+ sampling_rate : int + The sampling rate of the input audio (either 8000 or 16000). + window_size : int + The size of the processing window, determined by the sampling rate. + - 512 samples for 16000 Hz + - 256 samples for 8000 Hz + threshold : float + Confidence threshold for determining voice activity. + + Raises + ------ + ValueError + If an unsupported sampling rate is provided. + """ + def __init__(self, sampling_rate: Literal[8000, 16000] = 16000, threshold=0.5): super(SileroVAD, self).__init__() self.model_name = "silero_vad" @@ -42,7 +77,7 @@ def __init__(self, sampling_rate: Literal[8000, 16000] = 16000, threshold=0.5): ) # TODO: consider if this should be a ValueError or something else self.threshold = threshold - def int2float(self, sound: NDArray[np.int16]): + def _int2float(self, sound: NDArray[np.int16]): converted_sound = sound.astype("float32") converted_sound *= 1 / 32768 converted_sound = converted_sound.squeeze() @@ -51,8 +86,27 @@ def int2float(self, sound: NDArray[np.int16]): def detect( self, audio_data: NDArray, input_parameters: dict[str, Any] ) -> Tuple[bool, dict[str, Any]]: + """ + Detects voice activity in the given audio data. + + This method processes a window of the most recent audio samples, computes a confidence score + using the SileroVAD model, and determines if the confidence exceeds the specified threshold. + + Parameters + ---------- + audio_data : NDArray + A NumPy array containing audio input data. + input_parameters : dict of str to Any + Additional parameters for detection. + + Returns + ------- + Tuple[bool, dict] + - A boolean indicating whether voice activity was detected (`True` if detected, `False` otherwise). + - A dictionary containing the computed VAD confidence score. + """ vad_confidence = self.model( - torch.tensor(self.int2float(audio_data[-self.window_size :])), + torch.tensor(self._int2float(audio_data[-self.window_size :])), self.sampling_rate, ).item() ret = input_parameters.copy() diff --git a/src/rai_core/rai/agents/tts_agent.py b/src/rai_core/rai/agents/tts_agent.py index 7e172de80..b11dcdc09 100644 --- a/src/rai_core/rai/agents/tts_agent.py +++ b/src/rai_core/rai/agents/tts_agent.py @@ -57,6 +57,21 @@ class PlayData: class TextToSpeechAgent(BaseAgent): + """ + Agent responsible for converting text to speech and handling audio playback. + + Parameters + ---------- + speaker_config : SoundDeviceConfig + Configuration for the sound device used for playback. + ros2_name : str + Name of the ROS2 node. + tts : TTSModel + Text-to-speech model used for generating audio. + logger : Optional[logging.Logger], optional + Logger instance for logging messages, by default None. + """ + def __init__( self, speaker_config: SoundDeviceConfig, @@ -94,6 +109,9 @@ def __call__(self): self.run() def run(self): + """ + Start the text-to-speech agent, initializing playback and launching the transcription thread. + """ self.running = True self.logger.info("TextToSpeechAgent started") self.transcription_thread = Thread(target=self._transcription_thread) @@ -145,6 +163,9 @@ def _speaker_callback(self, outdata, frames, time, status_dict): outdata[:] = np.zeros(outdata.size).reshape(outdata.shape) def stop(self): + """ + Clean exit the text-to-speech agent, terminating playback and joining the transcription thread. 
+ """ self.terminate_agent.set() if self.transcription_thread is not None: self.transcription_thread.join() diff --git a/src/rai_core/rai/agents/voice_agent.py b/src/rai_core/rai/agents/voice_agent.py index 2725c6dad..abb802054 100644 --- a/src/rai_core/rai/agents/voice_agent.py +++ b/src/rai_core/rai/agents/voice_agent.py @@ -41,6 +41,25 @@ class ThreadData(TypedDict): class VoiceRecognitionAgent(BaseAgent): + """ + Agent responsible for voice recognition, transcription, and processing voice activity. + + Parameters + ---------- + microphone_config : SoundDeviceConfig + Configuration for the microphone device used for audio input. + ros2_name : str + Name of the ROS2 node. + transcription_model : BaseTranscriptionModel + Model used for transcribing audio input to text. + vad : BaseVoiceDetectionModel + Voice activity detection model used to determine when speech is present. + grace_period : float, optional + Time in seconds to wait before stopping recording after speech ends, by default 1.0. + logger : Optional[logging.Logger], optional + Logger instance for logging messages, by default None. + """ + def __init__( self, microphone_config: SoundDeviceConfig, @@ -85,6 +104,23 @@ def __call__(self): def add_detection_model( self, model: BaseVoiceDetectionModel, pipeline: str = "record" ): + """ + Add a voice detection model to the specified processing pipeline. + + Parameters + ---------- + model : BaseVoiceDetectionModel + The voice detection model to be added. + pipeline : str, optional + The pipeline where the model should be added, either 'record' or 'stop'. + Default is 'record'. + + Raises + ------ + ValueError + If the specified pipeline is not 'record' or 'stop'. + """ + if pipeline == "record": self.should_record_pipeline.append(model) elif pipeline == "stop": @@ -93,17 +129,23 @@ def add_detection_model( raise ValueError("Pipeline should be either 'record' or 'stop'") def run(self): + """ + Start the voice recognition agent, initializing the microphone and handling incoming audio samples. + """ self.running = True assert isinstance(self.connectors["microphone"], SoundDeviceConnector) msg = SoundDeviceMessage(read=True) self.listener_handle = self.connectors["microphone"].start_action( action_data=msg, target="microphone", - on_feedback=self.on_new_sample, + on_feedback=self._on_new_sample, on_done=lambda: None, ) def stop(self): + """ + Clean exit the voice recognition agent, ensuring all transcription threads finish before termination. 
+ """ self.logger.info("Stopping voice agent") self.running = False self.connectors["microphone"].terminate_action(self.listener_handle) @@ -120,7 +162,7 @@ def stop(self): ) self.logger.info("Voice agent stopped") - def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]): + def _on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]): sample_time = time.time() with self.sample_buffer_lock: self.sample_buffer.append(indata) @@ -137,14 +179,14 @@ def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]): should_record = False # TODO: second condition is temporary if voice_detected and not self.recording_started: - should_record = self.should_record(indata, output_parameters) + should_record = self._should_record(indata, output_parameters) if should_record: self.logger.info("starting recording...") self.recording_started = True thread_id = str(uuid4())[0:8] transcription_thread = Thread( - target=self.transcription_thread, + target=self._transcription_thread, args=[thread_id], ) transcription_finished = Event() @@ -178,7 +220,7 @@ def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]): elif sample_time - self.grace_period_start > self.grace_period: self._send_ros2_message("play", "/voice_commands") - def should_record( + def _should_record( self, audio_data: NDArray, input_parameters: dict[str, Any] ) -> bool: for model in self.should_record_pipeline: @@ -188,7 +230,7 @@ def should_record( return True return False - def transcription_thread(self, identifier: str): + def _transcription_thread(self, identifier: str): self.logger.info(f"transcription thread {identifier} started") audio_data = np.concatenate(self.transcription_buffers[identifier]) with ( diff --git a/src/rai_core/rai/communication/sound_device/api.py b/src/rai_core/rai/communication/sound_device/api.py index d1c83a7de..009714b6a 100644 --- a/src/rai_core/rai/communication/sound_device/api.py +++ b/src/rai_core/rai/communication/sound_device/api.py @@ -35,6 +35,40 @@ def __init__(self, msg: str): @dataclass class SoundDeviceConfig: + """ + Configuration settings for a sound device. + + This dataclass holds configuration parameters for audio input/output devices. + It ensures that at least one identifier (`device_number` or `device_name`) is + provided when initialized. + + Parameters + ---------- + stream : bool, optional + Whether the device should operate in streaming mode. Default is False. + block_size : int, optional + The block size for audio processing. Default is 1024. + dtype : str, optional + The data type of the audio stream (e.g., "int16", "float32"). Default is "int16". + channels : int, optional + The number of audio channels (e.g., 1 for mono, 2 for stereo). Default is 1. + consumer_sampling_rate : Optional[int], optional + The desired sampling rate for the audio consumer. Default is None. + device_number : Optional[int], optional + The device number for the sound device. If None, `device_name` must be set. Default is None. + device_name : Optional[str], optional + The name of the sound device. If None, `device_number` must be set. Default is None. + is_input : bool, optional + Indicates whether the device is used for input (recording). Default is False. + is_output : bool, optional + Indicates whether the device is used for output (playback). Default is False. + + Raises + ------ + ValueError + If neither `device_number` nor `device_name` is provided. 
+ """ + stream: bool = False block_size: int = 1024 dtype: str = "int16" diff --git a/src/rai_core/rai/runners/s2s.py b/src/rai_core/rai/runners/s2s.py index af4678ca5..c38d1934c 100644 --- a/src/rai_core/rai/runners/s2s.py +++ b/src/rai_core/rai/runners/s2s.py @@ -40,6 +40,19 @@ class ASRConfig: class Speech2SpeechRunner(BaseRunner): + """ + Manages a speech-to-speech pipeline by integrating text-to-speech (TTS) and automatic speech recognition (ASR) agents. + + Parameters + ---------- + agents : Optional[list[BaseAgent]], optional + A list of existing agents to be included in the runner, by default None. + tts_cfg : TTSConfig, optional + Configuration for the text-to-speech agent, by default TTSConfig(). + asr_config : ASRConfig, optional + Configuration for the automatic speech recognition agent, by default ASRConfig(). + """ + def __init__( self, agents: Optional[list[BaseAgent]] = None, diff --git a/src/rai_tts/rai_tts/models/open_tts.py b/src/rai_tts/rai_tts/models/open_tts.py index e1e80d33a..5eb71fd5b 100644 --- a/src/rai_tts/rai_tts/models/open_tts.py +++ b/src/rai_tts/rai_tts/models/open_tts.py @@ -24,6 +24,17 @@ class OpenTTS(TTSModel): + """ + A text-to-speech (TTS) model interface for OpenTTS. + + Parameters + ---------- + url : str, optional + The API endpoint for the OpenTTS server, by default "http://localhost:5500/api/tts". + voice : str, optional + The voice model to use, by default "larynx:blizzard_lessac-glow_tts". + """ + def __init__( self, url: str = "http://localhost:5500/api/tts", @@ -33,6 +44,25 @@ def __init__( self.voice = voice def get_speech(self, text: str) -> AudioSegment: + """ + Converts text into speech using the OpenTTS API. + + Parameters + ---------- + text : str + The input text to be converted into speech. + + Returns + ------- + AudioSegment + The generated speech as an `AudioSegment` object. + + Raises + ------ + TTSModelError + If there is an issue with the request or the OpenTTS server is unreachable. + If the response does not contain valid audio data. + """ params = { "voice": self.voice, "text": text, @@ -47,7 +77,7 @@ def get_speech(self, text: str) -> AudioSegment: content_type = response.headers.get("Content-Type", "") if "audio" not in content_type: - raise ValueError("Response does not contain audio data") + raise TTSModelError("Response does not contain audio data") # Load audio into memory audio_bytes = BytesIO(response.content) @@ -64,5 +94,25 @@ def get_speech(self, text: str) -> AudioSegment: return AudioSegment(data, frame_rate=sample_rate, sample_width=2, channels=1) def get_tts_params(self) -> Tuple[int, int]: + """ + Returns TTS samling rate and channels. + + The information is retrieved by running a sample transcription request, to ensure that the information will be accurate for generation. + + Parameters + ---------- + + Returns + ------- + Tuple[int, int] + sample rate, channels + + Raises + ------ + TTSModelError + If there is an issue with the request or the OpenTTS server is unreachable. + If the response does not contain valid audio data. 
+ """ + data = self.get_speech("A") return data.frame_rate, 1 From 81a252acb6f02fc21107b1a9d7698786442b0fca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?= Date: Wed, 12 Feb 2025 15:38:53 +0100 Subject: [PATCH 10/30] chore: rename runner main method to run --- src/rai_core/rai/communication/ros2/api.py | 5 ++++- src/rai_core/rai/runners/base.py | 2 +- tests/communication/ros2/test_connectors.py | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/rai_core/rai/communication/ros2/api.py b/src/rai_core/rai/communication/ros2/api.py index 4cb9d922e..b71047f63 100644 --- a/src/rai_core/rai/communication/ros2/api.py +++ b/src/rai_core/rai/communication/ros2/api.py @@ -189,7 +189,9 @@ def publish( ) msg = build_ros2_msg(msg_type, msg_content) + print(msg) publisher = self._get_or_create_publisher(topic, type(msg), qos_profile) + print(topic) publisher.publish(msg) def _verify_receive_args( @@ -381,8 +383,9 @@ def configure_subscriber( msg_type = import_message_from_str(config.msg_type) + # TODO: this is definitely not generic def callback_wrapper(message): - text = message.data + text = message.text assert config.subscriber_callback is not None config.subscriber_callback( ROS2HRIMessage( diff --git a/src/rai_core/rai/runners/base.py b/src/rai_core/rai/runners/base.py index d0dfb4f20..96333d123 100644 --- a/src/rai_core/rai/runners/base.py +++ b/src/rai_core/rai/runners/base.py @@ -43,7 +43,7 @@ def __init__(self, agents: list[BaseAgent]): """ self.agents = agents - def run_indefinitely(self): + def run(self): """ Starts all agents and keeps them running indefinitely. diff --git a/tests/communication/ros2/test_connectors.py b/tests/communication/ros2/test_connectors.py index 7ddd674f7..066ca90cc 100644 --- a/tests/communication/ros2/test_connectors.py +++ b/tests/communication/ros2/test_connectors.py @@ -18,8 +18,8 @@ import pytest from PIL import Image from pydub import AudioSegment +from rai.communication import HRIPayload from rai.communication.ros2.connectors import ( - HRIPayload, ROS2ARIConnector, ROS2ARIMessage, ROS2HRIConnector, From b82f7e4435fa02199fffcf7a3cdb833a83455675 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?= Date: Wed, 12 Feb 2025 16:02:00 +0100 Subject: [PATCH 11/30] fix: tts agent support HRI msg --- src/rai_core/rai/agents/tts_agent.py | 22 ++++++++++++-------- src/rai_core/rai/communication/ros2/api.py | 24 +++++++++------------- 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/src/rai_core/rai/agents/tts_agent.py b/src/rai_core/rai/agents/tts_agent.py index b11dcdc09..a85597a55 100644 --- a/src/rai_core/rai/agents/tts_agent.py +++ b/src/rai_core/rai/agents/tts_agent.py @@ -20,6 +20,7 @@ from numpy._typing import NDArray from pydub import AudioSegment +from std_msgs.msg import String from rai.agents.base import BaseAgent from rai.communication import ( @@ -28,8 +29,10 @@ SoundDeviceConnector, TopicConfig, ) +from rai.communication.ros2.api import IROS2Message from rai.communication.ros2.connectors import ROS2HRIMessage from rai.communication.sound_device.connector import SoundDeviceMessage +from rai_interfaces.msg._hri_message import HRIMessage from rai_tts.models.base import TTSModel if TYPE_CHECKING: @@ -182,7 +185,7 @@ def _transcription_thread(self): def _setup_ros2_connector(self): to_human = TopicConfig( - msg_type="std_msgs/msg/String", + msg_type="rai_interfaces/msg/HRIMessage", auto_qos_matching=True, is_subscriber=True, subscriber_callback=self._on_to_human_message, @@ -200,18 +203,21 @@ def 
_setup_ros2_connector(self):
             sources=[("/to_human", to_human), ("/voice_commands", voice_commands)],
         )

-    def _on_to_human_message(self, message: ROS2HRIMessage):
+    def _on_to_human_message(self, message: IROS2Message):
+        assert isinstance(message, HRIMessage)
+        msg = ROS2HRIMessage.from_ros2(message, "ai")
         self.logger.info(f"Received message from human: {message.text}")
-        self.text_queue.put(message.text)
+        self.text_queue.put(msg.text)

-    def _on_command_message(self, message: ROS2HRIMessage):
-        if message.text == "tog_play":
+    def _on_command_message(self, message: IROS2Message):
+        assert isinstance(message, String)
+        if message.data == "tog_play":
             self.playback_data.playing = not self.playback_data.playing
-        elif message.text == "play":
+        elif message.data == "play":
             self.playback_data.playing = True
-        elif message.text == "pause":
+        elif message.data == "pause":
             self.playback_data.playing = False
-        elif message.text == "stop":
+        elif message.data == "stop":
             self.playback_data.playing = False
             while not self.audio_queue.empty():
                 _ = self.audio_queue.get()
diff --git a/src/rai_core/rai/communication/ros2/api.py b/src/rai_core/rai/communication/ros2/api.py
index b71047f63..dda9cb1a5 100644
--- a/src/rai_core/rai/communication/ros2/api.py
+++ b/src/rai_core/rai/communication/ros2/api.py
@@ -27,6 +27,7 @@
     List,
     Literal,
     Optional,
+    Protocol,
     Tuple,
     Type,
     TypedDict,
@@ -56,11 +57,15 @@
 from rclpy.task import Future
 from rclpy.topic_endpoint_info import TopicEndpointInfo

-from rai.communication.hri_connector import HRIPayload
-from rai.communication.ros2.messages import ROS2HRIMessage
 from rai.tools.ros.utils import import_message_from_str, wait_for_message


+class IROS2Message(Protocol):
+    __slots__: tuple
+
+    def get_fields_and_field_types(self) -> dict: ...
+
+
 def adapt_requests_to_offers(publisher_info: List[TopicEndpointInfo]) -> QoSProfile:
     if not publisher_info:
         return QoSProfile(depth=1)
@@ -328,7 +333,7 @@ class TopicConfig:
     auto_qos_matching: bool = True
     qos_profile: Optional[QoSProfile] = None
     is_subscriber: bool = False
-    subscriber_callback: Optional[Callable[[ROS2HRIMessage], None]] = None
+    subscriber_callback: Optional[Callable[[IROS2Message], None]] = None
     source_author: Literal["human", "ai"] = "ai"

     def __post_init__(self):
@@ -383,20 +388,11 @@ def configure_subscriber(

         msg_type = import_message_from_str(config.msg_type)

-        # TODO: this is definitely not generic
-        def callback_wrapper(message):
-            text = message.text
-            assert config.subscriber_callback is not None
-            config.subscriber_callback(
-                ROS2HRIMessage(
-                    HRIPayload(text=text), message_author=config.source_author
-                )
-            )
-
+        assert config.subscriber_callback is not None
         self._subscribtions[topic] = self._node.create_subscription(
             msg_type=msg_type,
             topic=topic,
-            callback=callback_wrapper,
+            callback=config.subscriber_callback,
             qos_profile=qos_profile,
         )

From aacf6c781ed15e175973c2add2972761d5f22165 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?=
Date: Wed, 12 Feb 2025 16:19:40 +0100
Subject: [PATCH 12/30] fix: s2s migrate to HRIMessage

---
 examples/s2s-demo.py                       | 125 ++++++++++++++++++++-
 src/rai_core/rai/communication/ros2/api.py |   2 +-
 2 files changed, 123 insertions(+), 4 deletions(-)

diff --git a/examples/s2s-demo.py b/examples/s2s-demo.py
index 33618c621..74abd5bb3 100644
--- a/examples/s2s-demo.py
+++ b/examples/s2s-demo.py
@@ -11,9 +11,128 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # limitations under the License. +import time +from queue import Queue +from threading import Event, Thread +from typing import Dict, List + import rclpy -from rai.runners import Speech2SpeechRunner +from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage +from rai.agents.base import BaseAgent +from rai.communication import BaseConnector +from rai.communication.ros2.api import IROS2Message +from rai.communication.ros2.connectors import ROS2HRIConnector, TopicConfig +from rai.runners.base import BaseRunner +from rai.utils.model_initialization import get_llm_model + +from rai_interfaces.msg import HRIMessage as InterfacesHRIMessage + +# NOTE: the Agent code included here is temporary until a dedicated speech agent is created +# it can still serve as a reference for writing your own RAI agents rclpy.init() -runner = Speech2SpeechRunner() -runner.run_indefinitely() + + +class LLMTextHandler(BaseCallbackHandler): + def __init__(self, connector: ROS2HRIConnector): + self.connector = connector + self.token_buffer = "" + + def on_llm_new_token(self, token: str, **kwargs): + self.token_buffer += token + if len(self.token_buffer) > 100 or token in [".", "?", "!", ",", ";", ":"]: + self.connector.send_all_targets(AIMessage(content=self.token_buffer)) + self.token_buffer = "" + + def on_llm_end( + self, + response, + *, + run_id, + parent_run_id=None, + **kwargs, + ): + if self.token_buffer: + self.connector.send_all_targets(AIMessage(content=self.token_buffer)) + self.token_buffer = "" + + +class S2SConversationalAgent(BaseAgent): + def __init__(self, connectors: Dict[str, BaseConnector]): # type: ignore + super().__init__(connectors=connectors) + self.message_history: List[HumanMessage | AIMessage | SystemMessage] = [ + SystemMessage( + content="Pretend you are a robot. Answer as if you were a robot." 
+            )
+        ]
+        self.speech_queue: Queue[InterfacesHRIMessage] = Queue()
+
+        self.llm = get_llm_model(model_type="complex_model", streaming=True)
+        self._setup_ros_connector()
+        self.main_thread = None
+        self.stop_thread = Event()
+
+    def run(self):
+        self.main_thread = Thread(target=self._main_loop)
+        self.main_thread.start()
+
+    def _main_loop(self):
+        while not self.stop_thread.is_set():
+            time.sleep(0.01)
+            speech = ""
+            while not self.speech_queue.empty():
+                speech += " ".join(self.speech_queue.get().text)
+                print(f"Got some speech {speech}!")
+            if speech != "":
+                self.message_history.append(HumanMessage(content=speech))
+                assert isinstance(self.connectors["ros2"], ROS2HRIConnector)
+                ai_answer = self.llm.invoke(
+                    speech,
+                    config={"callbacks": [LLMTextHandler(self.connectors["ros2"])]},
+                )
+                self.message_history.append(ai_answer)  # type: ignore
+
+    def _on_from_human(self, msg: IROS2Message):
+        assert isinstance(msg, InterfacesHRIMessage)
+        self.speech_queue.put(msg)
+
+    def _setup_ros_connector(self):
+        self.connectors["ros2"] = ROS2HRIConnector(
+            sources=[
+                (
+                    "/from_human",
+                    TopicConfig(
+                        "rai_interfaces/msg/HRIMessage",
+                        is_subscriber=True,
+                        source_author="human",
+                        subscriber_callback=self._on_from_human,
+                    ),
+                )
+            ],
+            targets=[
+                (
+                    "/to_human",
+                    TopicConfig(
+                        "rai_interfaces/msg/HRIMessage",
+                        source_author="ai",
+                        is_subscriber=False,
+                    ),
+                )
+            ],
+        )
+
+    def stop(self):
+        assert isinstance(self.connectors["ros2"], ROS2HRIConnector)
+        self.connectors["ros2"].shutdown()
+        self.stop_thread.set()
+        if self.main_thread is not None:
+            self.main_thread.join()
+
+
+talking_agent = S2SConversationalAgent({})
+
+
+runner = BaseRunner([talking_agent])
+runner.run()
diff --git a/src/rai_core/rai/communication/ros2/api.py b/src/rai_core/rai/communication/ros2/api.py
index dda9cb1a5..f685f2938 100644
--- a/src/rai_core/rai/communication/ros2/api.py
+++ b/src/rai_core/rai/communication/ros2/api.py
@@ -411,7 +411,7 @@ def publish_configured(self, topic: str, msg_content: dict[str, Any]) -> None:
         except Exception as e:
             raise ValueError(f"{topic} has not been configured for publishing") from e
         msg_type = publisher.msg_type
-        msg = build_ros2_msg(msg_type, {"data": msg_content.text})  # type: ignore
+        msg = build_ros2_msg(msg_type, msg_content)  # type: ignore
         publisher.publish(msg)

From d41053f578f25b25a080987edb8a9be3aa81d896 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?=
Date: Wed, 12 Feb 2025 16:52:50 +0100
Subject: [PATCH 13/30] fix: end to end working runner with HRI

---
 examples/s2s-demo.py                       |  6 ++---
 src/rai_core/rai/agents/tts_agent.py       |  3 +++
 src/rai_core/rai/agents/voice_agent.py     | 30 +++++++++++++++++-----
 src/rai_core/rai/communication/ros2/api.py |  2 --
 4 files changed, 29 insertions(+), 12 deletions(-)

diff --git a/examples/s2s-demo.py b/examples/s2s-demo.py
index 74abd5bb3..1b19185e9 100644
--- a/examples/s2s-demo.py
+++ b/examples/s2s-demo.py
@@ -23,7 +23,7 @@
 from rai.communication import BaseConnector
 from rai.communication.ros2.api import IROS2Message
 from rai.communication.ros2.connectors import ROS2HRIConnector, TopicConfig
-from rai.runners.base import BaseRunner
+from rai.runners import Speech2SpeechRunner
 from rai.utils.model_initialization import get_llm_model

 from rai_interfaces.msg import HRIMessage as InterfacesHRIMessage
@@ -133,6 +133,6 @@
 talking_agent = S2SConversationalAgent({})


-# runner = Speech2SpeechRunner()
-runner = BaseRunner([talking_agent])
+runner = Speech2SpeechRunner([talking_agent])
+# runner = BaseRunner([talking_agent])
 runner.run()
diff --git a/src/rai_core/rai/agents/tts_agent.py b/src/rai_core/rai/agents/tts_agent.py
index a85597a55..063f644c1 100644
--- a/src/rai_core/rai/agents/tts_agent.py
+++ b/src/rai_core/rai/agents/tts_agent.py
@@ -211,6 +211,7 @@ def _on_to_human_message(self, message: IROS2Message):

     def _on_command_message(self, message: IROS2Message):
         assert isinstance(message, String)
+        self.logger.debug(f"Received status message: {message}")
         if message.data == "tog_play":
             self.playback_data.playing = not self.playback_data.playing
         elif message.data == "play":
@@ -224,3 +225,5 @@ def _on_command_message(self, message: IROS2Message):
             self.playback_data.data = None
             self.playback_data.current_frame = 0
             self.playback_data.current_segment = None
+
+        self.logger.debug(f"Current status is: {self.playback_data.playing}")
diff --git a/src/rai_core/rai/agents/voice_agent.py b/src/rai_core/rai/agents/voice_agent.py
index abb802054..6e0111ac8 100644
--- a/src/rai_core/rai/agents/voice_agent.py
+++ b/src/rai_core/rai/agents/voice_agent.py
@@ -24,8 +24,11 @@

 from rai.agents.base import BaseAgent
 from rai.communication import (
+    HRIPayload,
     ROS2ARIConnector,
     ROS2ARIMessage,
+    ROS2HRIConnector,
+    ROS2HRIMessage,
     SoundDeviceConfig,
     SoundDeviceConnector,
     SoundDeviceMessage,
@@ -76,8 +79,15 @@ def __init__(
         microphone = SoundDeviceConnector(
             targets=[], sources=[("microphone", microphone_config)]
         )
-        ros2_connector = ROS2ARIConnector(ros2_name)
-        super().__init__(connectors={"microphone": microphone, "ros2": ros2_connector})
+        ros2_hri_connector = ROS2HRIConnector(ros2_name, targets=["/from_human"])
+        ros2_ari_connector = ROS2ARIConnector(ros2_name + "ari")
+        super().__init__(
+            connectors={
+                "microphone": microphone,
+                "ros2_hri": ros2_hri_connector,
+                "ros2_ari": ros2_ari_connector,
+            }
+        )

         self.should_record_pipeline: List[BaseVoiceDetectionModel] = []
         self.should_stop_pipeline: List[BaseVoiceDetectionModel] = []
@@ -141,6 +151,7 @@ def run(self):
             on_feedback=self._on_new_sample,
             on_done=lambda: None,
         )
+        self.logger.info("Started Voice Agent")

     def stop(self):
         """
@@ -149,6 +160,8 @@ def stop(self):
         self.logger.info("Stopping voice agent")
         self.running = False
         self.connectors["microphone"].terminate_action(self.listener_handle)
+        assert isinstance(self.connectors["ros2_hri"], ROS2HRIConnector)
+        self.connectors["ros2_hri"].shutdown()
         while not all(
             [thread["joined"] for thread in self.transcription_threads.values()]
         ):
@@ -242,8 +255,11 @@ def _transcription_thread(self, identifier: str):
         self.transcription_threads[identifier]["event"].set()

     def _send_ros2_message(self, data: str, topic: str):
-        self.connectors["ros2"].send_message(
-            ROS2ARIMessage({"data": data}, {"msg_type": "std_msgs/msg/String"}),
-            topic,
-            msg_type="std_msgs/msg/String",
-        )
+        if topic == "/voice_commands":
+            msg = ROS2ARIMessage({"data": data})
+            self.connectors["ros2_ari"].send_message(
+                msg, topic, msg_type="std_msgs/msg/String"
+            )
+        else:
+            msg = ROS2HRIMessage(HRIPayload(text=data), "human")
+            self.connectors["ros2_hri"].send_message(msg, topic)
diff --git a/src/rai_core/rai/communication/ros2/api.py b/src/rai_core/rai/communication/ros2/api.py
index f685f2938..e75116112 100644
--- a/src/rai_core/rai/communication/ros2/api.py
+++ b/src/rai_core/rai/communication/ros2/api.py
@@ -194,9 +194,7 @@ def publish(
         )

         msg = build_ros2_msg(msg_type, msg_content)
-        print(msg)
         publisher = self._get_or_create_publisher(topic, type(msg), qos_profile)
-
print(topic) publisher.publish(msg) def _verify_receive_args( From 237a2799848f84aedfeda2ce82e55354b61bbe7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?= Date: Wed, 12 Feb 2025 17:49:27 +0100 Subject: [PATCH 14/30] test: update tests to support AudioSegment api --- tests/communication/sounds_device/test_api.py | 48 ++++++++++++++++--- .../sounds_device/test_connector.py | 28 +++++------ 2 files changed, 54 insertions(+), 22 deletions(-) diff --git a/tests/communication/sounds_device/test_api.py b/tests/communication/sounds_device/test_api.py index b5a62a54a..2c05dd0f2 100644 --- a/tests/communication/sounds_device/test_api.py +++ b/tests/communication/sounds_device/test_api.py @@ -11,16 +11,51 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import io from unittest.mock import MagicMock, patch import numpy as np import pytest import sounddevice +from pydub import AudioSegment from rai.communication.sound_device import ( SoundDeviceAPI, SoundDeviceConfig, SoundDeviceError, ) +from scipy.io import wavfile + + +def audio_to_numpy(audio): + samples = np.array(audio.get_array_of_samples()) + if audio.channels == 2: # Stereo: reshape into two columns + samples = samples.reshape((-1, 2)) + return samples + + +def get_audio(): + frequency = 440 + duration = 2.0 + sample_rate = 44100 + amplitude = 0.5 + t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False) + wave = amplitude * np.sin(2 * np.pi * frequency * t) + wave_int16 = np.int16(wave * 32767) + wav_buffer = io.BytesIO() + wavfile.write(wav_buffer, sample_rate, wave_int16) + audio = AudioSegment.from_wav(wav_buffer) + return audio + + +@pytest.fixture +def sine_wav(): + return get_audio() + + +@pytest.fixture +def sine_wav_np(): + wav = get_audio() + return audio_to_numpy(wav) @pytest.fixture @@ -126,7 +161,7 @@ def test_init( @pytest.mark.parametrize("is_output", [True, False]) -def test_write_unsupported(input_device_id, mock_sd, is_output): +def test_write_unsupported(input_device_id, mock_sd, is_output, sine_wav): """Ensure writing raises an error if output is not supported.""" config = SoundDeviceConfig( stream=True, @@ -143,14 +178,14 @@ def test_write_unsupported(input_device_id, mock_sd, is_output): if not is_output: with pytest.raises(SoundDeviceError, match="does not support writing!"): - api.write(np.array([0.0, 1.0])) + api.write(sine_wav) else: - api.write(np.array([0.0, 1.0]), blocking=True) + api.write(sine_wav, blocking=True) mock_sd["play"].assert_called_once() @pytest.mark.parametrize("is_input", [True, False]) -def test_read_unsupported(input_device_id, mock_sd, is_input): +def test_read_unsupported(input_device_id, mock_sd, is_input, sine_wav, sine_wav_np): """Ensure reading raises an error if input is not supported.""" config = SoundDeviceConfig( stream=True, @@ -169,9 +204,10 @@ def test_read_unsupported(input_device_id, mock_sd, is_input): with pytest.raises(SoundDeviceError, match="does not support reading!"): api.read(1.0) else: - mock_sd["rec"].return_value = np.array([[0.0], [1.0]]) + mock_sd["rec"].return_value = sine_wav_np result = api.read(1.0, blocking=True) - np.testing.assert_array_equal(result, np.array([[0.0], [1.0]])) + arr = audio_to_numpy(result) + np.testing.assert_array_equal(arr.flatten(), sine_wav_np) @pytest.mark.parametrize("method", ["stop", "wait"]) diff --git a/tests/communication/sounds_device/test_connector.py 
b/tests/communication/sounds_device/test_connector.py index a7f86caa9..0de8b4bfb 100644 --- a/tests/communication/sounds_device/test_connector.py +++ b/tests/communication/sounds_device/test_connector.py @@ -11,13 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import base64 import io from unittest.mock import MagicMock, patch import numpy as np import pytest import sounddevice +from pydub import AudioSegment from rai.communication import HRIPayload from rai.communication.sound_device import SoundDeviceConfig, SoundDeviceError from rai.communication.sound_device.connector import ( # Replace with actual module name @@ -69,7 +69,7 @@ def connector(mock_sound_device_api): @pytest.fixture -def base64_audio(): +def sine_wav(): frequency = 440 duration = 2.0 sample_rate = 44100 @@ -79,11 +79,8 @@ def base64_audio(): wave_int16 = np.int16(wave * 32767) wav_buffer = io.BytesIO() wavfile.write(wav_buffer, sample_rate, wave_int16) - - wav_binary = wav_buffer.getvalue() - - base64_string = base64.b64encode(wav_binary).decode("utf-8") - return base64_string + audio = AudioSegment.from_wav(wav_buffer) + return audio @pytest.fixture @@ -102,11 +99,11 @@ def binary_audio(): return wav_binary -def test_send_message_play_audio(connector, mock_sound_device_api, base64_audio): +def test_send_message_play_audio(connector, mock_sound_device_api, sine_wav): message = SoundDeviceMessage( payload=HRIPayload( text="", - audios=[base64_audio], + audios=[sine_wav], ) ) connector.send_message(message, "speaker") @@ -128,21 +125,20 @@ def test_send_message_read_error(connector): connector.send_message(message, "speaker") -def test_service_call_play_audio(connector, mock_sound_device_api, base64_audio): - message = SoundDeviceMessage(payload=HRIPayload(text="", audios=[base64_audio])) +def test_service_call_play_audio(connector, mock_sound_device_api, sine_wav): + message = SoundDeviceMessage(payload=HRIPayload(text="", audios=[sine_wav])) result = connector.service_call(message, "speaker") mock_sound_device_api.write.assert_called_once() assert isinstance(result, SoundDeviceMessage) -def test_service_call_read_audio( - connector, mock_sound_device_api, binary_audio, base64_audio -): - mock_sound_device_api.read.return_value = binary_audio +def test_service_call_read_audio(connector, mock_sound_device_api, sine_wav): + mock_sound_device_api.read.return_value = sine_wav message = SoundDeviceMessage(read=True) result = connector.service_call(message, "microphone") mock_sound_device_api.read.assert_called_once_with(1.0, blocking=True) - assert result.audios == [base64_audio] + + assert result.audios == [sine_wav] def test_service_call_stop_error(connector): From ca88199068811b8dfa216fd4c1e081e794d849f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?= Date: Tue, 18 Feb 2025 17:25:58 +0100 Subject: [PATCH 15/30] feat: working multiterminal version --- examples/s2s-demo.py | 2 ++ src/rai_core/rai/agents/tts_agent.py | 7 ++-- src/rai_core/rai/agents/voice_agent.py | 2 +- .../rai/communication/sound_device/api.py | 35 +++++++++++++------ .../communication/sound_device/connector.py | 15 +++++++- src/rai_tts/rai_tts/models/base.py | 18 ++++++++++ src/rai_tts/rai_tts/models/open_tts.py | 8 ++++- 7 files changed, 70 insertions(+), 17 deletions(-) diff --git a/examples/s2s-demo.py b/examples/s2s-demo.py index 1b19185e9..72cd63d02 100644 --- a/examples/s2s-demo.py +++ 
b/examples/s2s-demo.py @@ -87,6 +87,8 @@ def _main_loop(self): if speech != "": self.message_history.append(HumanMessage(content=speech)) assert isinstance(self.connectors["ros2"], ROS2HRIConnector) + # ai_answer = AIMessage(content="Yes, I am Jar Jar Binks") + # self.connectors["ros2"].send_all_targets(ai_answer) ai_answer = self.llm.invoke( speech, config={"callbacks": [LLMTextHandler(self.connectors["ros2"])]}, diff --git a/src/rai_core/rai/agents/tts_agent.py b/src/rai_core/rai/agents/tts_agent.py index 063f644c1..da42566fb 100644 --- a/src/rai_core/rai/agents/tts_agent.py +++ b/src/rai_core/rai/agents/tts_agent.py @@ -90,6 +90,10 @@ def __init__( speaker = SoundDeviceConnector( targets=[("speaker", speaker_config)], sources=[] ) + sample_rate, _, out_channels = speaker.get_audio_params("speaker") + tts.sample_rate = sample_rate + tts.channels = out_channels + self.node_base_name = ros2_name self.model = tts ros2_connector = self._setup_ros2_connector() @@ -119,7 +123,6 @@ def run(self): self.logger.info("TextToSpeechAgent started") self.transcription_thread = Thread(target=self._transcription_thread) self.transcription_thread.start() - sample_rate, channels = self.model.get_tts_params() msg = SoundDeviceMessage(read=False) assert isinstance(self.connectors["speaker"], SoundDeviceConnector) @@ -128,8 +131,6 @@ def run(self): "speaker", on_feedback=self._speaker_callback, on_done=lambda: None, - sample_rate=sample_rate, - channels=channels, ) def _speaker_callback(self, outdata, frames, time, status_dict): diff --git a/src/rai_core/rai/agents/voice_agent.py b/src/rai_core/rai/agents/voice_agent.py index 6e0111ac8..634e702e4 100644 --- a/src/rai_core/rai/agents/voice_agent.py +++ b/src/rai_core/rai/agents/voice_agent.py @@ -238,7 +238,7 @@ def _should_record( ) -> bool: for model in self.should_record_pipeline: detected, output = model(audio_data, input_parameters) - self.logger.debug(f"detected {detected}, output {output}") + self.logger.info(f"detected {detected}, output {output}") if detected: return True return False diff --git a/src/rai_core/rai/communication/sound_device/api.py b/src/rai_core/rai/communication/sound_device/api.py index 009714b6a..9cdf4a023 100644 --- a/src/rai_core/rai/communication/sound_device/api.py +++ b/src/rai_core/rai/communication/sound_device/api.py @@ -72,7 +72,7 @@ class SoundDeviceConfig: stream: bool = False block_size: int = 1024 dtype: str = "int16" - channels: int = 1 + channels: Optional[int] = None consumer_sampling_rate: Optional[int] = None device_number: Optional[int] = None device_name: Optional[str] = None @@ -82,12 +82,15 @@ class SoundDeviceConfig: def __post_init__(self): if self.device_number is None and self.device_name is None: raise ValueError("Either 'device_number' or 'device_name' must be set.") + if not self.is_input and not self.is_output: + raise ValueError("Either 'is_input' or 'is_output' must be True.") class SoundDeviceAPI: def __init__(self, config: SoundDeviceConfig): self.device_name = "" + self.config = config if not sd: raise SoundDeviceError("SoundDeviceAPI requires sound_device module!") if config.device_name: @@ -100,16 +103,23 @@ def __init__(self, config: SoundDeviceConfig): break else: self.device_number = config.device_number - self.sample_rate = int( - sd.query_devices(device=self.device_number, kind="input")[ - "default_samplerate" - ] # type: ignore - ) + try: + device_data = sd.query_devices(device=self.device_number) + except AttributeError: + raise SoundDeviceError( + f"Device {self.device_name} was not 
found for configuration" + ) + self.sample_rate = int(device_data["default_samplerate"]) # type: ignore + if self.config.channels is None: + self.in_channels = int(device_data["max_input_channels"]) # type: ignore + self.out_channels = int(device_data["max_output_channels"]) # type: ignore + else: + self.in_channels = self.config.channels + self.out_channels = self.config.channels self.read_flag = config.is_input self.write_flag = config.is_output self.stream_flag = config.stream - self.config = config self.in_stream = None self.out_stream = None @@ -182,7 +192,7 @@ def read(self, time: float, blocking: bool = False) -> AudioSegment: recording = sd.rec( frames=frames, samplerate=self.sample_rate, - channels=self.config.channels, + channels=self.in_channels, device=self.device_number, blocking=blocking, dtype=self.config.dtype, @@ -192,7 +202,7 @@ def read(self, time: float, blocking: bool = False) -> AudioSegment: data=recording.flatten(), sample_width=recording.dtype.itemsize, frame_rate=self.sample_rate, - channels=self.config.channels, + channels=self.in_channels, ) def stop(self): @@ -247,9 +257,12 @@ def callback(indata: NDArray, frames: int, time: Any, status: CallbackFlags): try: assert sd is not None + print(sample_rate) sample_rate = self.sample_rate if sample_rate is None else sample_rate print(sample_rate) - channels = self.config.channels if channels is None else channels + print(channels) + channels = self.out_channels if channels is None else channels + print(channels) self.out_stream = sd.OutputStream( samplerate=sample_rate, channels=channels, @@ -312,7 +325,7 @@ def callback(indata: NDArray, frames: int, _, status: CallbackFlags): self.in_stream = sd.InputStream( samplerate=self.sample_rate, - channels=self.config.channels, + channels=self.in_channels, device=self.device_number, dtype=self.config.dtype, blocksize=window_size_samples, diff --git a/src/rai_core/rai/communication/sound_device/connector.py b/src/rai_core/rai/communication/sound_device/connector.py index 48f29b28b..2dd79f56c 100644 --- a/src/rai_core/rai/communication/sound_device/connector.py +++ b/src/rai_core/rai/communication/sound_device/connector.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import Callable, Literal, Optional, Tuple +from typing import Callable, Literal, NamedTuple, Optional, Tuple try: import sounddevice as sd @@ -51,6 +51,12 @@ def __init__( self.duration = duration +class AudioParams(NamedTuple): + sample_rate: int + in_channels: int + out_channels: int + + class SoundDeviceConnector(HRIConnector[SoundDeviceMessage]): """SoundDevice connector implementing the Human-Robot Interface. 
@@ -79,6 +85,13 @@ def __init__( super().__init__(configured_targets, configured_sources) sd.default.latency = ("low", "low") # type: ignore + def get_audio_params(self, target: str) -> AudioParams: + return AudioParams( + self.devices[target].sample_rate, + self.devices[target].in_channels, + self.devices[target].out_channels, + ) + def configure_device( self, target: str, diff --git a/src/rai_tts/rai_tts/models/base.py b/src/rai_tts/rai_tts/models/base.py index 6014cc50c..9af07595f 100644 --- a/src/rai_tts/rai_tts/models/base.py +++ b/src/rai_tts/rai_tts/models/base.py @@ -23,6 +23,9 @@ class TTSModelError(Exception): class TTSModel(ABC): + sample_rate: int = -1 + channels: int = 1 + @abstractmethod def get_speech(self, text: str) -> AudioSegment: pass @@ -30,3 +33,18 @@ def get_speech(self, text: str) -> AudioSegment: @abstractmethod def get_tts_params(self) -> Tuple[int, int]: pass + + def set_tts_params(self, target_sample_rate: int, channels: int): + self.sample_rate = target_sample_rate + self.channels = channels + + def _resample(self, audio: AudioSegment) -> AudioSegment: + """ + Resample an AudioSegment to a specified sample rate and number of channels. + + :param audio: The input AudioSegment. + :param target_sample_rate: The desired sample rate in Hz. + :param channels: The desired number of audio channels. + :return: A new AudioSegment with the specified sample rate and channels. + """ + return audio.set_frame_rate(self.sample_rate) diff --git a/src/rai_tts/rai_tts/models/open_tts.py b/src/rai_tts/rai_tts/models/open_tts.py index 5eb71fd5b..2aeee6446 100644 --- a/src/rai_tts/rai_tts/models/open_tts.py +++ b/src/rai_tts/rai_tts/models/open_tts.py @@ -91,7 +91,13 @@ def get_speech(self, text: str) -> AudioSegment: (data * 32768).clip(-32768, 32767).astype(np.int16) ) # Convert float32 to int16 - return AudioSegment(data, frame_rate=sample_rate, sample_width=2, channels=1) + audio = AudioSegment( + data.tobytes(), frame_rate=sample_rate, sample_width=2, channels=1 + ) + if self.sample_rate == -1: + return audio + else: + return self._resample(audio) def get_tts_params(self) -> Tuple[int, int]: """ From 590569794eb93db94fb337316306840d4d38554e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?= Date: Wed, 26 Feb 2025 16:14:26 +0100 Subject: [PATCH 16/30] feat: working singleterminal setup --- examples/s2s/asr.py | 119 ++++++++++++ examples/s2s/conversational.py | 171 ++++++++++++++++++ examples/s2s/run.sh | 66 +++++++ examples/s2s/tts.py | 76 ++++++++ src/rai_asr/rai_asr/models/local_whisper.py | 14 +- src/rai_core/rai/agents/tts_agent.py | 3 +- src/rai_core/rai/agents/voice_agent.py | 15 +- .../rai/communication/sound_device/api.py | 51 ++++-- .../communication/sound_device/connector.py | 8 - 9 files changed, 490 insertions(+), 33 deletions(-) create mode 100644 examples/s2s/asr.py create mode 100644 examples/s2s/conversational.py create mode 100755 examples/s2s/run.sh create mode 100644 examples/s2s/tts.py diff --git a/examples/s2s/asr.py b/examples/s2s/asr.py new file mode 100644 index 000000000..462d034ba --- /dev/null +++ b/examples/s2s/asr.py @@ -0,0 +1,119 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import signal
+import time
+
+import rclpy
+from rai.agents import VoiceRecognitionAgent
+from rai.communication.sound_device.api import SoundDeviceConfig
+
+from rai_asr.models import LocalWhisper, OpenWakeWord, SileroVAD
+
+VAD_THRESHOLD = 0.8  # Note that this might be different depending on your device
+OWW_THRESHOLD = 0.1  # Note that this might be different depending on your device
+
+VAD_SAMPLING_RATE = 16000  # Or 8000
+DEFAULT_BLOCKSIZE = 1280
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser(
+        description="Voice Activity Detection and Wake Word Detection Configuration",
+        allow_abbrev=True,
+    )
+
+    # Predefined arguments
+    parser.add_argument(
+        "--vad-threshold",
+        type=float,
+        default=VAD_THRESHOLD,
+        help="Voice Activity Detection threshold (default: 0.8)",
+    )
+    parser.add_argument(
+        "--oww-threshold",
+        type=float,
+        default=OWW_THRESHOLD,
+        help="OpenWakeWord threshold (default: 0.1)",
+    )
+    parser.add_argument(
+        "--vad-sampling-rate",
+        type=int,
+        choices=[8000, 16000],
+        default=VAD_SAMPLING_RATE,
+        help="VAD sampling rate (default: 16000)",
+    )
+    parser.add_argument(
+        "--block-size",
+        type=int,
+        default=DEFAULT_BLOCKSIZE,
+        help="Audio block size (default: 1280)",
+    )
+    parser.add_argument(
+        "--device-name",
+        type=str,
+        default="default",
+        help="Microphone device name (default: 'default')",
+    )
+
+    # Use parse_known_args to ignore unknown arguments
+    args, unknown = parser.parse_known_args()
+
+    if unknown:
+        print(f"Ignoring unknown arguments: {unknown}")
+
+    return args
+
+
+if __name__ == "__main__":
+    args = parse_arguments()
+
+    microphone_configuration = SoundDeviceConfig(
+        stream=True,
+        channels=1,
+        device_name=args.device_name,
+        block_size=args.block_size,
+        consumer_sampling_rate=args.vad_sampling_rate,
+        dtype="int16",
+        device_number=None,
+        is_input=True,
+        is_output=False,
+    )
+    vad = SileroVAD(args.vad_sampling_rate, args.vad_threshold)
+    oww = OpenWakeWord("hey jarvis", args.oww_threshold)
+    whisper = LocalWhisper("tiny", args.vad_sampling_rate)
+    # whisper = OpenAIWhisper("whisper-1", args.vad_sampling_rate, "en")
+
+    rclpy.init()
+    ros2_name = "rai_asr_agent"
+
+    agent = VoiceRecognitionAgent(microphone_configuration, ros2_name, whisper, vad)
+    agent.add_detection_model(oww, pipeline="record")
+
+    agent.run()
+
+    def cleanup(signum, frame):
+        print("\nCustom handler: Caught SIGINT (Ctrl+C).")
+        print("Performing cleanup")
+        # Optionally exit the program
+        agent.stop()
+        rclpy.shutdown()
+        exit(0)
+
+    signal.signal(signal.SIGINT, cleanup)
+
+    print("Running")
+    while True:
+        time.sleep(1)
diff --git a/examples/s2s/conversational.py b/examples/s2s/conversational.py
new file mode 100644
index 000000000..a59de6226
--- /dev/null
+++ b/examples/s2s/conversational.py
@@ -0,0 +1,171 @@
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import signal +import time +from queue import Queue +from threading import Event, Thread +from typing import Dict, List + +import rclpy +from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage +from rai.agents.base import BaseAgent +from rai.communication import BaseConnector +from rai.communication.ros2.api import IROS2Message +from rai.communication.ros2.connectors import ROS2HRIConnector, TopicConfig +from rai.utils.model_initialization import get_llm_model + +from rai_interfaces.msg import HRIMessage as InterfacesHRIMessage + +# NOTE: the Agent code included here is temporary until a dedicated speech agent is created +# it can still serve as a reference for writing your own RAI agents + + +class LLMTextHandler(BaseCallbackHandler): + def __init__(self, connector: ROS2HRIConnector): + self.connector = connector + self.token_buffer = "" + + def on_llm_new_token(self, token: str, **kwargs): + self.token_buffer += token + if len(self.token_buffer) > 100 or token in [".", "?", "!", ",", ";", ":"]: + self.connector.send_all_targets(AIMessage(content=self.token_buffer)) + self.token_buffer = "" + + def on_llm_end( + self, + response, + *, + run_id, + parent_run_id=None, + **kwargs, + ): + if self.token_buffer: + self.connector.send_all_targets(AIMessage(content=self.token_buffer)) + self.token_buffer = "" + + +class S2SConversationalAgent(BaseAgent): + def __init__(self, connectors: Dict[str, BaseConnector]): # type: ignore + super().__init__(connectors=connectors) + self.message_history: List[HumanMessage | AIMessage | SystemMessage] = [ + SystemMessage( + content="Pretend you are a robot. Answer as if you were a robot." 
+ ) + ] + self.speech_queue: Queue[InterfacesHRIMessage] = Queue() + + self.llm = get_llm_model(model_type="complex_model", streaming=True) + self._setup_ros_connector() + self.main_thread = None + self.stop_thread = Event() + + def run(self): + logging.info("Running S2SConversationalAgent") + self.main_thread = Thread(target=self._main_loop) + self.main_thread.start() + + def _main_loop(self): + while not self.stop_thread.is_set(): + time.sleep(0.01) + speech = "" + while not self.speech_queue.empty(): + speech += "".join(self.speech_queue.get().text) + logging.info(f"Received human speech {speech}!") + if speech != "": + self.message_history.append(HumanMessage(content=speech)) + assert isinstance(self.connectors["ros2"], ROS2HRIConnector) + # ai_answer = AIMessage(content="Yes, I am Jar Jar Binks") + # self.connectors["ros2"].send_all_targets(ai_answer) + ai_answer = self.llm.invoke( + speech, + config={"callbacks": [LLMTextHandler(self.connectors["ros2"])]}, + ) + self.message_history.append(ai_answer) # type: ignore + + def _on_from_human(self, msg: IROS2Message): + assert isinstance(msg, InterfacesHRIMessage) + logging.info("Received message from human: %s", msg.text) + self.speech_queue.put(msg) + + def _setup_ros_connector(self): + self.connectors["ros2"] = ROS2HRIConnector( + sources=[ + ( + "/from_human", + TopicConfig( + "rai_interfaces/msg/HRIMessage", + is_subscriber=True, + source_author="human", + subscriber_callback=self._on_from_human, + ), + ) + ], + targets=[ + ( + "/to_human", + TopicConfig( + "rai_interfaces/msg/HRIMessage", + source_author="ai", + is_subscriber=False, + ), + ) + ], + ) + + def stop(self): + assert isinstance(self.connectors["ros2"], ROS2HRIConnector) + self.connectors["ros2"].shutdown() + self.stop_thread.set() + if self.main_thread is not None: + self.main_thread.join() + + +def parse_arguments(): + parser = argparse.ArgumentParser( + description="Text To Speech Configuration", + allow_abbrev=True, + ) + + # Use parse_known_args to ignore unknown arguments + args, unknown = parser.parse_known_args() + + if unknown: + print(f"Ignoring unknown arguments: {unknown}") + + return args + + +if __name__ == "__main__": + args = parse_arguments() + rclpy.init() + agent = S2SConversationalAgent(connectors={}) + agent.run() + + def cleanup(signum, frame): + print("\nCustom handler: Caught SIGINT (Ctrl+C).") + print("Performing cleanup") + # Optionally exit the program + agent.stop() + rclpy.shutdown() + exit(0) + + signal.signal(signal.SIGINT, cleanup) + + print("Runnin") + while True: + time.sleep(1) diff --git a/examples/s2s/run.sh b/examples/s2s/run.sh new file mode 100755 index 000000000..660b9ca27 --- /dev/null +++ b/examples/s2s/run.sh @@ -0,0 +1,66 @@ +#!/usr/bin/env bash +# Directory where the scripts are located +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" + +# Array to store PIDs of background processes +declare -a PIDS + +# Function to run a script with the given arguments +run_script() { + local script="$1" + shift + python3 "$script" "$@" & + # Store the PID of the last background process + PIDS+=($!) +} + +# Function to handle Ctrl+C (SIGINT) +handle_sigint() { + echo -e "\nReceived SIGINT, forwarding to all running Python processes..." + + # Send SIGINT to all child processes + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Sending SIGINT to process $pid" + kill -SIGINT "$pid" + fi + done + + echo "Waiting for all processes to exit..." + wait + + echo "All processes have exited. 
Cleaning up and exiting." + exit 0 +} + +# Main logic +main() { + # Set up trap for SIGINT (Ctrl+C) + trap handle_sigint SIGINT + + # Find all Python scripts in the scripts directory + mapfile -t scripts < <(find "$SCRIPT_DIR" -name "*.py") + + # If no scripts found, exit + if [ ${#scripts[@]} -eq 0 ]; then + echo "No Python scripts found in $SCRIPT_DIR" + exit 1 + fi + + echo "Found ${#scripts[@]} Python scripts in $SCRIPT_DIR" + + # Run all scripts in parallel with all arguments properly quoted + for script in "${scripts[@]}"; do + run_script "$script" "$@" + done + + echo "All scripts are running in the background. Press Ctrl+C to stop them." + + # Wait for all background processes to finish + wait + + echo "All scripts completed successfully." +} + +# Call main with all arguments properly quoted +main "$@" diff --git a/examples/s2s/tts.py b/examples/s2s/tts.py new file mode 100644 index 000000000..997c8d4a5 --- /dev/null +++ b/examples/s2s/tts.py @@ -0,0 +1,76 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import signal +import time + +import rclpy +from rai.agents import TextToSpeechAgent +from rai.communication.sound_device import SoundDeviceConfig + +from rai_tts.models import OpenTTS + + +def parse_arguments(): + parser = argparse.ArgumentParser( + description="Text To Speech Configuration", + allow_abbrev=True, + ) + + parser.add_argument( + "--device-name", + type=str, + default="default", + help="Speaker device name (default: 'default')", + ) + + # Use parse_known_args to ignore unknown arguments + args, unknown = parser.parse_known_args() + + if unknown: + print(f"Ignoring unknown arguments: {unknown}") + + return args + + +if __name__ == "__main__": + rclpy.init() + args = parse_arguments() + + config = SoundDeviceConfig( + stream=True, + is_output=True, + # device_name="Sennheiser USB headset: Audio (hw:2,0)", + # device_name="Jabra Speak2 40 MS: USB Audio (hw:2,0)", + device_name=args.device_name, + ) + tts = OpenTTS() + + agent = TextToSpeechAgent(config, "text_to_speech", tts) + agent.run() + + def cleanup(signum, frame): + print("\nCustom handler: Caught SIGINT (Ctrl+C).") + print("Performing cleanup") + # Optionally exit the program + agent.stop() + rclpy.shutdown() + exit(0) + + signal.signal(signal.SIGINT, cleanup) + + print("Runnin") + while True: + time.sleep(1) diff --git a/src/rai_asr/rai_asr/models/local_whisper.py b/src/rai_asr/rai_asr/models/local_whisper.py index afd16433d..6aa8a15b8 100644 --- a/src/rai_asr/rai_asr/models/local_whisper.py +++ b/src/rai_asr/rai_asr/models/local_whisper.py @@ -54,6 +54,15 @@ def __init__( self, model_name: str, sample_rate: int, language: str = "en", **kwargs ): super().__init__(model_name, sample_rate, language) + self.decode_options = { + "language": language, # Set language to English + "task": "transcribe", # Set task to transcribe (not translate) + "fp16": False, # Use FP32 instead of FP16 for better precision + "without_timestamps": True, # Don't include timestamps in 
output
+            "suppress_tokens": [-1],  # Default tokens to suppress
+            "suppress_blank": True,  # Suppress blank outputs
+            "beam_size": 5,  # Beam size for beam search
+        }
         if torch.cuda.is_available():
             self.whisper = whisper.load_model(self.model_name, device="cuda", **kwargs)
         else:
@@ -79,9 +88,10 @@ def transcribe(self, data: NDArray[np.int16]) -> str:
             The transcribed text from the audio input.
         """
         normalized_data = data.astype(np.float32) / 32768.0
+
         result = whisper.transcribe(
-            self.whisper, normalized_data
-        )  # TODO: handling of additional transcribe arguments (perhaps in model init)
+            self.whisper, normalized_data, **self.decode_options
+        )
         transcription = result["text"]
         self.logger.info("transcription: %s", transcription)
         transcription = cast(str, transcription)
diff --git a/src/rai_core/rai/agents/tts_agent.py b/src/rai_core/rai/agents/tts_agent.py
index da42566fb..865233da7 100644
--- a/src/rai_core/rai/agents/tts_agent.py
+++ b/src/rai_core/rai/agents/tts_agent.py
@@ -170,6 +170,7 @@ def stop(self):
         """
         Clean exit the text-to-speech agent, terminating playback and joining the transcription thread.
         """
+        self.logger.info("Stopping TextToSpeechAgent")
         self.terminate_agent.set()
         if self.transcription_thread is not None:
             self.transcription_thread.join()
@@ -207,7 +208,7 @@ def _setup_ros2_connector(self):
     def _on_to_human_message(self, message: IROS2Message):
         assert isinstance(message, HRIMessage)
         msg = ROS2HRIMessage.from_ros2(message, "ai")
-        self.logger.info(f"Received message from human: {message.text}")
+        self.logger.debug(f"Received message from human: {message.text}")
         self.text_queue.put(msg.text)
 
     def _on_command_message(self, message: IROS2Message):
diff --git a/src/rai_core/rai/agents/voice_agent.py b/src/rai_core/rai/agents/voice_agent.py
index 634e702e4..ab3027cc5 100644
--- a/src/rai_core/rai/agents/voice_agent.py
+++ b/src/rai_core/rai/agents/voice_agent.py
@@ -157,7 +157,7 @@ def stop(self):
         """
         Clean exit the voice recognition agent, ensuring all transcription threads finish before termination.
""" - self.logger.info("Stopping voice agent") + self.logger.info("Stopping Voice Agent") self.running = False self.connectors["microphone"].terminate_action(self.listener_handle) assert isinstance(self.connectors["ros2_hri"], ROS2HRIConnector) @@ -189,6 +189,7 @@ def _on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]): self.transcription_threads[thread_id]["joined"] = True voice_detected, output_parameters = self.vad(indata, {}) + self.logger.debug(f"Voice detected: {voice_detected}: {output_parameters}") should_record = False # TODO: second condition is temporary if voice_detected and not self.recording_started: @@ -238,7 +239,7 @@ def _should_record( ) -> bool: for model in self.should_record_pipeline: detected, output = model(audio_data, input_parameters) - self.logger.info(f"detected {detected}, output {output}") + self.logger.debug(f"detected {detected}, output {output}") if detected: return True return False @@ -255,11 +256,15 @@ def _transcription_thread(self, identifier: str): self.transcription_threads[identifier]["event"].set() def _send_ros2_message(self, data: str, topic: str): + self.logger.debug(f"Sending message to {topic}: {data}") if topic == "/voice_commands": msg = ROS2ARIMessage({"data": data}) - self.connectors["ros2_ari"].send_message( - msg, topic, msg_type="std_msgs/msg/String" - ) + try: + self.connectors["ros2_ari"].send_message( + msg, topic, msg_type="std_msgs/msg/String" + ) + except Exception as e: + self.logger.error(f"Error sending message to {topic}: {e}") else: msg = ROS2HRIMessage(HRIPayload(text=data), "human") self.connectors["ros2_hri"].send_message(msg, topic) diff --git a/src/rai_core/rai/communication/sound_device/api.py b/src/rai_core/rai/communication/sound_device/api.py index 9cdf4a023..c9bdc7bff 100644 --- a/src/rai_core/rai/communication/sound_device/api.py +++ b/src/rai_core/rai/communication/sound_device/api.py @@ -21,12 +21,6 @@ from pydub import AudioSegment from scipy.signal import resample -try: - import sounddevice as sd -except ImportError: - logging.warning("Install sound_device module to use sound device features!") - sd = None - class SoundDeviceError(Exception): def __init__(self, msg: str): @@ -91,8 +85,11 @@ def __init__(self, config: SoundDeviceConfig): self.device_name = "" self.config = config - if not sd: + try: + import sounddevice as sd + except ImportError: raise SoundDeviceError("SoundDeviceAPI requires sound_device module!") + sd.default.latency = ("low", "low") # type: ignore if config.device_name: self.device_name = config.device_name devices = sd.query_devices() @@ -146,7 +143,10 @@ def write(self, data: AudioSegment, blocking: bool = False, loop: bool = False): """ if not self.write_flag: raise SoundDeviceError(f"{self.device_name} does not support writing!") - assert sd is not None + try: + import sounddevice as sd + except ImportError: + raise SoundDeviceError("SoundDeviceAPI requires sound_device module!") audio = np.array(data.get_array_of_samples()) sd.play( audio, @@ -187,7 +187,10 @@ def read(self, time: float, blocking: bool = False) -> AudioSegment: if not self.read_flag: raise SoundDeviceError(f"{self.device_name} does not support reading!") - assert sd is not None + try: + import sounddevice as sd + except ImportError: + raise SoundDeviceError("SoundDeviceAPI requires sound_device module!") frames = int(time * self.sample_rate) recording = sd.rec( frames=frames, @@ -214,7 +217,10 @@ def stop(self): - This is a convenience function to stop the sound device from playing or recording. 
- It will stop any sound that is currently playing and any recording currently happening. """ - assert sd is not None + try: + import sounddevice as sd + except ImportError: + raise SoundDeviceError("SoundDeviceAPI requires sound_device module!") sd.stop() def wait(self): @@ -226,7 +232,10 @@ def wait(self): - This is a convenience function to wait for the sound device to finish playing or recording. - It will block until the sound is played or recorded. """ - assert sd is not None + try: + import sounddevice as sd + except ImportError: + raise SoundDeviceError("SoundDeviceAPI requires sound_device module!") sd.wait() def open_write_stream( @@ -241,7 +250,10 @@ def open_write_stream( f"{self.device_name} does not support streaming writing!" ) - assert sd is not None + try: + import sounddevice as sd + except ImportError: + raise SoundDeviceError("SoundDeviceAPI requires sound_device module!") from sounddevice import CallbackFlags def callback(indata: NDArray, frames: int, time: Any, status: CallbackFlags): @@ -257,12 +269,11 @@ def callback(indata: NDArray, frames: int, time: Any, status: CallbackFlags): try: assert sd is not None - print(sample_rate) sample_rate = self.sample_rate if sample_rate is None else sample_rate - print(sample_rate) - print(channels) channels = self.out_channels if channels is None else channels - print(channels) + logging.warning( + f"Opening Output Stream with sample_rate: {sample_rate} channels: {channels} device: {self.device_number} dtype: {self.config.dtype}" + ) self.out_stream = sd.OutputStream( samplerate=sample_rate, channels=channels, @@ -313,7 +324,10 @@ def callback(indata: NDArray, frames: int, _, status: CallbackFlags): on_feedback(indata, flag_dict) try: - assert sd is not None + import sounddevice as sd + except ImportError: + raise SoundDeviceError("SoundDeviceAPI requires sound_device module!") + try: if self.config.consumer_sampling_rate is None: window_size_samples = self.config.block_size * self.sample_rate else: @@ -323,6 +337,9 @@ def callback(indata: NDArray, frames: int, _, status: CallbackFlags): / self.config.consumer_sampling_rate ) + logging.warning( + f"Opening Input Stream with sample_rate: {self.sample_rate} channels: {self.in_channels} device: {self.device_number} dtype: {self.config.dtype} blocksize: {window_size_samples}" + ) self.in_stream = sd.InputStream( samplerate=self.sample_rate, channels=self.in_channels, diff --git a/src/rai_core/rai/communication/sound_device/connector.py b/src/rai_core/rai/communication/sound_device/connector.py index 2dd79f56c..a61a97c2d 100644 --- a/src/rai_core/rai/communication/sound_device/connector.py +++ b/src/rai_core/rai/communication/sound_device/connector.py @@ -15,13 +15,6 @@ from typing import Callable, Literal, NamedTuple, Optional, Tuple -try: - import sounddevice as sd -except ImportError as e: - raise ImportError( - "The sounddevice package is required to use the SoundDeviceConnector." 
- ) from e - from rai.communication import HRIConnector, HRIMessage, HRIPayload from rai.communication.sound_device import ( SoundDeviceAPI, @@ -83,7 +76,6 @@ def __init__( self.configure_device(dev_target, dev_config) super().__init__(configured_targets, configured_sources) - sd.default.latency = ("low", "low") # type: ignore def get_audio_params(self, target: str) -> AudioParams: return AudioParams( From 18957ba625fa9e999ff5d1d36aff49d81e868bfa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?= Date: Wed, 26 Feb 2025 16:15:19 +0100 Subject: [PATCH 17/30] feat: remove runner --- examples/s2s-demo.py | 140 ------------------ .../o3de_test_bench/configs/o3de_config.yaml | 13 ++ src/rai_bench/rai_bench/results.csv | 5 + src/rai_core/rai/runners/base.py | 64 -------- src/rai_core/rai/runners/s2s.py | 94 ------------ 5 files changed, 18 insertions(+), 298 deletions(-) delete mode 100644 examples/s2s-demo.py create mode 100644 src/rai_bench/rai_bench/o3de_test_bench/configs/o3de_config.yaml create mode 100644 src/rai_bench/rai_bench/results.csv delete mode 100644 src/rai_core/rai/runners/base.py delete mode 100644 src/rai_core/rai/runners/s2s.py diff --git a/examples/s2s-demo.py b/examples/s2s-demo.py deleted file mode 100644 index 72cd63d02..000000000 --- a/examples/s2s-demo.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright (C) 2024 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import time -from queue import Queue -from threading import Event, Thread -from typing import Dict, List - -import rclpy -from langchain_core.callbacks import BaseCallbackHandler -from langchain_core.messages import AIMessage, HumanMessage, SystemMessage -from rai.agents.base import BaseAgent -from rai.communication import BaseConnector -from rai.communication.ros2.api import IROS2Message -from rai.communication.ros2.connectors import ROS2HRIConnector, TopicConfig -from rai.runners import Speech2SpeechRunner -from rai.utils.model_initialization import get_llm_model - -from rai_interfaces.msg import HRIMessage as InterfacesHRIMessage - -# NOTE: the Agent code included here is temporary until a dedicated speech agent is created -# it can still serve as a reference for writing your own RAI agents - -rclpy.init() - - -class LLMTextHandler(BaseCallbackHandler): - def __init__(self, connector: ROS2HRIConnector): - self.connector = connector - self.token_buffer = "" - - def on_llm_new_token(self, token: str, **kwargs): - self.token_buffer += token - if len(self.token_buffer) > 100 or token in [".", "?", "!", ",", ";", ":"]: - self.connector.send_all_targets(AIMessage(content=self.token_buffer)) - self.token_buffer = "" - - def on_llm_end( - self, - response, - *, - run_id, - parent_run_id=None, - **kwargs, - ): - if self.token_buffer: - self.connector.send_all_targets(AIMessage(content=self.token_buffer)) - self.token_buffer = "" - - -class S2SConversationalAgent(BaseAgent): - def __init__(self, connectors: Dict[str, BaseConnector]): # type: ignore - super().__init__(connectors=connectors) - self.message_history: List[HumanMessage | AIMessage | SystemMessage] = [ - SystemMessage( - content="Pretend you are a robot. Answer as if you were a robot." - ) - ] - self.speech_queue: Queue[InterfacesHRIMessage] = Queue() - - self.llm = get_llm_model(model_type="complex_model", streaming=True) - self._setup_ros_connector() - self.main_thread = None - self.stop_thread = Event() - - def run(self): - self.main_thread = Thread(target=self._main_loop) - self.main_thread.start() - - def _main_loop(self): - while not self.stop_thread.is_set(): - time.sleep(0.01) - speech = "" - while not self.speech_queue.empty(): - speech += " ".join(self.speech_queue.get().text) - print(f"Got sum speeach {speech}!") - if speech != "": - self.message_history.append(HumanMessage(content=speech)) - assert isinstance(self.connectors["ros2"], ROS2HRIConnector) - # ai_answer = AIMessage(content="Yes, I am Jar Jar Binks") - # self.connectors["ros2"].send_all_targets(ai_answer) - ai_answer = self.llm.invoke( - speech, - config={"callbacks": [LLMTextHandler(self.connectors["ros2"])]}, - ) - self.message_history.append(ai_answer) # type: ignore - - def _on_from_human(self, msg: IROS2Message): - assert isinstance(msg, InterfacesHRIMessage) - self.speech_queue.put(msg) - - def _setup_ros_connector(self): - self.connectors["ros2"] = ROS2HRIConnector( - sources=[ - ( - "/from_human", - TopicConfig( - "rai_interfaces/msg/HRIMessage", - is_subscriber=True, - source_author="human", - subscriber_callback=self._on_from_human, - ), - ) - ], - targets=[ - ( - "/to_human", - TopicConfig( - "rai_interfaces/msg/HRIMessage", - source_author="ai", - is_subscriber=False, - ), - ) - ], - ) - - def stop(self): - assert isinstance(self.connectors["ros2"], ROS2HRIConnector) - self.connectors["ros2"].shutdown() - self.stop_thread.set() - if self.main_thread is not None: - self.main_thread.join() - - -talking_agent = S2SConversationalAgent({}) - - -runner = 
Speech2SpeechRunner([talking_agent]) -# runner = BaseRunner([talking_agent]) -runner.run() diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/o3de_config.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/o3de_config.yaml new file mode 100644 index 000000000..759e98e7c --- /dev/null +++ b/src/rai_bench/rai_bench/o3de_test_bench/configs/o3de_config.yaml @@ -0,0 +1,13 @@ +binary_path: /home/krachwal/binaries/rai/RAIManipulationDemo/RAIManipulationDemo.GameLauncher +robotic_stack_command: ros2 launch examples/manipulation-demo-no-binary.launch.py +required_services: + - /grounding_dino_classify + - /grounded_sam_segment + - /manipulator_move_to + - /spawn_entity + - /delete_entity +required_topics: + - /color_image5 + - /depth_image5 + - /color_camera_info5 +required_actions: [] diff --git a/src/rai_bench/rai_bench/results.csv b/src/rai_bench/rai_bench/results.csv new file mode 100644 index 000000000..27bece0ab --- /dev/null +++ b/src/rai_bench/rai_bench/results.csv @@ -0,0 +1,5 @@ +task,initial_score,simulation_config,final_score,total_time,number_of_tool_calls +"Manipulate objects, so that all carrots to the left side of the table (positive y)",0.0,src/rai_bench/rai_bench/o3de_test_bench/configs/scene1.yaml,0.0,13.648,5 +"Manipulate objects, so that all carrots to the left side of the table (positive y)",0.5,src/rai_bench/rai_bench/o3de_test_bench/configs/scene2.yaml,0.75,29.011,7 +"Manipulate objects, so that all cubes are adjacent to at least one cube",0.0,src/rai_bench/rai_bench/o3de_test_bench/configs/scene3.yaml,0.0,17.047,5 +"Manipulate objects, so that all cubes are adjacent to at least one cube",1.0,src/rai_bench/rai_bench/o3de_test_bench/configs/scene4.yaml,0.75,18.869,5 diff --git a/src/rai_core/rai/runners/base.py b/src/rai_core/rai/runners/base.py deleted file mode 100644 index 96333d123..000000000 --- a/src/rai_core/rai/runners/base.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright (C) 2024 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -import threading - -from rai.agents.base import BaseAgent - - -class BaseRunner: - """ - Manages and runs a collection of agents. - - Parameters - ---------- - agents : list of BaseAgent - A list of agent instances that implement `run` and `stop` methods. - - Attributes - ---------- - agents : list of BaseAgent - The list of agents managed by this runner. - """ - - def __init__(self, agents: list[BaseAgent]): - """ - Initializes the BaseRunner with a list of agents. - - Parameters - ---------- - agents : list of BaseAgent - A list of agent instances that will be managed and executed. - """ - self.agents = agents - - def run(self): - """ - Starts all agents and keeps them running indefinitely. - - This method runs each agent's `run` method and waits indefinitely for a stop signal. - If a `KeyboardInterrupt` is received, it logs the interruption and stops all agents. 
- """ - for agent in self.agents: - agent.run() - - stop_signal = threading.Event() - - try: - stop_signal.wait() - except KeyboardInterrupt: - logging.info("KeyboardInterrupt received! Shutting down...") - - for agent in self.agents: - agent.stop() diff --git a/src/rai_core/rai/runners/s2s.py b/src/rai_core/rai/runners/s2s.py deleted file mode 100644 index c38d1934c..000000000 --- a/src/rai_core/rai/runners/s2s.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (C) 2024 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass -from typing import Optional - -from rai.agents import TextToSpeechAgent, VoiceRecognitionAgent -from rai.agents.base import BaseAgent -from rai.communication.sound_device import SoundDeviceConfig -from rai.runners import BaseRunner -from rai_asr.models import LocalWhisper, OpenWakeWord, SileroVAD -from rai_tts.models import OpenTTS - - -@dataclass -class TTSConfig: - device_name: str = "default" - url: str = "http://localhost:5500/api/tts" - voice: str = "larynx:blizzard_lessac-glow_tts" - - -@dataclass -class ASRConfig: - device_name: str = "default" - vad_threshold: float = 0.5 - oww_threshold: float = 0.1 - whisper_model: str = "tiny" - oww_model: str = "hey jarvis" - - -class Speech2SpeechRunner(BaseRunner): - """ - Manages a speech-to-speech pipeline by integrating text-to-speech (TTS) and automatic speech recognition (ASR) agents. - - Parameters - ---------- - agents : Optional[list[BaseAgent]], optional - A list of existing agents to be included in the runner, by default None. - tts_cfg : TTSConfig, optional - Configuration for the text-to-speech agent, by default TTSConfig(). - asr_config : ASRConfig, optional - Configuration for the automatic speech recognition agent, by default ASRConfig(). 
- """ - - def __init__( - self, - agents: Optional[list[BaseAgent]] = None, - tts_cfg: TTSConfig = TTSConfig(), - asr_config: ASRConfig = ASRConfig(), - ): - if agents is None: - agents = [] - super().__init__(agents) - tts = self._setup_tts_agent(tts_cfg) - asr = self._setup_asr_agent(asr_config) - self.agents.extend([tts, asr]) - - def _setup_tts_agent(self, cfg: TTSConfig): - speaker_config = SoundDeviceConfig( - stream=True, - is_output=True, - device_name=cfg.device_name, - ) - tts = OpenTTS(cfg.url, cfg.voice) - return TextToSpeechAgent(speaker_config, "text_to_speech", tts) - - def _setup_asr_agent(self, cfg: ASRConfig): - vad = SileroVAD(threshold=cfg.vad_threshold) - oww = OpenWakeWord(cfg.oww_model, cfg.oww_threshold) - whisper = LocalWhisper( - cfg.whisper_model, vad.sampling_rate - ) # models should have compatible sampling rate - microphone_config = SoundDeviceConfig( - stream=True, - device_name=cfg.device_name, - consumer_sampling_rate=vad.sampling_rate, - is_input=True, - ) - asr_agent = VoiceRecognitionAgent( - microphone_config, "automatic_speech_recognition", whisper, vad - ) - asr_agent.add_detection_model(oww, pipeline="record") - return asr_agent From 7e7aa40f45ae5fd094376ff76c7390f563debdb0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?= Date: Wed, 26 Feb 2025 16:25:13 +0100 Subject: [PATCH 18/30] chore: remove trash file --- .../o3de_test_bench/configs/o3de_config.yaml | 13 ------------- src/rai_bench/rai_bench/results.csv | 5 ----- 2 files changed, 18 deletions(-) delete mode 100644 src/rai_bench/rai_bench/o3de_test_bench/configs/o3de_config.yaml delete mode 100644 src/rai_bench/rai_bench/results.csv diff --git a/src/rai_bench/rai_bench/o3de_test_bench/configs/o3de_config.yaml b/src/rai_bench/rai_bench/o3de_test_bench/configs/o3de_config.yaml deleted file mode 100644 index 759e98e7c..000000000 --- a/src/rai_bench/rai_bench/o3de_test_bench/configs/o3de_config.yaml +++ /dev/null @@ -1,13 +0,0 @@ -binary_path: /home/krachwal/binaries/rai/RAIManipulationDemo/RAIManipulationDemo.GameLauncher -robotic_stack_command: ros2 launch examples/manipulation-demo-no-binary.launch.py -required_services: - - /grounding_dino_classify - - /grounded_sam_segment - - /manipulator_move_to - - /spawn_entity - - /delete_entity -required_topics: - - /color_image5 - - /depth_image5 - - /color_camera_info5 -required_actions: [] diff --git a/src/rai_bench/rai_bench/results.csv b/src/rai_bench/rai_bench/results.csv deleted file mode 100644 index 27bece0ab..000000000 --- a/src/rai_bench/rai_bench/results.csv +++ /dev/null @@ -1,5 +0,0 @@ -task,initial_score,simulation_config,final_score,total_time,number_of_tool_calls -"Manipulate objects, so that all carrots to the left side of the table (positive y)",0.0,src/rai_bench/rai_bench/o3de_test_bench/configs/scene1.yaml,0.0,13.648,5 -"Manipulate objects, so that all carrots to the left side of the table (positive y)",0.5,src/rai_bench/rai_bench/o3de_test_bench/configs/scene2.yaml,0.75,29.011,7 -"Manipulate objects, so that all cubes are adjacent to at least one cube",0.0,src/rai_bench/rai_bench/o3de_test_bench/configs/scene3.yaml,0.0,17.047,5 -"Manipulate objects, so that all cubes are adjacent to at least one cube",1.0,src/rai_bench/rai_bench/o3de_test_bench/configs/scene4.yaml,0.75,18.869,5 From 6e02636e091b23f9ac26777ae2569d6631da25ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?= Date: Wed, 26 Feb 2025 16:38:20 +0100 Subject: [PATCH 19/30] fix: race condition on cancelling speech task --- 
src/rai_core/rai/agents/tts_agent.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/rai_core/rai/agents/tts_agent.py b/src/rai_core/rai/agents/tts_agent.py index 865233da7..d0d66d59c 100644 --- a/src/rai_core/rai/agents/tts_agent.py +++ b/src/rai_core/rai/agents/tts_agent.py @@ -182,8 +182,8 @@ def _transcription_thread(self): except Empty: continue audio = self.model.get_speech(data) - self.audio_queue.put(audio) - self.playback_data.playing = True + if self.playback_data.playing: + self.audio_queue.put(audio) def _setup_ros2_connector(self): to_human = TopicConfig( @@ -210,6 +210,7 @@ def _on_to_human_message(self, message: IROS2Message): msg = ROS2HRIMessage.from_ros2(message, "ai") self.logger.debug(f"Receieved message from human: {message.text}") self.text_queue.put(msg.text) + self.playback_data.playing = True def _on_command_message(self, message: IROS2Message): assert isinstance(message, String) From 92396a84c0e08e29c58f4fd1d32dfaf00b89453f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?= Date: Wed, 26 Feb 2025 18:14:58 +0100 Subject: [PATCH 20/30] fix: race condition on single transcribe queue --- examples/s2s/conversational.py | 2 + examples/s2s/run.sh | 1 + src/rai_asr/rai_asr/models/base.py | 9 ++++ src/rai_asr/rai_asr/models/open_wake_word.py | 6 +++ src/rai_asr/rai_asr/models/silero_vad.py | 6 +++ src/rai_core/rai/agents/tts_agent.py | 55 ++++++++++++++------ src/rai_core/rai/agents/voice_agent.py | 2 +- 7 files changed, 64 insertions(+), 17 deletions(-) diff --git a/examples/s2s/conversational.py b/examples/s2s/conversational.py index a59de6226..8aa654561 100644 --- a/examples/s2s/conversational.py +++ b/examples/s2s/conversational.py @@ -43,6 +43,7 @@ def __init__(self, connector: ROS2HRIConnector): def on_llm_new_token(self, token: str, **kwargs): self.token_buffer += token if len(self.token_buffer) > 100 or token in [".", "?", "!", ",", ";", ":"]: + logging.info(f"Sending token buffer: {self.token_buffer}") self.connector.send_all_targets(AIMessage(content=self.token_buffer)) self.token_buffer = "" @@ -55,6 +56,7 @@ def on_llm_end( **kwargs, ): if self.token_buffer: + logging.info(f"Sending token buffer: {self.token_buffer}") self.connector.send_all_targets(AIMessage(content=self.token_buffer)) self.token_buffer = "" diff --git a/examples/s2s/run.sh b/examples/s2s/run.sh index 660b9ca27..184c96274 100755 --- a/examples/s2s/run.sh +++ b/examples/s2s/run.sh @@ -35,6 +35,7 @@ handle_sigint() { # Main logic main() { + echo "NOTE: RUNNING ALL SCRIPTS IN PARALLEL IS AN EXPERIMENTAL FEATURE, AND IS NOT GUARANTEED TO WORK WITHOUT ISSUES. RUN SCRIPTS IN SEPARATE TERMINALS FOR BEST RESULTS " # Set up trap for SIGINT (Ctrl+C) trap handle_sigint SIGINT diff --git a/src/rai_asr/rai_asr/models/base.py b/src/rai_asr/rai_asr/models/base.py index a1c9e4c5b..c43dc481f 100644 --- a/src/rai_asr/rai_asr/models/base.py +++ b/src/rai_asr/rai_asr/models/base.py @@ -80,6 +80,15 @@ def detect( """ pass + @abstractmethod + def reset(self): + """ + Abstract method for resetting the voice detection model. + + Subclasses must implement this method to reset the internal state of the model. 
+        """
+        pass
+
 
 class BaseTranscriptionModel(ABC):
     """
diff --git a/src/rai_asr/rai_asr/models/open_wake_word.py b/src/rai_asr/rai_asr/models/open_wake_word.py
index c084eaff7..cbff9a0cc 100644
--- a/src/rai_asr/rai_asr/models/open_wake_word.py
+++ b/src/rai_asr/rai_asr/models/open_wake_word.py
@@ -109,3 +109,9 @@ def detect(
                 self.model.reset()
                 return True, ret
         return False, ret
+
+    def reset(self):
+        """
+        Resets the wake word detection model.
+        """
+        self.model.reset()
diff --git a/src/rai_asr/rai_asr/models/silero_vad.py b/src/rai_asr/rai_asr/models/silero_vad.py
index dd325f836..d0a0fe9ef 100644
--- a/src/rai_asr/rai_asr/models/silero_vad.py
+++ b/src/rai_asr/rai_asr/models/silero_vad.py
@@ -113,3 +113,9 @@ def detect(
         ret.update({self.model_name: {"vad_confidence": vad_confidence}})
 
         return vad_confidence > self.threshold, ret
+
+    def reset(self):
+        """
+        Resets the voice activity detection model.
+        """
+        self.model.reset()
diff --git a/src/rai_core/rai/agents/tts_agent.py b/src/rai_core/rai/agents/tts_agent.py
index d0d66d59c..b93256c38 100644
--- a/src/rai_core/rai/agents/tts_agent.py
+++ b/src/rai_core/rai/agents/tts_agent.py
@@ -17,6 +17,7 @@
 from dataclasses import dataclass
 from threading import Event, Thread
 from typing import TYPE_CHECKING, Optional
+from uuid import uuid4
 
 from numpy._typing import NDArray
 from pydub import AudioSegment
@@ -99,8 +100,9 @@ def __init__(
 
         ros2_connector = self._setup_ros2_connector()
         super().__init__(connectors={"ros2": ros2_connector, "speaker": speaker})
-        self.text_queue = Queue()
-        self.audio_queue = Queue()
+        self.current_transcription_id = str(uuid4())[0:8]
+        self.text_queues: dict[str, Queue] = {self.current_transcription_id: Queue()}
+        self.audio_queues: dict[str, Queue] = {self.current_transcription_id: Queue()}
 
         self.tog_play_event = Event()
         self.stop_event = Event()
@@ -141,14 +143,16 @@ def _speaker_callback(self, outdata, frames, time, status_dict):
         if self.playback_data.playing:
             if self.playback_data.current_segment is None:
                 try:
-                    self.playback_data.current_segment = self.audio_queue.get(
-                        block=False
-                    )
+                    self.playback_data.current_segment = self.audio_queues[
+                        self.current_transcription_id
+                    ].get(block=False)
                     self.playback_data.data = np.array(
                         self.playback_data.current_segment.get_array_of_samples()  # type: ignore
                     ).reshape(-1, self.playback_data.channels)
                 except Empty:
                     pass
+                except KeyError:
+                    pass
             if self.playback_data.data is not None:
                 current_frame = self.playback_data.current_frame
                 chunksize = min(len(self.playback_data.data) - current_frame, frames)
@@ -177,13 +181,21 @@ def stop(self):
 
     def _transcription_thread(self):
         while not self.terminate_agent.wait(timeout=0.01):
-            try:
-                data = self.text_queue.get(block=False)
-            except Empty:
-                continue
-            audio = self.model.get_speech(data)
-            if self.playback_data.playing:
-                self.audio_queue.put(audio)
+            if self.current_transcription_id in self.text_queues:
+                try:
+                    data = self.text_queues[self.current_transcription_id].get(
+                        block=False
+                    )
+                except Empty:
+                    continue
+                audio = self.model.get_speech(data)
+                try:
+                    self.audio_queues[self.current_transcription_id].put(audio)
+                except KeyError as e:
+                    self.logger.error(
+                        f"Could not find queue for {self.current_transcription_id}; queues: {self.audio_queues.keys()}"
+                    )
+                    raise e
 
     def _setup_ros2_connector(self):
         to_human = TopicConfig(
@@ -208,8 +220,11 @@ def _setup_ros2_connector(self):
     def _on_to_human_message(self, message: IROS2Message):
         assert isinstance(message, HRIMessage)
         msg = ROS2HRIMessage.from_ros2(message, "ai")
-        self.logger.debug(f"Received message from human: {message.text}")
-        self.text_queue.put(msg.text)
+        self.logger.info(f"Received message from human: {message.text}")
+        self.logger.warning(
+            f"Starting playback, current id: {self.current_transcription_id}"
+        )
+        self.text_queues[self.current_transcription_id].put(msg.text)
         self.playback_data.playing = True
 
     def _on_command_message(self, message: IROS2Message):
@@ -223,8 +238,16 @@ def _on_command_message(self, message: IROS2Message):
             self.playback_data.playing = False
         elif message.data == "stop":
             self.playback_data.playing = False
-            while not self.audio_queue.empty():
-                _ = self.audio_queue.get()
+            previous_id = self.current_transcription_id
+            self.logger.warning(f"Stopping playback, previous id: {previous_id}")
+            self.current_transcription_id = str(uuid4())[0:8]
+            self.audio_queues[self.current_transcription_id] = Queue()
+            self.text_queues[self.current_transcription_id] = Queue()
+            try:
+                del self.audio_queues[previous_id]
+                del self.text_queues[previous_id]
+            except KeyError:
+                pass
             self.playback_data.data = None
             self.playback_data.current_frame = 0
             self.playback_data.current_segment = None
diff --git a/src/rai_core/rai/agents/voice_agent.py b/src/rai_core/rai/agents/voice_agent.py
index ab3027cc5..e52635a49 100644
--- a/src/rai_core/rai/agents/voice_agent.py
+++ b/src/rai_core/rai/agents/voice_agent.py
@@ -191,7 +191,6 @@ def _on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
 
         voice_detected, output_parameters = self.vad(indata, {})
         self.logger.debug(f"Voice detected: {voice_detected}: {output_parameters}")
         should_record = False
-        # TODO: second condition is temporary
         if voice_detected and not self.recording_started:
             should_record = self._should_record(indata, output_parameters)
 
@@ -241,6 +240,7 @@ def _should_record(
             detected, output = model(audio_data, input_parameters)
             self.logger.debug(f"detected {detected}, output {output}")
             if detected:
+                model.reset()
                 return True
         return False
 

From cb60dbc6e6232625018d7e301b2e80d42ab93399 Mon Sep 17 00:00:00 2001
From: Maciej Majek
Date: Thu, 27 Feb 2025 14:32:54 +0100
Subject: [PATCH 21/30] fix: send voice commands only on changes

---
 src/rai_core/rai/agents/voice_agent.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/rai_core/rai/agents/voice_agent.py b/src/rai_core/rai/agents/voice_agent.py
index e52635a49..29956515d 100644
--- a/src/rai_core/rai/agents/voice_agent.py
+++ b/src/rai_core/rai/agents/voice_agent.py
@@ -107,6 +107,7 @@ def __init__(
         self.active_thread = ""
         self.transcription_threads: dict[str, ThreadData] = {}
         self.transcription_buffers: dict[str, list[NDArray]] = {}
+        self.last_command_sent = ""
 
     def __call__(self):
         self.run()
@@ -230,8 +231,11 @@ def _on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
             self.transcription_threads[self.active_thread]["thread"].start()
             self.active_thread = ""
             self._send_ros2_message("stop", "/voice_commands")
+            self.last_command_sent = "stop"
         elif sample_time - self.grace_period_start > self.grace_period:
-            self._send_ros2_message("play", "/voice_commands")
+            if self.last_command_sent == "stop":
+                self._send_ros2_message("play", "/voice_commands")
+                self.last_command_sent = "start"
 
     def _should_record(
         self, audio_data: NDArray, input_parameters: dict[str, Any]

From d4340d5cc0a43a99fa2f8eaaa957051cd5fe8f8e Mon Sep 17 00:00:00 2001
From: Maciej Majek
Date: Thu, 27 Feb 2025 14:36:59 +0100
Subject: [PATCH 22/30] Revert "fix: send voice commands only on changes"

This reverts commit
cb60dbc6e6232625018d7e301b2e80d42ab93399. --- src/rai_core/rai/agents/voice_agent.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/rai_core/rai/agents/voice_agent.py b/src/rai_core/rai/agents/voice_agent.py index 29956515d..e52635a49 100644 --- a/src/rai_core/rai/agents/voice_agent.py +++ b/src/rai_core/rai/agents/voice_agent.py @@ -107,7 +107,6 @@ def __init__( self.active_thread = "" self.transcription_threads: dict[str, ThreadData] = {} self.transcription_buffers: dict[str, list[NDArray]] = {} - self.last_command_sent = "" def __call__(self): self.run() @@ -231,11 +230,8 @@ def _on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]): self.transcription_threads[self.active_thread]["thread"].start() self.active_thread = "" self._send_ros2_message("stop", "/voice_commands") - self.last_command_sent = "stop" elif sample_time - self.grace_period_start > self.grace_period: - if self.last_command_sent == "stop": - self._send_ros2_message("play", "/voice_commands") - self.last_command_sent = "start" + self._send_ros2_message("play", "/voice_commands") def _should_record( self, audio_data: NDArray, input_parameters: dict[str, Any] From 161fe5b564deedaf9370de125985feba26842b6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?= Date: Thu, 27 Feb 2025 15:13:25 +0100 Subject: [PATCH 23/30] fix: minimise ros2 traffic --- src/rai_core/rai/agents/voice_agent.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/rai_core/rai/agents/voice_agent.py b/src/rai_core/rai/agents/voice_agent.py index e52635a49..378814d63 100644 --- a/src/rai_core/rai/agents/voice_agent.py +++ b/src/rai_core/rai/agents/voice_agent.py @@ -107,6 +107,7 @@ def __init__( self.active_thread = "" self.transcription_threads: dict[str, ThreadData] = {} self.transcription_buffers: dict[str, list[NDArray]] = {} + self.is_playing = True def __call__(self): self.run() @@ -215,6 +216,7 @@ def _on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]): self.logger.debug("Voice detected... 
resetting grace period")
             self.grace_period_start = sample_time
             self._send_ros2_message("pause", "/voice_commands")
+            self.is_playing = False
         if (
             self.recording_started
             and sample_time - self.grace_period_start > self.grace_period
@@ -230,12 +232,16 @@ def _on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
             self.transcription_threads[self.active_thread]["thread"].start()
             self.active_thread = ""
             self._send_ros2_message("stop", "/voice_commands")
+            self.is_playing = False
         elif sample_time - self.grace_period_start > self.grace_period:
             self._send_ros2_message("play", "/voice_commands")
+            self.is_playing = True
 
     def _should_record(
         self, audio_data: NDArray, input_parameters: dict[str, Any]
     ) -> bool:
+        if len(self.should_record_pipeline) == 0:
+            return True
         for model in self.should_record_pipeline:
             detected, output = model(audio_data, input_parameters)
             self.logger.debug(f"detected {detected}, output {output}")

From e2d08d429e8f5ae328d001b2ed85e03e97436bc8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?=
Date: Thu, 27 Feb 2025 15:53:19 +0100
Subject: [PATCH 24/30] docs: add S2S docs

---
 docs/human_robot_interface/voice_interface.md | 65 ++++++++++++-----
 examples/s2s/asr.py                           |  3 +-
 examples/s2s/conversational.py                |  1 -
 examples/s2s/tts.py                           |  1 -
 4 files changed, 48 insertions(+), 22 deletions(-)

diff --git a/docs/human_robot_interface/voice_interface.md b/docs/human_robot_interface/voice_interface.md
index 541ad5c30..368e5e477 100644
--- a/docs/human_robot_interface/voice_interface.md
+++ b/docs/human_robot_interface/voice_interface.md
@@ -1,36 +1,65 @@
 # Human Robot Interface via Voice
 
-> [!IMPORTANT]
-> RAI_ASR supports both local Whisper models and OpenAI Whisper (cloud). When using the cloud version, the OPENAI_API_KEY environment variable must be set with a valid API key.
+RAI provides two ROS-enabled agents for speech-to-speech communication.
 
-## Running example
+## Automatic Speech Recognition Agent
+
+See `examples/s2s/asr.py` for an example usage.
+
+The agent requires configuration of the `sounddevice` and `ros2` connectors, a voice activity detection model (e.g. `SileroVAD`), and a transcription model (e.g. `LocalWhisper`), as well as, optionally, additional models to decide if the transcription should start (e.g. `OpenWakeWord`).
+
+The Agent publishes information on two topics:
 
-When your robot's whoami package is ready, run the following:
+`/from_human`: `rai_interfaces/msg/HRIMessage` - containing transcriptions of the recorded speech.
 
-> [!TIP]
-> Make sure rai_whoami is running.
+`/voice_commands`: `std_msgs/msg/String` - containing control commands that inform the consumer whether speech is currently detected (`{"data": "pause"}`), whether it has just stopped (`{"data": "play"}`), and whether a transcription was sent (`{"data": "stop"}`).
 
-** Parameters **
-recording_device: The device you want to record with. Check available with:
+The Agent utilises the sounddevice module to access the user's microphone; by default, the `"default"` sound device is used.
+To get information about available sounddeives use:
 
-```bash
-python -c 'import sounddevice as sd; print(sd.query_devices())'
+```
+python -c "import sounddevice; sounddevice.query_devices()"
 ```
 
-keep_speaker_busy: some speakers may go into low power mode, which may result in truncated speech beginnings. Set to true to play low frequency, low volume noise to prevent sleep mode.
+The device can be identified by name and passed to the configuration.
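+
+A minimal sketch of the wiring, condensed from `examples/s2s/asr.py` (the threshold and sampling-rate values below are illustrative, not required):
+
+```python
+import rclpy
+from rai.agents import VoiceRecognitionAgent
+from rai.communication.sound_device.api import SoundDeviceConfig
+
+from rai_asr.models import LocalWhisper, SileroVAD
+
+rclpy.init()
+# Microphone configuration; pick device_name from sounddevice.query_devices()
+microphone = SoundDeviceConfig(
+    stream=True,
+    channels=1,
+    device_name="default",
+    block_size=1280,
+    consumer_sampling_rate=16000,
+    dtype="int16",
+    device_number=None,
+    is_input=True,
+    is_output=False,
+)
+vad = SileroVAD(16000, 0.5)  # sampling rate, detection threshold
+whisper = LocalWhisper("tiny", 16000)  # model size, sampling rate
+agent = VoiceRecognitionAgent(microphone, "rai_asr_agent", whisper, vad)
+agent.run()
+```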
+
+## TextToSpeechAgent
+
+See `examples/s2s/wtts.py` for an example usage.
+
+The agent requires configuration of the `sounddevice` and `ros2` connectors, as well as a TextToSpeech model (e.g. `OpenTTS`).
+The Agent listens for information on two topics:
+
+`/to_human`: `rai_interfaces/msg/HRIMessage` - containing responses to be played to the human. These responses are then synthesized and put into the playback queue.
+
+`/voice_commands`: `std_msgs/msg/String` - containing control commands to pause current playback (`{"data": "pause"}`), start/continue playback (`{"data": "play"}`), or stop the playback and drop the current playback queue (`{"data": "stop"}`).
+
+The Agent utilises the sounddevice module to access the user's speaker; by default, the `"default"` sound device is used.
+To get information about available sounddevices, use:
+
+```
+python -c "import sounddevice; sounddevice.query_devices()"
+```
+
+The device can be identified by name and passed to the configuration.
 
 ### OpenTTS
 
-```bash
-ros2 launch rai_bringup hri.launch.py tts_vendor:=opentts robot_description_package:= recording_device:=0 keep_speaker_busy:=(true|false) asr_vendor:=(whisper|openai)
+To run OpenTTS (and the example), a Docker container serving the model must be running.
+To start it, run:
+
+```
+docker run -it -p 5500:5500 synesthesiam/opentts:en --no-espeak
 ```
 
-> [!NOTE]
-> Run OpenTTS with `docker run -it -p 5500:5500 synesthesiam/opentts:en --no-espeak`
+## Running example
 
-### ElevenLabs
+To run the provided example of S2S configuration with a minimal LLM-based agent, run the following in four separate terminals:
 
-```bash
-ros2 launch rai_bringup hri.launch.py robot_description_package:= recording_device:=0 keep_speaker_busy:=(true|false) asr_vendor:=(whisper|openai)
+```
+$ docker run -it -p 5500:5500 synesthesiam/opentts:en --no-espeak
+$ python ./examples/s2s/asr.py
+$ python examples/s2s/tts.py
+$ python examples/s2s/conversational.py
 ```
diff --git a/examples/s2s/asr.py b/examples/s2s/asr.py
index 462d034ba..5ea131652 100644
--- a/examples/s2s/asr.py
+++ b/examples/s2s/asr.py
@@ -100,7 +100,7 @@ def parse_arguments():
     ros2_name = "rai_asr_agent"
 
     agent = VoiceRecognitionAgent(microphone_configuration, ros2_name, whisper, vad)
-    agent.add_detection_model(oww, pipeline="record")
+    # agent.add_detection_model(oww, pipeline="record")
 
     agent.run()
 
@@ -114,6 +114,5 @@ def cleanup(signum, frame):
 
     signal.signal(signal.SIGINT, cleanup)
 
-    print("Runnin")
     while True:
         time.sleep(1)
diff --git a/examples/s2s/conversational.py b/examples/s2s/conversational.py
index 8aa654561..b0e8b46d3 100644
--- a/examples/s2s/conversational.py
+++ b/examples/s2s/conversational.py
@@ -168,6 +168,5 @@ def cleanup(signum, frame):
 
     signal.signal(signal.SIGINT, cleanup)
 
-    print("Runnin")
     while True:
         time.sleep(1)
diff --git a/examples/s2s/tts.py b/examples/s2s/tts.py
index 997c8d4a5..f5e33c13d 100644
--- a/examples/s2s/tts.py
+++ b/examples/s2s/tts.py
@@ -71,6 +71,5 @@ def cleanup(signum, frame):
 
     signal.signal(signal.SIGINT, cleanup)
 
-    print("Runnin")
     while True:
         time.sleep(1)

From 7d2e24da896a62de32e6d37e4da8b8c10ed6a8be Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?=
Date: Fri, 28 Feb 2025 11:03:15 +0100
Subject: [PATCH 25/30] fix: minimise ros2 traffic -- add missing if

---
 src/rai_core/rai/agents/voice_agent.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/rai_core/rai/agents/voice_agent.py b/src/rai_core/rai/agents/voice_agent.py
index 378814d63..1cb934e89 100644
--- a/src/rai_core/rai/agents/voice_agent.py
+++ b/src/rai_core/rai/agents/voice_agent.py
@@ -233,7 +233,9 @@ def _on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
             self.active_thread = ""
             self._send_ros2_message("stop", "/voice_commands")
             self.is_playing = False
-        elif sample_time - self.grace_period_start > self.grace_period:
+        elif not self.is_playing and (
+            sample_time - self.grace_period_start > self.grace_period
+        ):
             self._send_ros2_message("play", "/voice_commands")
             self.is_playing = True
 

From 3d5a1eaaa0d68cb0fb83727432b7c8e08052ffb6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?=
Date: Fri, 28 Feb 2025 12:25:18 +0100
Subject: [PATCH 26/30] fix: conversational example use history

---
 examples/s2s/conversational.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/examples/s2s/conversational.py b/examples/s2s/conversational.py
index b0e8b46d3..b50660068 100644
--- a/examples/s2s/conversational.py
+++ b/examples/s2s/conversational.py
@@ -91,10 +91,8 @@ def _main_loop(self):
             if speech != "":
                 self.message_history.append(HumanMessage(content=speech))
                 assert isinstance(self.connectors["ros2"], ROS2HRIConnector)
-                # ai_answer = AIMessage(content="Yes, I am Jar Jar Binks")
-                # self.connectors["ros2"].send_all_targets(ai_answer)
                 ai_answer = self.llm.invoke(
-                    speech,
+                    self.message_history,
                     config={"callbacks": [LLMTextHandler(self.connectors["ros2"])]},
                 )
                 self.message_history.append(ai_answer)  # type: ignore

From 3579e8e16709c12fd4b47347e4e4d448c3607295 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?=
Date: Fri, 28 Feb 2025 12:30:37 +0100
Subject: [PATCH 27/30] docs: fix typos

---
 docs/human_robot_interface/voice_interface.md | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/docs/human_robot_interface/voice_interface.md b/docs/human_robot_interface/voice_interface.md
index 368e5e477..f10c4d734 100644
--- a/docs/human_robot_interface/voice_interface.md
+++ b/docs/human_robot_interface/voice_interface.md
@@ -15,7 +15,7 @@ The Agent publishes information on two topics:
 
 `/voice_commands`: `std_msgs/msg/String` - containing control commands that inform the consumer whether speech is currently detected (`{"data": "pause"}`), whether it has just stopped (`{"data": "play"}`), and whether a transcription was sent (`{"data": "stop"}`).
 
 The Agent utilises the sounddevice module to access the user's microphone; by default, the `"default"` sound device is used.
-To get information about available sounddeives use:
+To get information about available sounddevices, use:
 
 ```
 python -c "import sounddevice; sounddevice.query_devices()"
@@ -25,7 +25,7 @@ The device can be identified by name and passed to the configuration.
 
 ## TextToSpeechAgent
 
-See `examples/s2s/wtts.py` for an example usage.
+See `examples/s2s/tts.py` for an example usage.
 
 The agent requires configuration of the `sounddevice` and `ros2` connectors, as well as a TextToSpeech model (e.g. `OpenTTS`).
 The Agent listens for information on two topics:
@@ -60,6 +60,6 @@ To run the provided example of S2S configuration with a minimal LLM-based agent
 ```
 $ docker run -it -p 5500:5500 synesthesiam/opentts:en --no-espeak
 $ python ./examples/s2s/asr.py
-$ python examples/s2s/tts.py
-$ python examples/s2s/conversational.py
+$ python ./examples/s2s/tts.py
+$ python ./examples/s2s/conversational.py
 ```

From 6123d5cb358a34e711a30204feec6b035070579d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?=
Date: Fri, 28 Feb 2025 15:16:35 +0100
Subject: [PATCH 28/30] chore: add comments on example

---
 examples/s2s/asr.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/examples/s2s/asr.py b/examples/s2s/asr.py
index 5ea131652..6f8027cb6 100644
--- a/examples/s2s/asr.py
+++ b/examples/s2s/asr.py
@@ -94,20 +94,19 @@ def parse_arguments():
     vad = SileroVAD(args.vad_sampling_rate, args.vad_threshold)
     oww = OpenWakeWord("hey jarvis", args.oww_threshold)
     whisper = LocalWhisper("tiny", args.vad_sampling_rate)
+    # you can easily switch the provider by changing the whisper object
    # whisper = OpenAIWhisper("whisper-1", args.vad_sampling_rate, "en")
 
     rclpy.init()
     ros2_name = "rai_asr_agent"
 
     agent = VoiceRecognitionAgent(microphone_configuration, ros2_name, whisper, vad)
+    # optionally add additional models to decide when to record data for transcription
     # agent.add_detection_model(oww, pipeline="record")
 
     agent.run()
 
     def cleanup(signum, frame):
-        print("\nCustom handler: Caught SIGINT (Ctrl+C).")
-        print("Performing cleanup")
-        # Optionally exit the program
         agent.stop()
         rclpy.shutdown()
         exit(0)

From e69437eb4c6ae6f7b860005957b35347e3e74445 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?=
Date: Fri, 28 Feb 2025 15:18:59 +0100
Subject: [PATCH 29/30] chore: remove useless comment

---
 src/rai_core/rai/communication/ros2/connectors.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/src/rai_core/rai/communication/ros2/connectors.py b/src/rai_core/rai/communication/ros2/connectors.py
index 170781326..ff62dc7ff 100644
--- a/src/rai_core/rai/communication/ros2/connectors.py
+++ b/src/rai_core/rai/communication/ros2/connectors.py
@@ -232,9 +232,6 @@ def __init__(
         self._thread = threading.Thread(target=self._executor.spin)
         self._thread.start()
 
-    # def run(self):
-    #     self._executor.spin()
-
     def _configure_publishers(self, targets: List[Tuple[str, TopicConfig]]):
         for target in targets:
             self._topic_api.configure_publisher(target[0], target[1])

From 582b9447e2375ae202ba162f927fa61ad249616b Mon Sep 17 00:00:00 2001
From: Maciej Majek
Date: Wed, 12 Feb 2025 20:20:44 +0100
Subject: [PATCH 30/30] feat: add ElevenLabsTTS

---
 src/rai_tts/rai_tts/models/__init__.py       |   3 +-
 src/rai_tts/rai_tts/models/elevenlabs_tts.py | 112 +++++++++++++++++++
 2 files changed, 114 insertions(+), 1 deletion(-)
 create mode 100644 src/rai_tts/rai_tts/models/elevenlabs_tts.py

diff --git a/src/rai_tts/rai_tts/models/__init__.py b/src/rai_tts/rai_tts/models/__init__.py
index b1187962b..e4744cf72 100644
--- a/src/rai_tts/rai_tts/models/__init__.py
+++ b/src/rai_tts/rai_tts/models/__init__.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from .base import TTSModel, TTSModelError
+from .elevenlabs_tts import ElevenLabsTTS
 from .open_tts import OpenTTS

-__all__ = ["OpenTTS", "TTSModel", "TTSModelError"]
+__all__ = ["ElevenLabsTTS", "OpenTTS", "TTSModel", "TTSModelError"]
diff --git a/src/rai_tts/rai_tts/models/elevenlabs_tts.py b/src/rai_tts/rai_tts/models/elevenlabs_tts.py
new file mode 100644
index 000000000..699bfb6d5
--- /dev/null
+++ b/src/rai_tts/rai_tts/models/elevenlabs_tts.py
@@ -0,0 +1,112 @@
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from io import BytesIO
+from typing import Tuple
+
+from elevenlabs.client import ElevenLabs
+from elevenlabs.types import Voice
+from elevenlabs.types.voice_settings import VoiceSettings
+from pydub import AudioSegment
+
+from rai_tts.models import TTSModel, TTSModelError
+
+
+class ElevenLabsTTS(TTSModel):
+    """
+    A text-to-speech (TTS) model interface for ElevenLabs.
+
+    Parameters
+    ----------
+    voice : str
+        The name of the voice to use; it is resolved to a voice ID via the API.
+    base_url : str, optional
+        The API endpoint for the ElevenLabs API, by default None.
+    """
+
+    def __init__(
+        self,
+        voice: str,
+        base_url: str | None = None,
+    ):
+        api_key = os.getenv(key="ELEVENLABS_API_KEY")
+        if api_key is None:
+            raise TTSModelError("ELEVENLABS_API_KEY environment variable is not set.")
+
+        self.client = ElevenLabs(base_url=base_url, api_key=api_key)
+        self.voice_settings = VoiceSettings(
+            stability=0.7,
+            similarity_boost=0.5,
+        )
+
+        voices = self.client.voices.get_all().voices
+        voice_id = next((v.voice_id for v in voices if v.name == voice), None)
+        if voice_id is None:
+            raise TTSModelError(f"Voice {voice} not found")
+        self.voice = Voice(voice_id=voice_id, settings=self.voice_settings)
+
+    def get_speech(self, text: str) -> AudioSegment:
+        """
+        Converts text into speech using the ElevenLabs API.
+
+        Parameters
+        ----------
+        text : str
+            The input text to be converted into speech.
+
+        Returns
+        -------
+        AudioSegment
+            The generated speech as an `AudioSegment` object.
+
+        Raises
+        ------
+        TTSModelError
+            If there is an issue with the request or the ElevenLabs API is unreachable.
+            If the response does not contain valid audio data.
+        """
+        try:
+            response = self.client.generate(
+                text=text,
+                voice=self.voice,
+                optimize_streaming_latency=4,
+            )
+            audio_data = b"".join(response)
+        except Exception as e:
+            raise TTSModelError(f"Error occurred while fetching audio: {e}") from e
+
+        # Load audio into memory (ElevenLabs returns MP3)
+        audio_segment = AudioSegment.from_mp3(BytesIO(audio_data))
+        return audio_segment
+
+    def get_tts_params(self) -> Tuple[int, int]:
+        """
+        Returns TTS sampling rate and channels.
+
+        The information is retrieved by running a sample synthesis request, to ensure it is accurate for actual generation.
+
+        Returns
+        -------
+        Tuple[int, int]
+            Sample rate and number of channels.
+
+        Raises
+        ------
+        TTSModelError
+            If there is an issue with the request or the ElevenLabs API is unreachable.
+            If the response does not contain valid audio data.
+        """
+        data = self.get_speech("A")
+        return data.frame_rate, data.channels
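
A minimal usage sketch for the new `ElevenLabsTTS` model follows. It is illustrative and not part of the patch: it assumes `ELEVENLABS_API_KEY` is exported and that the account has a voice named "Rachel" (the voice name is a placeholder).

```
from rai_tts.models import ElevenLabsTTS, TTSModelError

try:
    # The voice name is looked up in the account and resolved to a voice ID.
    tts = ElevenLabsTTS(voice="Rachel")
    # Probed via a short synthesis request, as implemented in get_tts_params().
    sample_rate, channels = tts.get_tts_params()
    # get_speech() returns a pydub AudioSegment; decoding the MP3 needs ffmpeg.
    speech = tts.get_speech("Hello from RAI!")
    speech.export("hello.wav", format="wav")
except TTSModelError as err:
    print(f"TTS unavailable: {err}")
```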