Skip to content

Commit

Permalink
Screen sharing mode
Browse files Browse the repository at this point in the history
  • Loading branch information
Giom-V committed Dec 18, 2024
1 parent 34f898c commit 5c54883
Showing 1 changed file with 45 additions and 2 deletions.
47 changes: 45 additions & 2 deletions gemini-2/websockets/live_api_starter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# And to run this script, ensure the GOOGLE_API_KEY environment
# variable is set to the key you obtained from Google AI Studio.

# Add the "--mode screen" if you want to share your screen to the model
# Add "--mode screen" if you want to share your screen with the model
# instead of your camera stream.

import asyncio
Expand All @@ -32,6 +32,18 @@
import cv2
import pyaudio
import PIL.Image
import mss
import argparse

# Command-line interface: a single --mode option selects which video source
# ("camera" or "screen") is streamed to the model; the default is the camera.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--mode",
    choices=["camera", "screen"],
    default="camera",
    type=str,
    help="pixels to stream from",
)
# Arguments are parsed as soon as the module is loaded; later code reads the
# selection from `args`.
args = parser.parse_args()

from websockets.asyncio.client import connect

Expand All @@ -50,6 +62,8 @@
# Host name and model identifier used for the connection.
host = "generativelanguage.googleapis.com"
model = "gemini-2.0-flash-exp"

# Video source chosen on the command line via --mode ("camera" or "screen").
MODE = args.mode

# The API key is read from the environment at import time; a missing
# GOOGLE_API_KEY raises KeyError here, before any connection is attempted.
api_key = os.environ["GOOGLE_API_KEY"]
# Secure WebSocket URL of the BidiGenerateContent service; the API key is
# passed as the `key` query parameter.
uri = f"wss://{host}/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={api_key}"

Expand Down Expand Up @@ -123,6 +137,32 @@ async def get_frames(self):
# Release the VideoCapture object
cap.release()

def _get_screen(self):
    """Capture the screen and return it as a base64-encoded JPEG payload.

    This call blocks while grabbing and encoding, so async callers should
    run it in a worker thread.

    Returns:
        dict: ``{"mime_type": "image/jpeg", "data": <base64 text>}`` where
        ``data`` is the JPEG-encoded screenshot.
    """
    # Use mss as a context manager so the capture handle opened for this
    # frame is released again; this method is called once per frame, and an
    # unclosed instance per call would leak resources.
    with mss.mss() as sct:
        # Per the mss documentation, monitors[0] is the bounding box of all
        # monitors combined, so the whole desktop is captured.
        shot = sct.grab(sct.monitors[0])
        # Build the image directly from the raw RGB bytes; a PNG
        # encode/decode round-trip would yield exactly the same pixels.
        img = PIL.Image.frombytes("RGB", shot.size, shot.rgb)

    image_io = io.BytesIO()
    img.save(image_io, format="jpeg")

    return {
        "mime_type": "image/jpeg",
        "data": base64.b64encode(image_io.getvalue()).decode(),
    }

async def get_screen(self):
    """Push screen captures onto ``self.out_queue`` until a capture is None.

    Each iteration obtains a frame from ``self._get_screen`` in a worker
    thread, waits one second, and then enqueues the frame. The coroutine
    returns as soon as ``self._get_screen`` yields ``None``.
    """
    # asyncio.to_thread keeps the blocking capture off the event loop.
    while (snapshot := await asyncio.to_thread(self._get_screen)) is not None:
        # Pace the stream: one frame per second, enqueued after the pause.
        await asyncio.sleep(1.0)
        await self.out_queue.put(snapshot)

async def send_realtime(self):
while True:
msg = await self.out_queue.get()
Expand Down Expand Up @@ -214,7 +254,10 @@ async def run(self):

tg.create_task(self.send_realtime())
tg.create_task(self.listen_audio())
tg.create_task(self.get_frames())
if MODE == "camera":
tg.create_task(self.get_frames())
elif MODE == "screen":
tg.create_task(self.get_screen())
tg.create_task(self.receive_audio())
tg.create_task(self.play_audio())

Expand Down

0 comments on commit 5c54883

Please sign in to comment.