Skip to content

Commit

Permalink
fix: additional changes to refactor/asr (#187)
Browse files (browse the repository at this point in the history)
Signed-off-by: Kacper Dąbrowski <[email protected]>
Signed-off-by: Wiktoria Siekierska <[email protected]>
Co-authored-by: Kacper Dąbrowski <[email protected]>
  • Loading branch information
2 people authored and maciejmajek committed Sep 3, 2024
1 parent 909ee6f commit 518e71c
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions src/rai_asr/rai_asr/asr_node.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -238,13 +238,6 @@ def hmi_status_callback(self, msg: String):
self.hmi_lock = False

def should_listen(self, audio_data: NDArray[np.int16]) -> bool:
if self.oww_model and not self.is_recording: # use only for detecting wake word
predictions = self.oww_model.predict(audio_data[-512:])
for key, value in predictions.items():
if value > self.wake_word_threshold:
self.get_logger().debug(f"Detected wake word: {key}") # type: ignore
return True

def int2float(sound: NDArray[np.int16]):
abs_max = np.abs(sound).max()
sound = sound.astype("float32")
Expand All @@ -253,12 +246,24 @@ def int2float(sound: NDArray[np.int16]):
sound = sound.squeeze()
return sound

confidence = self.vad_model(
vad_confidence = self.vad_model(
torch.tensor(int2float(audio_data[-512:])), self.sample_rate
).item()
if confidence > self.vad_threshold:
self.get_logger().debug(f"Detected speech with confidence: {confidence:.2f}") # type: ignore
return True

if self.oww_model:
if self.is_recording:
self.get_logger().debug(f"VAD confidence: {vad_confidence}") # type: ignore
return vad_confidence > self.vad_threshold
else:
predictions = self.oww_model.predict(audio_data)
for key, value in predictions.items():
if value > self.wake_word_threshold:
self.get_logger().debug(f"Detected wake word: {key}") # type: ignore
self.oww_model.reset()
return True
else:
return vad_confidence > self.vad_threshold

return False

def capture_sound(self):
Expand All @@ -282,15 +287,17 @@ def capture_sound(self):
)
if asr_lock or self.hmi_lock or self.tts_lock:
continue


self.audio_buffer.append(audio_data)
if self.should_listen(audio_data):
self.silence_start_time = datetime.now()
if not self.is_recording:
self.start_recording()
self.reset_buffer()
self.audio_buffer.append(audio_data)
elif self.is_recording:
if datetime.now() - self.silence_start_time > self.grace_period:
self.audio_buffer.append(audio_data)
if datetime.now() - self.silence_start_time > timedelta(seconds=self.silence_grace_period):
self.stop_recording_and_transcribe()
self.reset_buffer()

Expand Down

0 comments on commit 518e71c

Please sign in to comment.