From 9f0537dcbd0e9bac7054c403a8a8ef715c0e7bbd Mon Sep 17 00:00:00 2001
From: Fedir Zadniprovskyi
Date: Mon, 30 Sep 2024 08:30:29 -0700
Subject: [PATCH] feat: model unloading

---
 src/faster_whisper_server/config.py        |  25 ++--
 src/faster_whisper_server/dependencies.py  |   2 +-
 src/faster_whisper_server/model_manager.py | 143 ++++++++++++++++-----
 src/faster_whisper_server/routers/misc.py  |  18 +--
 src/faster_whisper_server/routers/stt.py   |  96 +++++++-------
 tests/model_manager_test.py                | 120 +++++++++++++++++
 6 files changed, 301 insertions(+), 103 deletions(-)
 create mode 100644 tests/model_manager_test.py

diff --git a/src/faster_whisper_server/config.py b/src/faster_whisper_server/config.py
index bf64f430..14b4230c 100644
--- a/src/faster_whisper_server/config.py
+++ b/src/faster_whisper_server/config.py
@@ -1,7 +1,6 @@
 import enum
-from typing import Self
 
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field
 from pydantic_settings import BaseSettings, SettingsConfigDict
 
 SAMPLES_PER_SECOND = 16000
@@ -163,6 +162,12 @@ class WhisperConfig(BaseModel):
     compute_type: Quantization = Field(default=Quantization.DEFAULT)
     cpu_threads: int = 0
     num_workers: int = 1
+    ttl: int = Field(default=300, ge=-1)
+    """
+    Time in seconds until the model is unloaded if it is not being used.
+    -1: Never unload the model.
+    0: Unload the model immediately after usage.
+    """
 
 
 class Config(BaseSettings):
@@ -198,10 +203,6 @@ class Config(BaseSettings):
     """
     default_response_format: ResponseFormat = ResponseFormat.JSON
     whisper: WhisperConfig = WhisperConfig()
-    max_models: int = 1
-    """
-    Maximum number of models that can be loaded at a time.
-    """
     preload_models: list[str] = Field(
         default_factory=list,
         examples=[
@@ -210,8 +211,8 @@ class Config(BaseSettings):
         ],
     )
     """
-    List of models to preload on startup. Shouldn't be greater than `max_models`. By default, the model is first loaded on first request.
-    """  # noqa: E501
+    List of models to preload on startup. By default, the model is first loaded on first request.
+    """
     max_no_data_seconds: float = 1.0
     """
     Max duration to wait for the next audio chunk before transcription is finalized and connection is closed.
@@ -230,11 +231,3 @@ class Config(BaseSettings):
     Controls how many latest seconds of audio are being passed through VAD.
     Should be greater than `max_inactivity_seconds`
     """
-
-    @model_validator(mode="after")
-    def ensure_preloaded_models_is_lte_max_models(self) -> Self:
-        if len(self.preload_models) > self.max_models:
-            raise ValueError(
-                f"Number of preloaded models ({len(self.preload_models)}) is greater than max_models ({self.max_models})"  # noqa: E501
-            )
-        return self
diff --git a/src/faster_whisper_server/dependencies.py b/src/faster_whisper_server/dependencies.py
index 9846d0fc..ade976fe 100644
--- a/src/faster_whisper_server/dependencies.py
+++ b/src/faster_whisper_server/dependencies.py
@@ -18,7 +18,7 @@ def get_config() -> Config:
 @lru_cache
 def get_model_manager() -> ModelManager:
     config = get_config()  # HACK
-    return ModelManager(config)
+    return ModelManager(config.whisper)
 
 
 ModelManagerDependency = Annotated[ModelManager, Depends(get_model_manager)]
diff --git a/src/faster_whisper_server/model_manager.py b/src/faster_whisper_server/model_manager.py
index a715c1f9..095ff4d4 100644
--- a/src/faster_whisper_server/model_manager.py
+++ b/src/faster_whisper_server/model_manager.py
@@ -3,48 +3,131 @@
 from collections import OrderedDict
 import gc
 import logging
+import threading
 import time
 from typing import TYPE_CHECKING
 
 from faster_whisper import WhisperModel
 
 if TYPE_CHECKING:
+    from collections.abc import Callable
+
     from faster_whisper_server.config import (
-        Config,
+        WhisperConfig,
     )
 
 logger = logging.getLogger(__name__)
 
 
+class SelfDisposingWhisperModel:
+    def __init__(
+        self,
+        model_id: str,
+        whisper_config: WhisperConfig,
+        *,
+        on_unload: Callable[[str], None] | None = None,
+    ) -> None:
+        self.model_id = model_id
+        self.whisper_config = whisper_config
+        self.on_unload = on_unload
+
+        self.ref_count: int = 0
+        self.rlock = threading.RLock()
+        self.expire_timer: threading.Timer | None = None
+        self.whisper: WhisperModel | None = None
+
+    def unload(self) -> None:
+        with self.rlock:
+            if self.whisper is None:
+                raise ValueError(f"Model {self.model_id} is not loaded. {self.ref_count=}")
+            if self.ref_count > 0:
+                raise ValueError(f"Model {self.model_id} is still in use. {self.ref_count=}")
+            if self.expire_timer:
+                self.expire_timer.cancel()
+            self.whisper = None
+            # WARN: ~300 MB of memory will still be held by the model. See https://github.com/SYSTRAN/faster-whisper/issues/992
+            gc.collect()
+            logger.info(f"Model {self.model_id} unloaded")
+            if self.on_unload is not None:
+                self.on_unload(self.model_id)
+
+    def _load(self) -> None:
+        with self.rlock:
+            assert self.whisper is None
+            logger.debug(f"Loading model {self.model_id}")
+            start = time.perf_counter()
+            self.whisper = WhisperModel(
+                self.model_id,
+                device=self.whisper_config.inference_device,
+                device_index=self.whisper_config.device_index,
+                compute_type=self.whisper_config.compute_type,
+                cpu_threads=self.whisper_config.cpu_threads,
+                num_workers=self.whisper_config.num_workers,
+                local_files_only=True,  # TODO
+            )
+            logger.info(f"Model {self.model_id} loaded in {time.perf_counter() - start:.2f}s")
+
+    def _increment_ref(self) -> None:
+        with self.rlock:
+            self.ref_count += 1
+            if self.expire_timer:
+                logger.debug(f"Model was set to expire in {self.expire_timer.interval}s, cancelling")
+                self.expire_timer.cancel()
+            logger.debug(f"Incremented ref count for {self.model_id}, {self.ref_count=}")
+
+    def _decrement_ref(self) -> None:
+        with self.rlock:
+            self.ref_count -= 1
+            logger.debug(f"Decremented ref count for {self.model_id}, {self.ref_count=}")
+            if self.ref_count <= 0:
+                if self.whisper_config.ttl > 0:
+                    logger.info(f"Model {self.model_id} is idle, scheduling unload in {self.whisper_config.ttl}s")
+                    self.expire_timer = threading.Timer(self.whisper_config.ttl, self.unload)
+                    self.expire_timer.start()
+                elif self.whisper_config.ttl == 0:
+                    logger.info(f"Model {self.model_id} is idle, unloading immediately")
+                    self.unload()
+                else:
+                    logger.info(f"Model {self.model_id} is idle, not unloading")
+
+    def __enter__(self) -> WhisperModel:
+        with self.rlock:
+            if self.whisper is None:
+                self._load()
+            self._increment_ref()
+            assert self.whisper is not None
+            return self.whisper
+
+    def __exit__(self, *_args) -> None:  # noqa: ANN002
+        self._decrement_ref()
+
+
 class ModelManager:
-    def __init__(self, config: Config) -> None:
-        self.config = config
-        self.loaded_models: OrderedDict[str, WhisperModel] = OrderedDict()
+    def __init__(self, whisper_config: WhisperConfig) -> None:
+        self.whisper_config = whisper_config
+        self.loaded_models: OrderedDict[str, SelfDisposingWhisperModel] = OrderedDict()
+        self._lock = threading.RLock()  # reentrant: unload_model -> unload -> _handle_model_unload re-acquires it
 
-    def load_model(self, model_name: str) -> WhisperModel:
-        if model_name in self.loaded_models:
-            logger.debug(f"{model_name} model already loaded")
-            return self.loaded_models[model_name]
-        if len(self.loaded_models) >= self.config.max_models:
-            oldest_model_name = next(iter(self.loaded_models))
-            logger.info(
-                f"Max models ({self.config.max_models}) reached. Unloading the oldest model: {oldest_model_name}"
+    def _handle_model_unload(self, model_name: str) -> None:
+        with self._lock:
+            if model_name in self.loaded_models:
+                del self.loaded_models[model_name]
+
+    def unload_model(self, model_name: str) -> None:
+        with self._lock:
+            model = self.loaded_models.get(model_name)
+            if model is None:
+                raise KeyError(f"Model {model_name} not found")
+            model.unload()
+
+    def load_model(self, model_name: str) -> SelfDisposingWhisperModel:
+        with self._lock:
+            if model_name in self.loaded_models:
+                logger.debug(f"{model_name} model already loaded")
+                return self.loaded_models[model_name]
+            self.loaded_models[model_name] = SelfDisposingWhisperModel(
+                model_name,
+                self.whisper_config,
+                on_unload=self._handle_model_unload,
             )
-            del self.loaded_models[oldest_model_name]
-            gc.collect()
-        logger.debug(f"Loading {model_name}...")
-        start = time.perf_counter()
-        # NOTE: will raise an exception if the model name isn't valid. Should I do an explicit check?
-        whisper = WhisperModel(
-            model_name,
-            device=self.config.whisper.inference_device,
-            device_index=self.config.whisper.device_index,
-            compute_type=self.config.whisper.compute_type,
-            cpu_threads=self.config.whisper.cpu_threads,
-            num_workers=self.config.whisper.num_workers,
-        )
-        logger.info(
-            f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {self.config.whisper.inference_device}({self.config.whisper.compute_type}) will be used for inference."  # noqa: E501
-        )
-        self.loaded_models[model_name] = whisper
-        return whisper
+            return self.loaded_models[model_name]
diff --git a/src/faster_whisper_server/routers/misc.py b/src/faster_whisper_server/routers/misc.py
index ca02c96f..26ac2941 100644
--- a/src/faster_whisper_server/routers/misc.py
+++ b/src/faster_whisper_server/routers/misc.py
@@ -1,7 +1,5 @@
 from __future__ import annotations
 
-import gc
-
 from fastapi import (
     APIRouter,
     Response,
@@ -42,15 +40,19 @@ def get_running_models(
 def load_model_route(model_manager: ModelManagerDependency, model_name: str) -> Response:
     if model_name in model_manager.loaded_models:
         return Response(status_code=409, content="Model already loaded")
-    model_manager.load_model(model_name)
+    with model_manager.load_model(model_name):
+        pass
     return Response(status_code=201)
 
 
 @router.delete("/api/ps/{model_name:path}", tags=["experimental"], summary="Unload a model from memory.")
 def stop_running_model(model_manager: ModelManagerDependency, model_name: str) -> Response:
-    model = model_manager.loaded_models.get(model_name)
-    if model is not None:
-        del model_manager.loaded_models[model_name]
-        gc.collect()
+    try:
+        model_manager.unload_model(model_name)
         return Response(status_code=204)
-    return Response(status_code=404)
+    except (KeyError, ValueError) as e:
+        match e:
+            case KeyError():
+                return Response(status_code=404, content="Model not found")
+            case ValueError():
+                return Response(status_code=409, content=str(e))
diff --git a/src/faster_whisper_server/routers/stt.py b/src/faster_whisper_server/routers/stt.py
index 91e2be6a..70ef2044 100644
--- a/src/faster_whisper_server/routers/stt.py
+++ b/src/faster_whisper_server/routers/stt.py
@@ -142,20 +142,20 @@ def translate_file(
         model = config.whisper.model
     if response_format is None:
         response_format = config.default_response_format
-    whisper = model_manager.load_model(model)
-    segments, transcription_info = whisper.transcribe(
-        file.file,
-        task=Task.TRANSLATE,
-        initial_prompt=prompt,
-        temperature=temperature,
-        vad_filter=vad_filter,
-    )
-    segments = TranscriptionSegment.from_faster_whisper_segments(segments)
-
-    if stream:
-        return segments_to_streaming_response(segments, transcription_info, response_format)
-    else:
-        return segments_to_response(segments, transcription_info, response_format)
+    with model_manager.load_model(model) as whisper:
+        segments, transcription_info = whisper.transcribe(
+            file.file,
+            task=Task.TRANSLATE,
+            initial_prompt=prompt,
+            temperature=temperature,
+            vad_filter=vad_filter,
+        )
+        segments = TranscriptionSegment.from_faster_whisper_segments(segments)
+
+        if stream:
+            return segments_to_streaming_response(segments, transcription_info, response_format)
+        else:
+            return segments_to_response(segments, transcription_info, response_format)
 
 
 # HACK: Since Form() doesn't support `alias`, we need to use a workaround.
@@ -206,23 +206,23 @@ def transcribe_file(
         logger.warning(
             "It only makes sense to provide `timestamp_granularities[]` when `response_format` is set to `verbose_json`. See https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities."  # noqa: E501
         )
-    whisper = model_manager.load_model(model)
-    segments, transcription_info = whisper.transcribe(
-        file.file,
-        task=Task.TRANSCRIBE,
-        language=language,
-        initial_prompt=prompt,
-        word_timestamps="word" in timestamp_granularities,
-        temperature=temperature,
-        vad_filter=vad_filter,
-        hotwords=hotwords,
-    )
-    segments = TranscriptionSegment.from_faster_whisper_segments(segments)
-
-    if stream:
-        return segments_to_streaming_response(segments, transcription_info, response_format)
-    else:
-        return segments_to_response(segments, transcription_info, response_format)
+    with model_manager.load_model(model) as whisper:
+        segments, transcription_info = whisper.transcribe(
+            file.file,
+            task=Task.TRANSCRIBE,
+            language=language,
+            initial_prompt=prompt,
+            word_timestamps="word" in timestamp_granularities,
+            temperature=temperature,
+            vad_filter=vad_filter,
+            hotwords=hotwords,
+        )
+        segments = TranscriptionSegment.from_faster_whisper_segments(segments)
+
+        if stream:
+            return segments_to_streaming_response(segments, transcription_info, response_format)
+        else:
+            return segments_to_response(segments, transcription_info, response_format)
 
 
 async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
@@ -280,24 +280,24 @@ async def transcribe_stream(
         "vad_filter": vad_filter,
         "condition_on_previous_text": False,
     }
-    whisper = model_manager.load_model(model)
-    asr = FasterWhisperASR(whisper, **transcribe_opts)
-    audio_stream = AudioStream()
-    async with asyncio.TaskGroup() as tg:
-        tg.create_task(audio_receiver(ws, audio_stream))
-        async for transcription in audio_transcriber(asr, audio_stream, min_duration=config.min_duration):
-            logger.debug(f"Sending transcription: {transcription.text}")
-            if ws.client_state == WebSocketState.DISCONNECTED:
-                break
+    with model_manager.load_model(model) as whisper:
+        asr = FasterWhisperASR(whisper, **transcribe_opts)
+        audio_stream = AudioStream()
+        async with asyncio.TaskGroup() as tg:
+            tg.create_task(audio_receiver(ws, audio_stream))
+            async for transcription in audio_transcriber(asr, audio_stream, min_duration=config.min_duration):
+                logger.debug(f"Sending transcription: {transcription.text}")
+                if ws.client_state == WebSocketState.DISCONNECTED:
+                    break
 
-            if response_format == ResponseFormat.TEXT:
-                await ws.send_text(transcription.text)
-            elif response_format == ResponseFormat.JSON:
-                await ws.send_json(CreateTranscriptionResponseJson.from_transcription(transcription).model_dump())
-            elif response_format == ResponseFormat.VERBOSE_JSON:
-                await ws.send_json(
-                    CreateTranscriptionResponseVerboseJson.from_transcription(transcription).model_dump()
-                )
+                if response_format == ResponseFormat.TEXT:
+                    await ws.send_text(transcription.text)
+                elif response_format == ResponseFormat.JSON:
+                    await ws.send_json(CreateTranscriptionResponseJson.from_transcription(transcription).model_dump())
+                elif response_format == ResponseFormat.VERBOSE_JSON:
+                    await ws.send_json(
+                        CreateTranscriptionResponseVerboseJson.from_transcription(transcription).model_dump()
+                    )
 
     if ws.client_state != WebSocketState.DISCONNECTED:
         logger.info("Closing the connection.")
diff --git a/tests/model_manager_test.py b/tests/model_manager_test.py
new file mode 100644
index 00000000..2904f894
--- /dev/null
+++ b/tests/model_manager_test.py
@@ -0,0 +1,120 @@
+import asyncio
+import os
+
+import anyio
+from httpx import ASGITransport, AsyncClient
+import pytest
+
+from faster_whisper_server.main import create_app
+
+
+@pytest.mark.asyncio
+async def test_model_unloaded_after_ttl() -> None:
+    ttl = 5
+    model = "Systran/faster-whisper-tiny.en"
+    os.environ["WHISPER__TTL"] = str(ttl)
+    os.environ["ENABLE_UI"] = "false"
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 0
+        await aclient.post(f"/api/ps/{model}")
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+        await asyncio.sleep(ttl + 1)
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 0
+
+
+@pytest.mark.asyncio
+async def test_ttl_resets_after_usage() -> None:
+    ttl = 5
+    model = "Systran/faster-whisper-tiny.en"
+    os.environ["WHISPER__TTL"] = str(ttl)
+    os.environ["ENABLE_UI"] = "false"
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+        await aclient.post(f"/api/ps/{model}")
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+        await asyncio.sleep(ttl - 2)
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+
+        async with await anyio.open_file("audio.wav", "rb") as f:
+            data = await f.read()
+        res = (
+            await aclient.post(
+                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
+            )
+        ).json()
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+        await asyncio.sleep(ttl - 2)
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+
+        await asyncio.sleep(3)
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 0
+
+        # test the model can be used again after being unloaded
+        # (this only checks that the request succeeds; the response content is not asserted on)
+        res = (
+            await aclient.post(
+                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
+            )
+        ).json()
+
+
+@pytest.mark.asyncio
+async def test_model_cant_be_unloaded_when_used() -> None:
+    ttl = 0
+    model = "Systran/faster-whisper-tiny.en"
+    os.environ["WHISPER__TTL"] = str(ttl)
+    os.environ["ENABLE_UI"] = "false"
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+        async with await anyio.open_file("audio.wav", "rb") as f:
+            data = await f.read()
+
+        task = asyncio.create_task(
+            aclient.post(
+                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model} + ) + ) + await asyncio.sleep(0.01) + res = await aclient.delete(f"/api/ps/{model}") + assert res.status_code == 409 + + await task + res = (await aclient.get("/api/ps")).json() + assert len(res["models"]) == 0 + + +@pytest.mark.asyncio +async def test_model_cant_be_loaded_twice() -> None: + model = "Systran/faster-whisper-tiny.en" + os.environ["ENABLE_UI"] = "false" + async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient: + res = await aclient.post(f"/api/ps/{model}") + assert res.status_code == 201 + res = await aclient.post(f"/api/ps/{model}") + assert res.status_code == 409 + res = (await aclient.get("/api/ps")).json() + assert len(res["models"]) == 1 + + +@pytest.mark.asyncio +async def test_model_is_unloaded_after_request_when_ttl_is_zero() -> None: + ttl = 0 + os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en" + os.environ["WHISPER__TTL"] = str(ttl) + os.environ["ENABLE_UI"] = "false" + async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient: + async with await anyio.open_file("audio.wav", "rb") as f: + data = await f.read() + res = await aclient.post( + "/v1/audio/transcriptions", + files={"file": ("audio.wav", data, "audio/wav")}, + data={"model": "Systran/faster-whisper-tiny.en"}, + ) + res = (await aclient.get("/api/ps")).json() + assert len(res["models"]) == 0