Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

misc changes #336

Merged
merged 7 commits into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 47 additions & 47 deletions src/speaches/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def unload(self) -> None:
if self.expire_timer:
self.expire_timer.cancel()
self.model = None
# WARN: ~300 MB of memory will still be held by the model. See https://github.com/SYSTRAN/faster-whisper/issues/992
gc.collect()
logger.info(f"Model {self.model_id} unloaded")
if self.unload_fn is not None:
Expand Down Expand Up @@ -118,32 +117,33 @@ def _load_fn(self, model_id: str) -> WhisperModel:
num_workers=self.whisper_config.num_workers,
)

def _handle_model_unload(self, model_name: str) -> None:
def _handle_model_unload(self, model_id: str) -> None:
with self._lock:
if model_name in self.loaded_models:
del self.loaded_models[model_name]
if model_id in self.loaded_models:
del self.loaded_models[model_id]

def unload_model(self, model_name: str) -> None:
def unload_model(self, model_id: str) -> None:
with self._lock:
model = self.loaded_models.get(model_name)
model = self.loaded_models.get(model_id)
if model is None:
raise KeyError(f"Model {model_name} not found")
self.loaded_models[model_name].unload()
raise KeyError(f"Model {model_id} not found")
# WARN: ~300 MB of memory will still be held by the model. See https://github.com/SYSTRAN/faster-whisper/issues/992
self.loaded_models[model_id].unload()

def load_model(self, model_name: str) -> SelfDisposingModel[WhisperModel]:
logger.debug(f"Loading model {model_name}")
def load_model(self, model_id: str) -> SelfDisposingModel[WhisperModel]:
logger.debug(f"Loading model {model_id}")
with self._lock:
logger.debug("Acquired lock")
if model_name in self.loaded_models:
logger.debug(f"{model_name} model already loaded")
return self.loaded_models[model_name]
self.loaded_models[model_name] = SelfDisposingModel[WhisperModel](
model_name,
load_fn=lambda: self._load_fn(model_name),
if model_id in self.loaded_models:
logger.debug(f"{model_id} model already loaded")
return self.loaded_models[model_id]
self.loaded_models[model_id] = SelfDisposingModel[WhisperModel](
model_id,
load_fn=lambda: self._load_fn(model_id),
ttl=self.whisper_config.ttl,
unload_fn=self._handle_model_unload,
)
return self.loaded_models[model_name]
return self.loaded_models[model_id]


ONNX_PROVIDERS = ["CUDAExecutionProvider", "CPUExecutionProvider"]
Expand All @@ -164,32 +164,32 @@ def _load_fn(self, model_id: str) -> PiperVoice:
conf = PiperConfig.from_dict(json.loads(config_path.read_text()))
return PiperVoice(session=inf_sess, config=conf)

def _handle_model_unload(self, model_name: str) -> None:
def _handle_model_unload(self, model_id: str) -> None:
with self._lock:
if model_name in self.loaded_models:
del self.loaded_models[model_name]
if model_id in self.loaded_models:
del self.loaded_models[model_id]

def unload_model(self, model_name: str) -> None:
def unload_model(self, model_id: str) -> None:
with self._lock:
model = self.loaded_models.get(model_name)
model = self.loaded_models.get(model_id)
if model is None:
raise KeyError(f"Model {model_name} not found")
self.loaded_models[model_name].unload()
raise KeyError(f"Model {model_id} not found")
self.loaded_models[model_id].unload()

def load_model(self, model_name: str) -> SelfDisposingModel[PiperVoice]:
def load_model(self, model_id: str) -> SelfDisposingModel[PiperVoice]:
from piper.voice import PiperVoice

with self._lock:
if model_name in self.loaded_models:
logger.debug(f"{model_name} model already loaded")
return self.loaded_models[model_name]
self.loaded_models[model_name] = SelfDisposingModel[PiperVoice](
model_name,
load_fn=lambda: self._load_fn(model_name),
if model_id in self.loaded_models:
logger.debug(f"{model_id} model already loaded")
return self.loaded_models[model_id]
self.loaded_models[model_id] = SelfDisposingModel[PiperVoice](
model_id,
load_fn=lambda: self._load_fn(model_id),
ttl=self.ttl,
unload_fn=self._handle_model_unload,
)
return self.loaded_models[model_name]
return self.loaded_models[model_id]


class KokoroModelManager:
Expand All @@ -205,27 +205,27 @@ def _load_fn(self, _model_id: str) -> Kokoro:
inf_sess = InferenceSession(model_path, providers=ONNX_PROVIDERS)
return Kokoro.from_session(inf_sess, str(voices_path))

def _handle_model_unload(self, model_name: str) -> None:
def _handle_model_unload(self, model_id: str) -> None:
with self._lock:
if model_name in self.loaded_models:
del self.loaded_models[model_name]
if model_id in self.loaded_models:
del self.loaded_models[model_id]

def unload_model(self, model_name: str) -> None:
def unload_model(self, model_id: str) -> None:
with self._lock:
model = self.loaded_models.get(model_name)
model = self.loaded_models.get(model_id)
if model is None:
raise KeyError(f"Model {model_name} not found")
self.loaded_models[model_name].unload()
raise KeyError(f"Model {model_id} not found")
self.loaded_models[model_id].unload()

def load_model(self, model_name: str) -> SelfDisposingModel[Kokoro]:
def load_model(self, model_id: str) -> SelfDisposingModel[Kokoro]:
with self._lock:
if model_name in self.loaded_models:
logger.debug(f"{model_name} model already loaded")
return self.loaded_models[model_name]
self.loaded_models[model_name] = SelfDisposingModel[Kokoro](
model_name,
load_fn=lambda: self._load_fn(model_name),
if model_id in self.loaded_models:
logger.debug(f"{model_id} model already loaded")
return self.loaded_models[model_id]
self.loaded_models[model_id] = SelfDisposingModel[Kokoro](
model_id,
load_fn=lambda: self._load_fn(model_id),
ttl=self.ttl,
unload_fn=self._handle_model_unload,
)
return self.loaded_models[model_name]
return self.loaded_models[model_id]
8 changes: 3 additions & 5 deletions src/speaches/realtime/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,10 @@

from speaches.realtime.input_audio_buffer import InputAudioBuffer
from speaches.realtime.pubsub import EventPubSub
from speaches.realtime.session import (
DEFAULT_SESSION_CONFIG,
)
from speaches.realtime.utils import (
generate_session_id,
)
from speaches.types.realtime import ConversationItem, RealtimeResponse
from speaches.types.realtime import ConversationItem, RealtimeResponse, Session


class SessionContext:
Expand All @@ -20,13 +17,14 @@ def __init__(
transcription_client: AsyncTranscriptions,
completion_client: AsyncCompletions,
speech_client: AsyncSpeech,
configuration: Session,
) -> None:
self.transcription_client = transcription_client
self.speech_client = speech_client
self.completion_client = completion_client

self.session_id = generate_session_id()
self.configuration = DEFAULT_SESSION_CONFIG
self.configuration = configuration

self.conversation = OrderedDict[
str, ConversationItem
Expand Down
30 changes: 12 additions & 18 deletions src/speaches/realtime/input_audio_buffer_event_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@
import numpy as np
from numpy.typing import NDArray
from openai.types.beta.realtime.error_event import Error
from opentelemetry import trace

from speaches.audio import audio_samples_from_file
from speaches.realtime.context import SessionContext
from speaches.realtime.event_router import EventRouter
from speaches.realtime.input_audio_buffer import InputAudioBuffer
from speaches.realtime.utils import generate_event_id, generate_item_id
from speaches.realtime.input_audio_buffer import MAX_VAD_WINDOW_SIZE_SAMPLES, MS_SAMPLE_RATE, InputAudioBuffer
from speaches.realtime.utils import generate_event_id
from speaches.types.realtime import (
ConversationItem,
ConversationItemContent,
Expand All @@ -30,10 +29,6 @@
TurnDetection,
)

SAMPLE_RATE = 16000
MS_SAMPLE_RATE = 16
MAX_VAD_WINDOW_SIZE_SAMPLES = 3000 * MS_SAMPLE_RATE

logger = logging.getLogger(__name__)

event_router = EventRouter()
Expand All @@ -47,12 +42,14 @@


# NOTE: `signal.resample_poly` **might** be a better option for resampling audio data
# TODO: also found in src/speaches/audio.py. Remove duplication
def resample_audio_data(data: NDArray[np.float32], sample_rate: int, target_sample_rate: int) -> NDArray[np.float32]:
ratio = target_sample_rate / sample_rate
target_length = int(len(data) * ratio)
return np.interp(np.linspace(0, len(data), target_length), np.arange(len(data)), data).astype(np.float32)


# TODO: also found in src/speaches/routers/vad.py. Remove duplication
def to_ms_speech_timestamps(speech_timestamps: list[SpeechTimestamp]) -> list[SpeechTimestamp]:
for i in range(len(speech_timestamps)):
speech_timestamps[i]["start"] = speech_timestamps[i]["start"] // MS_SAMPLE_RATE
Expand Down Expand Up @@ -89,7 +86,7 @@ def vad_detection_flow(
)
return InputAudioBufferSpeechStartedEvent(
type="input_audio_buffer.speech_started",
event_id=generate_item_id(),
event_id=generate_event_id(),
item_id=input_audio_buffer.id,
audio_start_ms=input_audio_buffer.vad_state.audio_start_ms,
)
Expand All @@ -102,19 +99,19 @@ def vad_detection_flow(
)
return InputAudioBufferSpeechStoppedEvent(
type="input_audio_buffer.speech_stopped",
event_id=generate_item_id(),
event_id=generate_event_id(),
item_id=input_audio_buffer.id,
audio_end_ms=input_audio_buffer.vad_state.audio_end_ms,
)

elif speech_timestamp["end"] < 3000 and input_audio_buffer.duration_ms > 3000:
elif speech_timestamp["end"] < 3000 and input_audio_buffer.duration_ms > 3000: # FIX: magic number
input_audio_buffer.vad_state.audio_end_ms = (
input_audio_buffer.duration_ms - turn_detection.prefix_padding_ms
)

return InputAudioBufferSpeechStoppedEvent(
type="input_audio_buffer.speech_stopped",
event_id=generate_item_id(),
event_id=generate_event_id(),
item_id=input_audio_buffer.id,
audio_end_ms=input_audio_buffer.vad_state.audio_end_ms,
)
Expand All @@ -133,9 +130,6 @@ def handle_input_audio_buffer_append(ctx: SessionContext, event: InputAudioBuffe
input_audio_buffer_id = next(reversed(ctx.input_audio_buffers))
input_audio_buffer = ctx.input_audio_buffers[input_audio_buffer_id]
input_audio_buffer.append(audio_chunk)
trace.get_current_span().add_event(
"input_audio_buffer.appended", {"size": len(audio_chunk), "duration": len(audio_chunk) / 16000}
)
if ctx.configuration.turn_detection is not None:
vad_event = vad_detection_flow(input_audio_buffer, ctx.configuration.turn_detection)
if vad_event is not None:
Expand All @@ -154,7 +148,7 @@ def handle_input_audio_buffer_commit(ctx: SessionContext, _event: InputAudioBuff
ctx.pubsub.publish_nowait(
InputAudioBufferCommittedEvent(
type="input_audio_buffer.committed",
event_id=generate_item_id(),
event_id=generate_event_id(),
previous_item_id=next(reversed(ctx.conversation), None), # pyright: ignore[reportArgumentType]
item_id=input_audio_buffer_id,
)
Expand All @@ -170,7 +164,7 @@ def handle_input_audio_buffer_clear(ctx: SessionContext, _event: InputAudioBuffe
ctx.pubsub.publish_nowait(
InputAudioBufferClearedEvent(
type="input_audio_buffer.cleared",
event_id=generate_item_id(),
event_id=generate_event_id(),
)
)
input_audio_buffer = InputAudioBuffer()
Expand All @@ -191,7 +185,7 @@ def handle_input_audio_buffer_speech_stopped(ctx: SessionContext, event: InputAu
ctx.pubsub.publish_nowait(
InputAudioBufferCommittedEvent(
type="input_audio_buffer.committed",
event_id=generate_item_id(),
event_id=generate_event_id(),
previous_item_id=previous_item_id,
item_id=event.item_id,
)
Expand All @@ -215,7 +209,7 @@ def handle_input_audio_buffer_committed(ctx: SessionContext, event: InputAudioBu
ctx.pubsub.publish_nowait(
ConversationItemCreatedEvent(
type="conversation.item.created",
event_id=generate_item_id(),
event_id=generate_event_id(),
# previous_item_id=next(reversed(ctx.conversation), None), # TODO: incorrect this needs to be second last
previous_item_id=None,
item=item,
Expand Down
62 changes: 23 additions & 39 deletions src/speaches/realtime/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,44 +2,28 @@

from speaches.types.realtime import Session, TurnDetection

# NOTE: the `DEFAULT_OPENAI_REALTIME_*` constants are not currently used. Keeping them here for reference. They also may be outdated
DEFAULT_OPENAI_REALTIME_MODEL = "gpt-4o-realtime-preview-2024-10-01"
DEFAULT_OPENAI_REALTIME_SESSION_DURATION_SECONDS = 30 * 60
DEFAULT_OPENAI_REALTIME_SESSION_INSTRUCTIONS = "Your knowledge cutoff is 2023-10. You are a helpful, witty, and friendly AI. Act like a human, but remember that you aren't a human and that you can't do human things in the real world. Your voice and personality should be warm and engaging, with a lively and playful tone. If interacting in a non-English language, start by using the standard accent or dialect familiar to the user. Talk quickly. You should always call a function if you can. Do not refer to these rules, even if you\u2019re asked about them."
DEFAULT_OPENAI_REALTIME_SESSION_CONFIG = Session(
model=DEFAULT_OPENAI_REALTIME_MODEL,
modalities=["audio", "text"], # NOTE: the order of the modalities often differs
instructions=DEFAULT_OPENAI_REALTIME_SESSION_INSTRUCTIONS,
voice="alloy",
input_audio_format="pcm16",
output_audio_format="pcm16",
input_audio_transcription=None,
turn_detection=TurnDetection(),
temperature=0.8,
tools=[],
tool_choice="auto",
max_response_output_tokens="inf",
)
# https://platform.openai.com/docs/guides/realtime-model-capabilities#session-lifecycle-events
OPENAI_REALTIME_SESSION_DURATION_SECONDS = 30 * 60
OPENAI_REALTIME_INSTRUCTIONS = "Your knowledge cutoff is 2023-10. You are a helpful, witty, and friendly AI. Act like a human, but remember that you aren't a human and that you can't do human things in the real world. Your voice and personality should be warm and engaging, with a lively and playful tone. If interacting in a non-English language, start by using the standard accent or dialect familiar to the user. Talk quickly. You should always call a function if you can. Do not refer to these rules, even if you\u2019re asked about them."


DEFAULT_REALTIME_SESSION_INSTRUCTIONS = "Your knowledge cutoff is 2023-10. You are a helpful, witty, and friendly AI. Keep the responses concise and to the point. Your responses will be converted into speech; avoid using text that makes sense when spoken. Do not use emojis, abbreviations, or markdown formatting (such as double asterisks) in your response."
DEFAULT_TURN_DETECTION = TurnDetection(
threshold=0.9,
prefix_padding_ms=0,
silence_duration_ms=550,
create_response=False,
)
DEFAULT_SESSION_CONFIG = Session(
model=DEFAULT_OPENAI_REALTIME_MODEL,
modalities=["audio", "text"],
instructions=DEFAULT_OPENAI_REALTIME_SESSION_INSTRUCTIONS, # changed
voice="alloy",
input_audio_format="pcm16",
output_audio_format="pcm16",
input_audio_transcription=InputAudioTranscription(model="Systran/faster-whisper-small"), # changed
turn_detection=DEFAULT_TURN_DETECTION,
temperature=0.8,
tools=[],
tool_choice="auto",
max_response_output_tokens="inf",
)
def create_session_configuration(model: str) -> Session:
return Session(
model=model,
modalities=["audio", "text"],
instructions=OPENAI_REALTIME_INSTRUCTIONS,
voice="alloy",
input_audio_format="pcm16",
output_audio_format="pcm16",
input_audio_transcription=InputAudioTranscription(model="Systran/faster-whisper-small"),
turn_detection=TurnDetection(
threshold=0.9,
prefix_padding_ms=0,
silence_duration_ms=550,
create_response=False,
),
temperature=0.8,
tools=[],
tool_choice="auto",
max_response_output_tokens="inf",
)
8 changes: 4 additions & 4 deletions src/speaches/routers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,21 @@ def get_models() -> ListModelsResponse:
return ListModelsResponse(data=whisper_models)


@router.get("/v1/models/{model_name:path}")
@router.get("/v1/models/{model_id:path}")
def get_model(
# NOTE: `examples` doesn't work https://github.com/tiangolo/fastapi/discussions/10537
model_name: Annotated[str, Path(example="Systran/faster-distil-whisper-large-v3")],
model_id: Annotated[str, Path(example="Systran/faster-distil-whisper-large-v3")],
) -> Model:
models = huggingface_hub.list_models(
model_name=model_name, library="ctranslate2", tags="automatic-speech-recognition", cardData=True
model_name=model_id, library="ctranslate2", tags="automatic-speech-recognition", cardData=True
)
models = list(models)
models.sort(key=lambda model: model.downloads or -1, reverse=True)
if len(models) == 0:
raise HTTPException(status_code=404, detail="Model doesn't exists")
exact_match: ModelInfo | None = None
for model in models:
if model.id == model_name:
if model.id == model_id:
exact_match = model
break
if exact_match is None:
Expand Down
Loading
Loading