Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat/realtime webrtc #340

Merged
merged 4 commits into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ dependencies = [
"aiostream>=0.6.4",
"cachetools>=5.5.1",
"httpx-ws>=0.7.1",
"aiortc>=1.10.1",
]

[project.optional-dependencies]
Expand Down
4 changes: 4 additions & 0 deletions src/speaches/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def setup_logger(log_level: str) -> None:
"level": "INFO",
"handlers": ["stdout"],
},
"aiortc.rtcrtpreceiver": {
"level": "INFO",
"handlers": ["stdout"],
},
"numba.core": {
"level": "INFO",
"handlers": ["stdout"],
Expand Down
4 changes: 4 additions & 0 deletions src/speaches/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from speaches.routers.models import (
router as models_router,
)
from speaches.routers.realtime.rtc import (
router as realtime_rtc_router,
)
from speaches.routers.realtime.ws import (
router as realtime_ws_router,
)
Expand Down Expand Up @@ -66,6 +69,7 @@ def create_app() -> FastAPI:
app.include_router(stt_router)
app.include_router(models_router)
app.include_router(misc_router)
app.include_router(realtime_rtc_router)
app.include_router(realtime_ws_router)
app.include_router(speech_router)
app.include_router(vad_router)
Expand Down
13 changes: 9 additions & 4 deletions src/speaches/realtime/response_event_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ def add_output_item[T: ServerConversationItem](self, item: T) -> Generator[T, No
self.response.output.append(item)
self.pubsub.publish_nowait(ResponseOutputItemAddedEvent(response_id=self.id, item=item))
yield item
assert item.status != "incomplete", item
item.status = "completed"
self.pubsub.publish_nowait(ResponseOutputItemDoneEvent(response_id=self.id, item=item))
self.pubsub.publish_nowait(ResponseDoneEvent(response=self.response))

Expand All @@ -118,7 +120,7 @@ def add_item_content[T: ConversationItemContentText | ConversationItemContentAud
)

async def conversation_item_message_text_handler(self, chunk_stream: aiostream.Stream[ChatCompletionChunk]) -> None:
with self.add_output_item(ConversationItemMessage(role="assistant", content=[])) as item:
with self.add_output_item(ConversationItemMessage(role="assistant", status="incomplete", content=[])) as item:
self.conversation.create_item(item)

with self.add_item_content(item, ConversationItemContentText(text="")) as content:
Expand All @@ -139,7 +141,7 @@ async def conversation_item_message_text_handler(self, chunk_stream: aiostream.S
async def conversation_item_message_audio_handler(
self, chunk_stream: aiostream.Stream[ChatCompletionChunk]
) -> None:
with self.add_output_item(ConversationItemMessage(role="assistant", content=[])) as item:
with self.add_output_item(ConversationItemMessage(role="assistant", status="incomplete", content=[])) as item:
self.conversation.create_item(item)

with self.add_item_content(item, ConversationItemContentAudio(audio="", transcript="")) as content:
Expand Down Expand Up @@ -190,7 +192,10 @@ async def conversation_item_function_call_handler(
and tool_call.function.arguments is not None
), chunk
item = ConversationItemFunctionCall(
call_id=tool_call.id, name=tool_call.function.name, arguments=tool_call.function.arguments
status="incomplete",
call_id=tool_call.id,
name=tool_call.function.name,
arguments=tool_call.function.arguments,
)
assert item.call_id is not None and item.arguments is not None and item.name is not None, item

Expand Down Expand Up @@ -225,7 +230,7 @@ async def generate_response(self) -> None:
try:
completion_params = create_completion_params(
self.model,
list(items_to_chat_messages(self.conversation.items)),
list(items_to_chat_messages(self.configuration.input)),
self.configuration,
)
chunk_stream = await self.completion_client.create(**completion_params)
Expand Down
184 changes: 184 additions & 0 deletions src/speaches/realtime/rtc/audio_stream_track.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
import asyncio
import base64
import io
import logging

from aiortc import MediaStreamTrack
from av.audio.frame import AudioFrame
import numpy as np
from openai.types.beta.realtime import ResponseAudioDeltaEvent
from opentelemetry import trace

from speaches.audio import audio_samples_from_file
from speaches.realtime.context import SessionContext
from speaches.realtime.input_audio_buffer_event_router import resample_audio_data

logger = logging.getLogger(__name__)
tracer = trace.get_tracer(__name__)


class AudioStreamTrack(MediaStreamTrack):
    """Outgoing WebRTC audio track fed by ``response.audio.delta`` events.

    A background task subscribes to the session's pub/sub stream, decodes
    each base64 PCM delta, resamples it from 24 kHz to 48 kHz, slices it
    into fixed-size 10 ms frames, and queues them for ``recv``.
    """

    kind = "audio"

    def __init__(self, ctx: SessionContext) -> None:
        super().__init__()
        self.ctx = ctx
        # Frames ready to be handed out by `recv`.
        self.frame_queue: asyncio.Queue[AudioFrame] = asyncio.Queue()
        self._timestamp = 0  # running PTS, counted in samples
        self._sample_rate = 48000
        self._frame_duration = 0.01  # seconds per frame (10 ms)
        self._samples_per_frame = int(self._sample_rate * self._frame_duration)
        self._running = True

        # Background task that converts pubsub audio deltas into AudioFrames.
        self._process_task = asyncio.create_task(self._audio_frame_generator())

    async def recv(self) -> AudioFrame:
        """Return the next queued audio frame.

        Raises:
            MediaStreamError: once the track has been stopped or the wait
                for a frame is cancelled.

        """
        if not self._running:
            raise MediaStreamError("Track has ended")  # noqa: EM101

        try:
            frame = await self.frame_queue.get()
            # NOTE: I believe some delay is neccessary to prevent buffers from being dropped.
            await asyncio.sleep(0.005)
        except asyncio.CancelledError as e:
            raise MediaStreamError("Track has ended") from e  # noqa: EM101
        else:
            return frame

    async def _audio_frame_generator(self) -> None:
        """Consume ``response.audio.delta`` events and enqueue AudioFrames."""
        try:
            async for event in self.ctx.pubsub.subscribe_to("response.audio.delta"):
                assert isinstance(event, ResponseAudioDeltaEvent)

                if not self._running:
                    return

                # Same decoding path as the `input_audio_buffer.append` handler.
                audio_array = audio_samples_from_file(io.BytesIO(base64.b64decode(event.delta)))
                audio_array = resample_audio_data(audio_array, 24000, self._sample_rate)

                # WebRTC wants 16-bit PCM; scale float samples (assumed to be
                # normalized to [-1.0, 1.0] -- TODO confirm) into int16 range.
                if audio_array.dtype != np.int16:
                    audio_array = (audio_array * 32767).astype(np.int16)

                frames = self._split_into_frames(audio_array)
                # Lazy %-style args: avoid formatting work when INFO is disabled.
                logger.info("Received audio: %d samples, split into %d frames", len(audio_array), len(frames))
                for frame_data in frames:
                    self.frame_queue.put_nowait(self._create_frame(frame_data))

        except asyncio.CancelledError:
            logger.warning("Audio frame generator task cancelled")

    def _split_into_frames(self, audio_array: np.ndarray) -> list[np.ndarray]:
        """Slice a 1-D array into `_samples_per_frame`-sized chunks.

        The final partial chunk, if any, is zero-padded to a full frame.
        Returns an empty list for empty input.
        """
        # Ensure the array is 1-D before slicing.
        if audio_array.ndim > 1:
            audio_array = audio_array.flatten()

        remainder = len(audio_array) % self._samples_per_frame
        if remainder:
            pad = self._samples_per_frame - remainder
            logger.info("Padding final frame with %d zero samples", pad)
            audio_array = np.pad(audio_array, (0, pad), "constant", constant_values=0)

        # One vectorized reshape instead of a Python-level slicing loop.
        return list(audio_array.reshape(-1, self._samples_per_frame))

    def _create_frame(self, frame_data: np.ndarray) -> AudioFrame:
        """Create a mono s16 AudioFrame from exactly one frame's worth of samples.

        Args:
            frame_data: Numpy array containing exactly `_samples_per_frame` samples.

        Returns:
            AudioFrame with its PTS set from the running sample counter.

        """
        frame = AudioFrame(
            format="s16",
            layout="mono",
            samples=self._samples_per_frame,
        )
        frame.sample_rate = self._sample_rate
        frame.planes[0].update(frame_data.tobytes())

        frame.pts = self._timestamp
        self._timestamp += self._samples_per_frame

        return frame

    def stop(self) -> None:
        """Stop the audio track and cancel the background producer task."""
        self._running = False
        # `_process_task` is always created in __init__, so no hasattr guard needed.
        self._process_task.cancel()
        super().stop()


class ToneAudioStreamTrack(MediaStreamTrack):
    """Debug audio track that emits a continuous 440 Hz sine tone."""

    kind = "audio"

    def __init__(self) -> None:
        super().__init__()
        self._timestamp = 0  # running PTS, counted in samples
        self._sample_rate = 24000
        self._samples_per_frame = self._sample_rate // 100  # 10ms
        self._running = True
        self._frequency = 440  # 440 Hz tone

    async def recv(self) -> AudioFrame:
        """Return the next 10 ms tone frame, paced at roughly real time.

        Raises:
            MediaStreamError: once the track has been stopped.

        """
        if not self._running:
            raise MediaStreamError("Track has ended")  # noqa: EM101

        # Derive sample times from the running timestamp so the sine wave is
        # phase-continuous across frames. (Restarting at t=0 each frame, as a
        # naive linspace does, leaves 4.4 cycles per 10 ms frame and therefore
        # a phase jump -- an audible click -- at every frame boundary.)
        sample_indices = self._timestamp + np.arange(self._samples_per_frame)
        samples = np.sin(2 * np.pi * self._frequency * sample_indices / self._sample_rate)
        samples = (samples * 32767).astype(np.int16)

        # Create frame
        frame = AudioFrame(
            format="s16",
            layout="mono",
            samples=self._samples_per_frame,
        )
        frame.sample_rate = self._sample_rate
        frame.pts = self._timestamp
        frame.planes[0].update(samples.tobytes())

        self._timestamp += self._samples_per_frame

        # Sleep for frame duration
        await asyncio.sleep(0.01)  # 10ms

        return frame

    def stop(self) -> None:
        """Mark the track as ended so subsequent recv() calls raise."""
        self._running = False
        super().stop()


class MediaStreamError(Exception):
    """Raised by a track's ``recv()`` once the track has been stopped.

    NOTE(review): aiortc ships its own ``aiortc.mediastreams.MediaStreamError``;
    presumably this local class is meant to stand in for it — verify that
    callers catch the intended exception type.
    """

    pass
Loading
Loading