Skip to content

Commit

Permalink
kws microphone
Browse files Browse the repository at this point in the history
  • Loading branch information
roatienza committed Apr 4, 2022
1 parent b8fd022 commit f0f9116
Showing 1 changed file with 96 additions and 18 deletions.
114 changes: 96 additions & 18 deletions versions/2022/supervised/python/kws-infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@
Usage:
python3 kws-infer.py --wav-file test_audio/right.wav --model-path resnet18-kws-best-acc.pt
To use microphone input, run:
python3 kws-infer.py --gui
Dependencies:
pip3 install pysimplegui
pip3 install sounddevice
Inference time:
0.03 sec on a quad-core Intel i7 @ 2.3 GHz
'''


Expand All @@ -13,6 +23,8 @@
import os
import numpy as np
import librosa
import sounddevice as sd
import time
from torchvision.transforms import ToTensor


Expand All @@ -24,8 +36,8 @@ def get_args():
parser.add_argument("--win-length", type=int, default=None)
parser.add_argument("--hop-length", type=int, default=512)
parser.add_argument("--wav-file", type=str, default=None)
parser.add_argument("--model-path", type=str, default=None)

parser.add_argument("--model-path", type=str, default="resnet18-kws-best-acc.pt")
parser.add_argument("--gui", default=False, action="store_true")
args = parser.parse_args()
return args

Expand All @@ -38,37 +50,103 @@ def get_args():
'tree', 'two', 'up', 'visual', 'wow', 'yes', 'zero']

# Map class indices back to human-readable keyword labels.
idx_to_class = {i: c for i, c in enumerate(CLASSES)}

args = get_args()

# Fall back to the checkpoint bundled under --path when --model-path is unset.
model_path = args.model_path if args.model_path else os.path.join(
    args.path, "checkpoints", "resnet18-kws-best-acc.pt")
print("Loading model checkpoint: ", model_path)
scripted_module = torch.jit.load(model_path)

if args.gui:
    # Microphone/GUI mode: classify live audio continuously.  PySimpleGUI is
    # imported lazily so the file-based modes do not require it.
    import PySimpleGUI as sg
    # 16 kHz mono capture -- assumes the model was trained on 16 kHz audio;
    # TODO confirm against the training configuration.
    sample_rate = 16000
    sd.default.samplerate = sample_rate
    sd.default.channels = 1
    sg.theme('DarkAmber')
elif args.wav_file is None:
    # No wav file given: pick a random sample from the downloaded Speech
    # Commands dataset.  The first two CLASSES entries are skipped --
    # presumably they have no matching dataset folder; TODO confirm.
    print("Searching for random kws wav file...")
    label = CLASSES[2:]
    label = np.random.choice(label)
    path = os.path.join(args.path, "SpeechCommands/speech_commands_v0.02/")
    path = os.path.join(path, label)
    wav_files = [os.path.join(path, f)
                 for f in os.listdir(path) if f.endswith('.wav')]
    # select random wav file
    wav_file = np.random.choice(wav_files)
else:
    # Explicit wav file; the ground-truth label is taken from the file name.
    wav_file = args.wav_file
    label = args.wav_file.split("/")[-1].split(".")[0]

if not args.gui:
    # File modes take the sample rate from the wav file itself.
    waveform, sample_rate = torchaudio.load(wav_file)

transform = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate,
                                                 n_fft=args.n_fft,
                                                 win_length=args.win_length,
                                                 hop_length=args.hop_length,
                                                 n_mels=args.n_mels,
                                                 power=2.0)

if not args.gui:
    # Single-shot file inference: mel spectrogram (in dB) -> model -> argmax.
    mel = ToTensor()(librosa.power_to_db(
        transform(waveform).squeeze().numpy(), ref=np.max))
    mel = mel.unsqueeze(0)
    pred = torch.argmax(scripted_module(mel), dim=1)
    print(f"Ground Truth: {label}, Prediction: {idx_to_class[pred.item()]}")
    exit(0)

# ---- GUI / microphone loop ----
layout = [
    [sg.Text('Detecting...', justification='center', size=(10, 1),
             font=("Helvetica", 100), key='-OUTPUT-')],
    [sg.Text('', size=(10, 1), font=("Helvetica", 24), key='-TIME-')],
]
window = sg.Window('KWS Inference', layout, finalize=True)
total_runtime = 0
n_loops = 0
while True:
    # Poll UI events every 100 ms so window close is handled promptly.
    event, values = window.read(100)
    if event == sg.WIN_CLOSED:
        break

    # Record sample_rate frames (~1 s of mono audio at 16 kHz).
    waveform = sd.rec(sample_rate).squeeze()
    sd.wait()
    # Skip recordings whose peak exceeds full scale -- presumably a guard
    # against clipped/saturated input; TODO confirm intent.
    if waveform.max() > 1.0:
        continue

    start_time = time.time()
    waveform = torch.from_numpy(waveform).unsqueeze(0)
    mel = ToTensor()(librosa.power_to_db(
        transform(waveform).squeeze().numpy(), ref=np.max))
    mel = mel.unsqueeze(0)
    pred = scripted_module(mel)
    max_prob = pred.max()
    elapsed_time = time.time() - start_time
    total_runtime += elapsed_time
    n_loops += 1
    ave_pred_time = total_runtime / n_loops

    # Show a label only when the top model output clears the threshold 2.0
    # (an empirical cut-off on the raw model score -- TODO confirm the model
    # emits logits rather than probabilities).
    if max_prob > 2.0:
        pred = torch.argmax(pred, dim=1)
        human_label = f"{idx_to_class[pred.item()]}"
        window['-OUTPUT-'].update(human_label)
    else:
        window['-OUTPUT-'].update("...")

    window['-TIME-'].update(f"{ave_pred_time:.2f} sec")

window.close()

0 comments on commit f0f9116

Please sign in to comment.