feat: implement base class for asr models, add local whisper
maciejmajek committed Sep 5, 2024
1 parent 6b5071e commit 83ed691
Showing 4 changed files with 174 additions and 29 deletions.
73 changes: 73 additions & 0 deletions src/rai_asr/launch/local.launch.py
@@ -0,0 +1,73 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node


def generate_launch_description():
return LaunchDescription(
[
DeclareLaunchArgument(
"recording_device",
default_value="0",
description="Microphone device number. See available by running python -c 'import sounddevice as sd; print(sd.query_devices())'",
),
DeclareLaunchArgument(
"language",
default_value="en",
description="Language code for the ASR model",
),
DeclareLaunchArgument(
"model_name",
default_value="base",
description="Model name for the ASR model",
),
DeclareLaunchArgument(
"model_vendor",
default_value="whisper",
description="Model vendor of the ASR",
),
DeclareLaunchArgument(
"silence_grace_period",
default_value="2.0",
description="Grace period in seconds after silence to stop recording",
),
DeclareLaunchArgument(
"sample_rate",
default_value="0",
description="Sample rate for audio capture (0 for auto-detect)",
),
Node(
package="rai_asr",
executable="asr_node",
name="rai_asr",
output="screen",
emulate_tty=True,
parameters=[
{
"language": LaunchConfiguration("language"),
"model": LaunchConfiguration("model"),
"silence_grace_period": LaunchConfiguration(
"silence_grace_period"
),
"sample_rate": LaunchConfiguration("sample_rate"),
}
],
),
]
)
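Assuming the workspace is built and sourced, this launch file can be started with, e.g., ros2 launch rai_asr local.launch.py model_name:=base model_vendor:=whisper language:=en (a usage sketch; argument names as declared above).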
9 changes: 7 additions & 2 deletions
@@ -33,9 +33,14 @@ def generate_launch_description():
             description="Language code for the ASR model",
         ),
         DeclareLaunchArgument(
-            "model",
+            "model_name",
             default_value="whisper-1",
-            description="Model type for the ASR model",
+            description="Model name for the ASR model",
         ),
+        DeclareLaunchArgument(
+            "model_vendor",
+            default_value="openai",
+            description="Model vendor of the ASR",
+        ),
         DeclareLaunchArgument(
             "silence_grace_period",
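Note that the launch argument is renamed from "model" to "model_name" here, so any existing invocation passing model:=... needs to switch to model_name:=..., and may now also set model_vendor:= (default "openai" in this file).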
59 changes: 59 additions & 0 deletions src/rai_asr/rai_asr/asr_clients.py
@@ -0,0 +1,59 @@
import io
import os
from abc import ABC, abstractmethod
from functools import partial

import numpy as np
import whisper
from numpy.typing import NDArray
from openai import OpenAI
from scipy.io import wavfile
from whisper.transcribe import transcribe


class ASRModel(ABC):
def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
self.model_name = model_name
self.sample_rate = sample_rate
self.language = language

@abstractmethod
def transcribe(self, data: NDArray[np.int16]) -> str:
pass

def __call__(self, data: NDArray[np.int16]) -> str:
return self.transcribe(data)


class OpenAIWhisper(ASRModel):
def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
super().__init__(model_name, sample_rate, language)
api_key = os.getenv("OPENAI_API_KEY")
if api_key is None:
raise ValueError("OPENAI_API_KEY environment variable is not set.")
self.api_key = api_key
self.openai_client = OpenAI()
self.model = partial(
self.openai_client.audio.transcriptions.create,
model=self.model_name,
)

def transcribe(self, data: NDArray[np.int16]) -> str:
with io.BytesIO() as temp_wav_buffer:
wavfile.write(temp_wav_buffer, self.sample_rate, data)
temp_wav_buffer.seek(0)
temp_wav_buffer.name = "temp.wav"
response = self.model(file=temp_wav_buffer, language=self.language)
transcription = response.text
return transcription


class LocalWhisper(ASRModel):
def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
super().__init__(model_name, sample_rate, language)
self.whisper = whisper.load_model(self.model_name)

def transcribe(self, data: NDArray[np.int16]) -> str:
        result = transcribe(
            self.whisper, data.astype(np.float32) / 32768.0, language=self.language
        )
transcription = result["text"]
return transcription
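Both clients are callable objects: __call__ delegates to transcribe, and LocalWhisper rescales int16 PCM samples to float32 in [-1, 1] before handing them to Whisper. A minimal standalone usage sketch under those assumptions (hypothetical script, not part of the commit):

import numpy as np

from rai_asr.asr_clients import LocalWhisper

SAMPLE_RATE = 16000  # must match the rate the audio was recorded at

# One second of silence as int16 PCM; substitute real microphone samples here.
audio = np.zeros(SAMPLE_RATE, dtype=np.int16)

asr = LocalWhisper("base", SAMPLE_RATE, language="en")
print(asr(audio))  # __call__ forwards to transcribe()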
62 changes: 35 additions & 27 deletions src/rai_asr/rai_asr/asr_node.py
@@ -13,23 +13,19 @@
 # limitations under the License.
 #
 
-import io
 import threading
 import time
 from datetime import datetime, timedelta
-from functools import partial
 from typing import Literal
 
 import numpy as np
 import rclpy
 import sounddevice as sd
 import torch
-from openai import OpenAI
 from rcl_interfaces.msg import ParameterDescriptor, ParameterType
 from rclpy.callback_groups import ReentrantCallbackGroup
 from rclpy.executors import MultiThreadedExecutor
 from rclpy.node import Node
-from scipy.io import wavfile
 from std_msgs.msg import String
 
 SAMPLING_RATE = 16000
@@ -43,7 +39,7 @@ def __init__(self):
         self._setup_node_components()
         self._initialize_variables()
         self._setup_publishers_and_subscribers()
-        self._load_whisper_model()
+        self._initialize_asr_model()
 
     def _declare_parameters(self):
         self.declare_parameter(
Expand All @@ -57,6 +53,14 @@ def _declare_parameters(self):
),
),
)
self.declare_parameter(
"model_vendor",
"whisper", # openai, whisper
ParameterDescriptor(
type=ParameterType.PARAMETER_STRING,
description="Vendor of the ASR model",
),
)
self.declare_parameter(
"language",
"en",
@@ -66,8 +70,8 @@
             ),
         )
         self.declare_parameter(
-            "model",
-            "whisper-1",
+            "model_name",
+            "base",
             ParameterDescriptor(
                 type=ParameterType.PARAMETER_STRING,
                 description="Model type for the ASR model",
@@ -106,8 +110,11 @@ def _initialize_variables(self):
             .get_parameter_value()
             .double_value
         )  # type: ignore
-        self.whisper_model = (
-            self.get_parameter("model").get_parameter_value().string_value
+        self.model_name = (
+            self.get_parameter("model_name").get_parameter_value().string_value
+        )  # type: ignore
+        self.model_vendor = (
+            self.get_parameter("model_vendor").get_parameter_value().string_value
         )  # type: ignore
         self.language = (
             self.get_parameter("language").get_parameter_value().string_value
@@ -135,11 +142,17 @@ def _setup_publishers_and_subscribers(self):
             callback_group=self.callback_group,
         )
 
-    def _load_whisper_model(self):
-        self.openai_client = OpenAI()
-        self.model = partial(
-            self.openai_client.audio.transcriptions.create, model=self.whisper_model
-        )
+    def _initialize_asr_model(self):
+        if self.model_vendor == "openai":
+            from rai_asr.asr_clients import OpenAIWhisper
+
+            self.model = OpenAIWhisper(self.model_name, self.sample_rate, self.language)
+        elif self.model_vendor == "whisper":
+            from rai_asr.asr_clients import LocalWhisper
+
+            self.model = LocalWhisper(self.model_name, self.sample_rate, self.language)
+        else:
+            raise ValueError(f"Unknown model vendor: {self.model_vendor}")
 
     def tts_status_callback(self, msg: String):
         if msg.data == "processing":
@@ -209,19 +222,14 @@ def transcribe_audio(self):
         self.get_logger().info("Calling ASR model")
         combined_audio = np.concatenate(self.audio_buffer)
 
-        with io.BytesIO() as temp_wav_buffer:
-            wavfile.write(temp_wav_buffer, self.sample_rate, combined_audio)
-            temp_wav_buffer.seek(0)
-            temp_wav_buffer.name = "temp.wav"
-
-            response = self.model(file=temp_wav_buffer, language=self.language)
-            transcription = response.text
-            if transcription.lower() in ["you", ""]:
-                self.get_logger().info(f"Dropping transcription: '{transcription}'")
-                self.publish_status("dropping")
-            else:
-                self.get_logger().info(f"Transcription: {transcription}")
-                self.publish_transcription(transcription)
+        transcription = self.model(data=combined_audio)
+
+        if transcription.lower() in ["you", ""]:
+            self.get_logger().info(f"Dropping transcription: '{transcription}'")
+            self.publish_status("dropping")
+        else:
+            self.get_logger().info(f"Transcription: {transcription}")
+            self.publish_transcription(transcription)
 
         self.last_transcription_time = time.time()
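A note on the filter retained in transcribe_audio: Whisper models are known to hallucinate short fillers such as "you" on silent or near-silent input, which appears to be why empty transcriptions and a bare "you" are dropped rather than published.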
