Skip to content

Commit

Permalink
style(asr_node): rename DEFAULT_SAMPLING_RATE to VAD_SAMPLING_RATE
Browse files Browse the repository at this point in the history
  • Loading branch information
maciejmajek committed Sep 9, 2024
1 parent 41b86ce commit 3e50e72
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions src/rai_asr/rai_asr/asr_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from scipy.signal import resample
from std_msgs.msg import String

DEFAULT_SAMPLING_RATE = 16000
VAD_SAMPLING_RATE = 16000 # default value used by silero vad
DEFAULT_BLOCKSIZE = 1280


Expand Down Expand Up @@ -231,14 +231,12 @@ def _initialize_asr_model(self):
from rai_asr.asr_clients import OpenAIWhisper

self.model = OpenAIWhisper(
self.model_name, DEFAULT_SAMPLING_RATE, self.language
self.model_name, VAD_SAMPLING_RATE, self.language
)
elif self.model_vendor == "whisper":
from rai_asr.asr_clients import LocalWhisper

self.model = LocalWhisper(
self.model_name, DEFAULT_SAMPLING_RATE, self.language
)
self.model = LocalWhisper(self.model_name, VAD_SAMPLING_RATE, self.language)
else:
raise ValueError(f"Unknown model vendor: {self.model_vendor}")

Expand All @@ -264,7 +262,7 @@ def int2float(sound: NDArray[np.int16]):
return sound

vad_confidence = self.vad_model(
torch.tensor(int2float(audio_data[-512:])), DEFAULT_SAMPLING_RATE
torch.tensor(int2float(audio_data[-512:])), VAD_SAMPLING_RATE
).item()

if self.oww_model:
Expand All @@ -288,8 +286,8 @@ def sd_callback(self, indata, frames, _, status):
self.get_logger().warning(f"Stream status: {status}") # type: ignore
indata = indata.flatten()
sample_time_length = len(indata) / self.device_sample_rate
if self.device_sample_rate != DEFAULT_SAMPLING_RATE:
indata = resample(indata, int(sample_time_length * DEFAULT_SAMPLING_RATE))
if self.device_sample_rate != VAD_SAMPLING_RATE:
indata = resample(indata, int(sample_time_length * VAD_SAMPLING_RATE))

asr_lock = (
time.time()
Expand Down Expand Up @@ -325,7 +323,7 @@ def initialize_sounddevice_stream(self):
"default_samplerate"
] # type: ignore
self.window_size_samples = int(
DEFAULT_BLOCKSIZE * self.device_sample_rate / DEFAULT_SAMPLING_RATE
DEFAULT_BLOCKSIZE * self.device_sample_rate / VAD_SAMPLING_RATE
)
self.stream = sd.InputStream(
samplerate=self.device_sample_rate,
Expand Down

0 comments on commit 3e50e72

Please sign in to comment.