feat/ps routes #73

Merged 5 commits on Sep 5, 2024
1 change: 0 additions & 1 deletion Taskfile.yaml
@@ -31,7 +31,6 @@ tasks:
    sources:
      - Dockerfile.*
      - faster_whisper_server/*.py
  sync: lsyncd lsyncd.conf
  cii:
    cmds:
      - act --rm --action-offline-mode --secret-file .secrets {{.CLI_ARGS}}
25 changes: 23 additions & 2 deletions faster_whisper_server/config.py
@@ -1,6 +1,7 @@
import enum
from typing import Self

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

SAMPLES_PER_SECOND = 16000
@@ -151,7 +152,9 @@ class WhisperConfig(BaseModel):

    model: str = Field(default="Systran/faster-whisper-medium.en")
    """
    Huggingface model to use for transcription. Note, the model must support being run using CTranslate2.
    Default Huggingface model to use for transcription. Note, the model must support being run using CTranslate2.
    This model will be used if no model is specified in the request.

    Models created by the authors of `faster-whisper` can be found at https://huggingface.co/Systran
    You can find other supported models at https://huggingface.co/models?p=2&sort=trending&search=ctranslate2 and https://huggingface.co/models?sort=trending&search=ct2
    """
@@ -199,6 +202,16 @@ class Config(BaseSettings):
"""
Maximum number of models that can be loaded at a time.
"""
preload_models: list[str] = Field(
default_factory=list,
examples=[
["Systran/faster-whisper-medium.en"],
["Systran/faster-whisper-medium.en", "Systran/faster-whisper-small.en"],
],
)
"""
List of models to preload on startup. Shouldn't be greater than `max_models`. By default, the model is first loaded on first request.
""" # noqa: E501
max_no_data_seconds: float = 1.0
"""
Max duration to wait for the next audio chunk before transcription is finilized and connection is closed.
@@ -218,5 +231,13 @@ class Config(BaseSettings):
    Should be greater than `max_inactivity_seconds`.
    """

    @model_validator(mode="after")
    def ensure_preloaded_models_is_lte_max_models(self) -> Self:
        if len(self.preload_models) > self.max_models:
            raise ValueError(
                f"Number of preloaded models ({len(self.preload_models)}) is greater than max_models ({self.max_models})"  # noqa: E501
            )
        return self


config = Config()
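
A minimal sketch of how the new `preload_models` option could be set. It assumes the default pydantic-settings behavior, where a `list[str]` field is populated from an environment variable of the same name containing JSON; the model names and `MAX_MODELS` value below are illustrative, and the actual env var names depend on the `SettingsConfigDict` in use:

import os

# Hypothetical env values; pydantic-settings parses JSON for complex (list) fields.
os.environ["PRELOAD_MODELS"] = '["Systran/faster-whisper-medium.en"]'
os.environ["MAX_MODELS"] = "2"

from faster_whisper_server.config import Config

config = Config()
# The model_validator above rejects configs where preload_models exceeds max_models.
assert len(config.preload_models) <= config.max_models
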
39 changes: 36 additions & 3 deletions faster_whisper_server/main.py
@@ -2,6 +2,8 @@

import asyncio
from collections import OrderedDict
from contextlib import asynccontextmanager
import gc
from io import BytesIO
import time
from typing import TYPE_CHECKING, Annotated, Literal
@@ -45,7 +47,7 @@
from faster_whisper_server.transcriber import audio_transcriber

if TYPE_CHECKING:
    from collections.abc import Generator, Iterable
    from collections.abc import AsyncGenerator, Generator, Iterable

    from faster_whisper.transcribe import TranscriptionInfo
    from huggingface_hub.hf_api import ModelInfo
@@ -63,7 +65,7 @@ def load_model(model_name: str) -> WhisperModel:
        del loaded_models[oldest_model_name]
    logger.debug(f"Loading {model_name}...")
    start = time.perf_counter()
    # NOTE: will raise an exception if the model name isn't valid
    # NOTE: will raise an exception if the model name isn't valid. Should I do an explicit check?
    whisper = WhisperModel(
        model_name,
        device=config.whisper.inference_device,
@@ -81,7 +83,15 @@ def load_model(model_name: str) -> WhisperModel:

logger.debug(f"Config: {config}")

app = FastAPI()

@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
    # Load the configured models once at startup instead of on first request.
    for model_name in config.preload_models:
        load_model(model_name)
    yield


app = FastAPI(lifespan=lifespan)

if config.allow_origins is not None:
    app.add_middleware(
@@ -98,6 +108,29 @@ def health() -> Response:
    return Response(status_code=200, content="OK")


@app.get("/api/ps", tags=["experimental"], summary="Get a list of loaded models.")
def get_running_models() -> dict[str, list[str]]:
    return {"models": list(loaded_models.keys())}


@app.post("/api/ps/{model_name:path}", tags=["experimental"], summary="Load a model into memory.")
def load_model_route(model_name: str) -> Response:
    if model_name in loaded_models:
        return Response(status_code=409, content="Model already loaded")
    load_model(model_name)
    return Response(status_code=201)


@app.delete("/api/ps/{model_name:path}", tags=["experimental"], summary="Unload a model from memory.")
def stop_running_model(model_name: str) -> Response:
    model = loaded_models.get(model_name)
    if model is not None:
        del loaded_models[model_name]
        # Nudge the garbage collector so the model's memory is reclaimed promptly.
        gc.collect()
        return Response(status_code=204)
    return Response(status_code=404)


@app.get("/v1/models")
def get_models() -> ModelListResponse:
    models = huggingface_hub.list_models(library="ctranslate2", tags="automatic-speech-recognition", cardData=True)
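
A quick sketch of how the new experimental endpoints could be exercised, assuming a server listening on localhost:8000 (the port and model name are illustrative):

import httpx

client = httpx.Client(base_url="http://localhost:8000")
model = "Systran/faster-whisper-medium.en"

# List the currently loaded models.
print(client.get("/api/ps").json())  # e.g. {"models": []}

# Load a model into memory: 201 on success, 409 if it is already loaded.
print(client.post(f"/api/ps/{model}").status_code)

# Unload it again: 204 on success, 404 if it was not loaded.
print(client.delete(f"/api/ps/{model}").status_code)
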
19 changes: 16 additions & 3 deletions flake.nix
@@ -20,19 +20,32 @@
            act
            ffmpeg-full
            go-task
            lsyncd
            parallel
            pre-commit
            pv
            python312
            rsync
            websocat
            uv
            cudaPackages_12.cudnn_8_9
          ];

          # https://github.com/NixOS/nixpkgs/issues/278976#issuecomment-1879685177
          # NOTE: Without adding `/run/...` the following error occurs
          # RuntimeError: CUDA failed with error CUDA driver version is insufficient for CUDA runtime version
          #
          # NOTE: sometimes it still doesn't work but rebooting the system fixes it
          LD_LIBRARY_PATH = "/run/opengl-driver/lib:${
            pkgs.lib.makeLibraryPath [
              pkgs.cudaPackages_12.cudnn_8_9
              pkgs.zlib
              pkgs.stdenv.cc.cc
              pkgs.openssl
            ]
          }";

          shellHook = ''
            source .venv/bin/activate
            export LD_LIBRARY_PATH=${pkgs.stdenv.cc.cc.lib}/lib:$LD_LIBRARY_PATH
            export LD_LIBRARY_PATH=${pkgs.zlib}/lib:$LD_LIBRARY_PATH
            source .env
          '';
        };
20 changes: 0 additions & 20 deletions lsyncd.conf

This file was deleted.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -70,6 +70,7 @@ ignore = [
"W505",
"ISC001", # recommended to disable for formatting
"INP001",
"PT018",
]

[tool.ruff.lint.isort]
Expand Down
48 changes: 26 additions & 22 deletions scripts/client.py
@@ -2,6 +2,7 @@
from pathlib import Path
import subprocess
import threading
import time

import httpx
import keyboard
@@ -14,11 +15,12 @@
# The audio file will be sent to the server for transcription.
# The transcription will be copied to the clipboard.
# For a short audio clip of a couple of sentences, with inference running on a GPU, the response time is very fast (less than 2 seconds). # noqa: E501
# Run this with `sudo -E python scripts/client.py`

CHUNK = 2**12
AUDIO_RECORD_CMD = [
"ffmpeg",
# "-hide_banner",
"-hide_banner",
# "-loglevel",
# "quiet",
"-f",
@@ -27,15 +29,6 @@
"default",
"-f",
"wav",
# "-ac",
# "1",
# "-ar",
# "16000",
# "-f",
# "s16le",
# "-acodec",
# "pcm_s16le",
# "-",
]
COPY_TO_CLIPBOARD_CMD = "wl-copy"
OPENAI_BASE_URL = "ws://localhost:8000/v1"
@@ -48,12 +41,13 @@

client = httpx.Client(base_url=OPENAI_BASE_URL, timeout=TIMEOUT)
is_running = threading.Event()
file = Path("test.wav") # TODO: use tempfile

file = Path("test.wav")  # HACK: I had a hard time using a temporary file due to permission issues


while True:
    keyboard.wait(KEYBIND)
    print("Action started")
    print("Recording started")
    process = subprocess.Popen(
        [*AUDIO_RECORD_CMD, "-y", str(file.name)],
        stdout=subprocess.PIPE,
@@ -63,17 +57,27 @@
    )
    keyboard.wait(KEYBIND)
    process.kill()
    print("Action finished")
    stdout, stderr = process.communicate()
    if stdout or stderr:
        print(f"stdout: {stdout}")
        print(f"stderr: {stderr}")
    print(f"Recording finished. File size: {file.stat().st_size} bytes")

    with open(file, "rb") as f:
        res = client.post(
            OPENAI_BASE_URL + TRANSCRIBE_PATH,
            files={"file": f},
            data={
                "response_format": RESPONSE_FORMAT,
                "language": LANGUAGE,
            },
        )
    try:
        with open(file, "rb") as fd:
            start = time.perf_counter()
            res = client.post(
                OPENAI_BASE_URL + TRANSCRIBE_PATH,
                files={"file": fd},
                data={
                    "response_format": RESPONSE_FORMAT,
                    "language": LANGUAGE,
                },
            )
            end = time.perf_counter()
            print(f"Transcription took {end - start} seconds")
        transcription = res.text
        print(transcription)
        subprocess.run([COPY_TO_CLIPBOARD_CMD], input=transcription.encode(), check=True)
    except httpx.ConnectError as e:
        print(f"Couldn't connect to server: {e}")