From 2055b01880ffa7266cf9d0b0597a886115278259 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Wed, 18 Dec 2024 15:53:01 -0800 Subject: [PATCH] Add realtime_input > media_chunks wrapping --- gemini-2/websockets/live_api_starter.ipynb | 1423 +++++++------- gemini-2/websockets/live_api_starter.py | 13 +- .../live_api_streaming_in_colab.ipynb | 1659 +++++++++-------- gemini-2/websockets/live_api_tool_use.ipynb | 1632 ++++++++-------- 4 files changed, 2374 insertions(+), 2353 deletions(-) diff --git a/gemini-2/websockets/live_api_starter.ipynb b/gemini-2/websockets/live_api_starter.ipynb index b75c75c73..2d511752a 100644 --- a/gemini-2/websockets/live_api_starter.ipynb +++ b/gemini-2/websockets/live_api_starter.ipynb @@ -1,719 +1,722 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "jFP_KbCLhM47" - }, - "source": [ - "*Copyright 2024 Google LLC.*" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "906e07f6e562" - }, - "outputs": [], - "source": [ - "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "#\n", - "# https://www.apache.org/licenses/LICENSE-2.0\n", - "#\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "R5DkeFMP75as" - }, - "source": [ - "# Live API - Websockets Quickstart" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "VQQYhS4-3abT" - }, - "source": [ - "\n", - " \n", - "
\n", - " Run in Google Colab\n", - "
" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "iS0rHk3RBrtA" - }, - "source": [ - "**This** notebok demonstrates simple usage of the Gemini Live API.\n", - "\n", - "This Notebook connects directly to the API websockets to demonstrate the the low-level details for anyone building without using an SDK.\n", - "\n", - "- If you are not interested in the low-level websocket details you should read the [SDK version of this notebook](../Live_API_Text_to_Audio.ipynb).\n", - "\n", - "This notebook implements a simple turn-based chat where you send messages as text, and the model replies with audio. The API is capable of much more than that. The goal here is to **demonstrate with simple code**.\n", - "\n", - "- The [Live API - Text to Text](../Live_API_Text_to_Text.ipynb) notebook is even simpler than this, as it doesn't deal with audio.\n", - "- The [Live API - Audio Streaming in Colab](./live_api_streaming_in_colab.ipynb) demonstrates streaming audio **in Colab**.
It's more _fun_ than this notebook but **not optimized for readability**.\n", - "- The [Live API Audio Video to Audio python script](./live_api_audio_video_to_audio.py) doesn't work in colab, but provides a relatively readable implementation of audio and video streaming." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gnzxta_57_Ip" - }, - "source": [ - "## Setup" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "IMLUXP3e8FUy" - }, - "source": [ - "### Install and import" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "8gTJzV6K4dh6" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/168.2 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[91m╸\u001b[0m\u001b[90m━\u001b[0m \u001b[32m163.8/168.2 kB\u001b[0m \u001b[31m4.7 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m168.2/168.2 kB\u001b[0m \u001b[31m3.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h" - ] - } + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "jFP_KbCLhM47" + }, + "source": [ + "*Copyright 2024 Google LLC.*" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "906e07f6e562" + }, + "outputs": [], + "source": [ + "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "R5DkeFMP75as" + }, + "source": [ + "# Live API - Websockets Quickstart" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VQQYhS4-3abT" + }, + "source": [ + "\n", + " \n", + "
\n", + " Run in Google Colab\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iS0rHk3RBrtA" + }, + "source": [ + "**This** notebok demonstrates simple usage of the Gemini Live API.\n", + "\n", + "This Notebook connects directly to the API websockets to demonstrate the the low-level details for anyone building without using an SDK.\n", + "\n", + "- If you are not interested in the low-level websocket details you should read the [SDK version of this notebook](../Live_API_Text_to_Audio.ipynb).\n", + "\n", + "This notebook implements a simple turn-based chat where you send messages as text, and the model replies with audio. The API is capable of much more than that. The goal here is to **demonstrate with simple code**.\n", + "\n", + "- The [Live API - Text to Text](../Live_API_Text_to_Text.ipynb) notebook is even simpler than this, as it doesn't deal with audio.\n", + "- The [Live API - Audio Streaming in Colab](./live_api_streaming_in_colab.ipynb) demonstrates streaming audio **in Colab**.
It's more _fun_ than this notebook but **not optimized for readability**.\n", + "- The [Live API Audio Video to Audio python script](./live_api_audio_video_to_audio.py) doesn't work in colab, but provides a relatively readable implementation of audio and video streaming." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gnzxta_57_Ip" + }, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IMLUXP3e8FUy" + }, + "source": [ + "### Install and import" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8gTJzV6K4dh6" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/168.2 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[91m╸\u001b[0m\u001b[90m━\u001b[0m \u001b[32m163.8/168.2 kB\u001b[0m \u001b[31m4.7 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m168.2/168.2 kB\u001b[0m \u001b[31m3.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h" + ] + } + ], + "source": [ + "!pip install -q websockets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Yd1vs3cP8EmS" + }, + "outputs": [], + "source": [ + "import asyncio\n", + "import base64\n", + "import contextlib\n", + "import datetime\n", + "import os\n", + "import json\n", + "import wave\n", + "import itertools\n", + "\n", + "from websockets.asyncio.client import connect\n", + "from IPython.display import display, Audio" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "T_C_11Lu8KjK" + }, + "source": [ + "### Constants\n", + "\n", + "To run the following cell, your API key must be stored in a Colab Secret named `GOOGLE_API_KEY`. If you don't already have an API key, or you're not sure how to create a Colab Secret, see [Authentication](../../quickstarts/Authentication.ipynb) for an example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "O3GSbPL99z0d" + }, + "outputs": [], + "source": [ + "from google.colab import userdata\n", + "\n", + "os.environ[\"GOOGLE_API_KEY\"] = userdata.get(\"GOOGLE_API_KEY\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fv1EcvfpmHjX" + }, + "source": [ + "Multimodal Live API are a new capability introduced with the [Gemini 2.0](https://ai.google.dev/gemini-api/docs/models/gemini-v2) model so only works with this model. 
You need to use the `v1alpha` client version.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QNxC25Pg4Hfr"
      },
      "outputs": [],
      "source": [
        "MODEL = \"models/gemini-2.0-flash-exp\"\n",
        "\n",
        "HOST = \"generativelanguage.googleapis.com\"\n",
        "\n",
        "URI = f'wss://{HOST}/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={os.environ[\"GOOGLE_API_KEY\"]}'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uffrNtdw8Ce2"
      },
      "source": [
        "### Logging"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Wj7XAlzBLZer"
      },
      "source": [
        "Uncomment the `logger.setLevel` call to show the log messages."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "30dIpYsUesm2"
      },
      "outputs": [],
      "source": [
        "import logging\n",
        "\n",
        "logger = logging.getLogger(\"Bidi\")\n",
        "# logger.setLevel('DEBUG')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qtCOiuRM8Osx"
      },
      "source": [
        "### Wave file writer\n",
        "\n",
        "The code in this section is not essential for understanding the API, so feel free to skip to the next section.\n",
        "\n",
        "The simplest way to play back the audio in Colab is to write it out to a `.wav` file."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZEbwH45fEBbc"
      },
      "outputs": [],
      "source": [
        "@contextlib.contextmanager\n",
        "def wave_file(filename, channels=1, rate=24000, sample_width=2):\n",
        "    with wave.open(filename, \"wb\") as wf:\n",
        "        wf.setnchannels(channels)\n",
        "        wf.setsampwidth(sample_width)\n",
        "        wf.setframerate(rate)\n",
        "        yield wf"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QutDG7r78Zf-"
      },
      "source": [
        "## Main audio loop"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ERqyY0IFN8G9"
      },
      "source": [
        "The class below implements the interaction with the Live API.\n",
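        "\n",
        "Before reading the whole class, it may help to see the bare handshake on its own. The sketch below is a minimal example (using a hypothetical `handshake` helper and the `URI` and `MODEL` constants defined above): it opens the websocket, sends the `setup` message, and waits for the server's `setupComplete` reply:\n",
        "\n",
        "```python\n",
        "async def handshake():\n",
        "    async with connect(\n",
        "        URI, additional_headers={\"Content-Type\": \"application/json\"}\n",
        "    ) as ws:\n",
        "        # The first message on the wire must be the `setup` message.\n",
        "        await ws.send(json.dumps({\"setup\": {\"model\": MODEL}}))\n",
        "        # Wait for the `setupComplete` acknowledgement before sending anything else.\n",
        "        raw_response = await ws.recv(decode=False)\n",
        "        print(json.loads(raw_response.decode(\"ascii\")))\n",
        "```\n",
        "\n",
        "In Colab you would run it with `await handshake()`. The `AudioLoop` class below wraps this same handshake together with the send and receive halves of the conversation."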
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3zAjMOZXFuxI" + }, + "outputs": [], + "source": [ + "class AudioLoop:\n", + " def __init__(self, tools=None):\n", + " if tools is None:\n", + " self.tools = []\n", + " else:\n", + " self.tools = tools\n", + " self.ws = None\n", + " self.index = 0\n", + "\n", + " async def run(self):\n", + " print(\"Type 'q' to quit\")\n", + "\n", + " logger.debug(\"connect\")\n", + " async with connect(\n", + " URI, additional_headers={\"Content-Type\": \"application/json\"}\n", + " ) as ws:\n", + " self.ws = ws\n", + " await self.setup()\n", + "\n", + " while True:\n", + " # Ideally these would be separate tasks.\n", + " if not await self.send():\n", + " break\n", + " await self.recv()\n", + "\n", + " async def setup(self):\n", + " logger.debug(\"set_up\")\n", + " await self.ws.send(json.dumps({\"setup\": {\"model\": MODEL, \"tools\": self.tools}}))\n", + " raw_response = await self.ws.recv(decode=False)\n", + " setup_response = json.loads(raw_response.decode(\"ascii\"))\n", + " logger.debug(f\"Connected: {setup_response}\")\n", + "\n", + " async def send(self):\n", + " logger.debug(\"send\")\n", + " # `asyncio.to_thread` is important here, without it all other tasks are blocked.\n", + " text = await asyncio.to_thread(input, \"message > \")\n", + "\n", + " # If the input returns 'q' quit.\n", + " if text.lower() == \"q\":\n", + " return False\n", + "\n", + " # Wrap the text into a \"client_content\" message.\n", + " msg = {\n", + " \"client_content\": {\n", + " \"turns\": [{\"role\": \"user\", \"parts\": [{\"text\": text}]}],\n", + " \"turn_complete\": True,\n", + " }\n", + " }\n", + "\n", + " # Send the message to the model.\n", + " await self.ws.send(json.dumps(msg))\n", + " logger.debug(\"sent\")\n", + " return True\n", + "\n", + " async def recv(self):\n", + " # Start a new `.wav` file.\n", + " file_name = f\"audio_{self.index}.wav\"\n", + " with wave_file(file_name) as wav:\n", + " self.index += 1\n", + "\n", + " logger.debug(\"receive\")\n", + "\n", + " # Read chunks from the socket.\n", + " async for raw_response in self.ws:\n", + " response = json.loads(raw_response.decode())\n", + " logger.debug(f\"got chunk: {str(response)[:200]}\")\n", + " # print(response)\n", + "\n", + " server_content = response.pop(\"serverContent\", None)\n", + " if server_content is None:\n", + " logger.error(f\"Unhandled server message! 
- {response}\")\n", + " break\n", + "\n", + " # Write audio the chunk to the `.wav` file.\n", + " model_turn = server_content.pop(\"modelTurn\", None)\n", + " if model_turn is not None:\n", + " b64data = model_turn[\"parts\"][0][\"inlineData\"][\"data\"]\n", + " pcm_data = base64.b64decode(b64data)\n", + " print(\".\", end=\"\")\n", + " logger.debug(\"Got pcm_data\")\n", + " wav.writeframes(pcm_data)\n", + "\n", + " # Break out of the loop if the model's turn is complete.\n", + " turn_complete = server_content.pop(\"turnComplete\", None)\n", + " if turn_complete:\n", + " logger.debug(\"turn_complete\")\n", + " break\n", + "\n", + " display(Audio(file_name, autoplay=True))\n", + " await asyncio.sleep(2)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AwNPuC_rAHAc" + }, + "source": [ + "There are 4 methods worth describing here:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tXPhEdHIPBif" + }, + "source": [ + "### `run` - The main loop\n", + "\n", + "This method:\n", + "\n", + "- Opens a `websocket` connecting to the Live API\n", + "- Calls the initial `setup` method\n", + "- Then enters the main loop where it alternates between `send` and `recv` until send returns `False`." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lj5MbvafPPCM" + }, + "source": [ + "### `setup` - Initial setup\n", + "\n", + "The `setup` method sends the `setup` message, and awaits the response. You shouldn't try to `send` or `recv` anything else from the model until you've gotten the model's `setup_complete` response.\n", + "\n", + "The `setup` message (a `BidiGenerateContentSetup` object) is where you can set the `model`, `generation_config`, `system_instructions`, `tools` and `safety_settings`." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oCg1qFf0PV44" + }, + "source": [ + "### `send` - Sends input text to the api\n", + "\n", + "The `send` method collects input text from the user, wraps it in a `client_content` message (an instance of `BidiGenerateContentClientContent`), and sends it to the model." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tLukmBhPPib4" + }, + "source": [ + "### `recv` - Collects audio from the API and plays it\n", + "\n", + "The `recv` method collects audio chunks in a loop and writes them to a `.wav` file. It breaks out of the loop once the model sends a `turn_complete` method, and then plays the audio.\n", + "\n", + "To keep things simple in Colab it collects **all** the audio before playing it. [TODO: link other examples]() demonstrate how to play audio as soon as you start to receive it (using `PyAudio`), and how to interrupt the model (implement input and audio playback on separate tasks)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gGYtiV2N8b2o" + }, + "source": [ + "## Run" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7fwjAvU9MMP7" + }, + "source": [ + "### Example 1: simple usage with Google Search" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "eme2cwH0JMwJ" + }, + "outputs": [ + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "Type 'q' to quit\n", + "........................................................................." 
+ ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " ], - "source": [ - "!pip install -q websockets" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Yd1vs3cP8EmS" - }, - "outputs": [], - "source": [ - "import asyncio\n", - "import base64\n", - "import contextlib\n", - "import datetime\n", - "import os\n", - "import json\n", - "import wave\n", - "import itertools\n", - "\n", - "from websockets.asyncio.client import connect\n", - "from IPython.display import display, Audio" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "T_C_11Lu8KjK" - }, - "source": [ - "### Constants\n", - "\n", - "To run the following cell, your API key must be stored in a Colab Secret named `GOOGLE_API_KEY`. If you don't already have an API key, or you're not sure how to create a Colab Secret, see [Authentication](../../quickstarts/Authentication.ipynb) for an example." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "O3GSbPL99z0d" - }, - "outputs": [], - "source": [ - "from google.colab import userdata\n", - "os.environ['GOOGLE_API_KEY'] = userdata.get('GOOGLE_API_KEY')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fv1EcvfpmHjX" - }, - "source": [ - "Multimodal Live API are a new capability introduced with the [Gemini 2.0](https://ai.google.dev/gemini-api/docs/models/gemini-v2) model so only works with this model. You need to use the `v1alpha` client version.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "QNxC25Pg4Hfr" - }, - "outputs": [], - "source": [ - "MODEL = 'models/gemini-2.0-flash-exp'\n", - "\n", - "HOST='generativelanguage.googleapis.com'\n", - "\n", - "URI = f'wss://{HOST}/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={os.environ[\"GOOGLE_API_KEY\"]}'" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "uffrNtdw8Ce2" - }, - "source": [ - "### Logging" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Wj7XAlzBLZer" - }, - "source": [ - "Uncomment the `logger.setLevel` call to show the log messages" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "30dIpYsUesm2" - }, - "outputs": [], - "source": [ - "import logging\n", - "\n", - "logger = logging.getLogger('Bidi')\n", - "#logger.setLevel('DEBUG')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qtCOiuRM8Osx" - }, - "source": [ - "### Wave file writer\n", - "\n", - "The code in this secrtion is not essential for understanding the API, feel free to skip to the next section.\n", - "\n", - "The simplest way to playback the audio in Colab, is to write it outto a `.wav` file." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ZEbwH45fEBbc" - }, - "outputs": [], - "source": [ - "@contextlib.contextmanager\n", - "def wave_file(filename, channels=1, rate=24000, sample_width=2):\n", - " with wave.open(filename, \"wb\") as wf:\n", - " wf.setnchannels(channels)\n", - " wf.setsampwidth(sample_width)\n", - " wf.setframerate(rate)\n", - " yield wf" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QutDG7r78Zf-" - }, - "source": [ - "## Main audio loop" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ERqyY0IFN8G9" - }, - "source": [ - "The class below implements the interaction with the Live API." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "3zAjMOZXFuxI" - }, - "outputs": [], - "source": [ - "class AudioLoop:\n", - " def __init__(self, tools=None):\n", - " if tools is None:\n", - " self.tools = []\n", - " else:\n", - " self.tools = tools\n", - " self.ws = None\n", - " self.index = 0\n", - "\n", - " async def run(self):\n", - " print(\"Type 'q' to quit\")\n", - "\n", - " logger.debug('connect')\n", - " async with connect(URI, additional_headers={'Content-Type': 'application/json'}) as ws:\n", - " self.ws = ws\n", - " await self.setup()\n", - "\n", - " while True:\n", - " # Ideally these would be separate tasks.\n", - " if not await self.send():\n", - " break\n", - " await self.recv()\n", - "\n", - " async def setup(self):\n", - " logger.debug(\"set_up\")\n", - " await self.ws.send(json.dumps({\n", - " 'setup' : {\n", - " \"model\": MODEL,\n", - " \"tools\": self.tools\n", - " }\n", - " }))\n", - " raw_response = await self.ws.recv(decode=False)\n", - " setup_response = json.loads(raw_response.decode('ascii'))\n", - " logger.debug(f'Connected: {setup_response}')\n", - "\n", - " async def send(self):\n", - " logger.debug('send')\n", - " # `asyncio.to_thread` is important here, without it all other tasks are blocked.\n", - " text = await asyncio.to_thread(input, \"message > \")\n", - "\n", - " # If the input returns 'q' quit.\n", - " if text.lower() == 'q':\n", - " return False\n", - "\n", - " # Wrap the text into a \"client_content\" message.\n", - " msg = {\n", - " \"client_content\": {\n", - " \"turns\": [{\n", - " \"role\": \"user\",\n", - " \"parts\": [{ \"text\": text }]\n", - " }],\n", - " 'turn_complete': True\n", - " }\n", - " }\n", - "\n", - " # Send the message to the model.\n", - " await self.ws.send(json.dumps(msg))\n", - " logger.debug('sent')\n", - " return True\n", - "\n", - " async def recv(self):\n", - " # Start a new `.wav` file.\n", - " file_name = f\"audio_{self.index}.wav\"\n", - " with wave_file(file_name) as wav:\n", - " self.index += 1\n", - "\n", - " logger.debug('receive')\n", - "\n", - " # Read chunks from the socket.\n", - " async for raw_response in self.ws:\n", - " response = json.loads(raw_response.decode())\n", - " logger.debug(f'got chunk: {str(response)[:200]}')\n", - " #print(response)\n", - "\n", - " server_content = response.pop('serverContent', None)\n", - " if server_content is None:\n", - " logger.error(f'Unhandled server message! 
- {response}')\n", - " break\n", - "\n", - " # Write audio the chunk to the `.wav` file.\n", - " model_turn = server_content.pop('modelTurn', None)\n", - " if model_turn is not None:\n", - " b64data = model_turn['parts'][0]['inlineData']['data']\n", - " pcm_data = base64.b64decode(b64data)\n", - " print('.', end='')\n", - " logger.debug('Got pcm_data')\n", - " wav.writeframes(pcm_data)\n", - "\n", - " # Break out of the loop if the model's turn is complete.\n", - " turn_complete = server_content.pop('turnComplete', None)\n", - " if turn_complete:\n", - " logger.debug('turn_complete')\n", - " break\n", - "\n", - " display(Audio(file_name, autoplay=True))\n", - " await asyncio.sleep(2)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AwNPuC_rAHAc" - }, - "source": [ - "There are 4 methods worth describing here:" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tXPhEdHIPBif" - }, - "source": [ - "### `run` - The main loop\n", - "\n", - "This method:\n", - "\n", - "- Opens a `websocket` connecting to the Live API\n", - "- Calls the initial `setup` method\n", - "- Then enters the main loop where it alternates between `send` and `recv` until send returns `False`." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "lj5MbvafPPCM" - }, - "source": [ - "### `setup` - Initial setup\n", - "\n", - "The `setup` method sends the `setup` message, and awaits the response. You shouldn't try to `send` or `recv` anything else from the model until you've gotten the model's `setup_complete` response.\n", - "\n", - "The `setup` message (a `BidiGenerateContentSetup` object) is where you can set the `model`, `generation_config`, `system_instructions`, `tools` and `safety_settings`." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "oCg1qFf0PV44" - }, - "source": [ - "### `send` - Sends input text to the api\n", - "\n", - "The `send` method collects input text from the user, wraps it in a `client_content` message (an instance of `BidiGenerateContentClientContent`), and sends it to the model." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tLukmBhPPib4" - }, - "source": [ - "### `recv` - Collects audio from the API and plays it\n", - "\n", - "The `recv` method collects audio chunks in a loop and writes them to a `.wav` file. It breaks out of the loop once the model sends a `turn_complete` method, and then plays the audio.\n", - "\n", - "To keep things simple in Colab it collects **all** the audio before playing it. [TODO: link other examples]() demonstrate how to play audio as soon as you start to receive it (using `PyAudio`), and how to interrupt the model (implement input and audio playback on separate tasks)." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gGYtiV2N8b2o" - }, - "source": [ - "## Run" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7fwjAvU9MMP7" - }, - "source": [ - "### Example 1: simple usage with Google Search" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "eme2cwH0JMwJ" - }, - "outputs": [ - { - "metadata": { - "tags": null - }, - "name": "stdout", - "output_type": "stream", - "text": [ - "Type 'q' to quit\n", - "........................................................................." 
- ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "metadata": { - "tags": null - }, - "name": "stdout", - "output_type": "stream", - "text": [ - "................................................." - ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "metadata": { - "tags": null - }, - "name": "stdout", - "output_type": "stream", - "text": [ - "..................." - ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "metadata": { - "tags": null - }, - "name": "stdout", - "output_type": "stream", - "text": [ - "............................................" - ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "metadata": { - "tags": null - }, - "name": "stdout", - "output_type": "stream", - "text": [ - "........................................." - ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "................................................." + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " ], - "source": [ - "tools = [\n", - " {'google_search': {}},\n", - "]\n", - "\n", - "await AudioLoop(tools=tools).run()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DCz7jK87MSWz" - }, - "source": [ - "### Example 2: function calling" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "WxdwgTKIGIlY" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Type 'q' to quit\n", - "message > make it dark \n", - "....." - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "ERROR:Bidi:Unhandled server message! - {'toolCall': {'functionCalls': [{'name': 'turn_off_the_lights', 'args': {}, 'id': 'function-call-1061054218196002212'}]}}\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "..................." 
+ ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " ], - "source": [ - "tools = [\n", - " {'function_declarations': [{'name': 'turn_on_the_lights', 'description': None}, {'name': 'turn_off_the_lights', 'description': None}]}\n", - "]\n", - "\n", - "await AudioLoop(tools=tools).run()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "xI5x2imVQxnr" - }, - "outputs": [], - "source": [ - "tools = [\n", - " {'google_search': {}},\n", - " {'function_declarations': [{'name': 'turn_on_the_lights', 'description': None}, {'name': 'turn_off_the_lights', 'description': None}]}\n", - "]\n", - "\n", - "await AudioLoop(tools=tools).run()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "leaqbF2dNt0V" - }, - "outputs": [], - "source": [ - "tools = [\n", - " {'code_execution': {}}\n", - "]\n", - "\n", - "await AudioLoop(tools=tools).run()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xdyyMqckmnf8" - }, - "source": [ - "## Next steps\n", - "\n", - "\n", - "\n", - "This tutorial just shows basic usage of the Live API, using the Python GenAI SDK.\n", - "\n", - "- If you aren't looking for code, and just want to try multimedia streaming use [Live API in Google AI Studio](https://aistudio.google.com/app/live).\n", - "- If you want to see how to setup streaming interruptible audio and video using the Live API and the SDK see the [Audio and Video input Tutorial](../../gemini-2/live_api_starter.py).\n", - "- Try the [Tool use in the live API tutorial](../../gemini-2/websockets/live_api_tool_use.ipynb) for an walkthrough of Gemini 2.0's new tool use capabilities.\n", - "- There is a [Streaming audio in Colab example](../../gemini-2/websockets/live_api_streaming_in_colab.ipynb), but this is more of a **demo**, it's **not optimized for readability**.\n", - "- Other nice Gemini 2.0 examples can also be found in the [Cookbook](https://github.com/google-gemini/cookbook/blob/main/gemini-2/).\n" + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "............................................" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "metadata": { + "tags": null + }, + "name": "stdout", + "output_type": "stream", + "text": [ + "........................................." + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" } - ], - "metadata": { - "colab": { - "name": "live_api_starter.ipynb", - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" + ], + "source": [ + "tools = [\n", + " {\"google_search\": {}},\n", + "]\n", + "\n", + "await AudioLoop(tools=tools).run()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DCz7jK87MSWz" + }, + "source": [ + "### Example 2: function calling" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WxdwgTKIGIlY" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Type 'q' to quit\n", + "message > make it dark \n", + "....." + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "ERROR:Bidi:Unhandled server message! 
- {'toolCall': {'functionCalls': [{'name': 'turn_off_the_lights', 'args': {}, 'id': 'function-call-1061054218196002212'}]}}\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } + ], + "source": [ + "tools = [\n", + " {\n", + " \"function_declarations\": [\n", + " {\"name\": \"turn_on_the_lights\", \"description\": None},\n", + " {\"name\": \"turn_off_the_lights\", \"description\": None},\n", + " ]\n", + " }\n", + "]\n", + "\n", + "await AudioLoop(tools=tools).run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "xI5x2imVQxnr" + }, + "outputs": [], + "source": [ + "tools = [\n", + " {\"google_search\": {}},\n", + " {\n", + " \"function_declarations\": [\n", + " {\"name\": \"turn_on_the_lights\", \"description\": None},\n", + " {\"name\": \"turn_off_the_lights\", \"description\": None},\n", + " ]\n", + " },\n", + "]\n", + "\n", + "await AudioLoop(tools=tools).run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "leaqbF2dNt0V" + }, + "outputs": [], + "source": [ + "tools = [{\"code_execution\": {}}]\n", + "\n", + "await AudioLoop(tools=tools).run()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xdyyMqckmnf8" + }, + "source": [ + "## Next steps\n", + "\n", + "\n", + "\n", + "This tutorial just shows basic usage of the Live API, using the Python GenAI SDK.\n", + "\n", + "- If you aren't looking for code, and just want to try multimedia streaming use [Live API in Google AI Studio](https://aistudio.google.com/app/live).\n", + "- If you want to see how to setup streaming interruptible audio and video using the Live API and the SDK see the [Audio and Video input Tutorial](../../gemini-2/live_api_starter.py).\n", + "- Try the [Tool use in the live API tutorial](../../gemini-2/websockets/live_api_tool_use.ipynb) for an walkthrough of Gemini 2.0's new tool use capabilities.\n", + "- There is a [Streaming audio in Colab example](../../gemini-2/websockets/live_api_streaming_in_colab.ipynb), but this is more of a **demo**, it's **not optimized for readability**.\n", + "- Other nice Gemini 2.0 examples can also be found in the [Cookbook](https://github.com/google-gemini/cookbook/blob/main/gemini-2/).\n" + ] + } + ], + "metadata": { + "colab": { + "name": "live_api_starter.ipynb", + "toc_visible": true }, - "nbformat": 4, - "nbformat_minor": 0 + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/gemini-2/websockets/live_api_starter.py b/gemini-2/websockets/live_api_starter.py index 42b4042aa..089776030 100755 --- a/gemini-2/websockets/live_api_starter.py +++ b/gemini-2/websockets/live_api_starter.py @@ -140,16 +140,16 @@ async def get_frames(self): def _get_screen(self): sct = mss.mss() monitor = sct.monitors[0] - + i = sct.grab(monitor) mime_type = "image/jpeg" image_bytes = mss.tools.to_png(i.rgb, i.size) img = PIL.Image.open(io.BytesIO(image_bytes)) - + image_io = io.BytesIO() img.save(image_io, format="jpeg") image_io.seek(0) - + image_bytes = image_io.read() return {"mime_type": mime_type, "data": base64.b64encode(image_bytes).decode()} @@ -158,10 +158,11 @@ async def get_screen(self): frame = await asyncio.to_thread(self._get_screen) if frame is None: break - + await asyncio.sleep(1.0) - - await self.out_queue.put(frame) + + msg = {"realtime_input": {"media_chunks": [frame]}} + await self.out_queue.put(msg) async def 
send_realtime(self): while True: diff --git a/gemini-2/websockets/live_api_streaming_in_colab.ipynb b/gemini-2/websockets/live_api_streaming_in_colab.ipynb index 39773823a..f4a74d7bd 100644 --- a/gemini-2/websockets/live_api_streaming_in_colab.ipynb +++ b/gemini-2/websockets/live_api_streaming_in_colab.ipynb @@ -1,832 +1,851 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "jWESX0tpdrE-" - }, - "source": [ - "##### Copyright 2024 Google LLC." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "YQvTrJpxzRlJ" - }, - "outputs": [], - "source": [ - "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "#\n", - "# https://www.apache.org/licenses/LICENSE-2.0\n", - "#\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3hp_P0cDzTWp" - }, - "source": [ - "# Gemini 2.0 - Multimodal live API: Streaming in Colab" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "OLW8VU78zZOc" - }, - "source": [ - "\n", - " \n", - "
\n", - " Run in Google Colab\n", - "
" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "y7f4kFby0E6j" - }, - "source": [ - "This notebook uses the Multimodel Live API to stream bidirectional audio in Colab. This notebook is much more a **demo** than a tutorial. This code demonstrates that it is possible to stream audio with interruptions in Colab. It takes a few hacks to make it work.\n", - "\n", - "* For an overview of the Live API, see the [Live API docs](https://ai.google.dev/api/multimodal-live).\n", - "* If you want a good live API experience, try the [Live API in Google AI Studio](https://aistudio.google.com/app/live).\n", - "* If you want to learn how the Live API works, please refer to the [Live API starter tutorial](../live_api_starter.ipynb).\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NSUz31fds3Z9" - }, - "source": [ - "### Set up\n", - "\n", - "To run the following cell, your API key must be stored in a Colab Secret named `GOOGLE_API_KEY`. If you don't already have an API key, or you're not sure how to create a Colab Secret, see [Authentication](../../quickstarts/Authentication.ipynb) for an example." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "bCwTqSAKsYPI" - }, - "outputs": [], - "source": [ - "from google.colab import userdata\n", - "GOOGLE_API_KEY=userdata.get('GOOGLE_API_KEY')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "h6hvdOmqs1lT" - }, - "source": [ - "Now to run it just run all the cells.\n", - "\n", - "**Important**: On first try it will typically throw an error and ask for permission to record audio, if that happens allow audio, and **run it again**." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "i7Hc33s8lPjg" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/168.2 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[90m╺\u001b[0m\u001b[90m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m81.9/168.2 kB\u001b[0m \u001b[31m2.2 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m168.2/168.2 kB\u001b[0m \u001b[31m3.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h" - ] - } - ], - "source": [ - "# @title Install stuff, monkey patch old Python {display-mode: 'form'}\n", - "!pip install -q websockets taskgroup\n", - "\n", - "# Colab runs Python 3.11, but this needs a backport of taskgroup\n", - "# monkey patch:\n", - "import asyncio, taskgroup, exceptiongroup\n", - "asyncio.TaskGroup = taskgroup.TaskGroup\n", - "asyncio.ExceptionGroup = exceptiongroup.ExceptionGroup" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "LDTWCuZyl-zd" - }, - "outputs": [], - "source": [ - "# @title Inline copy of colab_stream {display-mode: 'form'}\n", - "import asyncio, contextlib, json\n", - "from google.colab import output\n", - "from IPython import display\n", - "\n", - "# alt:\n", - "# message.WaitForRawInput()\n", - "# colab.frontend.sendMessage({'action': 'keyboard_input', 'payload': state.send});\n", - "\n", - "_start_session_js = \"\"\"\n", - "let start_session = (userFn) => {\n", - " let debug = console.log;\n", - " debug = ()=>{};\n", - "\n", - " let ctrl = new AbortController();\n", - " let state = {\n", - " recv: [],\n", - " onRecv: () => {},\n", - " send: [],\n", - " 
onDone: new Promise((acc) => ctrl.signal.addEventListener('abort', () => acc())),\n", - " write: (data) => {\n", - " state.send.push(data);\n", - " }\n", - " };\n", - " window._js_session_on_poll = (data) => {\n", - " debug(\"on_poll\", data);\n", - " for (let msg of data) {\n", - " if ('data' in msg) {\n", - " state.recv.push(msg.data);\n", - " }\n", - " if ('error' in msg) {\n", - " ctrl.abort(new Error('Remote: ' + msg.error));\n", - " }\n", - " if ('finish' in msg) {\n", - " // TODO\n", - " ctrl.abort(new Error('Remote: finished'));\n", - " }\n", - " }\n", - " state.onRecv();\n", - " let result = state.send;\n", - " state.send = [];\n", - " debug(\"on_poll: result\", result);\n", - " return result;\n", - " };\n", - " let connection = {\n", - " signal: ctrl.signal,\n", - " read: async () => {\n", - " while(!ctrl.signal.aborted) {\n", - " if (state.recv.length != 0) {\n", - " return state.recv.shift();\n", - " }\n", - " await Promise.race([\n", - " new Promise((acc) => state.onRecv = acc),\n", - " state.onDone,\n", - " ]);\n", - " }\n", - " },\n", - " write: (data) => {\n", - " state.write({'data': data});\n", - " }\n", - " };\n", - " debug(\"starting userFn\");\n", - " userFn(connection).then(() => {\n", - " debug(\"userFn finished\");\n", - " ctrl.abort(new Error(\"end of input\"));\n", - " state.write({'finished': true});\n", - " },\n", - " (e) => {\n", - " debug(\"userFn error\", e);\n", - " console.error(\"Stream function failed\", e);\n", - " ctrl.abort(e);\n", - " state.write({'error': '' + e});\n", - " });\n", - "};\n", - "\"\"\"\n", - "\n", - "\n", - "class Connection:\n", - "\n", - " def __init__(self):\n", - " self._recv = []\n", - " self._on_recv_ready = asyncio.Event()\n", - " self._send = []\n", - " self._on_done = asyncio.Future()\n", - "\n", - " async def write(self, data):\n", - " self._send.append({'data': data})\n", - "\n", - " async def read(self):\n", - " while not self._on_done.done() and not self._recv:\n", - " self._on_recv_ready.clear()\n", - " await self._on_recv_ready.wait()\n", - " # print(\"read, done waiting: \", self._recv, self._on_done)\n", - " if self._on_done.done() and self._on_done.exception() is not None:\n", - " raise self._on_done.exception()\n", - " elif self._recv:\n", - " return self._recv.pop(0)\n", - " else:\n", - " return EOFError('End of stream')\n", - "\n", - " def _poll(self):\n", - " # Polling is needed as ipykernel has blocking mainloop\n", - " # (Comms do not work)\n", - " # print(\"calling poll\")\n", - " res = output.eval_js(f'window._js_session_on_poll({json.dumps(self._send)})')\n", - " # print(\"poll: \", res)\n", - " self._send = []\n", - " for r in res:\n", - " if 'data' in r:\n", - " self._recv.append(r['data'])\n", - " self._on_recv_ready.set()\n", - " elif 'error' in r:\n", - " self._on_done.set_exception(Exception('Remote error: ' + r['error']))\n", - " self._on_recv_ready.set()\n", - " elif 'finished' in r:\n", - " self._on_done.set_result(None)\n", - " self._on_recv_ready.set()\n", - "\n", - " async def _pump(self, pump_interval):\n", - " while not self._on_done.done():\n", - " self._poll()\n", - " await asyncio.sleep(pump_interval)\n", - "\n", - "\n", - "@contextlib.asynccontextmanager\n", - "async def RunningLiveJs(userCode, pump_interval=0.1):\n", - " \"\"\"Runs given javascript async code connecting it to colab.\n", - "\n", - " Use .write(msg) and .read() methods on this context manager\n", - " to exchange messages with JavaScript code.\n", - "\n", - " From JavaScript use 'connection.write(data)'\n", - " and 'await 
connection.read()' to exchange messages with colab.\n", - " \"\"\"\n", - " c = Connection()\n", - " output.eval_js(\n", - " f\"\"\"\n", - " let userFn = async (connection) => {{\n", - " {userCode}\n", - " }};\n", - " {_start_session_js};\n", - " start_session(userFn);\n", - " 1;\n", - " \"\"\",\n", - " ignore_result=True\n", - " )\n", - " t = asyncio.create_task(c._pump(pump_interval))\n", - "\n", - " def log_error(f):\n", - " if f.exception() is not None:\n", - " print('error: ', f.exception())\n", - "\n", - " t.add_done_callback(log_error)\n", - " try:\n", - " yield c\n", - " finally:\n", - " t.cancel()\n", - " output.eval_js(\n", - " \"\"\"window._js_session_on_poll([{finish: true}]);\"\"\", ignore_result=True\n", - " )" - ] - }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "jWESX0tpdrE-" + }, + "source": [ + "##### Copyright 2024 Google LLC." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "YQvTrJpxzRlJ" + }, + "outputs": [], + "source": [ + "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3hp_P0cDzTWp" + }, + "source": [ + "# Gemini 2.0 - Multimodal live API: Streaming in Colab" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OLW8VU78zZOc" + }, + "source": [ + "\n", + " \n", + "
\n", + " Run in Google Colab\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "y7f4kFby0E6j" + }, + "source": [ + "This notebook uses the Multimodel Live API to stream bidirectional audio in Colab. This notebook is much more a **demo** than a tutorial. This code demonstrates that it is possible to stream audio with interruptions in Colab. It takes a few hacks to make it work.\n", + "\n", + "* For an overview of the Live API, see the [Live API docs](https://ai.google.dev/api/multimodal-live).\n", + "* If you want a good live API experience, try the [Live API in Google AI Studio](https://aistudio.google.com/app/live).\n", + "* If you want to learn how the Live API works, please refer to the [Live API starter tutorial](../live_api_starter.ipynb).\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NSUz31fds3Z9" + }, + "source": [ + "### Set up\n", + "\n", + "To run the following cell, your API key must be stored in a Colab Secret named `GOOGLE_API_KEY`. If you don't already have an API key, or you're not sure how to create a Colab Secret, see [Authentication](../../quickstarts/Authentication.ipynb) for an example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bCwTqSAKsYPI" + }, + "outputs": [], + "source": [ + "from google.colab import userdata\n", + "\n", + "GOOGLE_API_KEY = userdata.get(\"GOOGLE_API_KEY\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "h6hvdOmqs1lT" + }, + "source": [ + "Now to run it just run all the cells.\n", + "\n", + "**Important**: On first try it will typically throw an error and ask for permission to record audio, if that happens allow audio, and **run it again**." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "i7Hc33s8lPjg" + }, + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "lR8Pbzu6pcda" - }, - "outputs": [], - "source": [ - "# @title Inline copy of colab_audio {display-mode: 'form'}\n", - "\n", - "\"\"\"Realtime Audio I/O support.\n", - "\n", - "Example use:\n", - "\n", - " async with colab_audio.RunningLiveAudio() as audio:\n", - " bytes_per_second = audio.config.sample_rate * audio.config.frame_size\n", - " print ('recording (3sec)')\n", - " buf = b''\n", - " while len(buf) < 3*bytes_per_second:\n", - " buf += await audio.read()\n", - " print ('playing')\n", - " await audio.enqueue(buf)\n", - " await asyncio.sleep(3)\n", - " print ('done')\n", - " display.display(colab_audio.Audio(audio.config, buf))\n", - "\"\"\"\n", - "\n", - "import asyncio\n", - "import base64\n", - "from collections.abc import AsyncIterator\n", - "import contextlib\n", - "import dataclasses\n", - "import io\n", - "import json\n", - "import time\n", - "import wave\n", - "import numpy as np\n", - "\n", - "\n", - "@dataclasses.dataclass(frozen=True)\n", - "class AudioConfig:\n", - " \"\"\"Configuration of audio stream.\"\"\"\n", - "\n", - " sample_rate: int\n", - " format: str = 'S16_LE' # only supported value\n", - " channels: int = 1 # only supported value\n", - "\n", - " @property\n", - " def sample_size(self) -> int:\n", - " assert self.format == 'S16_LE'\n", - " return 2\n", - "\n", - " @property\n", - " def frame_size(self) -> int:\n", - " return self.channels * self.sample_size\n", - "\n", - " @property\n", - " def numpy_dtype(self) -> np.dtype:\n", - " assert self.format == 'S16_LE'\n", - " return np.dtype(np.int16).newbyteorder('<')\n", - "\n", - "\n", - "@dataclasses.dataclass(frozen=True)\n", - "class Audio:\n", - " \"\"\"Unit of 
audio data with configuration.\"\"\"\n", - "\n", - " config: AudioConfig\n", - " data: bytes\n", - "\n", - " @staticmethod\n", - " def silence(config: AudioConfig, length_seconds: float | int) -> 'Audio':\n", - " frame = b'\\0' * config.frame_size\n", - " num_frames = int(length_seconds * config.sample_rate)\n", - " if num_frames < 0:\n", - " num_frames = 0\n", - " return Audio(config=config, data=frame * num_frames)\n", - "\n", - " def as_numpy(self):\n", - " return np.frombuffer(self.data, dtype=self.config.numpy_dtype)\n", - "\n", - " def as_wav_bytes(self) -> bytes:\n", - " buf = io.BytesIO()\n", - " with wave.open(buf, 'w') as wav:\n", - " wav.setnchannels(self.config.channels)\n", - " wav.setframerate(self.config.sample_rate)\n", - " assert self.config.format == 'S16_LE'\n", - " wav.setsampwidth(2) # 16bit\n", - " wav.writeframes(self.data)\n", - " return buf.getvalue()\n", - "\n", - " def _ipython_display_(self):\n", - " \"\"\"Hook displaying audio as HTML tag.\"\"\"\n", - " from IPython.display import display, HTML\n", - "\n", - " b64_wav = base64.b64encode(self.as_wav_bytes()).decode('utf-8')\n", - " display(HTML(f\"\"\"\n", - " \n", - " \"\"\".strip()))\n", - "\n", - " async def astream_realtime(\n", - " self, expected_delta_sec: float = 0.1\n", - " ) -> AsyncIterator[bytes]:\n", - " \"\"\"Yields audio data in chunks as if it was played realtime.\"\"\"\n", - " current_pos = 0\n", - " mono_start_ns = time.monotonic_ns()\n", - " while current_pos < len(self.data):\n", - " # print('sleep')\n", - " await asyncio.sleep(expected_delta_sec)\n", - " delta_ns = time.monotonic_ns() - mono_start_ns\n", - " expected_pos_frames = int(delta_ns * self.config.sample_rate / 1e9)\n", - " next_pos = expected_pos_frames * self.config.frame_size\n", - " # print (f'{next_pos = }, {current_pos =}, {len(self.data) = }')\n", - " if next_pos > current_pos:\n", - " yield self.data[current_pos:next_pos]\n", - " current_pos = next_pos\n", - "\n", - " def __add__(self, other: 'Audio') -> 'Audio':\n", - " assert self.config == other.config\n", - " return Audio(config=self.config, data=self.data + other.data)\n", - "\n", - "\n", - "class FailedToStartError(Exception):\n", - " \"\"\"Raised when audio session fails to start.\"\"\"\n", - "\n", - "\n", - "class AudioSession:\n", - " \"\"\"Connection to audio recording/playback on client side.\"\"\"\n", - "\n", - " def __init__(self, config: AudioConfig, connection: Connection):\n", - " self._config = config\n", - " self._connection = connection\n", - " self._done = False\n", - " self._read_queue: asyncio.Queue[bytes] = asyncio.Queue()\n", - " self._started = asyncio.Future()\n", - "\n", - " @property\n", - " def config(self) -> AudioConfig:\n", - " return self._config\n", - "\n", - " async def await_start(self):\n", - " await self._started\n", - "\n", - " async def _read_loop(self):\n", - " # print ('read_loop')\n", - " while True:\n", - " # print ('await read')\n", - " data = await self._connection.read()\n", - " # print(\"data\", data)\n", - " if 'audio_in' in data:\n", - " # print(\"audio_in\", data['audio_in'])\n", - " raw_data = base64.b64decode(data['audio_in'].encode('utf-8'))\n", - " # print(\"audio_in\", raw_data)\n", - " self._read_queue.put_nowait(raw_data)\n", - " if 'started' in data:\n", - " self._started.set_result(None)\n", - " if 'failed_to_start' in data:\n", - " self._started.set_exception(\n", - " FailedToStartError(\n", - " f'Failed to start audio: {data[\"failed_to_start\"]}'\n", - " )\n", - " )\n", - "\n", - " async def enqueue(self, 
audio_data: bytes):\n", - " b64_data = base64.b64encode(audio_data).decode('utf-8')\n", - " await self._connection.write({'audio_out': b64_data})\n", - "\n", - " async def clear_queue(self):\n", - " await self._connection.write({'flush': True})\n", - "\n", - " async def read(self) -> bytes:\n", - " return await self._read_queue.get()\n", - "\n", - "\n", - "STANDARD_AUDIO_CONFIG = AudioConfig(sample_rate=16000, channels=1)\n", - "\n", - "\n", - "# JavaScript code running in AudioWorklet, executing realtime audio processing.\n", - "_audio_processor_worklet_js = \"\"\"\n", - "class PortProcessor extends AudioWorkletProcessor {\n", - " constructor() {\n", - " super();\n", - " this._queue = [];\n", - " this.port.onmessage = (event) => {\n", - " //console.log(event.data);\n", - " if ('enqueue' in event.data) {\n", - " this.enqueueAudio(event.data.enqueue);\n", - " }\n", - " if ('clear' in event.data) {\n", - " this.clearAudio();\n", - " }\n", - " };\n", - " this._out = [];\n", - " this._out_len = 0;\n", - " console.log(\"PortProcessor ctor\", this);\n", - "\n", - " this.port.postMessage({\n", - " debug: \"Hello from the processor!\",\n", - " });\n", - " }\n", - "\n", - " encodeAudio(input) {\n", - " const channel = input[0];\n", - " const data = new ArrayBuffer(2 * channel.length);\n", - " const view = new DataView(data);\n", - " for (let i=0; i (2*sampleRate / 20)) {\n", - " let concat = new Uint8Array(this._out_len);\n", - " let idx = 0;\n", - " for (let a of this._out) {\n", - " concat.set(new Uint8Array(a), idx);\n", - " idx += a.byteLength;\n", - " }\n", - " this._out = [];\n", - " this._out_len = 0;\n", - " this.port.postMessage({\n", - " 'audio_in': concat.buffer,\n", - " });\n", - " }\n", - "\n", - " // forward output\n", - " this.dequeueIntoBuffer(outputs[0][0]);\n", - " // copy to other channels\n", - " for (let i=1; i {\n", - " if ('audio_in' in event.data) {\n", - " // base64 encode ugly way\n", - " let encoded = btoa(String.fromCharCode(\n", - " ...Array.from(new Uint8Array(event.data.audio_in))));\n", - " //console.log(\"base64 input\", encoded);\n", - " connection.write({audio_in: encoded});\n", - " } else {\n", - " console.log(\"from processor (unhandled)\", event);\n", - " }\n", - " };\n", - " source.connect(processor);\n", - " processor.connect(audioCtx.destination);\n", - " //await new Promise((acc) => setTimeout(acc, 1000));\n", - " while(!connection.signal.aborted) {\n", - " let request = await connection.read();\n", - " //console.log(request);\n", - " if ('audio_out' in request) {\n", - " let decoded = Uint8Array.from(\n", - " atob(request.audio_out), c => c.charCodeAt(0)).buffer;\n", - " //console.log('Enqueue', decoded);\n", - " processor.port.postMessage({'enqueue': decoded});\n", - " } else if('flush' in request) {\n", - " processor.port.postMessage({'clear': ''});\n", - " }\n", - " }\n", - "} finally {\n", - " userMedia.getTracks().forEach(t => t.stop());\n", - " audioCtx.close();\n", - "}\n", - "\"\"\"\n", - "\n", - "\n", - "@contextlib.asynccontextmanager\n", - "async def RunningLiveAudio(\n", - " config: AudioConfig = STANDARD_AUDIO_CONFIG, pump_interval=0.1\n", - "):\n", - " \"\"\"Runs audio connection to Colab UI and returns `AudioConnection` connected to it.\"\"\"\n", - " assert config.channels == 1\n", - " assert config.format == 'S16_LE'\n", - " required_js = f\"\"\"\n", - " const audio_worklet_js = {json.dumps(_audio_processor_worklet_js)};\n", - " const sample_rate = {json.dumps(config.sample_rate)};\n", - " {_audio_session_js}\n", - " \"\"\"\n", - " try:\n", 
- " async with contextlib.AsyncExitStack() as stack:\n", - " tg = await stack.enter_async_context(asyncio.TaskGroup())\n", - " connection = await stack.enter_async_context(\n", - " RunningLiveJs(required_js, pump_interval)\n", - " )\n", - " session = AudioSession(config, connection)\n", - " read_task = tg.create_task(session._read_loop()) # copy data to queue\n", - " tg.create_task(session.await_start()) # fail session if it fails to start\n", - " yield session\n", - " read_task.cancel()\n", - " except asyncio.ExceptionGroup as e:\n", - " if len(e.exceptions) == 1:\n", - " raise e.exceptions[0]\n", - " else:\n", - " raise" - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/168.2 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[90m╺\u001b[0m\u001b[90m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m81.9/168.2 kB\u001b[0m \u001b[31m2.2 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m168.2/168.2 kB\u001b[0m \u001b[31m3.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h" + ] + } + ], + "source": [ + "# @title Install stuff, monkey patch old Python {display-mode: 'form'}\n", + "!pip install -q websockets taskgroup\n", + "\n", + "# Colab runs Python 3.11, but this needs a backport of taskgroup\n", + "# monkey patch:\n", + "import asyncio, taskgroup, exceptiongroup\n", + "\n", + "asyncio.TaskGroup = taskgroup.TaskGroup\n", + "asyncio.ExceptionGroup = exceptiongroup.ExceptionGroup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LDTWCuZyl-zd" + }, + "outputs": [], + "source": [ + "# @title Inline copy of colab_stream {display-mode: 'form'}\n", + "import asyncio, contextlib, json\n", + "from google.colab import output\n", + "from IPython import display\n", + "\n", + "# alt:\n", + "# message.WaitForRawInput()\n", + "# colab.frontend.sendMessage({'action': 'keyboard_input', 'payload': state.send});\n", + "\n", + "_start_session_js = \"\"\"\n", + "let start_session = (userFn) => {\n", + " let debug = console.log;\n", + " debug = ()=>{};\n", + "\n", + " let ctrl = new AbortController();\n", + " let state = {\n", + " recv: [],\n", + " onRecv: () => {},\n", + " send: [],\n", + " onDone: new Promise((acc) => ctrl.signal.addEventListener('abort', () => acc())),\n", + " write: (data) => {\n", + " state.send.push(data);\n", + " }\n", + " };\n", + " window._js_session_on_poll = (data) => {\n", + " debug(\"on_poll\", data);\n", + " for (let msg of data) {\n", + " if ('data' in msg) {\n", + " state.recv.push(msg.data);\n", + " }\n", + " if ('error' in msg) {\n", + " ctrl.abort(new Error('Remote: ' + msg.error));\n", + " }\n", + " if ('finish' in msg) {\n", + " // TODO\n", + " ctrl.abort(new Error('Remote: finished'));\n", + " }\n", + " }\n", + " state.onRecv();\n", + " let result = state.send;\n", + " state.send = [];\n", + " debug(\"on_poll: result\", result);\n", + " return result;\n", + " };\n", + " let connection = {\n", + " signal: ctrl.signal,\n", + " read: async () => {\n", + " while(!ctrl.signal.aborted) {\n", + " if (state.recv.length != 0) {\n", + " return state.recv.shift();\n", + " }\n", + " await Promise.race([\n", + " new Promise((acc) => state.onRecv = acc),\n", + " state.onDone,\n", + " ]);\n", + " }\n", + " },\n", + " write: (data) => {\n", + " state.write({'data': data});\n", + " }\n", + " };\n", + 
" debug(\"starting userFn\");\n", + " userFn(connection).then(() => {\n", + " debug(\"userFn finished\");\n", + " ctrl.abort(new Error(\"end of input\"));\n", + " state.write({'finished': true});\n", + " },\n", + " (e) => {\n", + " debug(\"userFn error\", e);\n", + " console.error(\"Stream function failed\", e);\n", + " ctrl.abort(e);\n", + " state.write({'error': '' + e});\n", + " });\n", + "};\n", + "\"\"\"\n", + "\n", + "\n", + "class Connection:\n", + "\n", + " def __init__(self):\n", + " self._recv = []\n", + " self._on_recv_ready = asyncio.Event()\n", + " self._send = []\n", + " self._on_done = asyncio.Future()\n", + "\n", + " async def write(self, data):\n", + " self._send.append({\"data\": data})\n", + "\n", + " async def read(self):\n", + " while not self._on_done.done() and not self._recv:\n", + " self._on_recv_ready.clear()\n", + " await self._on_recv_ready.wait()\n", + " # print(\"read, done waiting: \", self._recv, self._on_done)\n", + " if self._on_done.done() and self._on_done.exception() is not None:\n", + " raise self._on_done.exception()\n", + " elif self._recv:\n", + " return self._recv.pop(0)\n", + " else:\n", + " return EOFError(\"End of stream\")\n", + "\n", + " def _poll(self):\n", + " # Polling is needed as ipykernel has blocking mainloop\n", + " # (Comms do not work)\n", + " # print(\"calling poll\")\n", + " res = output.eval_js(f\"window._js_session_on_poll({json.dumps(self._send)})\")\n", + " # print(\"poll: \", res)\n", + " self._send = []\n", + " for r in res:\n", + " if \"data\" in r:\n", + " self._recv.append(r[\"data\"])\n", + " self._on_recv_ready.set()\n", + " elif \"error\" in r:\n", + " self._on_done.set_exception(Exception(\"Remote error: \" + r[\"error\"]))\n", + " self._on_recv_ready.set()\n", + " elif \"finished\" in r:\n", + " self._on_done.set_result(None)\n", + " self._on_recv_ready.set()\n", + "\n", + " async def _pump(self, pump_interval):\n", + " while not self._on_done.done():\n", + " self._poll()\n", + " await asyncio.sleep(pump_interval)\n", + "\n", + "\n", + "@contextlib.asynccontextmanager\n", + "async def RunningLiveJs(userCode, pump_interval=0.1):\n", + " \"\"\"Runs given javascript async code connecting it to colab.\n", + "\n", + " Use .write(msg) and .read() methods on this context manager\n", + " to exchange messages with JavaScript code.\n", + "\n", + " From JavaScript use 'connection.write(data)'\n", + " and 'await connection.read()' to exchange messages with colab.\n", + " \"\"\"\n", + " c = Connection()\n", + " output.eval_js(\n", + " f\"\"\"\n", + " let userFn = async (connection) => {{\n", + " {userCode}\n", + " }};\n", + " {_start_session_js};\n", + " start_session(userFn);\n", + " 1;\n", + " \"\"\",\n", + " ignore_result=True,\n", + " )\n", + " t = asyncio.create_task(c._pump(pump_interval))\n", + "\n", + " def log_error(f):\n", + " if f.exception() is not None:\n", + " print(\"error: \", f.exception())\n", + "\n", + " t.add_done_callback(log_error)\n", + " try:\n", + " yield c\n", + " finally:\n", + " t.cancel()\n", + " output.eval_js(\n", + " \"\"\"window._js_session_on_poll([{finish: true}]);\"\"\", ignore_result=True\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "lR8Pbzu6pcda" + }, + "outputs": [], + "source": [ + "# @title Inline copy of colab_audio {display-mode: 'form'}\n", + "\n", + "\"\"\"Realtime Audio I/O support.\n", + "\n", + "Example use:\n", + "\n", + " async with colab_audio.RunningLiveAudio() as audio:\n", + " bytes_per_second = audio.config.sample_rate * 
audio.config.frame_size\n", + " print ('recording (3sec)')\n", + " buf = b''\n", + " while len(buf) < 3*bytes_per_second:\n", + " buf += await audio.read()\n", + " print ('playing')\n", + " await audio.enqueue(buf)\n", + " await asyncio.sleep(3)\n", + " print ('done')\n", + " display.display(colab_audio.Audio(audio.config, buf))\n", + "\"\"\"\n", + "\n", + "import asyncio\n", + "import base64\n", + "from collections.abc import AsyncIterator\n", + "import contextlib\n", + "import dataclasses\n", + "import io\n", + "import json\n", + "import time\n", + "import wave\n", + "import numpy as np\n", + "\n", + "\n", + "@dataclasses.dataclass(frozen=True)\n", + "class AudioConfig:\n", + " \"\"\"Configuration of audio stream.\"\"\"\n", + "\n", + " sample_rate: int\n", + " format: str = \"S16_LE\" # only supported value\n", + " channels: int = 1 # only supported value\n", + "\n", + " @property\n", + " def sample_size(self) -> int:\n", + " assert self.format == \"S16_LE\"\n", + " return 2\n", + "\n", + " @property\n", + " def frame_size(self) -> int:\n", + " return self.channels * self.sample_size\n", + "\n", + " @property\n", + " def numpy_dtype(self) -> np.dtype:\n", + " assert self.format == \"S16_LE\"\n", + " return np.dtype(np.int16).newbyteorder(\"<\")\n", + "\n", + "\n", + "@dataclasses.dataclass(frozen=True)\n", + "class Audio:\n", + " \"\"\"Unit of audio data with configuration.\"\"\"\n", + "\n", + " config: AudioConfig\n", + " data: bytes\n", + "\n", + " @staticmethod\n", + " def silence(config: AudioConfig, length_seconds: float | int) -> \"Audio\":\n", + " frame = b\"\\0\" * config.frame_size\n", + " num_frames = int(length_seconds * config.sample_rate)\n", + " if num_frames < 0:\n", + " num_frames = 0\n", + " return Audio(config=config, data=frame * num_frames)\n", + "\n", + " def as_numpy(self):\n", + " return np.frombuffer(self.data, dtype=self.config.numpy_dtype)\n", + "\n", + " def as_wav_bytes(self) -> bytes:\n", + " buf = io.BytesIO()\n", + " with wave.open(buf, \"w\") as wav:\n", + " wav.setnchannels(self.config.channels)\n", + " wav.setframerate(self.config.sample_rate)\n", + " assert self.config.format == \"S16_LE\"\n", + " wav.setsampwidth(2) # 16bit\n", + " wav.writeframes(self.data)\n", + " return buf.getvalue()\n", + "\n", + " def _ipython_display_(self):\n", + " \"\"\"Hook displaying audio as HTML tag.\"\"\"\n", + " from IPython.display import display, HTML\n", + "\n", + " b64_wav = base64.b64encode(self.as_wav_bytes()).decode(\"utf-8\")\n", + " display(\n", + " HTML(\n", + " f\"\"\"\n", + " \n", + " \"\"\".strip()\n", + " )\n", + " )\n", + "\n", + " async def astream_realtime(\n", + " self, expected_delta_sec: float = 0.1\n", + " ) -> AsyncIterator[bytes]:\n", + " \"\"\"Yields audio data in chunks as if it was played realtime.\"\"\"\n", + " current_pos = 0\n", + " mono_start_ns = time.monotonic_ns()\n", + " while current_pos < len(self.data):\n", + " # print('sleep')\n", + " await asyncio.sleep(expected_delta_sec)\n", + " delta_ns = time.monotonic_ns() - mono_start_ns\n", + " expected_pos_frames = int(delta_ns * self.config.sample_rate / 1e9)\n", + " next_pos = expected_pos_frames * self.config.frame_size\n", + " # print (f'{next_pos = }, {current_pos =}, {len(self.data) = }')\n", + " if next_pos > current_pos:\n", + " yield self.data[current_pos:next_pos]\n", + " current_pos = next_pos\n", + "\n", + " def __add__(self, other: \"Audio\") -> \"Audio\":\n", + " assert self.config == other.config\n", + " return Audio(config=self.config, data=self.data + other.data)\n", + "\n", 
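As a quick aside, the `AudioConfig`/`Audio` pair above is designed to compose: `silence` builds zero-filled frames, `+` concatenates same-config clips, and `as_numpy` exposes the raw samples. A minimal sketch, assuming the definitions in this cell have been executed:

```python
# Minimal sketch exercising the Audio helpers defined above
# (assumes AudioConfig/Audio from this cell have been executed).
config = AudioConfig(sample_rate=24000)

half = Audio.silence(config, 0.5)   # 0.5 s of zero-filled S16_LE frames
clip = half + half                  # __add__ concatenates same-config clips

# one second of mono 16-bit audio: sample_rate frames * 2 bytes each
assert len(clip.data) == config.sample_rate * config.frame_size
print(clip.as_numpy().shape)        # -> (24000,)
```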
+ "\n", + "class FailedToStartError(Exception):\n", + " \"\"\"Raised when audio session fails to start.\"\"\"\n", + "\n", + "\n", + "class AudioSession:\n", + " \"\"\"Connection to audio recording/playback on client side.\"\"\"\n", + "\n", + " def __init__(self, config: AudioConfig, connection: Connection):\n", + " self._config = config\n", + " self._connection = connection\n", + " self._done = False\n", + " self._read_queue: asyncio.Queue[bytes] = asyncio.Queue()\n", + " self._started = asyncio.Future()\n", + "\n", + " @property\n", + " def config(self) -> AudioConfig:\n", + " return self._config\n", + "\n", + " async def await_start(self):\n", + " await self._started\n", + "\n", + " async def _read_loop(self):\n", + " # print ('read_loop')\n", + " while True:\n", + " # print ('await read')\n", + " data = await self._connection.read()\n", + " # print(\"data\", data)\n", + " if \"audio_in\" in data:\n", + " # print(\"audio_in\", data['audio_in'])\n", + " raw_data = base64.b64decode(data[\"audio_in\"].encode(\"utf-8\"))\n", + " # print(\"audio_in\", raw_data)\n", + " self._read_queue.put_nowait(raw_data)\n", + " if \"started\" in data:\n", + " self._started.set_result(None)\n", + " if \"failed_to_start\" in data:\n", + " self._started.set_exception(\n", + " FailedToStartError(\n", + " f'Failed to start audio: {data[\"failed_to_start\"]}'\n", + " )\n", + " )\n", + "\n", + " async def enqueue(self, audio_data: bytes):\n", + " b64_data = base64.b64encode(audio_data).decode(\"utf-8\")\n", + " await self._connection.write({\"audio_out\": b64_data})\n", + "\n", + " async def clear_queue(self):\n", + " await self._connection.write({\"flush\": True})\n", + "\n", + " async def read(self) -> bytes:\n", + " return await self._read_queue.get()\n", + "\n", + "\n", + "STANDARD_AUDIO_CONFIG = AudioConfig(sample_rate=16000, channels=1)\n", + "\n", + "\n", + "# JavaScript code running in AudioWorklet, executing realtime audio processing.\n", + "_audio_processor_worklet_js = \"\"\"\n", + "class PortProcessor extends AudioWorkletProcessor {\n", + " constructor() {\n", + " super();\n", + " this._queue = [];\n", + " this.port.onmessage = (event) => {\n", + " //console.log(event.data);\n", + " if ('enqueue' in event.data) {\n", + " this.enqueueAudio(event.data.enqueue);\n", + " }\n", + " if ('clear' in event.data) {\n", + " this.clearAudio();\n", + " }\n", + " };\n", + " this._out = [];\n", + " this._out_len = 0;\n", + " console.log(\"PortProcessor ctor\", this);\n", + "\n", + " this.port.postMessage({\n", + " debug: \"Hello from the processor!\",\n", + " });\n", + " }\n", + "\n", + " encodeAudio(input) {\n", + " const channel = input[0];\n", + " const data = new ArrayBuffer(2 * channel.length);\n", + " const view = new DataView(data);\n", + " for (let i=0; i (2*sampleRate / 20)) {\n", + " let concat = new Uint8Array(this._out_len);\n", + " let idx = 0;\n", + " for (let a of this._out) {\n", + " concat.set(new Uint8Array(a), idx);\n", + " idx += a.byteLength;\n", + " }\n", + " this._out = [];\n", + " this._out_len = 0;\n", + " this.port.postMessage({\n", + " 'audio_in': concat.buffer,\n", + " });\n", + " }\n", + "\n", + " // forward output\n", + " this.dequeueIntoBuffer(outputs[0][0]);\n", + " // copy to other channels\n", + " for (let i=1; i {\n", + " if ('audio_in' in event.data) {\n", + " // base64 encode ugly way\n", + " let encoded = btoa(String.fromCharCode(\n", + " ...Array.from(new Uint8Array(event.data.audio_in))));\n", + " //console.log(\"base64 input\", encoded);\n", + " connection.write({audio_in: 
encoded});\n", + " } else {\n", + " console.log(\"from processor (unhandled)\", event);\n", + " }\n", + " };\n", + " source.connect(processor);\n", + " processor.connect(audioCtx.destination);\n", + " //await new Promise((acc) => setTimeout(acc, 1000));\n", + " while(!connection.signal.aborted) {\n", + " let request = await connection.read();\n", + " //console.log(request);\n", + " if ('audio_out' in request) {\n", + " let decoded = Uint8Array.from(\n", + " atob(request.audio_out), c => c.charCodeAt(0)).buffer;\n", + " //console.log('Enqueue', decoded);\n", + " processor.port.postMessage({'enqueue': decoded});\n", + " } else if('flush' in request) {\n", + " processor.port.postMessage({'clear': ''});\n", + " }\n", + " }\n", + "} finally {\n", + " userMedia.getTracks().forEach(t => t.stop());\n", + " audioCtx.close();\n", + "}\n", + "\"\"\"\n", + "\n", + "\n", + "@contextlib.asynccontextmanager\n", + "async def RunningLiveAudio(\n", + " config: AudioConfig = STANDARD_AUDIO_CONFIG, pump_interval=0.1\n", + "):\n", + " \"\"\"Runs audio connection to Colab UI and returns `AudioConnection` connected to it.\"\"\"\n", + " assert config.channels == 1\n", + " assert config.format == \"S16_LE\"\n", + " required_js = f\"\"\"\n", + " const audio_worklet_js = {json.dumps(_audio_processor_worklet_js)};\n", + " const sample_rate = {json.dumps(config.sample_rate)};\n", + " {_audio_session_js}\n", + " \"\"\"\n", + " try:\n", + " async with contextlib.AsyncExitStack() as stack:\n", + " tg = await stack.enter_async_context(asyncio.TaskGroup())\n", + " connection = await stack.enter_async_context(\n", + " RunningLiveJs(required_js, pump_interval)\n", + " )\n", + " session = AudioSession(config, connection)\n", + " read_task = tg.create_task(session._read_loop()) # copy data to queue\n", + " tg.create_task(session.await_start()) # fail session if it fails to start\n", + " yield session\n", + " read_task.cancel()\n", + " except asyncio.ExceptionGroup as e:\n", + " if len(e.exceptions) == 1:\n", + " raise e.exceptions[0]\n", + " else:\n", + " raise" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bLLTAfUjpBJ4" + }, + "source": [ + "## Run the client" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2leQzHwTqIOM" + }, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "bLLTAfUjpBJ4" - }, - "source": [ - "## Run the client" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "unhandled message: {'setupComplete': {}}\n" + ] }, { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "2leQzHwTqIOM" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "unhandled message: {'setupComplete': {}}\n" - ] - }, - { - "data": { - "text/html": [ - "" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - } + "data": { + "text/html": [ + "" ], - "source": [ - "# @title Client implementation {display-mode: 'form'}\n", - "# @markdown This cell runs client connection to BidiGenerate with realtime audio I/O\n", - "from websockets.asyncio.client import connect\n", - "import asyncio\n", - "import contextlib\n", - "import base64\n", - "import json\n", - "\n", - "\n", - "HOST = 'generativelanguage.googleapis.com' # @param {type:'string'}\n", - "API_KEY = GOOGLE_API_KEY\n", - "MODEL = 'models/gemini-2.0-flash-exp' # @param {type:'string'}\n", - "INITIAL_REQUEST_TEXT = \"what's 
up?\" # @param {type:'string'}\n", - "\n", - "\n", - "def encode_audio_input(data: bytes, config: AudioConfig) -> dict:\n", - " \"\"\"Build JSPB message with user input audio bytes.\"\"\"\n", - " return {\n", - " 'realtimeInput': {\n", - " 'mediaChunks': [{\n", - " 'mimeType': f'audio/pcm;rate={config.sample_rate}',\n", - " 'data': base64.b64encode(data).decode('UTF-8'),\n", - " }],\n", - " },\n", - " }\n", - "\n", - "\n", - "def encode_text_input(text: str) -> dict:\n", - " \"\"\"Builds JSPB message with user input text.\"\"\"\n", - " return {\n", - " 'clientContent': {\n", - " 'turns': [{\n", - " 'role': 'USER',\n", - " 'parts': [{'text': text}],\n", - " }],\n", - " 'turnComplete': True,\n", - " },\n", - " }\n", - "\n", - "\n", - "def decode_audio_output(input: dict) -> bytes:\n", - " \"\"\"Returns byte string with model output audio.\"\"\"\n", - " result = []\n", - " content_input = input.get('serverContent', {})\n", - " content = content_input.get('modelTurn', {})\n", - " for part in content.get('parts', []):\n", - " data = part.get('inlineData', {}).get('data', '')\n", - " if data:\n", - " result.append(base64.b64decode(data))\n", - " return b''.join(result)\n", - "\n", - "\n", - "async def main():\n", - " async with contextlib.AsyncExitStack() as es:\n", - " tg = await es.enter_async_context(asyncio.TaskGroup())\n", - " audio = await es.enter_async_context(RunningLiveAudio(AudioConfig(sample_rate=24000)))\n", - " conn = await es.enter_async_context(connect(f'wss://{HOST}/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={API_KEY}'))\n", - " print('')\n", - "\n", - " initial_request = {\n", - " 'setup': {\n", - " 'model': MODEL,\n", - " },\n", - " }\n", - " await conn.send(json.dumps(initial_request))\n", - "\n", - " if text := INITIAL_REQUEST_TEXT:\n", - " await conn.send(json.dumps(encode_text_input(text)))\n", - "\n", - " async def send_audio():\n", - " while True:\n", - " data = await audio.read()\n", - " await conn.send(json.dumps(encode_audio_input(data, audio.config)))\n", - "\n", - " tg.create_task(send_audio())\n", - " enqueued_audio = []\n", - " async for msg in conn:\n", - " msg = json.loads(msg)\n", - " if to_play := decode_audio_output(msg):\n", - " enqueued_audio.append(to_play)\n", - " await audio.enqueue(to_play) # enqueue TTS\n", - " elif 'interrupted' in msg.get('serverContent', {}):\n", - " print('')\n", - " await audio.clear_queue() # stop TTS\n", - " elif 'turnComplete' in msg.get('serverContent', {}):\n", - " if enqueued_audio: # display it for later playback\n", - " display.display(Audio(config=audio.config, data=b''.join(enqueued_audio)))\n", - " enqueued_audio = []\n", - " print('')\n", - " else:\n", - " if msg != {'serverContent': {}}:\n", - " print(f'unhandled message: {msg}')\n", - "\n", - "try:\n", - " await main()\n", - "except asyncio.ExceptionGroup as e:\n", - " raise e.exceptions[0]" + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "markdown", - "metadata": { - "id": "QzmXBx1cBowf" - }, - "source": [ - "**Important**: On first try it will typically throw an error and ask for permission to record audio, if that happens allow audio, and **run it again**." 
- ] - } - ], - "metadata": { - "colab": { - "name": "live_api_streaming_in_colab.ipynb", - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] } + ], + "source": [ + "# @title Client implementation {display-mode: 'form'}\n", + "# @markdown This cell runs client connection to BidiGenerate with realtime audio I/O\n", + "from websockets.asyncio.client import connect\n", + "import asyncio\n", + "import contextlib\n", + "import base64\n", + "import json\n", + "\n", + "\n", + "HOST = \"generativelanguage.googleapis.com\" # @param {type:'string'}\n", + "API_KEY = GOOGLE_API_KEY\n", + "MODEL = \"models/gemini-2.0-flash-exp\" # @param {type:'string'}\n", + "INITIAL_REQUEST_TEXT = \"what's up?\" # @param {type:'string'}\n", + "\n", + "\n", + "def encode_audio_input(data: bytes, config: AudioConfig) -> dict:\n", + " \"\"\"Build JSPB message with user input audio bytes.\"\"\"\n", + " return {\n", + " \"realtimeInput\": {\n", + " \"mediaChunks\": [\n", + " {\n", + " \"mimeType\": f\"audio/pcm;rate={config.sample_rate}\",\n", + " \"data\": base64.b64encode(data).decode(\"UTF-8\"),\n", + " }\n", + " ],\n", + " },\n", + " }\n", + "\n", + "\n", + "def encode_text_input(text: str) -> dict:\n", + " \"\"\"Builds JSPB message with user input text.\"\"\"\n", + " return {\n", + " \"clientContent\": {\n", + " \"turns\": [\n", + " {\n", + " \"role\": \"USER\",\n", + " \"parts\": [{\"text\": text}],\n", + " }\n", + " ],\n", + " \"turnComplete\": True,\n", + " },\n", + " }\n", + "\n", + "\n", + "def decode_audio_output(input: dict) -> bytes:\n", + " \"\"\"Returns byte string with model output audio.\"\"\"\n", + " result = []\n", + " content_input = input.get(\"serverContent\", {})\n", + " content = content_input.get(\"modelTurn\", {})\n", + " for part in content.get(\"parts\", []):\n", + " data = part.get(\"inlineData\", {}).get(\"data\", \"\")\n", + " if data:\n", + " result.append(base64.b64decode(data))\n", + " return b\"\".join(result)\n", + "\n", + "\n", + "async def main():\n", + " async with contextlib.AsyncExitStack() as es:\n", + " tg = await es.enter_async_context(asyncio.TaskGroup())\n", + " audio = await es.enter_async_context(\n", + " RunningLiveAudio(AudioConfig(sample_rate=24000))\n", + " )\n", + " conn = await es.enter_async_context(\n", + " connect(\n", + " f\"wss://{HOST}/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={API_KEY}\"\n", + " )\n", + " )\n", + " print(\"\")\n", + "\n", + " initial_request = {\n", + " \"setup\": {\n", + " \"model\": MODEL,\n", + " },\n", + " }\n", + " await conn.send(json.dumps(initial_request))\n", + "\n", + " if text := INITIAL_REQUEST_TEXT:\n", + " await conn.send(json.dumps(encode_text_input(text)))\n", + "\n", + " async def send_audio():\n", + " while True:\n", + " data = await audio.read()\n", + " await conn.send(json.dumps(encode_audio_input(data, audio.config)))\n", + "\n", + " tg.create_task(send_audio())\n", + " enqueued_audio = []\n", + " async for msg in conn:\n", + " msg = json.loads(msg)\n", + " if to_play := decode_audio_output(msg):\n", + " enqueued_audio.append(to_play)\n", + " await audio.enqueue(to_play) # enqueue TTS\n", + " elif \"interrupted\" in msg.get(\"serverContent\", {}):\n", + " print(\"\")\n", + " await audio.clear_queue() # stop TTS\n", + " elif \"turnComplete\" in msg.get(\"serverContent\", {}):\n", + " if enqueued_audio: # display it for later playback\n", + " display.display(\n", + " 
Audio(config=audio.config, data=b\"\".join(enqueued_audio))\n", + " )\n", + " enqueued_audio = []\n", + " print(\"\")\n", + " else:\n", + " if msg != {\"serverContent\": {}}:\n", + " print(f\"unhandled message: {msg}\")\n", + "\n", + "\n", + "try:\n", + " await main()\n", + "except asyncio.ExceptionGroup as e:\n", + " raise e.exceptions[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QzmXBx1cBowf" + }, + "source": [ + "**Important**: On first try it will typically throw an error and ask for permission to record audio, if that happens allow audio, and **run it again**." + ] + } + ], + "metadata": { + "colab": { + "name": "live_api_streaming_in_colab.ipynb", + "toc_visible": true }, - "nbformat": 4, - "nbformat_minor": 0 + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/gemini-2/websockets/live_api_tool_use.ipynb b/gemini-2/websockets/live_api_tool_use.ipynb index 064659651..0c47e4914 100644 --- a/gemini-2/websockets/live_api_tool_use.ipynb +++ b/gemini-2/websockets/live_api_tool_use.ipynb @@ -1,833 +1,831 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "jWESX0tpdrE-" - }, - "source": [ - "##### Copyright 2024 Google LLC." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "cellView": "form", - "id": "YQvTrJpxzRlJ" - }, - "outputs": [], - "source": [ - "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "#\n", - "# https://www.apache.org/licenses/LICENSE-2.0\n", - "#\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3hp_P0cDzTWp" - }, - "source": [ - "# Gemini 2.0 - Multimodal live API tool use with websockets" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "OLW8VU78zZOc" - }, - "source": [ - "\n", - " \n", - "
\n", - " Run in Google Colab\n", - "
" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "y7f4kFby0E6j" - }, - "source": [ - "This notebook provides examples of how to use tools with the Multimodal Live API with the Gemini 2.0 models. The API provides Google Search, Code Execution and Function Calling.\n", - "\n", - "This tutorial assumes you are familiar with the Live API, as described in the [Live API starter tutorial](live_api_starter.ipynb).\n", - "\n", - "Note: This version of the tutorial uses websockets directly. The [SDK version of this tutorial](../live_api_tool_use.ipynb) is a bit simpler because the SDK handles some of the details for you." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "belfJKH-p7HT" - }, - "source": [ - "## Set up" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NSUz31fds3Z9" - }, - "source": [ - "To run the following cell, your API key must be stored in a Colab Secret named `GOOGLE_API_KEY`. If you don't already have an API key, or you're not sure how to create a Colab Secret, see [Authentication](../../quickstarts/Authentication.ipynb) for an example." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "bCwTqSAKsYPI" - }, - "outputs": [], - "source": [ - "from google.colab import userdata\n", - "GOOGLE_API_KEY = userdata.get('GOOGLE_API_KEY')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "GI4cFArLMoOR" - }, - "source": [ - "Install the `websockets` package:" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "HBGg4WPcpvMv" - }, - "outputs": [], - "source": [ - "!pip install -q websockets" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2KFW-aPMMspT" - }, - "source": [ - "import the necessary modules" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "6G-3y03tIJTo" - }, - "outputs": [], - "source": [ - "import asyncio\n", - "import base64\n", - "import contextlib\n", - "import os\n", - "import json\n", - "import wave\n", - "\n", - "from IPython import display\n", - "\n", - "from websockets.asyncio.client import connect" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "id": "YFqqwegvtVZJ" - }, - "outputs": [], - "source": [ - "uri = f\"wss://generativelanguage.googleapis.com/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={GOOGLE_API_KEY}\"\n", - "\n", - "model = \"models/gemini-2.0-flash-exp\"" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "x5fC6UNT0Umv" - }, - "source": [ - "Define a context manager to convert streamed PCM data into a wave file that can be played directly using an IPython audio widget." - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "id": "eZcinbBdkZ1h" - }, - "outputs": [], - "source": [ - "@contextlib.contextmanager\n", - "def wave_file(filename, channels=1, rate=24000, sample_width=2):\n", - " with wave.open(filename, \"wb\") as wf:\n", - " wf.setnchannels(channels)\n", - " wf.setsampwidth(sample_width)\n", - " wf.setframerate(rate)\n", - " yield wf" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Tb82LM7j0c3_" - }, - "source": [ - "And define a custom logger so you can toggle extra information, like the in-flight requests and responses." 
- ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "id": "0JTDa_0N2VPj" - }, - "outputs": [], - "source": [ - "import logging\n", - "\n", - "logger = logging.getLogger(\"Live\")\n", - "logger.setLevel(\"INFO\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FdN7sXRx0lc_" - }, - "source": [ - "These helpers handle the websocket connection and prompt transmission (`run` and `send`), server handshake (`setup`) and process server responses (`handle_server_content`, `handle_tool_call`)." - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "lwLZrmW5zR_P" - }, - "outputs": [], - "source": [ - "async def setup(ws, modality, tools):\n", - " setup = {\n", - " \"setup\": {\n", - " \"model\": model,\n", - " \"tools\": tools,\n", - " \"generation_config\": {\n", - " \"response_modalities\": [modality]\n", - " }}}\n", - " await ws.send(json.dumps(setup))\n", - " setup_response = json.loads(await ws.recv())\n", - " logger.debug(setup_response)\n", - "\n", - "async def send(ws, prompt):\n", - " msg = {\n", - " \"client_content\": {\n", - " \"turns\": [{\"role\": \"user\", \"parts\": [{\"text\": prompt}]}],\n", - " \"turn_complete\": True,\n", - " }\n", - " }\n", - " print(\">>> \", msg)\n", - " await ws.send(json.dumps(msg))\n", - "\n", - "\n", - "def handle_server_content(wf, server_content):\n", - " audio = False\n", - " model_turn = server_content.pop(\"modelTurn\", None)\n", - " if model_turn:\n", - " text = model_turn[\"parts\"][0].pop(\"text\", None)\n", - " if text is not None:\n", - " print(text)\n", - "\n", - " inline_data = model_turn['parts'][0].pop('inlineData', None)\n", - " if inline_data is not None:\n", - " print('.', end='')\n", - " b64data = inline_data['data']\n", - " pcm_data = base64.b64decode(b64data)\n", - " wf.writeframes(pcm_data)\n", - " audio = True\n", - "\n", - " turn_complete = server_content.pop('turnComplete', None)\n", - " return turn_complete, audio\n", - "\n", - "\n", - "async def handle_tool_call(ws, tool_call):\n", - " print(\" \", tool_call)\n", - " for fc in tool_call['functionCalls']:\n", - "\n", - " msg = {\n", - " 'tool_response': {\n", - " 'function_responses': [{\n", - " 'id': fc['id'],\n", - " 'name': fc['name'],\n", - " 'response':{'result': {'string_value': 'ok'}}\n", - " }]\n", - " }\n", - " }\n", - " print('>>> ', msg)\n", - " await ws.send(json.dumps(msg))\n", - "\n", - "\n", - "\n", - "async def run(prompt, modality='TEXT', tools=None):\n", - " if tools is None:\n", - " tools=[]\n", - "\n", - " async with (\n", - " connect(uri, additional_headers={\"Content-Type\": \"application/json\"}) as ws,\n", - " ):\n", - " await setup(ws, modality, tools)\n", - " await send(ws, prompt)\n", - "\n", - " audio = False\n", - " filename = 'audio.wav'\n", - " with wave_file(filename) as wf:\n", - " async for raw_response in ws:\n", - " response = json.loads(raw_response.decode())\n", - " logger.debug(str(response)[:150])\n", - "\n", - " server_content = response.pop(\"serverContent\", None)\n", - " if server_content is not None:\n", - " turn_complete, a = handle_server_content(wf, server_content)\n", - " audio = audio or a\n", - " if turn_complete:\n", - " print()\n", - " print('Turn complete')\n", - " break\n", - "\n", - " tool_call = response.pop('toolCall', None)\n", - " if tool_call is not None:\n", - " await handle_tool_call(ws, tool_call)\n", - "\n", - " if audio:\n", - " display.display(display.Audio(filename, autoplay=True))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - 
"id": "1DVnwW4P04LO" - }, - "source": [ - "Run a test prompt to ensure everything is set up." - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "id": "QWQBQmqQCymg" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - ">>> {'client_content': {'turns': [{'role': 'user', 'parts': [{'text': 'Hello?'}]}], 'turn_complete': True}}\n", - "Hello\n", - " there! How can I help you today?\n", - "\n", - "\n", - "Turn complete\n" - ] - } - ], - "source": [ - "await run(prompt='Hello?')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Z_BFBLLGp-Ye" - }, - "source": [ - "## Simple function call" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "nisn9xRy09CJ" - }, - "source": [ - "Define some stub functions to use in a function calling example." - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "id": "vu0YDPRE0Roj" - }, - "outputs": [], - "source": [ - "turn_on_the_lights_schema = {'name': 'turn_on_the_lights'}\n", - "turn_off_the_lights_schema = {'name': 'turn_off_the_lights'}" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "nVP6aZ_K1Buw" - }, - "source": [ - "Send the function declarations as part of the `tools` (in the generation config)." - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "id": "LNE59h4QfUX8" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - ">>> {'client_content': {'turns': [{'role': 'user', 'parts': [{'text': 'Turn on the lights'}]}], 'turn_complete': True}}\n", - " {'functionCalls': [{'name': 'turn_on_the_lights', 'args': {}, 'id': 'function-call-12001301744900464704'}]}\n", - ">>> {'tool_response': {'function_responses': [{'id': 'function-call-12001301744900464704', 'name': 'turn_on_the_lights', 'response': {'result': {'string_value': 'ok'}}}]}}\n", - "OK\n", - ". The lights are now on.\n", - "\n", - "\n", - "Turn complete\n" - ] - } - ], - "source": [ - "prompt = \"Turn on the lights\"\n", - "\n", - "tools = [\n", - " {'function_declarations': [turn_on_the_lights_schema, turn_off_the_lights_schema]}\n", - "]\n", - "\n", - "await run(prompt, tools=tools)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "e5lktqsx1KMc" - }, - "source": [ - "Try the same thing again, but using audio-out this time." 
- ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "id": "8dCjPmz8nEbv" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - ">>> {'client_content': {'turns': [{'role': 'user', 'parts': [{'text': 'Turn on the lights'}]}], 'turn_complete': True}}\n", - " {'functionCalls': [{'name': 'turn_on_the_lights', 'args': {}, 'id': 'function-call-15210722669530560737'}]}\n", - ">>> {'tool_response': {'function_responses': [{'id': 'function-call-15210722669530560737', 'name': 'turn_on_the_lights', 'response': {'result': {'string_value': 'ok'}}}]}}\n", - "..........\n", - "Turn complete\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "prompt = \"Turn on the lights\"\n", - "\n", - "tools = [\n", - " {'function_declarations': [turn_on_the_lights_schema, turn_off_the_lights_schema]}\n", - "]\n", - "\n", - "await run(prompt, tools=tools, modality = \"AUDIO\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "eCnCiTbhqE8q" - }, - "source": [ - "## Code execution" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MTatLRXi1Qni" - }, - "source": [ - "The API can generate and execute code during the conversation too." - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": { - "id": "k4dURhC-QoSw" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - ">>> {'client_content': {'turns': [{'role': 'user', 'parts': [{'text': 'What is the largest prime palindrome under 100000'}]}], 'turn_complete': True}}\n", - "......................................................\n", - "Turn complete\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "jWESX0tpdrE-" + }, + "source": [ + "##### Copyright 2024 Google LLC." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "cellView": "form", + "id": "YQvTrJpxzRlJ" + }, + "outputs": [], + "source": [ + "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3hp_P0cDzTWp" + }, + "source": [ + "# Gemini 2.0 - Multimodal live API tool use with websockets" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OLW8VU78zZOc" + }, + "source": [ + "\n", + " \n", + "
\n", + " Run in Google Colab\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "y7f4kFby0E6j" + }, + "source": [ + "This notebook provides examples of how to use tools with the Multimodal Live API with the Gemini 2.0 models. The API provides Google Search, Code Execution and Function Calling.\n", + "\n", + "This tutorial assumes you are familiar with the Live API, as described in the [Live API starter tutorial](live_api_starter.ipynb).\n", + "\n", + "Note: This version of the tutorial uses websockets directly. The [SDK version of this tutorial](../live_api_tool_use.ipynb) is a bit simpler because the SDK handles some of the details for you." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "belfJKH-p7HT" + }, + "source": [ + "## Set up" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NSUz31fds3Z9" + }, + "source": [ + "To run the following cell, your API key must be stored in a Colab Secret named `GOOGLE_API_KEY`. If you don't already have an API key, or you're not sure how to create a Colab Secret, see [Authentication](../../quickstarts/Authentication.ipynb) for an example." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "id": "bCwTqSAKsYPI" + }, + "outputs": [], + "source": [ + "from google.colab import userdata\n", + "\n", + "GOOGLE_API_KEY = userdata.get(\"GOOGLE_API_KEY\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GI4cFArLMoOR" + }, + "source": [ + "Install the `websockets` package:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "HBGg4WPcpvMv" + }, + "outputs": [], + "source": [ + "!pip install -q websockets" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2KFW-aPMMspT" + }, + "source": [ + "import the necessary modules" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "id": "6G-3y03tIJTo" + }, + "outputs": [], + "source": [ + "import asyncio\n", + "import base64\n", + "import contextlib\n", + "import os\n", + "import json\n", + "import wave\n", + "\n", + "from IPython import display\n", + "\n", + "from websockets.asyncio.client import connect" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "id": "YFqqwegvtVZJ" + }, + "outputs": [], + "source": [ + "uri = f\"wss://generativelanguage.googleapis.com/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={GOOGLE_API_KEY}\"\n", + "\n", + "model = \"models/gemini-2.0-flash-exp\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "x5fC6UNT0Umv" + }, + "source": [ + "Define a context manager to convert streamed PCM data into a wave file that can be played directly using an IPython audio widget." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "id": "eZcinbBdkZ1h" + }, + "outputs": [], + "source": [ + "@contextlib.contextmanager\n", + "def wave_file(filename, channels=1, rate=24000, sample_width=2):\n", + " with wave.open(filename, \"wb\") as wf:\n", + " wf.setnchannels(channels)\n", + " wf.setsampwidth(sample_width)\n", + " wf.setframerate(rate)\n", + " yield wf" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Tb82LM7j0c3_" + }, + "source": [ + "And define a custom logger so you can toggle extra information, like the in-flight requests and responses." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "id": "0JTDa_0N2VPj" + }, + "outputs": [], + "source": [ + "import logging\n", + "\n", + "logger = logging.getLogger(\"Live\")\n", + "logger.setLevel(\"INFO\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FdN7sXRx0lc_" + }, + "source": [ + "These helpers handle the websocket connection and prompt transmission (`run` and `send`), server handshake (`setup`) and process server responses (`handle_server_content`, `handle_tool_call`)." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "id": "lwLZrmW5zR_P" + }, + "outputs": [], + "source": [ + "async def setup(ws, modality, tools):\n", + " setup = {\n", + " \"setup\": {\n", + " \"model\": model,\n", + " \"tools\": tools,\n", + " \"generation_config\": {\"response_modalities\": [modality]},\n", + " }\n", + " }\n", + " await ws.send(json.dumps(setup))\n", + " setup_response = json.loads(await ws.recv())\n", + " logger.debug(setup_response)\n", + "\n", + "\n", + "async def send(ws, prompt):\n", + " msg = {\n", + " \"client_content\": {\n", + " \"turns\": [{\"role\": \"user\", \"parts\": [{\"text\": prompt}]}],\n", + " \"turn_complete\": True,\n", + " }\n", + " }\n", + " print(\">>> \", msg)\n", + " await ws.send(json.dumps(msg))\n", + "\n", + "\n", + "def handle_server_content(wf, server_content):\n", + " audio = False\n", + " model_turn = server_content.pop(\"modelTurn\", None)\n", + " if model_turn:\n", + " text = model_turn[\"parts\"][0].pop(\"text\", None)\n", + " if text is not None:\n", + " print(text)\n", + "\n", + " inline_data = model_turn[\"parts\"][0].pop(\"inlineData\", None)\n", + " if inline_data is not None:\n", + " print(\".\", end=\"\")\n", + " b64data = inline_data[\"data\"]\n", + " pcm_data = base64.b64decode(b64data)\n", + " wf.writeframes(pcm_data)\n", + " audio = True\n", + "\n", + " turn_complete = server_content.pop(\"turnComplete\", None)\n", + " return turn_complete, audio\n", + "\n", + "\n", + "async def handle_tool_call(ws, tool_call):\n", + " print(\" \", tool_call)\n", + " for fc in tool_call[\"functionCalls\"]:\n", + "\n", + " msg = {\n", + " \"tool_response\": {\n", + " \"function_responses\": [\n", + " {\n", + " \"id\": fc[\"id\"],\n", + " \"name\": fc[\"name\"],\n", + " \"response\": {\"result\": {\"string_value\": \"ok\"}},\n", + " }\n", + " ]\n", + " }\n", + " }\n", + " print(\">>> \", msg)\n", + " await ws.send(json.dumps(msg))\n", + "\n", + "\n", + "async def run(prompt, modality=\"TEXT\", tools=None):\n", + " if tools is None:\n", + " tools = []\n", + "\n", + " async with (\n", + " connect(uri, additional_headers={\"Content-Type\": \"application/json\"}) as ws,\n", + " ):\n", + " await setup(ws, modality, tools)\n", + " await send(ws, prompt)\n", + "\n", + " audio = False\n", + " filename = \"audio.wav\"\n", + " with wave_file(filename) as wf:\n", + " async for raw_response in ws:\n", + " response = json.loads(raw_response.decode())\n", + " logger.debug(str(response)[:150])\n", + "\n", + " server_content = response.pop(\"serverContent\", None)\n", + " if server_content is not None:\n", + " turn_complete, a = handle_server_content(wf, server_content)\n", + " audio = audio or a\n", + " if turn_complete:\n", + " print()\n", + " print(\"Turn complete\")\n", + " break\n", + "\n", + " tool_call = response.pop(\"toolCall\", None)\n", + " if tool_call is not None:\n", + " await handle_tool_call(ws, tool_call)\n", + "\n", + " if audio:\n", + " display.display(display.Audio(filename, 
autoplay=True))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1DVnwW4P04LO" + }, + "source": [ + "Run a test prompt to ensure everything is set up." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "id": "QWQBQmqQCymg" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + ">>> {'client_content': {'turns': [{'role': 'user', 'parts': [{'text': 'Hello?'}]}], 'turn_complete': True}}\n", + "Hello\n", + " there! How can I help you today?\n", + "\n", + "\n", + "Turn complete\n" + ] + } + ], + "source": [ + "await run(prompt=\"Hello?\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Z_BFBLLGp-Ye" + }, + "source": [ + "## Simple function call" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nisn9xRy09CJ" + }, + "source": [ + "Define some stub functions to use in a function calling example." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "id": "vu0YDPRE0Roj" + }, + "outputs": [], + "source": [ + "turn_on_the_lights_schema = {\"name\": \"turn_on_the_lights\"}\n", + "turn_off_the_lights_schema = {\"name\": \"turn_off_the_lights\"}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nVP6aZ_K1Buw" + }, + "source": [ + "Send the function declarations as part of the `tools` (in the generation config)." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "id": "LNE59h4QfUX8" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + ">>> {'client_content': {'turns': [{'role': 'user', 'parts': [{'text': 'Turn on the lights'}]}], 'turn_complete': True}}\n", + " {'functionCalls': [{'name': 'turn_on_the_lights', 'args': {}, 'id': 'function-call-12001301744900464704'}]}\n", + ">>> {'tool_response': {'function_responses': [{'id': 'function-call-12001301744900464704', 'name': 'turn_on_the_lights', 'response': {'result': {'string_value': 'ok'}}}]}}\n", + "OK\n", + ". The lights are now on.\n", + "\n", + "\n", + "Turn complete\n" + ] + } + ], + "source": [ + "prompt = \"Turn on the lights\"\n", + "\n", + "tools = [\n", + " {\"function_declarations\": [turn_on_the_lights_schema, turn_off_the_lights_schema]}\n", + "]\n", + "\n", + "await run(prompt, tools=tools)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "e5lktqsx1KMc" + }, + "source": [ + "Try the same thing again, but using audio-out this time." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "id": "8dCjPmz8nEbv" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + ">>> {'client_content': {'turns': [{'role': 'user', 'parts': [{'text': 'Turn on the lights'}]}], 'turn_complete': True}}\n", + " {'functionCalls': [{'name': 'turn_on_the_lights', 'args': {}, 'id': 'function-call-15210722669530560737'}]}\n", + ">>> {'tool_response': {'function_responses': [{'id': 'function-call-15210722669530560737', 'name': 'turn_on_the_lights', 'response': {'result': {'string_value': 'ok'}}}]}}\n", + "..........\n", + "Turn complete\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " ], - "source": [ - "prompt=\"What is the largest prime palindrome under 100000\"\n", - "\n", - "tools = [\n", - " {'code_execution': {}}\n", - "]\n", - "\n", - "await run(prompt, tools=tools, modality='AUDIO')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "G78sxDEcqHyO" - }, - "source": [ - "## Google search" + "text/plain": [ + "" ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "eHH7jpk11VC_" - }, - "source": [ - "A `google_search` tool is also available for use during live conversations." - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": { - "id": "dYcAVEqpDW6D" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - ">>> {'client_content': {'turns': [{'role': 'user', 'parts': [{'text': 'Can you use google search tell me about the largest earthquake in california the week of Dec 5 2024?'}]}], 'turn_complete': True}}\n", - "\n", - "\n", - "The\n", - " largest earthquake in California during the week of December 5, 202\n", - "4, was a magnitude 7.0 that occurred offshore of Cape Mendoc\n", - "ino on December 5, 2024, at 10:44 AM PT. The earthquake was located approximately 60 miles (\n", - "100km) southwest of Ferndale, in Northern California. The earthquake prompted a brief tsunami warning for coastal areas from Santa Cruz into Oregon, which was\n", - " canceled shortly after.\n", - "\n", - "Here's a summary of the key details:\n", - "\n", - "* **Magnitude:** 7.0\n", - "* **Date:** December 5, 2024\n", - "* **Time:** 1\n", - "0:44 AM PT\n", - "* **Location:** Approximately 60 miles (100km) southwest of Ferndale, California, offshore of Cape Mendocino\n", - "* **Tsunami Warning:** A tsunami warning was\n", - " issued but later canceled.\n", - "* **Aftershocks:** There were many aftershocks following the main earthquake, with the strongest being a magnitude 4.7, which occurred about two minutes after the main shock. By the next day, there had been about 200 aftershocks with 9\n", - " of them measuring 4.0 or greater.\n", - "* **Impact:** The earthquake caused shaking that was felt as far away as San Francisco, and caused some minor damage like broken windows and burst pipes.\n", - "\n", - "The area where the earthquake occurred is known to be seismically active, as it is near the Mendoc\n", - "ino Triple Junction, where three tectonic plates meet. 
Five of the eleven earthquakes with a magnitude of 7 or greater since 1900 have occurred in this area.\n", - "\n", - "\n", - "Turn complete\n" - ] - } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "prompt = \"Turn on the lights\"\n", + "\n", + "tools = [\n", + " {\"function_declarations\": [turn_on_the_lights_schema, turn_off_the_lights_schema]}\n", + "]\n", + "\n", + "await run(prompt, tools=tools, modality=\"AUDIO\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eCnCiTbhqE8q" + }, + "source": [ + "## Code execution" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MTatLRXi1Qni" + }, + "source": [ + "The API can generate and execute code during the conversation too." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "id": "k4dURhC-QoSw" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + ">>> {'client_content': {'turns': [{'role': 'user', 'parts': [{'text': 'What is the largest prime palindrome under 100000'}]}], 'turn_complete': True}}\n", + "......................................................\n", + "Turn complete\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " ], - "source": [ - "prompt=\"Can you use google search tell me about the largest earthquake in california the week of Dec 5 2024?\"\n", - "\n", - "tools = [\n", - " {'google_search': {}}\n", - "]\n", - "\n", - "await run(prompt, tools=tools, modality='TEXT')" + "text/plain": [ + "" ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "y-S9dBPu1cNk" - }, - "source": [ - "Try the same again, with audio." - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": { - "id": "QKvWzROJic60" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - ">>> {'client_content': {'turns': [{'role': 'user', 'parts': [{'text': 'Can you use google search tell me about the largest earthquake in california the week of Dec 5 2024?'}]}], 'turn_complete': True}}\n", - "...............................................................................................................................................................................................\n", - "Turn complete\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "prompt = \"What is the largest prime palindrome under 100000\"\n", + "\n", + "tools = [{\"code_execution\": {}}]\n", + "\n", + "await run(prompt, tools=tools, modality=\"AUDIO\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "G78sxDEcqHyO" + }, + "source": [ + "## Google search" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eHH7jpk11VC_" + }, + "source": [ + "A `google_search` tool is also available for use during live conversations." 
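Like `code_execution`, it is switched on with an empty config object, and tool entries can share a single `tools` list — a small sketch, anticipating the multi-tool section at the end:

```python
# Search is enabled by an empty config object; tools can also be combined
# in one list, as the multi-tool section at the end of this notebook does.
search_only = [{"google_search": {}}]
search_plus_code = [{"google_search": {}}, {"code_execution": {}}]
```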
+ ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "id": "dYcAVEqpDW6D" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + ">>> {'client_content': {'turns': [{'role': 'user', 'parts': [{'text': 'Can you use google search tell me about the largest earthquake in california the week of Dec 5 2024?'}]}], 'turn_complete': True}}\n", + "\n", + "\n", + "The\n", + " largest earthquake in California during the week of December 5, 202\n", + "4, was a magnitude 7.0 that occurred offshore of Cape Mendoc\n", + "ino on December 5, 2024, at 10:44 AM PT. The earthquake was located approximately 60 miles (\n", + "100km) southwest of Ferndale, in Northern California. The earthquake prompted a brief tsunami warning for coastal areas from Santa Cruz into Oregon, which was\n", + " canceled shortly after.\n", + "\n", + "Here's a summary of the key details:\n", + "\n", + "* **Magnitude:** 7.0\n", + "* **Date:** December 5, 2024\n", + "* **Time:** 1\n", + "0:44 AM PT\n", + "* **Location:** Approximately 60 miles (100km) southwest of Ferndale, California, offshore of Cape Mendocino\n", + "* **Tsunami Warning:** A tsunami warning was\n", + " issued but later canceled.\n", + "* **Aftershocks:** There were many aftershocks following the main earthquake, with the strongest being a magnitude 4.7, which occurred about two minutes after the main shock. By the next day, there had been about 200 aftershocks with 9\n", + " of them measuring 4.0 or greater.\n", + "* **Impact:** The earthquake caused shaking that was felt as far away as San Francisco, and caused some minor damage like broken windows and burst pipes.\n", + "\n", + "The area where the earthquake occurred is known to be seismically active, as it is near the Mendoc\n", + "ino Triple Junction, where three tectonic plates meet. Five of the eleven earthquakes with a magnitude of 7 or greater since 1900 have occurred in this area.\n", + "\n", + "\n", + "Turn complete\n" + ] + } + ], + "source": [ + "prompt = \"Can you use google search tell me about the largest earthquake in california the week of Dec 5 2024?\"\n", + "\n", + "tools = [{\"google_search\": {}}]\n", + "\n", + "await run(prompt, tools=tools, modality=\"TEXT\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "y-S9dBPu1cNk" + }, + "source": [ + "Try the same again, with audio." 
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 22,
+    "metadata": {
+     "id": "QKvWzROJic60"
+    },
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       ">>> {'client_content': {'turns': [{'role': 'user', 'parts': [{'text': 'Can you use google search to tell me about the largest earthquake in california the week of Dec 5 2024?'}]}], 'turn_complete': True}}\n",
+       "...............................................................................................................................................................................................\n",
+       "Turn complete\n"
+      ]
+     },
+     {
+      "data": {
+       "text/html": [
+        "\n",
+        " \n",
+        " "
+       ],
+       "text/plain": [
+        ""
+       ]
+      },
+      "metadata": {},
+      "output_type": "display_data"
+     }
+    ],
+    "source": [
+     "prompt = \"Can you use google search to tell me about the largest earthquake in california the week of Dec 5 2024?\"\n",
+     "\n",
+     "tools = [{\"google_search\": {}}]\n",
+     "\n",
+     "await run(prompt, tools=tools, modality=\"AUDIO\")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {
+     "id": "Qu6ZZwJs2AYC"
+    },
+    "source": [
+     "## Compositional Function Calling\n",
+     "\n",
+     "Compositional function calling lets you ask the model to use your provided functions inside code it generates. In this example, you can test it by asking for a `sleep` between calls to the provided light functions."
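One more reference point before the compositional and multi-tool examples: the chunked text in the google-search transcript above was assembled by a receive loop along these lines. This is a sketch, not the notebook's exact helper — `ws` is again the assumed open connection, and the camelCase field names follow the server messages shown in the logs:

```python
import json


async def read_text_turn(ws) -> str:
    """Collect streamed text parts until the server signals turnComplete."""
    pieces = []
    async for raw in ws:
        msg = json.loads(raw)
        content = msg.get("serverContent", {})
        for part in content.get("modelTurn", {}).get("parts", []):
            if "text" in part:
                pieces.append(part["text"])
        if content.get("turnComplete"):
            break
    return "".join(pieces)
```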
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 26,
+    "metadata": {
+     "id": "hpuGzKXZ2DaN"
+    },
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       ">>> {'client_content': {'turns': [{'role': 'user', 'parts': [{'text': '\\n    Hey, can you write and run some python code to turn on the lights, wait 10s and then turn off the lights?\\n    '}]}], 'turn_complete': True}}\n",
+       ".......................................\n",
+       "Turn complete\n"
+      ]
+     },
+     {
+      "data": {
+       "text/html": [
+        "\n",
+        " \n",
+        " "
+       ],
+       "text/plain": [
+        ""
+       ]
+      },
+      "metadata": {},
+      "output_type": "display_data"
+     },
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "Elapsed: 10.031386852264404s\n"
+      ]
+     }
+    ],
+    "source": [
+     "prompt = \"\"\"\n",
+     "    Hey, can you write and run some python code to turn on the lights, wait 10s and then turn off the lights?\n",
+     "    \"\"\"\n",
+     "\n",
+     "tools = [\n",
+     "    {\"code_execution\": {}},\n",
+     "    {\"function_declarations\": [turn_on_the_lights_schema, turn_off_the_lights_schema]},\n",
+     "]\n",
+     "\n",
+     "import time\n",
+     "\n",
+     "start = time.time()\n",
+     "await run(prompt, tools=tools, modality=\"AUDIO\")\n",
+     "end = time.time()\n",
+     "print(f\"Elapsed: {end-start}s\")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {
+     "id": "HM9y5rwfqKfY"
+    },
+    "source": [
+     "## Multi-tool\n",
+     "\n",
+     "The model can be asked to use multiple tools in a single conversational turn. In this example, a single prompt is used to perform three tasks using all three provided tools."
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 27,
+    "metadata": {
+     "id": "QmB_4XPOslyA"
+    },
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       ">>> {'client_content': {'turns': [{'role': 'user', 'parts': [{'text': '\\n    Hey, I need you to do three things for me.\\n\\n    1. Turn on the lights\\n    2. Then compute the largest prime palindrome under 100000.\\n    3. Then use google search to look up information about the largest earthquake in california the week of Dec 5 2024?\\n\\n    Thanks!\\n    '}]}], 'turn_complete': True}}\n",
+       ". {'functionCalls': [{'name': 'turn_on_the_lights', 'args': {}, 'id': 'function-call-7055392265610039783'}]}\n",
+       ">>> {'tool_response': {'function_responses': [{'id': 'function-call-7055392265610039783', 'name': 'turn_on_the_lights', 'response': {'result': {'string_value': 'ok'}}}]}}\n",
+       "......................................................................................................................................................................................................................................................................................................................\n",
+       "Turn complete\n"
+      ]
+     },
+     {
+      "data": {
+       "text/html": [
+        "\n",
+        " \n",
+        " "
+       ],
+       "text/plain": [
+        ""
+       ]
+      },
+      "metadata": {},
+      "output_type": "display_data"
+     }
+    ],
+    "source": [
+     "prompt = \"\"\"\n",
+     "    Hey, I need you to do three things for me.\n",
+     "\n",
+     "    1. Turn on the lights\n",
+     "    2. Then compute the largest prime palindrome under 100000.\n",
+     "    3. Then use google search to look up information about the largest earthquake in california the week of Dec 5 2024?\n",
+     "\n",
+     "    Thanks!\n",
+     "    \"\"\"\n",
+     "\n",
+     "tools = [\n",
+     "    {\"google_search\": {}},\n",
+     "    {\"code_execution\": {}},\n",
+     "    {\"function_declarations\": [turn_on_the_lights_schema, turn_off_the_lights_schema]},\n",
+     "]\n",
+     "\n",
+     "await run(prompt, tools=tools, modality=\"AUDIO\")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {
+     "id": "D0CWk0iTpVx-"
+    },
+    "source": [
+     "## Next steps\n",
+     "\n",
+     "This tutorial shows basic tool use with the Live API over a raw websocket connection, without an SDK.\n",
+     "\n",
+     "- If you aren't looking for code and just want to try multimedia streaming, use the [Live API in Google AI Studio](https://aistudio.google.com/app/live).\n",
+     "- If you want to see how to set up streaming, interruptible audio and video with the Live API and the SDK, see the [Audio and Video input Tutorial](../../gemini-2/live_api_starter.py).\n",
+     "- There is a [Streaming audio in Colab example](../../gemini-2/websockets/live_api_streaming_in_colab.ipynb), but it is more of a **demo** and is **not optimized for readability**.\n",
+     "- More Gemini 2.0 examples can be found in the [Cookbook](https://github.com/google-gemini/cookbook/blob/main/gemini-2/).\n"
+    ]
+   }
+  ],
+  "metadata": {
+   "colab": {
+    "name": "live_api_tool_use.ipynb",
+    "toc_visible": true
+   },
+   "kernelspec": {
+    "display_name": "Python 3",
+    "name": "python3"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 0
+ }
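To close, here is roughly what the session `setup` message looks like once all three tools are declared together, as in the multi-tool cell above. This is a sketch: the model id, the placeholder schemas, and the exact field layout are assumptions to check against the notebook's own setup cell (the real schemas are defined earlier in the notebook).

```python
import json

# Assumed model id for these notebooks; adjust to whatever the setup cell uses.
MODEL = "models/gemini-2.0-flash-exp"

# Placeholder schemas -- the notebook defines richer versions of these.
turn_on_the_lights_schema = {"name": "turn_on_the_lights"}
turn_off_the_lights_schema = {"name": "turn_off_the_lights"}


def multi_tool_setup(modality: str = "AUDIO") -> str:
    """Build the first client message: model, response modality, and all three tools."""
    return json.dumps(
        {
            "setup": {
                "model": MODEL,
                "generation_config": {"response_modalities": [modality]},
                "tools": [
                    {"google_search": {}},
                    {"code_execution": {}},
                    {
                        "function_declarations": [
                            turn_on_the_lights_schema,
                            turn_off_the_lights_schema,
                        ]
                    },
                ],
            }
        }
    )
```

Sending this as the first websocket message is what lets a single turn mix search, code execution, and client-side function calls, as the multi-tool transcript demonstrates.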