Skip to content

Commit

Permalink
Use response.text and response.data (#344)
Browse files Browse the repository at this point in the history
* use .text & .data
* fix colors, limit output queue size
* Fix one TEXT -> AUDIO
* format
* mention video
* simplify error cleanup
* max_queue_size
* Block python versions < 3.11
* async_enumerate instead of aenumerate
---------
Co-authored-by: Guillaume Vernade <[email protected]>
  • Loading branch information
MarkDaoust authored Dec 17, 2024
1 parent 16a9a62 commit 97c0998
Show file tree
Hide file tree
Showing 7 changed files with 338 additions and 376 deletions.
2 changes: 1 addition & 1 deletion gemini-2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ Explore Gemini 2.0’s capabilities through the following notebooks using Google

Or explore on your own local machine.

* [Live API starter script](./live_api_starter.py) \- A locally runnable Python script using GenAI SDK that supports streaming audio in and out from your machine
* [Live API starter script](./live_api_starter.py) \- A locally runnable Python script using GenAI SDK that supports streaming audio + video in and audio out from your machine

Also find websocket-specific examples in the [`websockets`](./websockets/) directory.
155 changes: 83 additions & 72 deletions gemini-2/live_api_starter.ipynb

Large diffs are not rendered by default.

160 changes: 70 additions & 90 deletions gemini-2/live_api_starter.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,13 @@
CHANNELS = 1
SEND_SAMPLE_RATE = 16000
RECEIVE_SAMPLE_RATE = 24000
CHUNK_SIZE = 512
CHUNK_SIZE = 1024

MODEL = "models/gemini-2.0-flash-exp"

client = genai.Client(
http_options={'api_version': 'v1alpha'})
client = genai.Client(http_options={"api_version": "v1alpha"})

CONFIG={
"generation_config": {"response_modalities": ["AUDIO"]}}
CONFIG = {"generation_config": {"response_modalities": ["AUDIO"]}}

pya = pyaudio.PyAudio()

Expand All @@ -67,7 +65,10 @@ def __init__(self):

async def send_text(self):
while True:
text = await asyncio.to_thread(input, "message > ")
text = await asyncio.to_thread(
input,
"message > ",
)
if text.lower() == "q":
break
await self.session.send(text or ".", end_of_turn=True)
Expand All @@ -78,8 +79,11 @@ def _get_frame(self, cap):
# Check if the frame was read successfully
if not ret:
return None

img = PIL.Image.fromarray(frame)
# Fix: Convert BGR to RGB color space
# OpenCV captures in BGR but PIL expects RGB format
# This prevents the blue tint in the video feed
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
img = PIL.Image.fromarray(frame_rgb) # Now using RGB frame
img.thumbnail([1024, 1024])

image_io = io.BytesIO()
Expand All @@ -93,7 +97,9 @@ def _get_frame(self, cap):
async def get_frames(self):
# This takes about a second, and will block the whole program
# causing the audio pipeline to overflow if you don't to_thread it.
cap = await asyncio.to_thread(cv2.VideoCapture, 0) # 0 represents the default camera
cap = await asyncio.to_thread(
cv2.VideoCapture, 0
) # 0 represents the default camera

while True:
frame = await asyncio.to_thread(self._get_frame, cap)
Expand All @@ -102,21 +108,19 @@ async def get_frames(self):

await asyncio.sleep(1.0)

self.video_out_queue.put_nowait(frame)
await self.out_queue.put(frame)

# Release the VideoCapture object
cap.release()

async def send_frames(self):
async def send_realtime(self):
while True:
frame = await self.video_out_queue.get()
await self.session.send(frame)
msg = await self.out_queue.get()
await self.session.send(msg)

async def listen_audio(self):
pya = pyaudio.PyAudio()

mic_info = pya.get_default_input_device_info()
stream = await asyncio.to_thread(
self.audio_stream = await asyncio.to_thread(
pya.open,
format=FORMAT,
channels=CHANNELS,
Expand All @@ -125,94 +129,70 @@ async def listen_audio(self):
input_device_index=mic_info["index"],
frames_per_buffer=CHUNK_SIZE,
)
if __debug__:
kwargs = {"exception_on_overflow": False}
else:
kwargs = {}
while True:
data = await asyncio.to_thread(stream.read, CHUNK_SIZE)
self.audio_out_queue.put_nowait(data)

async def send_audio(self):
while True:
chunk = await self.audio_out_queue.get()
await self.session.send({"data": chunk, "mime_type": "audio/pcm"})
data = await asyncio.to_thread(self.audio_stream.read, CHUNK_SIZE, **kwargs)
await self.out_queue.put({"data": data, "mime_type": "audio/pcm"})

async def receive_audio(self):
"Background task to read from the websocket and write pcm chunks to the output queue"
while True:
async for response in self.session.receive():
server_content = response.server_content
if server_content is not None:
model_turn = server_content.model_turn
if model_turn is not None:
parts = model_turn.parts

for part in parts:
if part.text is not None:
print(part.text, end="")
elif part.inline_data is not None:
self.audio_in_queue.put_nowait(part.inline_data.data)

server_content.model_turn = None
turn_complete = server_content.turn_complete
if turn_complete:
# If you interrupt the model, it sends a turn_complete.
# For interruptions to work, we need to stop playback.
# So empty out the audio queue because it may have loaded
# much more audio than has played yet.
print("Turn complete")
while not self.audio_in_queue.empty():
self.audio_in_queue.get_nowait()
turn = self.session.receive()
async for response in turn:
if data := response.data:
self.audio_in_queue.put_nowait(data)
continue
if text := response.text:
print(text, end="")

# If you interrupt the model, it sends a turn_complete.
# For interruptions to work, we need to stop playback.
# So empty out the audio queue because it may have loaded
# much more audio than has played yet.
while not self.audio_in_queue.empty():
self.audio_in_queue.get_nowait()

async def play_audio(self):
pya = pyaudio.PyAudio()
stream = await asyncio.to_thread(
pya.open, format=FORMAT, channels=CHANNELS, rate=RECEIVE_SAMPLE_RATE, output=True
pya.open,
format=FORMAT,
channels=CHANNELS,
rate=RECEIVE_SAMPLE_RATE,
output=True,
)
while True:
bytestream = await self.audio_in_queue.get()
await asyncio.to_thread(stream.write, bytestream)

async def run(self):
"""Takes audio chunks off the input queue, and writes them to files.
Splits and displays files if the queue pauses for more than `max_pause`.
"""
async with (
client.aio.live.connect(model=MODEL, config=CONFIG) as session,
asyncio.TaskGroup() as tg,
):
self.session = session

self.audio_in_queue = asyncio.Queue()
self.audio_out_queue = asyncio.Queue()
self.video_out_queue = asyncio.Queue()

send_text_task = tg.create_task(self.send_text())

def cleanup(task):
for t in tg._tasks:
t.cancel()

send_text_task.add_done_callback(cleanup)

tg.create_task(self.listen_audio())
tg.create_task(self.send_audio())
tg.create_task(self.get_frames())
tg.create_task(self.send_frames())
tg.create_task(self.receive_audio())
tg.create_task(self.play_audio())

def check_error(task):
if task.cancelled():
return

if task.exception() is None:
return

e = task.exception()
traceback.print_exception(None, e, e.__traceback__)
sys.exit(1)

for task in tg._tasks:
task.add_done_callback(check_error)
try:
async with (
client.aio.live.connect(model=MODEL, config=CONFIG) as session,
asyncio.TaskGroup() as tg,
):
self.session = session

self.audio_in_queue = asyncio.Queue()
self.out_queue = asyncio.Queue(maxsize=5)

send_text_task = tg.create_task(self.send_text())
tg.create_task(self.send_realtime())
tg.create_task(self.listen_audio())
tg.create_task(self.get_frames())
tg.create_task(self.receive_audio())
tg.create_task(self.play_audio())

await send_text_task
raise asyncio.CancelledError("User requested exit")

except asyncio.CancelledError:
pass
except ExceptionGroup as EG:
self.audio_stream.close()
traceback.print_exception(EG)


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 97c0998

Please sign in to comment.