feat: add tts to rai core #419
base: development
Changes from all commits
@@ -1,36 +1,65 @@

# Human Robot Interface via Voice

> [!IMPORTANT]
> RAI_ASR supports both local Whisper models and OpenAI Whisper (cloud). When using the cloud version, the `OPENAI_API_KEY` environment variable must be set to a valid API key.

RAI provides two ROS-enabled agents for speech-to-speech communication.
## Automatic Speech Recognition Agent

See `examples/s2s/asr.py` for example usage.

The agent requires configuration of the `sounddevice` and `ros2` connectors, a voice activity detection model (e.g. `SileroVAD`), and a transcription model (e.g. `LocalWhisper`). Optionally, additional models can decide whether transcription should start at all (e.g. `OpenWakeWord`).

The agent publishes information on two topics:

`/from_human`: `rai_interfaces/msg/HRIMessage` - transcriptions of the recorded speech.

`/voice_commands`: `std_msgs/msg/String` - control commands informing the consumer that speech is currently being detected (`{"data": "pause"}`), that speech was detected and has now stopped (`{"data": "play"}`), or that speech was transcribed (`{"data": "stop"}`). A consumer can react to these commands as in the sketch below.
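For illustration, a minimal `/voice_commands` consumer could look like the following sketch. The node name and logging behavior are hypothetical; only the topic, message type, and command values come from this README.

```python
import rclpy
from rclpy.node import Node
from std_msgs.msg import String


class VoiceCommandListener(Node):
    """Hypothetical consumer of the ASR agent's control commands."""

    def __init__(self):
        super().__init__("voice_command_listener")
        self.create_subscription(String, "/voice_commands", self.on_command, 10)

    def on_command(self, msg: String) -> None:
        if msg.data == "pause":
            # Speech is currently being detected: pause any ongoing playback.
            self.get_logger().info("speech detected - pausing playback")
        elif msg.data == "play":
            # Speech was detected and has now stopped: playback may resume.
            self.get_logger().info("speech ended - resuming playback")
        elif msg.data == "stop":
            # A transcription was produced: stop and drop the playback queue.
            self.get_logger().info("transcription ready - stopping playback")


rclpy.init()
rclpy.spin(VoiceCommandListener())
```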
The agent uses the `sounddevice` module to access the user's microphone; by default, the `"default"` sound device is used.
To list the available sound devices, use:

```
python -c "import sounddevice; print(sounddevice.query_devices())"
```

The device can be identified by name and passed to the configuration, as in the sketch below.
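For example, a microphone picked by name can be passed to the ASR agent's sound device configuration. The field values below mirror those used in `examples/s2s/asr.py`; the device name itself is a hypothetical entry from `query_devices()`.

```python
from rai.communication.sound_device.api import SoundDeviceConfig

microphone_configuration = SoundDeviceConfig(
    stream=True,
    channels=1,
    device_name="USB Audio Device",  # hypothetical name copied from query_devices()
    block_size=1280,
    consumer_sampling_rate=16000,
    dtype="int16",
    device_number=None,
    is_input=True,
    is_output=False,
)
```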
## TextToSpeechAgent

See `examples/s2s/tts.py` for example usage.

The agent requires configuration of the `sounddevice` and `ros2` connectors as well as a text-to-speech model (e.g. `OpenTTS`).
The agent listens for information on two topics:

`/to_human`: `rai_interfaces/msg/HRIMessage` - responses to be played to the human. These responses are synthesized to speech and placed in the playback queue.

`/voice_commands`: `std_msgs/msg/String` - control commands to pause the current playback (`{"data": "pause"}`), start or continue playback (`{"data": "play"}`), or stop playback and drop the current playback queue (`{"data": "stop"}`).

The agent uses the `sounddevice` module to access the user's speaker; by default, the `"default"` sound device is used.
To list the available sound devices, use:

```
python -c "import sounddevice; print(sounddevice.query_devices())"
```

The device can be identified by name and passed to the configuration.
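To sanity-check a running `TextToSpeechAgent` from the command line, one option is to publish to `/to_human` directly. This is a sketch: it assumes the message's `text` field (the field used by the examples) and leaves all other `HRIMessage` fields at their defaults.

```bash
ros2 topic pub --once /to_human rai_interfaces/msg/HRIMessage "{text: 'Hello from RAI'}"
```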
### OpenTTS

To run OpenTTS (and the example), a Docker container serving the model must be running.

To start it, run:

```
docker run -it -p 5500:5500 synesthesiam/opentts:en --no-espeak
```
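To confirm the server came up before starting the TTS example, a plain HTTP check against the published port is enough (the port comes from the command above):

```bash
curl -s -o /dev/null -w "%{http_code}\n" http://localhost:5500/
# Expect 200 once the container has finished starting.
```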
## Running example

To run the provided example of the S2S configuration with a minimal LLM-based agent, run the following in 4 separate terminals:

```bash
$ docker run -it -p 5500:5500 synesthesiam/opentts:en --no-espeak
$ python ./examples/s2s/asr.py
$ python ./examples/s2s/tts.py
$ python ./examples/s2s/conversational.py
```
@@ -0,0 +1,117 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import signal
import time

import rclpy
from rai.agents import VoiceRecognitionAgent
from rai.communication.sound_device.api import SoundDeviceConfig

from rai_asr.models import LocalWhisper, OpenWakeWord, SileroVAD

VAD_THRESHOLD = 0.8  # Note that this might be different depending on your device
OWW_THRESHOLD = 0.1  # Note that this might be different depending on your device

VAD_SAMPLING_RATE = 16000  # Or 8000
DEFAULT_BLOCKSIZE = 1280


def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Voice Activity Detection and Wake Word Detection Configuration",
        allow_abbrev=True,
    )

    # Predefined arguments
    parser.add_argument(
        "--vad-threshold",
        type=float,
        default=VAD_THRESHOLD,
        help=f"Voice Activity Detection threshold (default: {VAD_THRESHOLD})",
    )
    parser.add_argument(
        "--oww-threshold",
        type=float,
        default=OWW_THRESHOLD,
        help=f"OpenWakeWord threshold (default: {OWW_THRESHOLD})",
    )
    parser.add_argument(
        "--vad-sampling-rate",
        type=int,
        choices=[8000, 16000],
        default=VAD_SAMPLING_RATE,
        help=f"VAD sampling rate (default: {VAD_SAMPLING_RATE})",
    )
    parser.add_argument(
        "--block-size",
        type=int,
        default=DEFAULT_BLOCKSIZE,
        help=f"Audio block size (default: {DEFAULT_BLOCKSIZE})",
    )
    parser.add_argument(
        "--device-name",
        type=str,
        default="default",
        help="Microphone device name (default: 'default')",
    )

    # Use parse_known_args to ignore unknown arguments
    args, unknown = parser.parse_known_args()

    if unknown:
        print(f"Ignoring unknown arguments: {unknown}")

    return args


if __name__ == "__main__":
    args = parse_arguments()

    microphone_configuration = SoundDeviceConfig(
        stream=True,
        channels=1,
        device_name=args.device_name,
        block_size=args.block_size,
        consumer_sampling_rate=args.vad_sampling_rate,
        dtype="int16",
        device_number=None,
        is_input=True,
        is_output=False,
    )
    vad = SileroVAD(args.vad_sampling_rate, args.vad_threshold)
    oww = OpenWakeWord("hey jarvis", args.oww_threshold)
    whisper = LocalWhisper("tiny", args.vad_sampling_rate)
    # You can easily switch the transcription provider by changing the whisper object,
    # e.g. to OpenAI's cloud Whisper (requires the OPENAI_API_KEY environment variable):
    # from rai_asr.models import OpenAIWhisper
    # whisper = OpenAIWhisper("whisper-1", args.vad_sampling_rate, "en")

    rclpy.init()
    ros2_name = "rai_asr_agent"

    agent = VoiceRecognitionAgent(microphone_configuration, ros2_name, whisper, vad)
    # Optionally add additional models to decide when to record data for transcription.
    # Commented out by default so that all detected speech is transcribed; uncomment to
    # require the "hey jarvis" wake word before recording starts:
    # agent.add_detection_model(oww, pipeline="record")

    agent.run()

    def cleanup(signum, frame):
        agent.stop()
        rclpy.shutdown()
        exit(0)

    signal.signal(signal.SIGINT, cleanup)

    while True:
        time.sleep(1)
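Assuming the script is saved as `examples/s2s/asr.py` (the path referenced by the README), it can be run with its defaults or with any of the arguments defined above, for example:

```bash
python ./examples/s2s/asr.py --device-name default --vad-threshold 0.8
```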
@@ -0,0 +1,170 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import signal
import time
from queue import Queue
from threading import Event, Thread
from typing import Dict, List

import rclpy
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from rai.agents.base import BaseAgent
from rai.communication import BaseConnector
from rai.communication.ros2.api import IROS2Message
from rai.communication.ros2.connectors import ROS2HRIConnector, TopicConfig
from rai.utils.model_initialization import get_llm_model

from rai_interfaces.msg import HRIMessage as InterfacesHRIMessage

# NOTE: the agent code included here is temporary until a dedicated speech agent is
# created; it can still serve as a reference for writing your own RAI agents.


class LLMTextHandler(BaseCallbackHandler):
    """Streams LLM tokens to the HRI connector in sentence-sized chunks."""

    def __init__(self, connector: ROS2HRIConnector):
        self.connector = connector
        self.token_buffer = ""

    def on_llm_new_token(self, token: str, **kwargs):
        # Flush on punctuation or once the buffer exceeds 100 characters,
        # so the TTS side receives natural-sounding chunks.
        self.token_buffer += token
        if len(self.token_buffer) > 100 or token in [".", "?", "!", ",", ";", ":"]:
            logging.info(f"Sending token buffer: {self.token_buffer}")
            self.connector.send_all_targets(AIMessage(content=self.token_buffer))
            self.token_buffer = ""

    def on_llm_end(
        self,
        response,
        *,
        run_id,
        parent_run_id=None,
        **kwargs,
    ):
        # Flush whatever is left once the LLM finishes generating.
        if self.token_buffer:
            logging.info(f"Sending token buffer: {self.token_buffer}")
            self.connector.send_all_targets(AIMessage(content=self.token_buffer))
            self.token_buffer = ""


class S2SConversationalAgent(BaseAgent):
    def __init__(self, connectors: Dict[str, BaseConnector]):  # type: ignore
        super().__init__(connectors=connectors)
        self.message_history: List[HumanMessage | AIMessage | SystemMessage] = [
            SystemMessage(
                content="Pretend you are a robot. Answer as if you were a robot."
            )
        ]
        self.speech_queue: Queue[InterfacesHRIMessage] = Queue()

        self.llm = get_llm_model(model_type="complex_model", streaming=True)
        self._setup_ros_connector()
        self.main_thread = None
        self.stop_thread = Event()

    def run(self):
        logging.info("Running S2SConversationalAgent")
        self.main_thread = Thread(target=self._main_loop)
        self.main_thread.start()

    def _main_loop(self):
        while not self.stop_thread.is_set():
            time.sleep(0.01)
            speech = ""
            while not self.speech_queue.empty():
                speech += self.speech_queue.get().text
            if speech != "":
                logging.info(f"Received human speech {speech}!")
                self.message_history.append(HumanMessage(content=speech))
                assert isinstance(self.connectors["ros2"], ROS2HRIConnector)
                ai_answer = self.llm.invoke(
                    self.message_history,
                    config={"callbacks": [LLMTextHandler(self.connectors["ros2"])]},
                )
                self.message_history.append(ai_answer)  # type: ignore

    def _on_from_human(self, msg: IROS2Message):
        assert isinstance(msg, InterfacesHRIMessage)
        logging.info("Received message from human: %s", msg.text)
        self.speech_queue.put(msg)

    def _setup_ros_connector(self):
        self.connectors["ros2"] = ROS2HRIConnector(
            sources=[
                (
                    "/from_human",
                    TopicConfig(
                        "rai_interfaces/msg/HRIMessage",
                        is_subscriber=True,
                        source_author="human",
                        subscriber_callback=self._on_from_human,
                    ),
                )
            ],
            targets=[
                (
                    "/to_human",
                    TopicConfig(
                        "rai_interfaces/msg/HRIMessage",
                        source_author="ai",
                        is_subscriber=False,
                    ),
                )
            ],
        )

    def stop(self):
        assert isinstance(self.connectors["ros2"], ROS2HRIConnector)
        self.connectors["ros2"].shutdown()
        self.stop_thread.set()
        if self.main_thread is not None:
            self.main_thread.join()


def parse_arguments():
    parser = argparse.ArgumentParser(
        description="S2S Conversational Agent Configuration",
        allow_abbrev=True,
    )

    # Use parse_known_args to ignore unknown arguments
    args, unknown = parser.parse_known_args()

    if unknown:
        print(f"Ignoring unknown arguments: {unknown}")

    return args


if __name__ == "__main__":
    args = parse_arguments()
    rclpy.init()
    agent = S2SConversationalAgent(connectors={})
    agent.run()

    def cleanup(signum, frame):
        print("\nCaught SIGINT (Ctrl+C), performing cleanup.")
        agent.stop()
        rclpy.shutdown()
        exit(0)

    signal.signal(signal.SIGINT, cleanup)

    while True:
        time.sleep(1)