Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

realtime misc changes #356

Merged
merged 12 commits into from
Feb 27, 2025
12 changes: 9 additions & 3 deletions src/speaches/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ def audio_file_dependency(
def get_completion_client() -> AsyncCompletions:
    """Build the chat-completions client from the app config.

    Returns:
        The ``chat.completions`` resource of an ``AsyncOpenAI`` client pointed
        at the configured upstream base URL.
    """
    config = get_config()
    # NOTE: the captured diff left the pre-change single-line kwargs call in
    # place alongside its replacement, producing duplicate keyword arguments;
    # only the post-change form is kept here.
    oai_client = AsyncOpenAI(
        base_url=config.chat_completion_base_url,
        api_key=config.chat_completion_api_key.get_secret_value(),
        max_retries=1,  # fail fast: one retry instead of the SDK default of 2
    )
    return oai_client.chat.completions

Expand All @@ -129,7 +131,9 @@ def get_speech_client() -> AsyncSpeech:
transport=ASGITransport(speech_router), base_url="http://test/v1"
) # NOTE: "test" can be replaced with any other value
oai_client = AsyncOpenAI(
http_client=http_client, api_key=config.api_key.get_secret_value() if config.api_key else "cant-be-empty"
http_client=http_client,
api_key=config.api_key.get_secret_value() if config.api_key else "cant-be-empty",
max_retries=1,
)
return oai_client.audio.speech

Expand All @@ -149,7 +153,9 @@ def get_transcription_client() -> AsyncTranscriptions:
transport=ASGITransport(stt_router), base_url="http://test/v1"
) # NOTE: "test" can be replaced with any other value
oai_client = AsyncOpenAI(
http_client=http_client, api_key=config.api_key.get_secret_value() if config.api_key else "cant-be-empty"
http_client=http_client,
api_key=config.api_key.get_secret_value() if config.api_key else "cant-be-empty",
max_retries=1,
)
return oai_client.audio.transcriptions

Expand Down
6 changes: 5 additions & 1 deletion src/speaches/realtime/input_audio_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import TYPE_CHECKING

import numpy as np
from openai import NotGiven
from pydantic import BaseModel
import soundfile as sf

Expand Down Expand Up @@ -118,7 +119,10 @@ async def _handler(self) -> None:
format="wav",
)
transcript = await self.transcription_client.create(
file=file, model=self.session.input_audio_transcription.model, response_format="text"
file=file,
model=self.session.input_audio_transcription.model,
response_format="text",
language=self.session.input_audio_transcription.language or NotGiven(),
)
content_item.transcript = transcript
self.pubsub.publish_nowait(
Expand Down
27 changes: 15 additions & 12 deletions src/speaches/realtime/response_event_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import logging
from typing import TYPE_CHECKING

import aiostream
import openai
from openai.types.beta.realtime.error_event import Error
from pydantic import BaseModel
Expand Down Expand Up @@ -47,7 +46,7 @@
)

if TYPE_CHECKING:
from collections.abc import Generator
from collections.abc import AsyncGenerator, Generator

from openai.resources.chat import AsyncCompletions
from openai.types.chat import ChatCompletionChunk
Expand Down Expand Up @@ -122,7 +121,7 @@ def add_item_content[T: ConversationItemContentText | ConversationItemContentAud
ResponseContentPartDoneEvent(response_id=self.id, item_id=item.id, part=content.to_part())
)

async def conversation_item_message_text_handler(self, chunk_stream: aiostream.Stream[ChatCompletionChunk]) -> None:
async def conversation_item_message_text_handler(self, chunk_stream: AsyncGenerator[ChatCompletionChunk]) -> None:
with self.add_output_item(ConversationItemMessage(role="assistant", status="incomplete", content=[])) as item:
self.conversation.create_item(item)

Expand All @@ -141,9 +140,7 @@ async def conversation_item_message_text_handler(self, chunk_stream: aiostream.S
ResponseTextDoneEvent(item_id=item.id, response_id=self.id, text=content.text)
)

async def conversation_item_message_audio_handler(
self, chunk_stream: aiostream.Stream[ChatCompletionChunk]
) -> None:
async def conversation_item_message_audio_handler(self, chunk_stream: AsyncGenerator[ChatCompletionChunk]) -> None:
with self.add_output_item(ConversationItemMessage(role="assistant", status="incomplete", content=[])) as item:
self.conversation.create_item(item)

Expand Down Expand Up @@ -179,10 +176,8 @@ async def conversation_item_message_audio_handler(
)
)

async def conversation_item_function_call_handler(
self, chunk_stream: aiostream.Stream[ChatCompletionChunk]
) -> None:
chunk = await chunk_stream
async def conversation_item_function_call_handler(self, chunk_stream: AsyncGenerator[ChatCompletionChunk]) -> None:
chunk = await anext(chunk_stream)

assert len(chunk.choices) == 1, chunk
choice = chunk.choices[0]
Expand Down Expand Up @@ -237,15 +232,23 @@ async def generate_response(self) -> None:
self.configuration,
)
chunk_stream = await self.completion_client.create(**completion_params)
chunk = await chunk_stream.__anext__()
chunk = await anext(chunk_stream)
if chunk.choices[0].delta.tool_calls is not None:
handler = self.conversation_item_function_call_handler
elif self.configuration.modalities == ["text"]:
handler = self.conversation_item_message_text_handler
else:
handler = self.conversation_item_message_audio_handler

await handler(aiostream.stream.just(chunk) + chunk_stream)
async def merge_chunks_and_chunk_stream(
    *chunks: ChatCompletionChunk, chunk_stream: openai.AsyncStream[ChatCompletionChunk]
) -> AsyncGenerator[ChatCompletionChunk]:
    """Replay already-consumed chunks, then yield the rest of the live stream.

    Used to "un-read" the first chunk pulled off the completion stream for
    routing, so downstream handlers see the full sequence.
    """
    for buffered in chunks:
        yield buffered
    async for streamed in chunk_stream:
        yield streamed

await handler(merge_chunks_and_chunk_stream(chunk, chunk_stream=chunk_stream))
except openai.APIError as e:
logger.exception("Error while generating response")
self.pubsub.publish_nowait(
Expand Down
42 changes: 0 additions & 42 deletions src/speaches/realtime/rtc/audio_stream_track.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,47 +136,5 @@ def stop(self) -> None:
super().stop()


class ToneAudioStreamTrack(MediaStreamTrack):
    """Audio track that synthesizes a continuous 440 Hz sine tone.

    Emits mono s16 PCM frames of 10 ms each at 24 kHz, pacing delivery with a
    10 ms sleep per frame so consumers receive audio in (approximately)
    real time.
    """

    kind = "audio"

    def __init__(self) -> None:
        super().__init__()
        self._timestamp = 0  # running PTS, in samples
        self._sample_rate = 24000
        self._samples_per_frame = self._sample_rate // 100  # 10ms
        self._running = True
        self._frequency = 440  # 440 Hz tone

    async def recv(self) -> AudioFrame:
        """Return the next 10 ms tone frame; raise once the track is stopped."""
        if not self._running:
            raise MediaStreamError("Track has ended")  # noqa: EM101

        # Synthesize this frame's slice of the sine wave as 16-bit PCM.
        frame_duration = self._samples_per_frame / self._sample_rate
        time_axis = np.linspace(0, frame_duration, self._samples_per_frame)
        pcm = (np.sin(2 * np.pi * self._frequency * time_axis) * 32767).astype(np.int16)

        frame = AudioFrame(
            format="s16",
            layout="mono",
            samples=self._samples_per_frame,
        )
        frame.sample_rate = self._sample_rate
        frame.pts = self._timestamp
        frame.planes[0].update(pcm.tobytes())

        self._timestamp += self._samples_per_frame

        # Pace frame delivery to real time.
        await asyncio.sleep(0.01)  # 10ms

        return frame

    def stop(self) -> None:
        """Mark the track ended so the next recv() raises, then stop the base track."""
        self._running = False
        super().stop()


class MediaStreamError(Exception):
    """Raised when a media track is read after it has been stopped."""
4 changes: 3 additions & 1 deletion src/speaches/realtime/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ def create_session_object_configuration(model: str) -> Session:
voice="alloy",
input_audio_format="pcm16",
output_audio_format="pcm16",
input_audio_transcription=InputAudioTranscription(model="Systran/faster-whisper-small"),
input_audio_transcription=InputAudioTranscription(
model="Systran/faster-distil-whisper-small.en", language="en"
),
turn_detection=TurnDetection(
type="server_vad",
threshold=0.9,
Expand Down
21 changes: 16 additions & 5 deletions src/speaches/realtime/session_event_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@
from openai.types.beta.realtime.error_event import Error

from speaches.realtime.event_router import EventRouter
from speaches.types.realtime import ErrorEvent, Session, SessionUpdatedEvent, SessionUpdateEvent
from speaches.types.realtime import (
NOT_GIVEN,
ErrorEvent,
Session,
SessionUpdatedEvent,
SessionUpdateEvent,
TurnDetection,
)

if TYPE_CHECKING:
from speaches.realtime.context import SessionContext
Expand Down Expand Up @@ -36,16 +43,20 @@ def unsupported_field_error(field: str) -> ErrorEvent:

@event_router.register("session.update")
def handle_session_update_event(ctx: SessionContext, event: SessionUpdateEvent) -> None:
if event.session.input_audio_format is not None:
if event.session.input_audio_format != NOT_GIVEN:
ctx.pubsub.publish_nowait(unsupported_field_error("session.input_audio_format"))
if event.session.output_audio_format is not None:
if event.session.output_audio_format != NOT_GIVEN:
ctx.pubsub.publish_nowait(unsupported_field_error("session.output_audio_format"))
if event.session.turn_detection is not None and event.session.turn_detection.prefix_padding_ms is not None:
if (
event.session.turn_detection is not None
and isinstance(event.session.turn_detection, TurnDetection)
and event.session.turn_detection.prefix_padding_ms != NOT_GIVEN
):
ctx.pubsub.publish_nowait(unsupported_field_error("session.turn_detection.prefix_padding_ms"))

session_dict = ctx.session.model_dump()
session_update_dict = event.session.model_dump(
exclude_none=True,
exclude_defaults=True,
# https://docs.pydantic.dev/latest/concepts/serialization/#advanced-include-and-exclude
exclude={"input_audio_format": True, "output_audio_format": True, "turn_detection": {"prefix_padding_ms"}},
)
Expand Down
2 changes: 1 addition & 1 deletion src/speaches/routers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ async def handle_completions( # noqa: C901
# NOTE: Adding --use-one-literal-as-default breaks the `exclude_defaults=True` behavior
try:
chat_completion = await chat_completion_client.create(**proxied_body.model_dump(exclude_defaults=True))
except openai.BadRequestError as e:
except openai.APIStatusError as e:
return Response(content=e.message, status_code=e.status_code)
if isinstance(chat_completion, AsyncStream):

Expand Down
Loading
Loading