feat: model unloading #92

Merged
merged 1 commit on Oct 1, 2024
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -19,10 +19,10 @@ dependencies = [
client = [
"keyboard>=0.13.5",
]
# NOTE: when installing `dev` group, all other groups should also be installed
dev = [
"anyio>=4.4.0",
"basedpyright>=1.18.0",
"pytest-antilru>=2.0.0",
"pytest-asyncio>=0.24.0",
"pytest-xdist>=3.6.1",
"pytest>=8.3.3",
25 changes: 9 additions & 16 deletions src/faster_whisper_server/config.py
@@ -1,7 +1,6 @@
import enum
from typing import Self

from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings, SettingsConfigDict

SAMPLES_PER_SECOND = 16000
@@ -163,6 +162,12 @@ class WhisperConfig(BaseModel):
compute_type: Quantization = Field(default=Quantization.DEFAULT)
cpu_threads: int = 0
num_workers: int = 1
ttl: int = Field(default=300, ge=-1)
"""
Time in seconds until the model is unloaded if it is not being used.
-1: Never unload the model.
0: Unload the model immediately after usage.
"""


class Config(BaseSettings):
@@ -198,10 +203,6 @@ class Config(BaseSettings):
"""
default_response_format: ResponseFormat = ResponseFormat.JSON
whisper: WhisperConfig = WhisperConfig()
max_models: int = 1
"""
Maximum number of models that can be loaded at a time.
"""
preload_models: list[str] = Field(
default_factory=list,
examples=[
@@ -210,8 +211,8 @@
],
)
"""
List of models to preload on startup. Shouldn't be greater than `max_models`. By default, the model is first loaded on first request.
""" # noqa: E501
List of models to preload on startup. By default, the model is first loaded on first request.
"""
max_no_data_seconds: float = 1.0
"""
Max duration to wait for the next audio chunk before transcription is finalized and connection is closed.
@@ -230,11 +231,3 @@
Controls how many latest seconds of audio are being passed through VAD.
Should be greater than `max_inactivity_seconds`
"""

@model_validator(mode="after")
def ensure_preloaded_models_is_lte_max_models(self) -> Self:
if len(self.preload_models) > self.max_models:
raise ValueError(
f"Number of preloaded models ({len(self.preload_models)}) is greater than max_models ({self.max_models})" # noqa: E501
)
return self
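
For reference, a minimal sketch (not part of this diff) of what the three `ttl` values added to `WhisperConfig` mean in practice; it assumes the remaining `WhisperConfig` fields can be left at their defaults:

from faster_whisper_server.config import WhisperConfig

# ttl=-1: never unload the model once it has been loaded
keep_forever = WhisperConfig(ttl=-1)
# ttl=0: unload the model as soon as the last user releases it
unload_immediately = WhisperConfig(ttl=0)
# ttl=300 (the default): unload the model after 300 seconds of sitting idle
unload_when_idle = WhisperConfig(ttl=300)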
2 changes: 1 addition & 1 deletion src/faster_whisper_server/dependencies.py
@@ -18,7 +18,7 @@ def get_config() -> Config:
@lru_cache
def get_model_manager() -> ModelManager:
config = get_config() # HACK
return ModelManager(config)
return ModelManager(config.whisper)


ModelManagerDependency = Annotated[ModelManager, Depends(get_model_manager)]
144 changes: 114 additions & 30 deletions src/faster_whisper_server/model_manager.py
@@ -3,48 +3,132 @@
from collections import OrderedDict
import gc
import logging
import threading
import time
from typing import TYPE_CHECKING

from faster_whisper import WhisperModel

if TYPE_CHECKING:
from collections.abc import Callable

from faster_whisper_server.config import (
Config,
WhisperConfig,
)

logger = logging.getLogger(__name__)

# TODO: enable concurrent model downloads


class SelfDisposingWhisperModel:
def __init__(
self,
model_id: str,
whisper_config: WhisperConfig,
*,
on_unload: Callable[[str], None] | None = None,
) -> None:
self.model_id = model_id
self.whisper_config = whisper_config
self.on_unload = on_unload

self.ref_count: int = 0
self.rlock = threading.RLock()
self.expire_timer: threading.Timer | None = None
self.whisper: WhisperModel | None = None

def unload(self) -> None:
with self.rlock:
if self.whisper is None:
raise ValueError(f"Model {self.model_id} is not loaded. {self.ref_count=}")
if self.ref_count > 0:
raise ValueError(f"Model {self.model_id} is still in use. {self.ref_count=}")
if self.expire_timer:
self.expire_timer.cancel()
self.whisper = None
# WARN: ~300 MB of memory will still be held by the model. See https://github.com/SYSTRAN/faster-whisper/issues/992
gc.collect()
logger.info(f"Model {self.model_id} unloaded")
if self.on_unload is not None:
self.on_unload(self.model_id)

def _load(self) -> None:
with self.rlock:
assert self.whisper is None
logger.debug(f"Loading model {self.model_id}")
start = time.perf_counter()
self.whisper = WhisperModel(
self.model_id,
device=self.whisper_config.inference_device,
device_index=self.whisper_config.device_index,
compute_type=self.whisper_config.compute_type,
cpu_threads=self.whisper_config.cpu_threads,
num_workers=self.whisper_config.num_workers,
)
logger.info(f"Model {self.model_id} loaded in {time.perf_counter() - start:.2f}s")

def _increment_ref(self) -> None:
with self.rlock:
self.ref_count += 1
if self.expire_timer:
logger.debug(f"Model was set to expire in {self.expire_timer.interval}s, cancelling")
self.expire_timer.cancel()
logger.debug(f"Incremented ref count for {self.model_id}, {self.ref_count=}")

def _decrement_ref(self) -> None:
with self.rlock:
self.ref_count -= 1
logger.debug(f"Decremented ref count for {self.model_id}, {self.ref_count=}")
if self.ref_count <= 0:
if self.whisper_config.ttl > 0:
logger.info(f"Model {self.model_id} is idle, scheduling offload in {self.whisper_config.ttl}s")
self.expire_timer = threading.Timer(self.whisper_config.ttl, self.unload)
self.expire_timer.start()
elif self.whisper_config.ttl == 0:
logger.info(f"Model {self.model_id} is idle, unloading immediately")
self.unload()
else:
logger.info(f"Model {self.model_id} is idle, not unloading")

def __enter__(self) -> WhisperModel:
with self.rlock:
if self.whisper is None:
self._load()
self._increment_ref()
assert self.whisper is not None
return self.whisper

def __exit__(self, *_args) -> None: # noqa: ANN002
self._decrement_ref()


class ModelManager:
def __init__(self, config: Config) -> None:
self.config = config
self.loaded_models: OrderedDict[str, WhisperModel] = OrderedDict()
def __init__(self, whisper_config: WhisperConfig) -> None:
self.whisper_config = whisper_config
self.loaded_models: OrderedDict[str, SelfDisposingWhisperModel] = OrderedDict()
self._lock = threading.Lock()

def load_model(self, model_name: str) -> WhisperModel:
if model_name in self.loaded_models:
logger.debug(f"{model_name} model already loaded")
return self.loaded_models[model_name]
if len(self.loaded_models) >= self.config.max_models:
oldest_model_name = next(iter(self.loaded_models))
logger.info(
f"Max models ({self.config.max_models}) reached. Unloading the oldest model: {oldest_model_name}"
def _handle_model_unload(self, model_name: str) -> None:
with self._lock:
if model_name in self.loaded_models:
del self.loaded_models[model_name]

def unload_model(self, model_name: str) -> None:
with self._lock:
model = self.loaded_models.get(model_name)
if model is None:
raise KeyError(f"Model {model_name} not found")
self.loaded_models[model_name].unload()

def load_model(self, model_name: str) -> SelfDisposingWhisperModel:
with self._lock:
if model_name in self.loaded_models:
logger.debug(f"{model_name} model already loaded")
return self.loaded_models[model_name]
self.loaded_models[model_name] = SelfDisposingWhisperModel(
model_name,
self.whisper_config,
on_unload=self._handle_model_unload,
)
del self.loaded_models[oldest_model_name]
gc.collect()
logger.debug(f"Loading {model_name}...")
start = time.perf_counter()
# NOTE: will raise an exception if the model name isn't valid. Should I do an explicit check?
whisper = WhisperModel(
model_name,
device=self.config.whisper.inference_device,
device_index=self.config.whisper.device_index,
compute_type=self.config.whisper.compute_type,
cpu_threads=self.config.whisper.cpu_threads,
num_workers=self.config.whisper.num_workers,
)
logger.info(
f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {self.config.whisper.inference_device}({self.config.whisper.compute_type}) will be used for inference." # noqa: E501
)
self.loaded_models[model_name] = whisper
return whisper
return self.loaded_models[model_name]
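
Since `load_model` now returns a `SelfDisposingWhisperModel` instead of a bare `WhisperModel`, callers are expected to use it as a context manager. A rough usage sketch (not part of this PR; the model name and audio path are only examples, and the remaining `WhisperConfig` fields are assumed to have defaults):

from faster_whisper_server.config import WhisperConfig
from faster_whisper_server.model_manager import ModelManager

manager = ModelManager(WhisperConfig(ttl=300))
with manager.load_model("Systran/faster-whisper-tiny.en") as whisper:
    # while inside the block the ref count is > 0, so any pending expiration timer is cancelled
    segments, info = whisper.transcribe("audio.wav")
# on exit the ref count drops to 0 and unloading is scheduled after `ttl` seconds

`__enter__` loads the model lazily on first use and `__exit__` only schedules the unload, so repeated requests inside the `ttl` window reuse the already-loaded model.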
18 changes: 10 additions & 8 deletions src/faster_whisper_server/routers/misc.py
@@ -1,7 +1,5 @@
from __future__ import annotations

import gc

from fastapi import (
APIRouter,
Response,
@@ -42,15 +40,19 @@ def get_running_models(
def load_model_route(model_manager: ModelManagerDependency, model_name: str) -> Response:
if model_name in model_manager.loaded_models:
return Response(status_code=409, content="Model already loaded")
model_manager.load_model(model_name)
with model_manager.load_model(model_name):
pass
return Response(status_code=201)


@router.delete("/api/ps/{model_name:path}", tags=["experimental"], summary="Unload a model from memory.")
def stop_running_model(model_manager: ModelManagerDependency, model_name: str) -> Response:
model = model_manager.loaded_models.get(model_name)
if model is not None:
del model_manager.loaded_models[model_name]
gc.collect()
try:
model_manager.unload_model(model_name)
return Response(status_code=204)
return Response(status_code=404)
except (KeyError, ValueError) as e:
match e:
case KeyError():
return Response(status_code=404, content="Model not found")
case ValueError():
return Response(status_code=409, content=str(e))
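
With the manager raising `KeyError`/`ValueError`, the two routes now map those errors onto HTTP status codes. A rough sketch of the expected responses (assumptions: the server is reachable at localhost:8000, the POST route uses the same `/api/ps/{model_name}` path as the DELETE route, and `httpx` is installed):

import httpx

url = "http://localhost:8000/api/ps/Systran/faster-whisper-tiny.en"

r = httpx.post(url)    # 201 when freshly loaded, 409 if the model was already loaded
print(r.status_code)

r = httpx.delete(url)  # 204 when unloaded, 404 for an unknown model, 409 if it is still in use
print(r.status_code)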
96 changes: 48 additions & 48 deletions src/faster_whisper_server/routers/stt.py
@@ -142,20 +142,20 @@ def translate_file(
model = config.whisper.model
if response_format is None:
response_format = config.default_response_format
whisper = model_manager.load_model(model)
segments, transcription_info = whisper.transcribe(
file.file,
task=Task.TRANSLATE,
initial_prompt=prompt,
temperature=temperature,
vad_filter=vad_filter,
)
segments = TranscriptionSegment.from_faster_whisper_segments(segments)

if stream:
return segments_to_streaming_response(segments, transcription_info, response_format)
else:
return segments_to_response(segments, transcription_info, response_format)
with model_manager.load_model(model) as whisper:
segments, transcription_info = whisper.transcribe(
file.file,
task=Task.TRANSLATE,
initial_prompt=prompt,
temperature=temperature,
vad_filter=vad_filter,
)
segments = TranscriptionSegment.from_faster_whisper_segments(segments)

if stream:
return segments_to_streaming_response(segments, transcription_info, response_format)
else:
return segments_to_response(segments, transcription_info, response_format)


# HACK: Since Form() doesn't support `alias`, we need to use a workaround.
@@ -206,23 +206,23 @@ def transcribe_file(
logger.warning(
"It only makes sense to provide `timestamp_granularities[]` when `response_format` is set to `verbose_json`. See https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities." # noqa: E501
)
whisper = model_manager.load_model(model)
segments, transcription_info = whisper.transcribe(
file.file,
task=Task.TRANSCRIBE,
language=language,
initial_prompt=prompt,
word_timestamps="word" in timestamp_granularities,
temperature=temperature,
vad_filter=vad_filter,
hotwords=hotwords,
)
segments = TranscriptionSegment.from_faster_whisper_segments(segments)

if stream:
return segments_to_streaming_response(segments, transcription_info, response_format)
else:
return segments_to_response(segments, transcription_info, response_format)
with model_manager.load_model(model) as whisper:
segments, transcription_info = whisper.transcribe(
file.file,
task=Task.TRANSCRIBE,
language=language,
initial_prompt=prompt,
word_timestamps="word" in timestamp_granularities,
temperature=temperature,
vad_filter=vad_filter,
hotwords=hotwords,
)
segments = TranscriptionSegment.from_faster_whisper_segments(segments)

if stream:
return segments_to_streaming_response(segments, transcription_info, response_format)
else:
return segments_to_response(segments, transcription_info, response_format)


async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
@@ -280,24 +280,24 @@ async def transcribe_stream(
"vad_filter": vad_filter,
"condition_on_previous_text": False,
}
whisper = model_manager.load_model(model)
asr = FasterWhisperASR(whisper, **transcribe_opts)
audio_stream = AudioStream()
async with asyncio.TaskGroup() as tg:
tg.create_task(audio_receiver(ws, audio_stream))
async for transcription in audio_transcriber(asr, audio_stream, min_duration=config.min_duration):
logger.debug(f"Sending transcription: {transcription.text}")
if ws.client_state == WebSocketState.DISCONNECTED:
break
with model_manager.load_model(model) as whisper:
asr = FasterWhisperASR(whisper, **transcribe_opts)
audio_stream = AudioStream()
async with asyncio.TaskGroup() as tg:
tg.create_task(audio_receiver(ws, audio_stream))
async for transcription in audio_transcriber(asr, audio_stream, min_duration=config.min_duration):
logger.debug(f"Sending transcription: {transcription.text}")
if ws.client_state == WebSocketState.DISCONNECTED:
break

if response_format == ResponseFormat.TEXT:
await ws.send_text(transcription.text)
elif response_format == ResponseFormat.JSON:
await ws.send_json(CreateTranscriptionResponseJson.from_transcription(transcription).model_dump())
elif response_format == ResponseFormat.VERBOSE_JSON:
await ws.send_json(
CreateTranscriptionResponseVerboseJson.from_transcription(transcription).model_dump()
)
if response_format == ResponseFormat.TEXT:
await ws.send_text(transcription.text)
elif response_format == ResponseFormat.JSON:
await ws.send_json(CreateTranscriptionResponseJson.from_transcription(transcription).model_dump())
elif response_format == ResponseFormat.VERBOSE_JSON:
await ws.send_json(
CreateTranscriptionResponseVerboseJson.from_transcription(transcription).model_dump()
)

if ws.client_state != WebSocketState.DISCONNECTED:
logger.info("Closing the connection.")