-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathspeech_recognition.py
94 lines (80 loc) · 2.51 KB
/
speech_recognition.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import tempfile
import ffmpeg
import asyncio
import subprocess
import os
# Sample rate (Hz) that convert_audio resamples audio to (passed as ffmpeg's ``ar``).
SAMPLE_RATE = 16_000
def convert_audio(data: bytes, out_filename: str):
    """Transcode in-memory media bytes to a mono 16-bit PCM WAV file with ffmpeg.

    ``data`` is written to a temporary file, which ffmpeg reads as its input.
    The result is written to ``out_filename`` at ``SAMPLE_RATE`` Hz, and any
    existing file at that path is overwritten.

    Args:
        data: Raw bytes of the input media, in any format ffmpeg can probe.
        out_filename: Destination path for the WAV output.

    Returns:
        The bytes ffmpeg wrote to its (captured) stdout. The audio itself goes
        to ``out_filename``.

    Raises:
        RuntimeError: If ffmpeg fails, or if it produces an empty output file.
    """
    try:
        with tempfile.NamedTemporaryFile("w+b") as file:
            file.write(data)
            file.flush()
            print(f"Converting media {file.name} to {out_filename}")
            out, err = (
                ffmpeg.input(file.name, threads=0)
                .output(out_filename, format="wav", acodec="pcm_s16le", ac=1, ar=SAMPLE_RATE)
                # ffmpeg reads the temp file, not stdin. Without -nostdin it would
                # treat stdin bytes as interactive key presses ('q' quits), so do
                # not pipe the audio there and tell ffmpeg not to read stdin at all.
                .global_args("-nostdin")
                .overwrite_output()
                .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True)
            )
    except ffmpeg.Error as e:
        # ffmpeg's stderr can contain non-UTF-8 bytes (e.g. file names); never
        # let decoding it mask the real error.
        raise RuntimeError(
            f"Failed to load audio: {e.stderr.decode('utf-8', errors='replace')}"
        ) from e
    if os.path.getsize(out_filename) == 0:
        print(err.decode("utf-8", errors="replace"))
        # RuntimeError subclasses Exception, so existing `except Exception`
        # callers still catch this.
        raise RuntimeError("Converted file is empty")
    return out
# Model names accepted by ASR. ASR interpolates the name into its local model
# paths and its download URL. Suffixes are presumably: ".en" = English-only
# variant, "-q5_0"/"-q5_1" = quantized weights (TODO confirm against model host).
MODELS = [
    # tiny
    "tiny.en", "tiny.en-q5_1", "tiny", "tiny-q5_1",
    # base
    "base.en", "base.en-q5_1", "base", "base-q5_1",
    # small
    "small.en", "small.en-q5_1", "small", "small-q5_1",
    # medium / large (only quantized variants are listed)
    "medium.en-q5_0", "medium-q5_0", "large-q5_0",
]
class ASR:
    """Speech-to-text wrapper that runs an external ``./main`` binary on WAV audio.

    Args:
        model: Model name; must be one of ``MODELS``.
        language: Language code passed to the binary's ``-l`` flag.

    Raises:
        ValueError: If ``model`` is not in ``MODELS``.
    """

    def __init__(self, model="tiny", language="en"):
        if model not in MODELS:
            raise ValueError(f"Invalid model: {model}. Must be one of {MODELS}")
        self.model = model
        self.language = language
        bundled_path = f"/app/ggml-model-whisper-{model}.bin"
        if os.path.exists(bundled_path):
            # A model file already present under /app is used as-is.
            self.model_path = bundled_path
        else:
            self.model_path = f"/data/models/ggml-{model}.bin"
            # Unlike exists()+mkdir(), this creates missing parents and does not
            # fail if another process creates the directory concurrently.
            os.makedirs("/data/models", exist_ok=True)
        self.model_url = f"https://ggml.ggerganov.com/ggml-model-whisper-{self.model}.bin"
        # Held around the ./main subprocess so only one transcription runs at a time.
        self.lock = asyncio.Lock()

    def load_model(self):
        """Download the model to ``self.model_path`` unless a non-empty file is already there.

        The download is written to ``<model_path>.part`` and moved into place
        only after ``wget`` succeeds. A failed or interrupted download therefore
        never leaves a truncated file that a later call would mistake for a
        complete model.

        Raises:
            subprocess.CalledProcessError: If ``wget`` exits non-zero.
            OSError: If ``wget`` cannot be started.
        """
        if os.path.exists(self.model_path) and os.path.getsize(self.model_path) > 0:
            return
        print("Downloading model...")
        partial_path = f"{self.model_path}.part"
        try:
            subprocess.run(["wget", "-nv", self.model_url, "-O", partial_path], check=True)
            os.replace(partial_path, self.model_path)
        finally:
            # After a successful replace this is a no-op; on failure it removes
            # whatever partial output wget left behind.
            if os.path.exists(partial_path):
                os.remove(partial_path)
        print("Done.")

    async def transcribe(self, audio: bytes) -> str:
        """Transcribe encoded audio bytes and return the recognized text.

        ``audio`` is converted to a temporary WAV file with ``convert_audio``.
        ``./main`` is then run on it with ``-m <model_path> -l <language> -f <wav> -nt``
        (``-nt`` presumably disables timestamps; confirm against the binary).
        The temporary file is always removed, even on error.

        Returns:
            The binary's stdout, decoded as UTF-8 and stripped.

        Raises:
            RuntimeError: If conversion fails or ``./main`` exits non-zero.
        """
        # mkstemp atomically creates the file. mktemp only returns a name, which
        # another process could claim first. Close our handle; ffmpeg reopens
        # the path and overwrites it.
        fd, filename = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        try:
            # convert_audio blocks on an ffmpeg subprocess; keep it off the event loop.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, convert_audio, audio, filename)
            async with self.lock:
                proc = await asyncio.create_subprocess_exec(
                    "./main",
                    "-m", self.model_path,
                    "-l", self.language,
                    "-f", filename,
                    "-nt",
                    stdout=asyncio.subprocess.PIPE,
                    stderr=asyncio.subprocess.PIPE,
                )
                stdout, stderr = await proc.communicate()
        finally:
            os.remove(filename)
        if stderr:
            print(stderr.decode("utf-8", errors="replace"))
        if proc.returncode != 0:
            # Previously a failed run silently returned "" (or partial text).
            raise RuntimeError(f"./main exited with status {proc.returncode}")
        text = stdout.decode("utf-8", errors="replace").strip()
        print(text)
        return text