From ab82d7513b5437dd60a8251d31e2a3b1f3d429e3 Mon Sep 17 00:00:00 2001 From: Jetha Chan Date: Fri, 6 Dec 2024 02:36:31 +0900 Subject: [PATCH] update with new paligemma2 notebooks --- .../Finetune_PaliGemma_2_with_JAX.ipynb | 1398 +++++++++++++++++ .../Finetune_PaliGemma_2_with_Keras.ipynb | 955 +++++++++++ README.md | 6 + 3 files changed, 2359 insertions(+) create mode 100644 PaliGemma 2/Finetune_PaliGemma_2_with_JAX.ipynb create mode 100644 PaliGemma 2/Finetune_PaliGemma_2_with_Keras.ipynb diff --git a/PaliGemma 2/Finetune_PaliGemma_2_with_JAX.ipynb b/PaliGemma 2/Finetune_PaliGemma_2_with_JAX.ipynb new file mode 100644 index 0000000..6ff5f82 --- /dev/null +++ b/PaliGemma 2/Finetune_PaliGemma_2_with_JAX.ipynb @@ -0,0 +1,1398 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "be34d25b", + "metadata": { + "id": "8377c056591f" + }, + "source": [ + "Copyright 2024 Google LLC." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "6130c8e6", + "metadata": { + "cellView": "form", + "id": "ca23c3f523a7" + }, + "outputs": [], + "source": [ + "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "id": "880b5dcf", + "metadata": { + "id": "u71STQRgnQ3a" + }, + "source": [ + "# Fine-tune PaliGemma with JAX\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
\n", + "View on ai.google.dev\n", + "\n", + "Run in Google Colab\n", + "\n", + "View source on GitHub\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "id": "74dcda33", + "metadata": { + "id": "wR53lePHuiP-" + }, + "source": [ + "This notebook shows how to fine-tune [PaliGemma](https://ai.google.dev/gemma/docs/paligemma) on a vision-language task with [JAX](https://jax.readthedocs.io/en/latest/index.html). *Fine-tuning* is a process that can improve your model's performance on specific tasks or help the model adhere to specific output requirements when instructions aren't sufficient and you have a set of examples that demonstrate the outputs you want. Gemma-based models like PaliGemma require fine-tuning to produce expected results.\n", + "\n", + "### What's in this notebook\n", + "\n", + "This notebook uses the model reference implementation from [`big_vision`](https://github.com/google-research/big_vision)\n", + "and shows how to:\n", + "\n", + " * Install dependencies, and download the PaliGemma model checkpoint and training data\n", + " * Load the model onto GPU devices\n", + " * Prepare the model's inputs for training and inference\n", + " * Fine-tune the model\n", + " * Inspect the output\n", + "\n", + "The training data for this notebook consists of 90 pairs of images and long captions describing them. To make it runnable on a Kaggle GPU runtime, you'll only fine-tune the attention layers of the language model and freeze the other parameters.\n", + "\n", + "This example is for learning purposes only. In a real use case, the amount of data, trainable parameters, training steps and hyper-parameters, and obtained results could be significantly different.\n", + "\n", + "### Before you begin\n", + "\n", + "Before going through this notebook, you should be familiar with Python code, as well as how large language models (LLMs) are trained. You don't need to be familiar with JAX, but basic knowledge about JAX (or similar technologies such as Keras) is helpful when reading through the example code." + ] + }, + { + "cell_type": "markdown", + "id": "a42e7554", + "metadata": { + "id": "6U0QUFveqSP2" + }, + "source": [ + "## Setup\n", + "\n", + "The following sections explain the preliminary steps for getting a notebook to use a PaliGemma model, including model access and configuring the notebook runtime." + ] + }, + { + "cell_type": "markdown", + "id": "16b96310", + "metadata": { + "id": "qRi1rF4MWlQi" + }, + "source": [ + "### Get access to PaliGemma\n", + "\n", + "Before using PaliGemma for the first time, you must request access to the model through Kaggle by completing the following steps:\n", + "\n", + "1. Log in to [Kaggle](https://www.kaggle.com), or create a new Kaggle account if you don't already have one.\n", + "1. Go to the [PaliGemma model card](https://www.kaggle.com/models/google/paligemma/) and click **Request Access**.\n", + "1. Complete the consent form and accept the terms and conditions." + ] + }, + { + "cell_type": "markdown", + "id": "ee60a2fe", + "metadata": { + "id": "Kp6XQ2hQB8lv" + }, + "source": [ + "### Select the runtime\n", + "\n", + "To complete this tutorial, you'll need to have a Kaggle runtime with sufficient resources to run the PaliGemma model. In this case, you can use a GPU:\n", + "\n", + "1. In the upper-right of the Kaggle notebook window, click on the three dots.\n", + "1. Select **Accelerator**.\n", + "1. Choose **GPU P100 or GPU T4 x2** from the available options." 
+ ] + }, + { + "cell_type": "markdown", + "id": "016fecda", + "metadata": { + "id": "rCd__uzW_eK-" + }, + "source": [ + "### Fetch the `big_vision` repository and install related dependencies\n", + "\n", + "Download the `big_vision` repository to your Kaggle notebook from GitHub and install dependencies related to `big_vision` by running the following code." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c92f001e", + "metadata": { + "id": "c2eba4d7d2d3" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: jax[cuda12] in /opt/conda/lib/python3.10/site-packages (0.4.26)\r\n", + "Collecting jax[cuda12]\r\n", + " Downloading jax-0.4.35-py3-none-any.whl.metadata (22 kB)\r\n", + "Collecting jaxlib<=0.4.35,>=0.4.34 (from jax[cuda12])\r\n", + " Downloading jaxlib-0.4.35-cp310-cp310-manylinux2014_x86_64.whl.metadata (983 bytes)\r\n", + "Collecting ml-dtypes>=0.4.0 (from jax[cuda12])\r\n", + " Downloading ml_dtypes-0.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (21 kB)\r\n", + "Requirement already satisfied: numpy>=1.24 in /opt/conda/lib/python3.10/site-packages (from jax[cuda12]) (1.26.4)\r\n", + "Requirement already satisfied: opt-einsum in /opt/conda/lib/python3.10/site-packages (from jax[cuda12]) (3.3.0)\r\n", + "Requirement already satisfied: scipy>=1.10 in /opt/conda/lib/python3.10/site-packages (from jax[cuda12]) (1.14.1)\r\n", + "Collecting jaxlib<=0.4.35,>=0.4.34 (from jax[cuda12])\r\n", + " Downloading jaxlib-0.4.34-cp310-cp310-manylinux2014_x86_64.whl.metadata (983 bytes)\r\n", + "Collecting jax-cuda12-plugin<=0.4.35,>=0.4.34 (from jax-cuda12-plugin[with_cuda]<=0.4.35,>=0.4.34; extra == \"cuda12\"->jax[cuda12])\r\n", + " Downloading jax_cuda12_plugin-0.4.35-cp310-cp310-manylinux2014_x86_64.whl.metadata (1.2 kB)\r\n", + "Collecting jax-cuda12-pjrt==0.4.35 (from jax-cuda12-plugin<=0.4.35,>=0.4.34->jax-cuda12-plugin[with_cuda]<=0.4.35,>=0.4.34; extra == \"cuda12\"->jax[cuda12])\r\n", + " Downloading jax_cuda12_pjrt-0.4.35-py3-none-manylinux2014_x86_64.whl.metadata (349 bytes)\r\n", + "\u001b[33mWARNING: jax-cuda12-plugin 0.4.35 does not provide the extra 'with-cuda'\u001b[0m\u001b[33m\r\n", + "\u001b[0mCollecting nvidia-cublas-cu12>=12.1.3.1 (from jax-cuda12-plugin[with_cuda]<=0.4.35,>=0.4.34; extra == \"cuda12\"->jax[cuda12])\r\n", + " Downloading nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.5 kB)\r\n", + "Collecting nvidia-cuda-cupti-cu12>=12.1.105 (from jax-cuda12-plugin[with_cuda]<=0.4.35,>=0.4.34; extra == \"cuda12\"->jax[cuda12])\r\n", + " Downloading nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.6 kB)\r\n", + "Collecting nvidia-cuda-nvcc-cu12>=12.1.105 (from jax-cuda12-plugin[with_cuda]<=0.4.35,>=0.4.34; extra == \"cuda12\"->jax[cuda12])\r\n", + " Downloading nvidia_cuda_nvcc_cu12-12.6.85-py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl.metadata (1.5 kB)\r\n", + "Collecting nvidia-cuda-runtime-cu12>=12.1.105 (from jax-cuda12-plugin[with_cuda]<=0.4.35,>=0.4.34; extra == \"cuda12\"->jax[cuda12])\r\n", + " Downloading nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.5 kB)\r\n", + "Collecting nvidia-cudnn-cu12<10.0,>=9.1 (from jax-cuda12-plugin[with_cuda]<=0.4.35,>=0.4.34; extra == \"cuda12\"->jax[cuda12])\r\n", + " Downloading nvidia_cudnn_cu12-9.6.0.74-py3-none-manylinux_2_27_x86_64.whl.metadata (1.6 kB)\r\n", + 
"Collecting nvidia-cufft-cu12>=11.0.2.54 (from jax-cuda12-plugin[with_cuda]<=0.4.35,>=0.4.34; extra == \"cuda12\"->jax[cuda12])\r\n", + " Downloading nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.5 kB)\r\n", + "Collecting nvidia-cusolver-cu12>=11.4.5.107 (from jax-cuda12-plugin[with_cuda]<=0.4.35,>=0.4.34; extra == \"cuda12\"->jax[cuda12])\r\n", + " Downloading nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.6 kB)\r\n", + "Collecting nvidia-cusparse-cu12>=12.1.0.106 (from jax-cuda12-plugin[with_cuda]<=0.4.35,>=0.4.34; extra == \"cuda12\"->jax[cuda12])\r\n", + " Downloading nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.6 kB)\r\n", + "Collecting nvidia-nccl-cu12>=2.18.1 (from jax-cuda12-plugin[with_cuda]<=0.4.35,>=0.4.34; extra == \"cuda12\"->jax[cuda12])\r\n", + " Downloading nvidia_nccl_cu12-2.23.4-py3-none-manylinux2014_x86_64.whl.metadata (1.8 kB)\r\n", + "Collecting nvidia-nvjitlink-cu12>=12.1.105 (from jax-cuda12-plugin[with_cuda]<=0.4.35,>=0.4.34; extra == \"cuda12\"->jax[cuda12])\r\n", + " Downloading nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl.metadata (1.5 kB)\r\n", + "Downloading jaxlib-0.4.34-cp310-cp310-manylinux2014_x86_64.whl (86.1 MB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m86.1/86.1 MB\u001b[0m \u001b[31m19.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25hDownloading jax_cuda12_plugin-0.4.35-cp310-cp310-manylinux2014_x86_64.whl (15.5 MB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m15.5/15.5 MB\u001b[0m \u001b[31m85.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25hDownloading jax_cuda12_pjrt-0.4.35-py3-none-manylinux2014_x86_64.whl (100.8 MB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m100.8/100.8 MB\u001b[0m \u001b[31m17.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25hDownloading ml_dtypes-0.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.5 MB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.5/4.5 MB\u001b[0m \u001b[31m93.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25hDownloading jax-0.4.35-py3-none-any.whl (2.2 MB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.2/2.2 MB\u001b[0m \u001b[31m67.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25hDownloading nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (393.1 MB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m393.1/393.1 MB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25hDownloading nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (8.9 MB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.9/8.9 MB\u001b[0m \u001b[31m59.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25hDownloading nvidia_cuda_nvcc_cu12-12.6.85-py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl (21.2 MB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.2/21.2 MB\u001b[0m \u001b[31m78.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25hDownloading 
nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (897 kB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m897.7/897.7 kB\u001b[0m \u001b[31m42.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25hDownloading nvidia_cudnn_cu12-9.6.0.74-py3-none-manylinux_2_27_x86_64.whl (508.1 MB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m508.1/508.1 MB\u001b[0m \u001b[31m3.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25hDownloading nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (200.2 MB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m200.2/200.2 MB\u001b[0m \u001b[31m8.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25hDownloading nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (158.2 MB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m158.2/158.2 MB\u001b[0m \u001b[31m7.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25hDownloading nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (216.6 MB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m216.6/216.6 MB\u001b[0m \u001b[31m7.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25hDownloading nvidia_nccl_cu12-2.23.4-py3-none-manylinux2014_x86_64.whl (199.0 MB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m199.0/199.0 MB\u001b[0m \u001b[31m7.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25hDownloading nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl (19.7 MB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m19.7/19.7 MB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25hInstalling collected packages: jax-cuda12-pjrt, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvcc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, ml-dtypes, jax-cuda12-plugin, nvidia-cusparse-cu12, nvidia-cufft-cu12, nvidia-cudnn-cu12, jaxlib, nvidia-cusolver-cu12, jax\r\n", + " Attempting uninstall: ml-dtypes\r\n", + " Found existing installation: ml-dtypes 0.3.2\r\n", + " Uninstalling ml-dtypes-0.3.2:\r\n", + " Successfully uninstalled ml-dtypes-0.3.2\r\n", + " Attempting uninstall: jaxlib\r\n", + " Found existing installation: jaxlib 0.4.26.dev20240620\r\n", + " Uninstalling jaxlib-0.4.26.dev20240620:\r\n", + " Successfully uninstalled jaxlib-0.4.26.dev20240620\r\n", + " Attempting uninstall: jax\r\n", + " Found existing installation: jax 0.4.26\r\n", + " Uninstalling jax-0.4.26:\r\n", + " Successfully uninstalled jax-0.4.26\r\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\r\n", + "tensorflow 2.16.1 requires ml-dtypes~=0.3.1, but you have ml-dtypes 0.5.0 which is incompatible.\u001b[0m\u001b[31m\r\n", + "\u001b[0mSuccessfully installed jax-0.4.35 jax-cuda12-pjrt-0.4.35 jax-cuda12-plugin-0.4.35 jaxlib-0.4.34 ml-dtypes-0.5.0 nvidia-cublas-cu12-12.6.4.1 nvidia-cuda-cupti-cu12-12.6.80 nvidia-cuda-nvcc-cu12-12.6.85 nvidia-cuda-runtime-cu12-12.6.77 nvidia-cudnn-cu12-9.6.0.74 nvidia-cufft-cu12-11.3.0.4 nvidia-cusolver-cu12-11.7.1.2 nvidia-cusparse-cu12-12.5.4.2 nvidia-nccl-cu12-2.23.4 nvidia-nvjitlink-cu12-12.6.85\r\n" + ] + } + ], + "source": [ + "!pip install -U \"jax[cuda12]\"" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "3927a091", + "metadata": { + "id": "DfxKb3F839Ks" + }, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "\n", + "!git clone --quiet --branch=main --depth=1 \\\n", + " https://github.com/google-research/big_vision big_vision_repo\n", + "\n", + "# Append big_vision code to python import path\n", + "if \"big_vision_repo\" not in sys.path:\n", + " sys.path.append(\"big_vision_repo\")\n", + "\n", + "# Install missing dependencies. Assume jax~=0.4.25 with GPU available.\n", + "!pip3 install -q \"overrides\" \"ml_collections\" \"einops~=0.7\" \"sentencepiece\"" + ] + }, + { + "cell_type": "markdown", + "id": "a61a030a", + "metadata": { + "id": "zDoq0O77GF30" + }, + "source": [ + "### Import JAX and other dependencies\n", + "\n", + "Import JAX and other dependencies required for PaliGemma, like TensorFlow and NumPy." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "e15a2524", + "metadata": { + "id": "dTfe2k8J4Bw0" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_24/840491807.py:16: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display\n", + " from IPython.core.display import display, HTML\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "JAX version: 0.4.35\n", + "JAX platform: gpu\n", + "JAX devices: 2\n" + ] + } + ], + "source": [ + "import base64\n", + "import functools\n", + "import html\n", + "import io\n", + "import os\n", + "import warnings\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import numpy as np\n", + "import ml_collections\n", + "\n", + "import tensorflow as tf\n", + "import sentencepiece\n", + "\n", + "from IPython.core.display import display, HTML\n", + "from PIL import Image\n", + "\n", + "# Import model definition from big_vision\n", + "from big_vision.models.proj.paligemma import paligemma\n", + "from big_vision.trainers.proj.paligemma import predict_fns\n", + "\n", + "# Import big vision utilities\n", + "import big_vision.datasets.jsonl\n", + "import big_vision.utils\n", + "import big_vision.sharding\n", + "\n", + "# Don't let TF use the GPU or TPUs\n", + "tf.config.set_visible_devices([], \"GPU\")\n", + "tf.config.set_visible_devices([], \"TPU\")\n", + "\n", + "backend = jax.extend.backend.get_backend()\n", + "print(f\"JAX version: {jax.__version__}\")\n", + "print(f\"JAX platform: {backend.platform}\")\n", + "print(f\"JAX devices: {jax.device_count()}\")" + ] + }, + { + "cell_type": "markdown", + "id": "92dbccf9", + "metadata": { + "id": "b9kSadtIhjlX" + }, + "source": [ + "## Download and configure the model\n", + "\n", + "In this step, you'll download the model checkpoint and configure it so that you can fine-tune it later 
on. This step shows you how to move model parameters into TPU memory, which is useful for fine-tuning models on devices with limited resources." + ] + }, + { + "cell_type": "markdown", + "id": "923baf02", + "metadata": { + "id": "7tvcc0oQHl4v" + }, + "source": [ + "### Download the model checkpoint\n", + "\n", + "PaliGemma includes several model variations. For this tutorial, you'll use the base [JAX/FLAX PaliGemma 3B weight model](https://www.kaggle.com/models/google/paligemma/jax/paligemma-3b-pt-224).\n", + "\n", + "Download the `float16` version of the model checkpoint from Kaggle by running the following code. This process takes several minutes to complete." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "fde048e2", + "metadata": { + "id": "gQNOTfF24AV4" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading the checkpoint from Kaggle, this could take a few minutes....\n", + "Model path: /kaggle/input/paligemma-2/jax/paligemma2-3b-pt-224/1/./paligemma2-3b-pt-224.b16.npz\n", + "Downloading the model tokenizer...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.10/pty.py:89: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " pid, fd = os.forkpty()\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Copying gs://big_vision/paligemma_tokenizer.model...\r\n", + "\r\n", + "Operation completed over 1 objects/4.1 MiB. \r\n", + "Tokenizer path: ./paligemma_tokenizer.model\n", + "Downloading the dataset...\n", + "Data path: ./longcap100\n" + ] + } + ], + "source": [ + "import os\n", + "import kagglehub\n", + "\n", + "# Use these for PaliGemma-2 3B 224px²\n", + "LLM_VARIANT = \"gemma2_2b\"\n", + "MODEL_PATH = \"./paligemma2-3b-pt-224.b16.npz\"\n", + "KAGGLE_HANDLE = \"google/paligemma-2/jax/paligemma2-3b-pt-224\" # Path to fetch from Kaggle.\n", + "\n", + "# Use these for PaliGemma 1:\n", + "# LLM_VARIANT = \"gemma_2b\"\n", + "# MODEL_PATH = \"./paligemma-3b-pt-224.f16.npz\"\n", + "# KAGGLE_HANDLE = \"google/paligemma/jax/paligemma-3b-pt-224\"\n", + "\n", + "if not os.path.exists(MODEL_PATH):\n", + " print(\"Downloading the checkpoint from Kaggle, this could take a few minutes....\")\n", + " MODEL_PATH = kagglehub.model_download(KAGGLE_HANDLE, MODEL_PATH)\n", + " print(f\"Model path: {MODEL_PATH}\")\n", + "\n", + "TOKENIZER_PATH = \"./paligemma_tokenizer.model\"\n", + "if not os.path.exists(TOKENIZER_PATH):\n", + " print(\"Downloading the model tokenizer...\")\n", + " !gsutil cp gs://big_vision/paligemma_tokenizer.model {TOKENIZER_PATH}\n", + " print(f\"Tokenizer path: {TOKENIZER_PATH}\")\n", + "\n", + "DATA_DIR=\"./longcap100\"\n", + "if not os.path.exists(DATA_DIR):\n", + " print(\"Downloading the dataset...\")\n", + " !gsutil -m -q cp -n -r gs://longcap100/ .\n", + " print(f\"Data path: {DATA_DIR}\")" + ] + }, + { + "cell_type": "markdown", + "id": "dd46593b", + "metadata": { + "id": "rv7w-cGuLj5o" + }, + "source": [ + "### Configure the model\n", + "\n", + "It's time to actually start configuring the model that you're going to use.\n", + "\n", + "For this notebook, you need to be able to fit your model onto a GPU. Having a limited resource like space constraints means that you have to be mindful of how your model is configured.\n", + "\n", + "If you fine-tune every parameter, your model won't be able to run in the notebook environment. 
As a result, in this part of the notebook, you'll configure your model so that it has the ability to freeze some of the parameters, and only fine-tune the parameters that really need to be fine-tuned for the model to give you accurate results. In LLMs, parameters are said to be *frozen* when they are no longer actively being used to train the model.\n", + "\n", + "In order to configure your model, you need to:\n", + "\n", + "* Initialize the `model_config` as a [`FrozenConfigDict`](https://github.com/google/ml_collections/tree/master#frozenconfigdict) so that you can freeze some of the parameters and keep memory usage low\n", + "* Initialize an instance of the PaliGemma `Model` class using the `model_config` as its configurations\n", + "* Load the model parameters into RAM\n", + "* Define a `decode` function to sample outputs from the model\n", + "\n", + "This code in this cell takes about a minute to run to completion." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "30747284", + "metadata": { + "id": "1aghcULcEdtv" + }, + "outputs": [], + "source": [ + "# Define model\n", + "\n", + "# IMPORTANT: Gemma-2 has a \"final_logits_softcap\" property, we set it to 0.0\n", + "# for better transfer results.\n", + "model_config = ml_collections.FrozenConfigDict({\n", + " \"llm\": {\"vocab_size\": 257_152, \"variant\": LLM_VARIANT, \"final_logits_softcap\": 0.0},\n", + " \"img\": {\"variant\": \"So400m/14\", \"pool_type\": \"none\", \"scan\": True, \"dtype_mm\": \"float16\"}\n", + "})\n", + "model = paligemma.Model(**model_config)\n", + "tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)\n", + "\n", + "# Load params - this can take up to 1 minute in T4 colabs.\n", + "params = paligemma.load(None, MODEL_PATH, model_config)\n", + "\n", + "# Define `decode` function to sample outputs from the model.\n", + "decode_fn = predict_fns.get_all(model)['decode']\n", + "decode = functools.partial(decode_fn, devices=jax.devices(), eos_token=tokenizer.eos_id())" + ] + }, + { + "cell_type": "markdown", + "id": "fdfd4faf", + "metadata": { + "id": "uidBwmb8LwZ5" + }, + "source": [ + "### Move model parameters into GPU/TPU memory\n", + "\n", + "Now you need to move the model parameters into GPU/TPU memory. First, shard the parameters across the available GPUs, then load the parameters. Here, you'll load the parameters sequentially. This process takes longer than loading them simultaneously, but it requires more RAM than you have available in this notebook.\n", + "\n", + "Finally, print out all of the parameters to see what type each individual parameter is cast to. Frozen parameters are kept as `float16`, while the trainable parameters are cast to `float32`. When you inspect the list, you'll see that most of the parameters have been frozen and are `float16`." 
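Before resharding the checkpoint, it can help to see how small the trainable subset is. The sketch below is an illustrative addition (not part of the original notebook): it assumes the `params` loaded above and the `trainable_mask` built in the next cell, and reuses the same `big_vision.utils.tree_flatten_with_names` helper the notebook uses for its parameter overview to count trainable versus frozen parameters.

```python
import numpy as np
import big_vision.utils

# Illustrative sketch: summarize how many parameters will be updated vs. frozen.
# `params` is loaded above; `trainable_mask` is created in the following cell.
def count_trainable(params, trainable_mask):
    named_params = big_vision.utils.tree_flatten_with_names(params)[0]
    named_masks = big_vision.utils.tree_flatten_with_names(trainable_mask)[0]
    num_trainable = num_frozen = 0
    for (_, arr), (_, is_trainable) in zip(named_params, named_masks):
        n = int(np.prod(arr.shape))
        if is_trainable:
            num_trainable += n
        else:
            num_frozen += n
    print(f"trainable params: {num_trainable:,}   frozen params: {num_frozen:,}")
```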
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "19f25fb1", + "metadata": { + "id": "RWOdf_fw2SAO" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " == Model params == \n", + "img/Transformer/encoder_norm/bias (1152,) float16\n", + "img/Transformer/encoder_norm/scale (1152,) float16\n", + "img/Transformer/encoderblock/LayerNorm_0/bias (27, 1152) float16\n", + "img/Transformer/encoderblock/LayerNorm_0/scale (27, 1152) float16\n", + "img/Transformer/encoderblock/LayerNorm_1/bias (27, 1152) float16\n", + "img/Transformer/encoderblock/LayerNorm_1/scale (27, 1152) float16\n", + "img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias (27, 4304) float16\n", + "img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel (27, 1152, 4304) float16\n", + "img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias (27, 1152) float16\n", + "img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel (27, 4304, 1152) float16\n", + "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias (27, 16, 72) float16\n", + "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel (27, 1152, 16, 72) float16\n", + "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias (27, 1152) float16\n", + "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel (27, 16, 72, 1152) float16\n", + "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias (27, 16, 72) float16\n", + "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel (27, 1152, 16, 72) float16\n", + "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias (27, 16, 72) float16\n", + "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel (27, 1152, 16, 72) float16\n", + "img/embedding/bias (1152,) float16\n", + "img/embedding/kernel (14, 14, 3, 1152) float16\n", + "img/head/bias (2304,) float16\n", + "img/head/kernel (1152, 2304) float16\n", + "img/pos_embedding (1, 256, 1152) float16\n", + "llm/embedder/input_embedding (257152, 2304) float16\n", + "llm/final_norm/scale (2304,) float16\n", + "llm/layers/attn/attn_vec_einsum/w (26, 8, 256, 2304) float32\n", + "llm/layers/attn/kv_einsum/w (26, 2, 4, 2304, 256) float32\n", + "llm/layers/attn/q_einsum/w (26, 8, 2304, 256) float32\n", + "llm/layers/mlp/gating_einsum (26, 2, 2304, 9216) float16\n", + "llm/layers/mlp/linear (26, 9216, 2304) float16\n", + "llm/layers/post_attention_norm/scale (26, 2304) float16\n", + "llm/layers/post_ffw_norm/scale (26, 2304) float16\n", + "llm/layers/pre_attention_norm/scale (26, 2304) float16\n", + "llm/layers/pre_ffw_norm/scale (26, 2304) float16\n" + ] + } + ], + "source": [ + "# Create a pytree mask of the trainable params.\n", + "def is_trainable_param(name, param): # pylint: disable=unused-argument\n", + " if name.startswith(\"llm/layers/attn/\"): return True\n", + " if name.startswith(\"llm/\"): return False\n", + " if name.startswith(\"img/\"): return False\n", + " raise ValueError(f\"Unexpected param name {name}\")\n", + "trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)\n", + "\n", + "# If more than one device is available (e.g. 
multiple GPUs) the parameters can\n", + "# be sharded across them to reduce HBM usage per device.\n", + "mesh = jax.sharding.Mesh(jax.devices(), (\"data\"))\n", + "\n", + "data_sharding = jax.sharding.NamedSharding(\n", + " mesh, jax.sharding.PartitionSpec(\"data\"))\n", + "\n", + "params_sharding = big_vision.sharding.infer_sharding(\n", + " params, strategy=[('.*', 'fsdp(axis=\"data\")')], mesh=mesh)\n", + "\n", + "# Yes: Some donated buffers are not usable.\n", + "warnings.filterwarnings(\n", + " \"ignore\", message=\"Some donated buffers were not usable\")\n", + "\n", + "@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))\n", + "def maybe_cast_to_f32(params, trainable):\n", + " # Cast others to float16, since some GPUs don't support bf16.\n", + " return jax.tree.map(lambda p, m: p.astype(jnp.float32)\n", + " if m else p.astype(jnp.float16),\n", + " params, trainable)\n", + "\n", + "# Loading all params in simultaneous - albeit much faster and more succinct -\n", + "# requires more RAM than the T4 colab runtimes have by default.\n", + "# Instead we do it param by param.\n", + "params, treedef = jax.tree.flatten(params)\n", + "sharding_leaves = jax.tree.leaves(params_sharding)\n", + "trainable_leaves = jax.tree.leaves(trainable_mask)\n", + "for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):\n", + " params[idx] = big_vision.utils.reshard(params[idx], sharding)\n", + " params[idx] = maybe_cast_to_f32(params[idx], trainable)\n", + " params[idx].block_until_ready()\n", + "params = jax.tree.unflatten(treedef, params)\n", + "\n", + "# Print params to show what the model is made of.\n", + "def parameter_overview(params):\n", + " for path, arr in big_vision.utils.tree_flatten_with_names(params)[0]:\n", + " print(f\"{path:80s} {str(arr.shape):22s} {arr.dtype}\")\n", + "\n", + "print(\" == Model params == \")\n", + "parameter_overview(params)" + ] + }, + { + "cell_type": "markdown", + "id": "4bde55a0", + "metadata": { + "id": "iD_9XXQkn1Mv" + }, + "source": [ + "## Prepare to tune the model\n", + "\n", + "Now that your model is configured, you can tune it. In this step, you'll create your model's inputs as well as the training and validation iterators, view the training examples, and define the training and validation loops." + ] + }, + { + "cell_type": "markdown", + "id": "3ef1ef32", + "metadata": { + "id": "83ZcnbddJKdx" + }, + "source": [ + "### Create model inputs\n", + "\n", + "The model checkpoint you're using has already been trained on images of various aspect ratios that have been resized to 224x224 pixels, and to handle tokenized texts.\n", + "\n", + "The code below defines three functions that you'll use in the next step create the model's inputs:\n", + "\n", + "* **`preprocess_image`:** Normalizes the image data. In this case, pre-processing converts the passed-in image to greyscale, removes the alpha layer, and resizes the passed-in image to the size required by the model for image inputs (224x224 pixels).\n", + "* **`preprocess_tokens`:** Splits the tokens up and adds flags to mark whether a token is a prefix or suffix token. 
These flags will be used later on in the code, during the training step and the evaluation loop.\n", + "* **`postprocess_tokens`:** Removes any tokens left at and/or after the end-of-sequence (EOS) token and returns the remaining decoded tokens.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "aea6b72a", + "metadata": { + "id": "8SRW0NuU4UcW" + }, + "outputs": [], + "source": [ + "def preprocess_image(image, size=224):\n", + " # Model has been trained to handle images of different aspects ratios\n", + " # resized to 224x224 in the range [-1, 1]. Bilinear and antialias resize\n", + " # options are helpful to improve quality in some tasks.\n", + " image = np.asarray(image)\n", + " if image.ndim == 2: # Convert image without last channel into greyscale.\n", + " image = np.stack((image,)*3, axis=-1)\n", + " image = image[..., :3] # Remove alpha layer.\n", + " assert image.shape[-1] == 3\n", + "\n", + " image = tf.constant(image)\n", + " image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)\n", + " return image.numpy() / 127.5 - 1.0 # [0, 255]->[-1,1]\n", + "\n", + "def preprocess_tokens(prefix, suffix=None, seqlen=None):\n", + " # Model has been trained to handle tokenized text composed of a prefix with\n", + " # full attention and a suffix with causal attention.\n", + " separator = \"\\n\"\n", + " tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode(separator)\n", + " mask_ar = [0] * len(tokens) # 0 to use full attention for prefix.\n", + " mask_loss = [0] * len(tokens) # 0 to not use prefix tokens in the loss.\n", + "\n", + " if suffix:\n", + " suffix = tokenizer.encode(suffix, add_eos=True)\n", + " tokens += suffix\n", + " mask_ar += [1] * len(suffix) # 1 to use causal attention for suffix.\n", + " mask_loss += [1] * len(suffix) # 1 to use suffix tokens in the loss.\n", + "\n", + " mask_input = [1] * len(tokens) # 1 if it's a token, 0 if padding.\n", + " if seqlen:\n", + " padding = [0] * max(0, seqlen - len(tokens))\n", + " tokens = tokens[:seqlen] + padding\n", + " mask_ar = mask_ar[:seqlen] + padding\n", + " mask_loss = mask_loss[:seqlen] + padding\n", + " mask_input = mask_input[:seqlen] + padding\n", + "\n", + " return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))\n", + "\n", + "def postprocess_tokens(tokens):\n", + " tokens = tokens.tolist() # np.array to list[int]\n", + " try: # Remove tokens at and after EOS if any.\n", + " eos_pos = tokens.index(tokenizer.eos_id())\n", + " tokens = tokens[:eos_pos]\n", + " except ValueError:\n", + " pass\n", + " return tokenizer.decode(tokens)\n" + ] + }, + { + "cell_type": "markdown", + "id": "672eed66", + "metadata": { + "id": "ovgWBgdHJZq3" + }, + "source": [ + "### Create the training and validation iterators\n", + "\n", + "Create two iterators:\n", + "\n", + "* A **training iterator** to allow the training process to go through the data in chunks rather than processing it all at once\n", + " * This allows you to do some data pre-processing before use\n", + "* A **validation iterator** that allows the training process to iterate over the validation dataset to see how well the tuned model aligned with the provided results" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "fc220ff0", + "metadata": { + "id": "whzWOojGOtzi" + }, + "outputs": [], + "source": [ + "SEQLEN = 128\n", + "\n", + "train_dataset = big_vision.datasets.jsonl.DataSource(\n", + " os.path.join(DATA_DIR, \"data_train90.jsonl\"),\n", + " fopen_keys={\"image\": DATA_DIR})\n", + 
"\n", + "val_dataset = big_vision.datasets.jsonl.DataSource(\n", + " os.path.join(DATA_DIR, \"data_val10.jsonl\"),\n", + " fopen_keys={\"image\": DATA_DIR})\n", + "\n", + "\n", + "def train_data_iterator():\n", + " \"\"\"Never ending iterator over training examples.\"\"\"\n", + " # Shuffle examples and repeat so one can train for many epochs.\n", + " dataset = train_dataset.get_tfdata().shuffle(1_000).repeat()\n", + " for example in dataset.as_numpy_iterator():\n", + " image = Image.open(io.BytesIO(example[\"image\"]))\n", + " image = preprocess_image(image)\n", + "\n", + " prefix = \"caption en\" # Could also be a different prefix per example.\n", + " suffix = example[\"suffix\"].decode().lower()\n", + " tokens, mask_ar, mask_loss, _ = preprocess_tokens(prefix, suffix, SEQLEN)\n", + "\n", + " yield {\n", + " \"image\": np.asarray(image),\n", + " \"text\": np.asarray(tokens),\n", + " \"mask_ar\": np.asarray(mask_ar),\n", + " \"mask_loss\": np.asarray(mask_loss),\n", + " }\n", + "\n", + "\n", + "def validation_data_iterator():\n", + " \"\"\"Single iterator over validation examples.\"\"\"\n", + " for example in val_dataset.get_tfdata(ordered=True).as_numpy_iterator():\n", + " image = Image.open(io.BytesIO(example[\"image\"]))\n", + " image = preprocess_image(image)\n", + "\n", + " prefix = \"caption en\" # Could also be a different prefix per example.\n", + " tokens, mask_ar, _, mask_input = preprocess_tokens(prefix, seqlen=SEQLEN)\n", + "\n", + " yield {\n", + " \"image\": np.asarray(image),\n", + " \"text\": np.asarray(tokens),\n", + " \"mask_ar\": np.asarray(mask_ar),\n", + " \"mask_input\": np.asarray(mask_input),\n", + " }\n" + ] + }, + { + "cell_type": "markdown", + "id": "0849f8a1", + "metadata": { + "id": "84olaM5dCiAl" + }, + "source": [ + "### View training examples\n", + "\n", + "In this notebook, the training data contains 90 images that are paired with long descriptions of what's depicted in the image.\n", + "\n", + "**Note:** Normal training data sets that are meant to be used for practical use cases should contain more images, but this notebook limits the number of data points so that you can train the model in a reasonable amount of time for an example.\n", + "\n", + "The code below prints a random selection of images with their descriptions from the training data set so that you can see what the images and descriptions your model is trained on looks like. Each image is displayed in as a 128x128 pixel JPEG, with the description printed next to the image to the right." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "55a7464e", + "metadata": { + "id": "BzJfb5t0nsLq" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training examples\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + "

a table topped with a variety of items, including a wooden box, a brush, a bowl, a jar, and a towel. the table is black, and the items are arranged neatly. the brush is made of wood, and the bowl is made of wood. the knife is made of wood, and the towel is striped. the jar is made of metal, and the lid is on the jar.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a martini glass sits on a bar, its contents neatly arranged. the glass sits on a black coaster, reflecting the lights of the city lights in the background. the bar is illuminated by a warm glow, casting long shadows on the wall. the glass on the coaster holds a lemon slice, a testament to the refreshing nature of the drink. the overall atmosphere is relaxed and inviting.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a plate of spaghetti with vegetables, including green leaves, green beans and bacon crumbs on the plate. the plate is white and sits on a black table. the spaghetti is yellow.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a blue bicycle is parked next to awooden fence. the bike has a black seat, a black kickstand, and a black tire. the fence is brown and the grass is green. there is a small green bush and a small green tree in the background. the bike has a light on the front and a light on the back. the front tire of the bike is on the ground and the back tire is on the fence. the bike has a black pedal and a black pedal on the bike. the bike has a black seat and a black seat on the bike. the bike has a black kickstand and a

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a large sign in the shape of a crown stands proudly in the center of a city square. the sign is illuminated by the reflection of the sun on the water, creating a vibrant display. a tall building casts long shadows on the ground, while a flag on top of a building waves proudly. people stroll along the sidewalk. the sky is clear and blue, with fluffy white clouds drifting above. the reflection of the city in the water is a mirror image of the city itself, showcasing the beauty and diversity of this urban landscape.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a leopard sits majestically on a tree branch, its eyes open and its mouth closed. the leopard's coat is adorned with intricate black spots, and its eyes are a vibrant blue. the tree behind the animal is tall and slender, its branches reaching out like a welcoming embrace. the leopard's whiskers are white.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a man stands on a red track, his leg raised high, his shoe firmly planted on the ground. the track is lined with white lines, and the grass is green. the man wears red shorts and grey and white shoes, and his socks are black. the man is running towards the finish line.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a group of people walk down a street. a black and white sign hangs from a building, while a brown sign with gold lettering advertises a business. a woman with a pink hat and a woman with a black backpack walk side by side, their backs facing the camera. a black and white sign on a pole and a black and white sign on a building are also visible. a woman with a blue jacket and a black backpack walk on the street.

\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def render_inline(image, resize=(128, 128)):\n", + " \"\"\"Convert image into inline html.\"\"\"\n", + " image = Image.fromarray(image)\n", + " image.resize(resize)\n", + " with io.BytesIO() as buffer:\n", + " image.save(buffer, format='jpeg')\n", + " image_b64 = str(base64.b64encode(buffer.getvalue()), \"utf-8\")\n", + " return f\"data:image/jpeg;base64,{image_b64}\"\n", + "\n", + "def render_example(image, caption):\n", + " image = ((image + 1)/2 * 255).astype(np.uint8) # [-1,1] -> [0, 255]\n", + " return f\"\"\"\n", + "
\n", + " \n", + "

{html.escape(caption)}

\n", + "
\n", + " \"\"\"\n", + "\n", + "html_out = \"\"\n", + "for idx, example in zip(range(8), train_data_iterator()):\n", + " caption = postprocess_tokens(example[\"text\"]) # detokenize model input.\n", + " caption = caption[len(\"caption en\\n\"):] # strip prefix\n", + " html_out += render_example(example[\"image\"], caption)\n", + "\n", + "print(\"Training examples\")\n", + "display(HTML(html_out))" + ] + }, + { + "cell_type": "markdown", + "id": "5a55c1c1", + "metadata": { + "id": "N2BwpXkfI8OT" + }, + "source": [ + "### Define the training and evaluation loops\n", + "\n", + "Define the training loop to train the model on the provided dataset, and the evaluation loop to look at all of the examples in the validation dataset and make its predictions.\n", + "\n", + "#### Defining the training loop\n", + "\n", + "The `update_fn` function defines the training step. During the training step, the loss per example is calculated and stochastic gradient descent (SGD) is applied to the trainable parameters.\n", + "\n", + "Recall that earlier in the notebook, you included flags in the `preprocess_tokens` function that included `mask_loss`. You'll use the `mask_loss` flag here to exclude prefix and padded tokens from the loss. Without it, the loss calculation will be skewed. You also need to normalize each example, since each of them has a different number of tokens. After the prefix and padded tokens have been excluded and the examples have been normalized, you can calculate the loss per example.\n", + "\n", + "The training step also includes a function to apply an SGD to optimize the training.\n", + "\n", + "#### Defining the evaluation loop\n", + "\n", + "The `make_predictions` function is your evaluation loop. The evaluation loop is fairly straight forward with one notable change. If you recall from the beginning of the notebook, you only have 90 examples in your training data set. This is a very small amount of training examples, and your model ends up not having enough examples for the batch size when you run the training. This means that in the evaluation loop, you need to pad the batch by repeating examples.\n", + "\n", + "To make sure that your evaluation loop only counts actual examples and not the padded examples, you have to apply a mask to the padded examples that excludes them from the output." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "ff9e0a81", + "metadata": { + "id": "dwUV_imW3WQJ" + }, + "outputs": [], + "source": [ + "# The main update_fn using a simple stochastic gradient descent (SGD).\n", + "@functools.partial(jax.jit, donate_argnums=(0,))\n", + "def update_fn(params, batch, learning_rate):\n", + " imgs, txts, mask_ar = batch[\"image\"], batch[\"text\"], batch[\"mask_ar\"]\n", + "\n", + " def loss_fn(params):\n", + " text_logits, _ = model.apply({\"params\": params}, imgs, txts[:, :-1], mask_ar[:, :-1], train=True)\n", + " logp = jax.nn.log_softmax(text_logits, axis=-1)\n", + "\n", + " # The model takes as input txts[:, :-1] but the loss is defined as predicting\n", + " # next tokens txts[:, 1:]. Additionally, mask_loss[:, 1:] indicates which tokens\n", + " # are part of the loss (e.g. prefix and padded tokens are not included).\n", + " mask_loss = batch[\"mask_loss\"][:, 1:]\n", + " targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])\n", + "\n", + " # Compute the loss per example. i.e. 
the mean of per token pplx.\n", + " # Since each example has a different number of tokens we normalize it.\n", + " token_pplx = jnp.sum(logp * targets, axis=-1) # sum across vocab_size.\n", + " example_loss = -jnp.sum(token_pplx * mask_loss, axis=-1) # sum across seq_len.\n", + " example_loss /= jnp.clip(jnp.sum(mask_loss, -1), 1) # weight by num of tokens.\n", + "\n", + " # batch_loss: mean of per example loss.\n", + " return jnp.mean(example_loss)\n", + "\n", + " loss, grads = jax.value_and_grad(loss_fn)(params)\n", + "\n", + " # Apply gradients to trainable params using SGD.\n", + " def apply_grad(param, gradient, trainable):\n", + " if not trainable: return param\n", + " return param - learning_rate * gradient\n", + "\n", + " params = jax.tree_util.tree_map(apply_grad, params, grads, trainable_mask)\n", + "\n", + " return params, loss\n", + "\n", + "# Evaluation/inference loop.\n", + "def make_predictions(data_iterator, *, num_examples=None,\n", + " batch_size=4, seqlen=SEQLEN, sampler=\"greedy\"):\n", + " outputs = []\n", + " while True:\n", + " # Construct a list of examples in the batch.\n", + " examples = []\n", + " try:\n", + " for _ in range(batch_size):\n", + " examples.append(next(data_iterator))\n", + " examples[-1][\"_mask\"] = np.array(True) # Indicates true example.\n", + " except StopIteration:\n", + " if len(examples) == 0:\n", + " return outputs\n", + "\n", + " # Not enough examples to complete a batch. Pad by repeating last example.\n", + " while len(examples) % batch_size:\n", + " examples.append(dict(examples[-1]))\n", + " examples[-1][\"_mask\"] = np.array(False) # Indicates padding example.\n", + "\n", + " # Convert list of examples into a dict of np.arrays and load onto devices.\n", + " batch = jax.tree.map(lambda *x: np.stack(x), *examples)\n", + " batch = big_vision.utils.reshard(batch, data_sharding)\n", + "\n", + " # Make model predictions\n", + " tokens = decode({\"params\": params}, batch=batch,\n", + " max_decode_len=seqlen, sampler=sampler)\n", + "\n", + " # Fetch model predictions to device and detokenize.\n", + " tokens, mask = jax.device_get((tokens, batch[\"_mask\"]))\n", + " tokens = tokens[mask] # remove padding examples.\n", + " responses = [postprocess_tokens(t) for t in tokens]\n", + "\n", + " # Append to html output.\n", + " for example, response in zip(examples, responses):\n", + " outputs.append((example[\"image\"], response))\n", + " if num_examples and len(outputs) >= num_examples:\n", + " return outputs" + ] + }, + { + "cell_type": "markdown", + "id": "bf6ad946", + "metadata": { + "id": "n9r9V1jwJvu9" + }, + "source": [ + "## Tune the model\n", + "\n", + "Now that you've set everything up and taken a look at the training data, it's time to finally tune the model. The code below runs the training loop for the model for 64 steps and prints the learning rate (`lr` in the printed output) and loss rate for each step.\n", + "\n", + "Every 16 steps, the model prints what its predictions are at that step in the training. This code prints out predictions for the same set of images so that you can see the model's ability to predict descriptions improve over time.\n", + "\n", + "At earlier steps in the training, there's likely issues with the descriptions, such as repeated sentences as the model gets stuck in its predictive loop or unfinished sentences. The model's predictions become steadily more accurate as training progresses. 
By step 64, the model's predictions should closely resemble the descriptions provided by the training data.\n", + "\n", + "This process takes around 15 minutes to complete on T4 GPUs." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "f78b3fea", + "metadata": { + "id": "067wj_6bZAG3" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "step: 1/64 lr: 0.00500 loss: 3.2539\n", + "step: 2/64 lr: 0.01000 loss: 1.9291\n", + "step: 3/64 lr: 0.01500 loss: 1.5984\n", + "step: 4/64 lr: 0.02000 loss: 1.6361\n", + "step: 5/64 lr: 0.02500 loss: 2.0249\n", + "step: 6/64 lr: 0.03000 loss: 2.6033\n", + "step: 7/64 lr: 0.02998 loss: 1.9704\n", + "step: 8/64 lr: 0.02992 loss: 1.6470\n", + "step: 9/64 lr: 0.02981 loss: 1.5255\n", + "step: 10/64 lr: 0.02966 loss: 1.5204\n", + "step: 11/64 lr: 0.02947 loss: 1.3989\n", + "step: 12/64 lr: 0.02924 loss: 1.2505\n", + "step: 13/64 lr: 0.02897 loss: 1.1247\n", + "step: 14/64 lr: 0.02866 loss: 1.0750\n", + "step: 15/64 lr: 0.02831 loss: 1.2703\n", + "step: 16/64 lr: 0.02792 loss: 1.0917\n", + "Model predictions at step 16\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + "

a woman's hand rests on a white wall, casting a shadow on the wall. the dress is pink, and the sleeves are long. the hand is on the wall, and the shadow is on the wall. the dress is flowing, and the sleeves are gathered. the wall is white, and the shadow is long. the hand is on the wall, and the shadow is long. the dress is pink, and the sleeves are gathered. the shadow is long.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a woman in a white dress with a pink flower on it sits on a stone wall overlooking the ocean. the dress has a floral pattern and a white bag on her hand. the sky is blue and the water is calm. the boat is on the water and the sails are white. the dress is flowing in the wind. the woman is wearing a hat and holding a bag. the dress is long and the flowers are pink. the sky is clear and the water is calm. the boat is on the water and the sails are white. the dress is flowing in the wind. the woman is sitting on a stone wall. the bag is

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a person wearing a red blazer and black pants, with a black bag on their hip. the bag has a silver zipper and a white writing on it. the person is wearing a white top underneath the blazer. the bag is black and has a silver zipper. the jacket is red and has a silver button. the pants are black and have a silver zipper. the person is wearing a white top underneath the blazer. the bag is black and has a silver zipper. the jacket is red and has a silver button. the pants are black and have a silver zipper. the bag is black and has a silver zipper. the person is wearing a

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a woman in a pink shirt and blue jeans stands on a stone staircase. the jeans have a hole in the knee. the woman is wearing a white cardigan and a pink bag. the bag is on her arm. the steps are gray. the wall is gray. the sky is blue. the ground is gray. the woman is wearing a pink bag. the sky is blue. the ground is gray. the wall is gray. the steps are gray. the sky is blue. the ground is gray. the woman is wearing a pink bag. the bag is on her arm. the steps are gray. the wall is gray. the sky

\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "step: 17/64 lr: 0.02750 loss: 1.1208\n", + "step: 18/64 lr: 0.02704 loss: 1.2137\n", + "step: 19/64 lr: 0.02655 loss: 1.0639\n", + "step: 20/64 lr: 0.02602 loss: 1.0356\n", + "step: 21/64 lr: 0.02546 loss: 0.9214\n", + "step: 22/64 lr: 0.02488 loss: 1.0569\n", + "step: 23/64 lr: 0.02426 loss: 0.9526\n", + "step: 24/64 lr: 0.02362 loss: 0.6038\n", + "step: 25/64 lr: 0.02296 loss: 0.8039\n", + "step: 26/64 lr: 0.02227 loss: 0.7570\n", + "step: 27/64 lr: 0.02156 loss: 0.7252\n", + "step: 28/64 lr: 0.02083 loss: 0.7221\n", + "step: 29/64 lr: 0.02009 loss: 0.7316\n", + "step: 30/64 lr: 0.01933 loss: 0.7288\n", + "step: 31/64 lr: 0.01856 loss: 0.6435\n", + "step: 32/64 lr: 0.01778 loss: 0.7477\n", + "Model predictions at step 32\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + "

a woman in a pink dress stands on a white wall, her hand on the wall. the dress is pink, and the woman's hand is on the wall. the dress is long and flowing, and the woman's hand is gripping the wall. the woman is wearing a bracelet and a watch. the dress is pink, and the woman's hand is on the wall. the woman is standing on a white wall, and the wall is white. the woman's hand is gripping the wall, and her fingers are curled. the woman is wearing a bracelet and a watch. the dress is pink, and the woman'

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a woman in a white dress with a floral pattern stands on a stone wall overlooking the ocean. the dress is long and flowing, and the woman is wearing a straw bag. the sky is clear and blue, and the water is calm. the woman is standing on a stone wall, and the dress is flowing in the wind. the woman is holding a white bag and wearing a pair of sandals. the dress is white and has a floral pattern. the woman is standing on a stone wall, and the dress is flowing in the wind. the woman is wearing a straw bag and a pair of sandals. the dress is long and has a

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a person wears a red blazer with a black fanny pack on their hip. the blazer is open and the person is wearing black pants. the person is standing in front of a green plant and is holding their hand in their pocket. the bag is black and has a zipper. the person is wearing a black top underneath the jacket. the jacket is red and has a button on the front. the person is wearing a black belt and a black fanny pack. the jacket is open and the person is wearing a black pants. the bag is on the person's hip and the zipper is on the bag. the person is standing in

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a woman stands on a stone staircase, her hand on her bag. her jeans are blue, and her shirt is pink. the woman is wearing a white cardigan and a pink bag. the stairs are made of stone, and the wall is made of concrete. the woman is standing on the stairs, and her hand is on her bag. the bag is pink, and the strap is long. the woman is wearing a bracelet and a necklace. the jeans are blue, and the buttons are white. the woman is wearing a pink shirt and a white cardigan. the bag is on her arm, and her hand is on the bag. the

\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "step: 33/64 lr: 0.01699 loss: 0.6949\n", + "step: 34/64 lr: 0.01620 loss: 0.6263\n", + "step: 35/64 lr: 0.01540 loss: 0.3855\n", + "step: 36/64 lr: 0.01460 loss: 0.2839\n", + "step: 37/64 lr: 0.01380 loss: 0.3310\n", + "step: 38/64 lr: 0.01301 loss: 0.4091\n", + "step: 39/64 lr: 0.01222 loss: 0.4324\n", + "step: 40/64 lr: 0.01144 loss: 0.3957\n", + "step: 41/64 lr: 0.01067 loss: 0.3261\n", + "step: 42/64 lr: 0.00991 loss: 0.4206\n", + "step: 43/64 lr: 0.00917 loss: 0.4413\n", + "step: 44/64 lr: 0.00844 loss: 0.3780\n", + "step: 45/64 lr: 0.00773 loss: 0.3321\n", + "step: 46/64 lr: 0.00704 loss: 0.2110\n", + "step: 47/64 lr: 0.00638 loss: 0.1994\n", + "step: 48/64 lr: 0.00574 loss: 0.1646\n", + "Model predictions at step 48\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + "

a woman in a pink dress stands on a white wall, her hand on the wall. the dress is pink, and the sleeves are long and gathered. the woman's hand is gripping the wall. the wall is white, and the shadow on the wall is long and dark.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a woman stands on a pier, her dress flowing in the wind. the sky is clear and blue, with a few clouds. the water is calm and blue, with a few waves. the woman holds her bag and stands with her legs crossed. the dress is white with a red and black flower print. the woman wears short sleeves and a tie on the dress. the dress is long and flowing.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a person wears a red blazer with a black belt and bag. the blazer is open and the bag is strapped to their waist. the bag is black and has a zipper. the person's hand is on the bag. the bag has a zipper and a silver chain. the blazer is loose and the buttons are unbuttoned. the person is standing next to a green plant.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a woman stands on a stone staircase, her hand on her bag. the jeans are blue, and the fabric is ripped. the shirt is pink, and the buttons are white. the woman is wearing a white cardigan and a silver bracelet on her wrist. the bag is pink, and the strap is pink.

\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "step: 49/64 lr: 0.00512 loss: 0.1421\n", + "step: 50/64 lr: 0.00454 loss: 0.2420\n", + "step: 51/64 lr: 0.00398 loss: 0.1420\n", + "step: 52/64 lr: 0.00345 loss: 0.1434\n", + "step: 53/64 lr: 0.00296 loss: 0.1580\n", + "step: 54/64 lr: 0.00250 loss: 0.2400\n", + "step: 55/64 lr: 0.00208 loss: 0.1307\n", + "step: 56/64 lr: 0.00169 loss: 0.1296\n", + "step: 57/64 lr: 0.00134 loss: 0.1500\n", + "step: 58/64 lr: 0.00103 loss: 0.1329\n", + "step: 59/64 lr: 0.00076 loss: 0.0738\n", + "step: 60/64 lr: 0.00053 loss: 0.1207\n", + "step: 61/64 lr: 0.00034 loss: 0.1089\n", + "step: 62/64 lr: 0.00019 loss: 0.1033\n", + "step: 63/64 lr: 0.00008 loss: 0.1217\n", + "step: 64/64 lr: 0.00002 loss: 0.1000\n", + "Model predictions at step 64\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + "

a woman in a pink dress stands on a white staircase, her hand on the wall. the dress is pink, and the fabric is sheer. the woman's hand is gripping the wall. the stairs are white, and the wall is painted white. the woman is wearing long sleeves, and the sleeves are gathered at the wrist. the dress has a collar, and the collar is white. the woman is standing on a step, and her hand is on the wall.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a woman stands on a pier, her dress flowing in the wind. the sky is clear and blue, with a few fluffy clouds. the water is calm and blue, and the boats on the water are visible. the woman's hand is on her hip, and her other hand is on her dress. the woman is wearing a long white dress with a floral pattern, and her hair is blonde. the dress is flowing in the wind, and the flowers on the dress are red and pink. the woman is standing next to the ocean, and the boats are floating on the water.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a person wears a red blazer with a black belt bag. the bag has a zipper and a silver zipper pull. the person wears black pants and has their fingers in the bag. the blazer has a button and a single vent. the person stands in front of a green plant.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a woman stands on a stone staircase, her hand on her purse. the jeans are blue, and the fabric is torn. the shirt is pink, and the buttons are white. the woman is wearing a white cardigan and a silver bracelet on her wrist. the bag is pink, and the strap is pink. the woman is walking on the street.

\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Run a short training loop with cosine learning rate schedule.\n", + "#\n", + "# Note: the first step can be quite slow on some machines (up to several minutes)\n", + "# due to XLA compilation of the jax.jit'd function.\n", + "#\n", + "\n", + "BATCH_SIZE = 8\n", + "TRAIN_EXAMPLES = 512\n", + "LEARNING_RATE = 0.03\n", + "\n", + "TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE\n", + "EVAL_STEPS = TRAIN_STEPS // 4\n", + "\n", + "train_data_it = train_data_iterator()\n", + "\n", + "sched_fn = big_vision.utils.create_learning_rate_schedule(\n", + " total_steps=TRAIN_STEPS+1, base=LEARNING_RATE,\n", + " decay_type=\"cosine\", warmup_percent=0.10)\n", + "\n", + "for step in range(1, TRAIN_STEPS+1):\n", + " # Make list of N training examples.\n", + " examples = [next(train_data_it) for _ in range(BATCH_SIZE)]\n", + "\n", + " # Convert list of examples into a dict of np.arrays and load onto devices.\n", + " batch = jax.tree.map(lambda *x: np.stack(x), *examples)\n", + " batch = big_vision.utils.reshard(batch, data_sharding)\n", + "\n", + " # Training step and report training loss\n", + " learning_rate = sched_fn(step)\n", + " params, loss = update_fn(params, batch, learning_rate)\n", + "\n", + " loss = jax.device_get(loss)\n", + " print(f\"step: {step:2d}/{TRAIN_STEPS:2d} lr: {learning_rate:.5f} loss: {loss:.4f}\")\n", + "\n", + " if (step % EVAL_STEPS) == 0:\n", + " print(f\"Model predictions at step {step}\")\n", + " html_out = \"\"\n", + " for image, caption in make_predictions(\n", + " validation_data_iterator(), num_examples=4, batch_size=4):\n", + " html_out += render_example(image, caption)\n", + " display(HTML(html_out))\n" + ] + }, + { + "cell_type": "markdown", + "id": "f6019d20", + "metadata": { + "id": "glScsFLVJ52c" + }, + "source": [ + "## Output\n", + "\n", + "The validation data for this notebook consists of just 10 images. In normal code, you would likely have many more data points for validation, but for this notebook, run the following code to generate descriptions for all 10 images. After tuning the model, these descriptions should be very similar in form and content coverage to the descriptions included with the training data that you looked at earlier in this notebook.\n", + "\n", + "Run the below code to generate descriptions for the validation data set." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "6c3b2164", + "metadata": { + "id": "hgUhEKjzPdMQ" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model predictions\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + "

a woman in a pink dress stands on a white staircase, her hand on the wall. the dress is pink, and the fabric is sheer. the woman's hand is gripping the wall. the stairs are white, and the wall is painted white. the woman is wearing long sleeves, and the sleeves are gathered at the wrist. the dress has a collar, and the collar is white. the woman is standing on a step, and her hand is on the wall.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a woman stands on a pier, her dress flowing in the wind. the sky is clear and blue, with a few fluffy clouds. the water is calm and blue, and the boats on the water are visible. the woman's hand is on her hip, and her other hand is on her dress. the woman is wearing a long white dress with a floral pattern, and her hair is blonde. the dress is flowing in the wind, and the flowers on the dress are red and pink. the woman is standing next to the ocean, and the boats are floating on the water.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a person wears a red blazer with a black belt bag. the bag has a zipper and a silver zipper pull. the person wears black pants and has their fingers in the bag. the blazer has a button and a single vent. the person stands in front of a green plant.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a woman stands on a stone staircase, her hand on her purse. the jeans are blue, and the fabric is torn. the shirt is pink, and the buttons are white. the woman is wearing a white cardigan and a silver bracelet on her wrist. the bag is pink, and the strap is pink. the woman is walking on the street.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a woman is lying on a bed, wearing a pink sweater with the words "love will save us" written on it. the sweater is long-sleeved and has a crew neckline. the woman is wearing white sneakers and has her hand on the bed. the jeans are blue and have a belt loop. the blanket is gray and fuzzy.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a man stands with his hand on his head, his long blonde hair flowing in the wind. the man wears a black sweater and a white and black plaid shirt. the sweater is navy blue and the shirt is white and black. the man's hair is messy and his eyes are closed. the man is standing against a pink wall.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a row of white hangers on a white clothes rack, with a white wall in the background. the hangers are white, and the metal bar on the rack is white. the rack has a white metal pole on the bottom, and a white metal bar on the top. the wall is white, and the light is shining on the rack.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a white sweater hangs on a wooden hanger, with a white drawstring on the bottom of the sweater. the sweater has a hood and a pocket on the front. the pants have a white drawstring on the bottom of the pants. the clothes are hanging on a black pole, with a black circle on the wall. the clothes are on a white rack, with a white tag on the hanger.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a woman stands on a sidewalk, showcasing her black knee-high boots and black bag. the boots are made of suede and have a low heel. the bag has a gold chain strap and a silver lock. the woman's hand is on the bag. the jeans are blue and have a slight stretch. the woman is wearing a black knee-high boot and a black long-sleeve shirt. the boots are black and have a low heel. the bag is black and has a gold chain strap. the woman is standing on a gray sidewalk.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

a man stands on a road, his hands in his pockets. his pants are brown, and his shirt is white. he wears a denim jacket and white shoes. the road is gray and the trees are green. the man's hands are in his pockets. the man is standing on the road, his back to the camera. the man is wearing a white t-shirt and a blue denim jacket. the man's shoes are white. the man's pants are brown. the man is smiling.

\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# The validation data consists of 10 images in a different domain than training\n", + "# data.\n", + "\n", + "print(\"Model predictions\")\n", + "html_out = \"\"\n", + "for image, caption in make_predictions(validation_data_iterator(), batch_size=4):\n", + " html_out += render_example(image, caption)\n", + "display(HTML(html_out))" + ] + } + ], + "metadata": { + "colab": { + "name": "Finetune_PaliGemma_2_with_JAX.ipynb", + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/PaliGemma 2/Finetune_PaliGemma_2_with_Keras.ipynb b/PaliGemma 2/Finetune_PaliGemma_2_with_Keras.ipynb new file mode 100644 index 0000000..43714b5 --- /dev/null +++ b/PaliGemma 2/Finetune_PaliGemma_2_with_Keras.ipynb @@ -0,0 +1,955 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "be34d25b", + "metadata": { + "id": "8377c056591f" + }, + "source": [ + "Copyright 2024 Google LLC." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "6130c8e6", + "metadata": { + "cellView": "form", + "id": "ca23c3f523a7" + }, + "outputs": [], + "source": [ + "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "id": "edb02808", + "metadata": { + "id": "afe2be3c68c3" + }, + "source": [ + "# Fine-tune PaliGemma 2 with Keras\n", + "\n", + "This notebook shows how to fine-tune [PaliGemma 2](https://ai.google.dev/gemma/docs/paligemma) on a vision-language task with [Keras](https://keras.io/). Fine-tuning is a process that can improve your model's performance on specific tasks or help the model adhere to specific output requirements when instructions aren't sufficient and you have a set of examples that demonstrate the outputs you want. Gemma-based models like PaliGemma require fine-tuning to produce expected results.\n", + "\n", + "## Before you begin\r\n", + "\r\n", + "Before going through this notebook, you should be familiar with Python code, as well as how large language models (LLMs) are trained. You don't need to be familiar with Keras, but basic knowledge about Keras (or similar technologies) is helpful when reading through the example code.e." 
+ ] + }, + { + "cell_type": "markdown", + "id": "c941724c", + "metadata": { + "id": "dfbbe167352c" + }, + "source": [ + "## Install KerasHub" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "468e967f", + "metadata": { + "id": "5341a3c6ebe7" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting einops\r\n", + " Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)\r\n", + "Downloading einops-0.8.0-py3-none-any.whl (43 kB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.2/43.2 kB\u001b[0m \u001b[31m1.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25hInstalling collected packages: einops\r\n", + "Successfully installed einops-0.8.0\r\n" + ] + } + ], + "source": [ + "!pip install -q -U keras-nlp\n", + "!pip install -q -U keras>=3\n", + "!pip install einops" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b4eadf69", + "metadata": { + "id": "973190af9fb4" + }, + "outputs": [], + "source": [ + "!pip install -q -U keras keras-hub\n", + "\n", + "import os\n", + "# Set the backbend before importing Keras\n", + "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n", + "# Avoid memory fragmentation on JAX backend.\n", + "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"1.00\"\n", + "\n", + "# Set the fine-tuning variables\n", + "BATCH_SIZE = 1\n", + "TRAIN_EXAMPLES = 4\n", + "LEARNING_RATE = 0.003\n", + "\n", + "TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE\n", + "EVAL_STEPS = TRAIN_STEPS // 4" + ] + }, + { + "cell_type": "markdown", + "id": "b7e49a04", + "metadata": { + "id": "0f0cc14c99cb" + }, + "source": [ + "## Download the training and validation data" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "6892fd5b", + "metadata": { + "id": "4fde23bb6f99" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_23/4203520384.py:15: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display\n", + " from IPython.core.display import display, HTML\n" + ] + } + ], + "source": [ + "import io\n", + "import json\n", + "import os\n", + "import urllib\n", + "\n", + "import base64\n", + "import html\n", + "\n", + "import numpy as np\n", + "import keras\n", + "import keras_hub\n", + "import tensorflow as tf\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from IPython.core.display import display, HTML\n", + "from PIL import Image\n", + "\n", + "train_file = urllib.request.urlopen(\n", + " \"https://storage.googleapis.com/longcap100/data_train90.jsonl\"\n", + ")\n", + "val_file = urllib.request.urlopen(\n", + " \"https://storage.googleapis.com/longcap100/data_val10.jsonl\"\n", + ")\n", + "\n", + "# Crop the image to the desired dimensions.\n", + "target_size = (224, 224)\n", + "\n", + "def load_image(image_url):\n", + " image = tf.io.decode_jpeg(urllib.request.urlopen(image_url).read())\n", + " return tf.image.resize(image, size=target_size)\n", + "\n", + "def load_dataset(file):\n", + " captions = []\n", + " images = []\n", + " for line in file:\n", + " sample = json.loads(line)\n", + " captions.append(sample[\"suffix\"])\n", + " image_name = sample[\"image\"]\n", + " image_url = f\"https://storage.googleapis.com/longcap100/{image_name}\"\n", + " images.append(load_image(image_url))\n", + " return tf.data.Dataset.from_tensor_slices({\n", + " \"images\": images,\n", + " \"prompts\": [\"caption en\\n\"] * len(images),\n", + " 
\"responses\": captions,\n", + " })\n", + "\n", + "train_data = load_dataset(train_file).shuffle(1000).batch(BATCH_SIZE)\n", + "val_data = load_dataset(val_file).shuffle(1000).batch(BATCH_SIZE)" + ] + }, + { + "cell_type": "markdown", + "id": "ccba239c", + "metadata": { + "id": "f2480f920144" + }, + "source": [ + "## View training examples\r\n", + "\r\n", + "In this notebook, the training data contains 90 images that are paired with long descriptions of what's depicted in the image.\r\n", + "\r\n", + "**Note:** Normal training data sets that are meant to be used for practical use cases should contain more images, but this notebook limits the number of data points so that you can train the model in a reasonable amount of time for an example.\r\n", + "\r\n", + "The code below prints a random selection of images with their descriptions from the training data set so that you can see what the images and descriptions your model is trained on looks like. Each image is displayed in as a 128x128 pixel JPEG, with the description printed next to the image to the right." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "cd8d59b7", + "metadata": { + "id": "ee549257aa4f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training examples\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + "

A pile of tools sits on a brown carpet, including a green and black cordless drill, a yellow and black electronic device.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

A white towel with a Happy Easter message written on it. The towel is white, and the eggs are scattered on the towel. There are four colors of eggs on the towel, yellow, pink, white and blue. The message is written in black, and the letters are black. The towel is clean.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

A yellow train is stopped at a station. The train is on the tracks, and the lights are on. The platform has a white line and a white stripe on the platform. There is a person standing on the platform, and a person walking on the platform. The train has a number on its side. The train is moving, and the doors are closed.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

A group of people stand next to a train on a platform. The train is parked on the tracks, and the platform is made of bricks. There are several people standing on the platform, including a woman wearing a red shirt, a woman wearing a black shirt, and a woman wearing a blue shirt. The train has a window on the side. The sky is clear, and the sun is shining.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

A flamingo stands gracefully in the calm water, its long legs extending gracefully. The flamingo's pink feathers and black tail feathers contrast beautifully against the blue water. Its long neck and slender legs are prominent features, while its beak, black as night, and eye, white as the moon, add a touch of whimsy. The flamingo's reflection dances on the water's surface, mirroring its graceful movement. The flamingo's legs are pink.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

A surfer rides a wave, his white shirt billowing in the wind. The wave is white and blue, and the water is clear and blue. The surfer's board is white and blue, and his hair is dark. He is crouched on the board, his arm extended out to the side. The wave is crashing behind him, and the water is splashing in the air. The surfer is wearing a white shirt and black shorts, and his hair is wet.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

A storefront with a sign that reads "Segovia Coffee Bar" sits on a city street. The store has a black awning with white lettering, a white sign on the building, and a window on the building. There is a tree in front of the building,. The sidewalk in front of the store is empty, and there is a shadow on the ground. The store has a window on the building, and the window is open.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

A large sign in the shape of a crown stands proudly in the center of a city square. The sign is illuminated by the reflection of the sun on the water, creating a vibrant display. A tall building casts long shadows on the ground, while a flag on top of a building waves proudly. People stroll along the sidewalk. The sky is clear and blue, with fluffy white clouds drifting above. The reflection of the city in the water is a mirror image of the city itself, showcasing the beauty and diversity of this urban landscape.

\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def render_inline(image, resize=(224, 224)):\n", + " \"\"\"Convert image into inline html.\"\"\"\n", + " image = tf.keras.preprocessing.image.array_to_img(image)\n", + " image.resize(resize)\n", + " with io.BytesIO() as buffer:\n", + " image.save(buffer, format='jpeg')\n", + " image_b64 = str(base64.b64encode(buffer.getvalue()), \"utf-8\")\n", + " return f\"data:image/jpeg;base64,{image_b64}\"\n", + "\n", + "def render_example(image, caption):\n", + " image = np.asarray(image)\n", + " return f\"\"\"\n", + "
\n", + " \n", + "

{html.escape(caption)}

\n", + "
\n", + " \"\"\"\n", + "\n", + "html_out = \"\"\n", + "\n", + "for element in train_data.take(8):\n", + " caption = tf.compat.as_str_any(element[\"responses\"].numpy()[0])\n", + " html_out += render_example(element[\"images\"].numpy()[0], caption)\n", + "\n", + "print(\"Training examples\")\n", + "display(HTML(html_out))" + ] + }, + { + "cell_type": "markdown", + "id": "53c0034b", + "metadata": { + "id": "e90446ab2251" + }, + "source": [ + "## Load Model" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "10967820", + "metadata": { + "id": "39c82b20d38a" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.\n" + ] + }, + { + "data": { + "text/html": [ + "
Preprocessor: \"pali_gemma_causal_lm_preprocessor\"\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1mPreprocessor: \"pali_gemma_causal_lm_preprocessor\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+              "┃ Layer (type)                                                                                     Config ┃\n",
+              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+              "│ pali_gemma_tokenizer (PaliGemmaTokenizer)                     │                      Vocab size: 257,152 │\n",
+              "├───────────────────────────────────────────────────────────────┼──────────────────────────────────────────┤\n",
+              "│ pali_gemma_image_converter (PaliGemmaImageConverter)          │                   Image size: (224, 224) │\n",
+              "└───────────────────────────────────────────────────────────────┴──────────────────────────────────────────┘\n",
+              "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Config\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ pali_gemma_tokenizer (\u001b[38;5;33mPaliGemmaTokenizer\u001b[0m) │ Vocab size: \u001b[38;5;34m257,152\u001b[0m │\n", + "├───────────────────────────────────────────────────────────────┼──────────────────────────────────────────┤\n", + "│ pali_gemma_image_converter (\u001b[38;5;33mPaliGemmaImageConverter\u001b[0m) │ Image size: (\u001b[38;5;34m224\u001b[0m, \u001b[38;5;34m224\u001b[0m) │\n", + "└───────────────────────────────────────────────────────────────┴──────────────────────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Model: \"pali_gemma_causal_lm\"\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1mModel: \"pali_gemma_causal_lm\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+              "┃ Layer (type)                   Output Shape                       Param #  Connected to               ┃\n",
+              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+              "│ images (InputLayer)           │ (None, 224, 224, 3)       │               0 │ -                          │\n",
+              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+              "│ padding_mask (InputLayer)     │ (None, None)              │               0 │ -                          │\n",
+              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+              "│ response_mask (InputLayer)    │ (None, None)              │               0 │ -                          │\n",
+              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+              "│ token_ids (InputLayer)        │ (None, None)              │               0 │ -                          │\n",
+              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+              "│ pali_gemma_backbone           │ (None, None, 2304)        │   3,032,094,960 │ images[0][0],              │\n",
+              "│ (PaliGemmaBackbone)           │                           │                 │ padding_mask[0][0],        │\n",
+              "│                               │                           │                 │ response_mask[0][0],       │\n",
+              "│                               │                           │                 │ token_ids[0][0]            │\n",
+              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+              "│ token_embedding               │ (None, None, 257152)      │     592,478,208 │ pali_gemma_backbone[0][0]  │\n",
+              "│ (ReversibleEmbedding)         │                           │                 │                            │\n",
+              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+              "│ get_item (GetItem)            │ (None, None, 257152)      │               0 │ token_embedding[1][0]      │\n",
+              "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n",
+              "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ images (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m224\u001b[0m, \u001b[38;5;34m224\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ response_mask (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ pali_gemma_backbone │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2304\u001b[0m) │ \u001b[38;5;34m3,032,094,960\u001b[0m │ images[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mPaliGemmaBackbone\u001b[0m) │ │ │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ response_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ token_embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m257152\u001b[0m) │ \u001b[38;5;34m592,478,208\u001b[0m │ pali_gemma_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m) │ │ │ │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ get_item (\u001b[38;5;33mGetItem\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m257152\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ token_embedding[\u001b[38;5;34m1\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Total params: 3,032,094,960 (11.30 GB)\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m3,032,094,960\u001b[0m (11.30 GB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Trainable params: 3,032,094,960 (11.30 GB)\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m3,032,094,960\u001b[0m (11.30 GB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Non-trainable params: 0 (0.00 B)\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "pali_gemma_lm = keras_hub.models.PaliGemmaCausalLM.from_preset(\n", + " \"/kaggle/input/paligemma2/keras/pali_gemma2_pt_3b_224/1\"\n", + "# \"kaggle://keras/paligemma2/keras/pali_gemma2_pt_3b_224\"\n", + ")\n", + "pali_gemma_lm.summary()" + ] + }, + { + "cell_type": "markdown", + "id": "dd4fe0c2", + "metadata": { + "id": "edeab21cb990" + }, + "source": [ + "## Inference before fine tuning" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6355a8e9", + "metadata": { + "id": "ab780f60d802" + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + "

caption en\n", + "a cow on the beach

\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Inference Result\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + "

caption en\n", + "a series of white hangers

\n", + "
\n", + " \n", + "
\n", + " \n", + "

caption en\n", + "person in a red blazer and black pants

\n", + "
\n", + " \n", + "
\n", + " \n", + "

caption en\n", + "the top is a mix of fabrics .

\n", + "
\n", + " \n", + "
\n", + " \n", + "

caption en\n", + "love will save us sweatshirt in pink

\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "test_image_url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'\n", + "test_image = load_image(test_image_url)\n", + "\n", + "def inference_test(image):\n", + " prompt = 'caption en\\n'\n", + " output = pali_gemma_lm.generate(\n", + " inputs={\n", + " \"images\": image,\n", + " \"prompts\": prompt,\n", + " }\n", + " )\n", + " return render_example(image, output)\n", + "\n", + "display(HTML(inference_test(test_image)))\n", + "\n", + "\n", + "def make_predictions():\n", + " html_out = \"\"\n", + " for element in val_data.take(4):\n", + " html_out += inference_test(element[\"images\"].numpy()[0])\n", + "\n", + " print(\"\\nInference Result\")\n", + " display(HTML(html_out))\n", + "\n", + "make_predictions()" + ] + }, + { + "cell_type": "markdown", + "id": "eaea17d7", + "metadata": { + "id": "7ba37a64794f" + }, + "source": [ + "## LoRA Fine-tuning" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "40d24252", + "metadata": { + "id": "0a61883f4a81" + }, + "outputs": [ + { + "data": { + "text/html": [ + "
Preprocessor: \"pali_gemma_causal_lm_preprocessor\"\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1mPreprocessor: \"pali_gemma_causal_lm_preprocessor\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+              "┃ Layer (type)                                                                                     Config ┃\n",
+              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+              "│ pali_gemma_tokenizer (PaliGemmaTokenizer)                     │                      Vocab size: 257,152 │\n",
+              "├───────────────────────────────────────────────────────────────┼──────────────────────────────────────────┤\n",
+              "│ pali_gemma_image_converter (PaliGemmaImageConverter)          │                   Image size: (224, 224) │\n",
+              "└───────────────────────────────────────────────────────────────┴──────────────────────────────────────────┘\n",
+              "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Config\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ pali_gemma_tokenizer (\u001b[38;5;33mPaliGemmaTokenizer\u001b[0m) │ Vocab size: \u001b[38;5;34m257,152\u001b[0m │\n", + "├───────────────────────────────────────────────────────────────┼──────────────────────────────────────────┤\n", + "│ pali_gemma_image_converter (\u001b[38;5;33mPaliGemmaImageConverter\u001b[0m) │ Image size: (\u001b[38;5;34m224\u001b[0m, \u001b[38;5;34m224\u001b[0m) │\n", + "└───────────────────────────────────────────────────────────────┴──────────────────────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Model: \"pali_gemma_causal_lm\"\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1mModel: \"pali_gemma_causal_lm\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+              "┃ Layer (type)                   Output Shape                       Param #  Connected to               ┃\n",
+              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+              "│ images (InputLayer)           │ (None, 224, 224, 3)       │               0 │ -                          │\n",
+              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+              "│ padding_mask (InputLayer)     │ (None, None)              │               0 │ -                          │\n",
+              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+              "│ response_mask (InputLayer)    │ (None, None)              │               0 │ -                          │\n",
+              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+              "│ token_ids (InputLayer)        │ (None, None)              │               0 │ -                          │\n",
+              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+              "│ pali_gemma_backbone           │ (None, None, 2304)        │   3,035,023,600 │ images[0][0],              │\n",
+              "│ (PaliGemmaBackbone)           │                           │                 │ padding_mask[0][0],        │\n",
+              "│                               │                           │                 │ response_mask[0][0],       │\n",
+              "│                               │                           │                 │ token_ids[0][0]            │\n",
+              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+              "│ token_embedding               │ (None, None, 257152)      │     592,478,208 │ pali_gemma_backbone[0][0]  │\n",
+              "│ (ReversibleEmbedding)         │                           │                 │                            │\n",
+              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+              "│ get_item (GetItem)            │ (None, None, 257152)      │               0 │ token_embedding[1][0]      │\n",
+              "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n",
+              "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ images (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m224\u001b[0m, \u001b[38;5;34m224\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ response_mask (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ pali_gemma_backbone │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2304\u001b[0m) │ \u001b[38;5;34m3,035,023,600\u001b[0m │ images[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mPaliGemmaBackbone\u001b[0m) │ │ │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ response_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ │ │ │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ token_embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m257152\u001b[0m) │ \u001b[38;5;34m592,478,208\u001b[0m │ pali_gemma_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m) │ │ │ │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ get_item (\u001b[38;5;33mGetItem\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m257152\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ token_embedding[\u001b[38;5;34m1\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Total params: 3,035,023,600 (11.31 GB)\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m3,035,023,600\u001b[0m (11.31 GB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Trainable params: 2,928,640 (11.17 MB)\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m2,928,640\u001b[0m (11.17 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Non-trainable params: 3,032,094,960 (11.30 GB)\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m3,032,094,960\u001b[0m (11.30 GB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Enable lora to freeze most of the model and save memory.\n", + "pali_gemma_lm.backbone.enable_lora(4)\n", + "pali_gemma_lm.summary()\n", + "\n", + "# Lower our sequence length to further save memory.\n", + "pali_gemma_lm.preprocessor.sequence_length = 64\n", + "\n", + "# Use Cosine Decay Scheduler with Warm up\n", + "#cosine_decay_scheduler = tf.keras.optimizers.schedules.CosineDecay(\n", + "# initial_learning_rate = 0,\n", + "# decay_steps = TRAIN_EXAMPLES,\n", + "# warmup_target = LEARNING_RATE,\n", + "# warmup_steps = TRAIN_EXAMPLES / 10\n", + "#)\n", + "\n", + "def plot_scheduler(step, scheduler):\n", + " x = range(step)\n", + " y = []\n", + " for step in x:\n", + " y.append(scheduler(step))\n", + " plt.plot(x, y, label=scheduler.name)\n", + " plt.xlabel('Epoch')\n", + " plt.ylabel('Learning Rate')\n", + " plt.legend()\n", + " plt.show()\n", + "\n", + "#plot_scheduler(TRAIN_EXAMPLES, cosine_decay_scheduler)\n", + "\n", + "# Use AdamW (a common optimizer for transformer models).\n", + "#optimizer = keras.optimizers.SGD(learning_rate=cosine_decay_scheduler)\n", + "optimizer = keras.optimizers.AdamW(learning_rate=LEARNING_RATE)\n", + "\n", + "pali_gemma_lm.compile(\n", + " loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", + " optimizer=optimizer,\n", + " weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "da88646d", + "metadata": { + "id": "016ed3975263" + }, + "source": [ + "## Fine-tune the model" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "56c00057", + "metadata": { + "id": "17867028751a" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/4\n", + "\u001b[1m90/90\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m117s\u001b[0m 877ms/step - loss: 1.6715 - sparse_categorical_accuracy: 0.5523\n", + "Epoch 2/4\n", + "\u001b[1m90/90\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m80s\u001b[0m 567ms/step - loss: 0.9483 - sparse_categorical_accuracy: 0.6871\n", + "Epoch 3/4\n", + "\u001b[1m90/90\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m52s\u001b[0m 555ms/step - loss: 0.7245 - sparse_categorical_accuracy: 0.7587\n", + "Epoch 4/4\n", + "\u001b[1m90/90\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m52s\u001b[0m 555ms/step - loss: 0.5710 - sparse_categorical_accuracy: 0.8049\n" + ] + }, + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/xnp5ZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA86ElEQVR4nO3de1xUdeL/8fcMdxTwjiB4zxsigqVptWpZpmRZbVnaan2335apZda22ppmtVFbW7aF3bZyK69dtFLzEoZmWa4C3kC8gOIF8A6Ccp3z+4OkKNEZFM4M83o+HucPh8/hvDnNY+bdnM+cj8UwDEMAAAAmsZodAAAAuDfKCAAAMBVlBAAAmIoyAgAATEUZAQAApqKMAAAAU1FGAACAqSgjAADAVJ5mB7CHzWbToUOHFBAQIIvFYnYcAABgB8MwdOrUKYWGhspqrf7zD5coI4cOHVJ4eLjZMQAAQA3s379fYWFh1f7cJcpIQECApIo/JjAw0OQ0AADAHvn5+QoPD698H6+OS5SRs5dmAgMDKSMAALiYC02xYAIrAAAwFWUEAACYijICAABMRRkBAACmoowAAABTUUYAAICpKCMAAMBUlBEAAGAqyggAADAVZQQAAJiKMgIAAExFGQEAAKZy6zKy/VCe7n7nRx0vLDE7CgAAbstty4jNZuixhZu1PuOYJi1Mkc1mmB0JAAC35LZlxGq16NURPeXjaVVi+hHNStxtdiQAANyS25YRSeoaEqhnh3eXJL2yaqd+2HPU5EQAALgfty4jknTn5eG6o1eYbIb08LwUHc4vMjsSAABuxe3LiCQ9c0t3dWkZoKMFxRo/L1ll5TazIwEA4DYoI5L8vD00a1SMGvp4akPmcb28cqfZkQAAcBuUkZ+1b95QL97eQ5L01po9SkjLNTkRAADugTLyK7E9QnRvv7aSpEkLN2v/8dPmBgIAwA1QRn7jyaFd1TO8kfLOlGrc3CQVl5WbHQkAgHqNMvIb3p5WxY+KUSN/L205kKd/LE0zOxIAAPUaZeQcWjXy06sjekqSPly/T19tPmRuIAAA6jHKSDUGdm6hcQM7SJImf7ZFe44UmJwIAID6iTJyHo8O6qQr2zdRYUm5xn68SadLysyOBABAvUMZOQ9PD6v+fXe0mgf4aGdugaYu3ibDYEE9AAAuJcrIBbQI8NXrd0fLapE+TzqoBf/bb3YkAADqFcqIHa5s31SPD+4sSZr25XZtP5RnciIAAOoPyoidHvxDB13XpYVKymx6aE6S8otKzY4EAEC9QBmxk9Vq0b/ujFKrRn7ad+y0nvhkC/NHAAC4BCgjDmjk7634UTHy8rBo+fYcvf/9XrMjAQDg8igjDuoZ3khTY7tJkuKWpWnTvhMmJwIAwLVRRmpgdN82iu0RojKbofFzk3S8sMTsSAAAuCyHy8jatWs1bNgwhYaGymKxaPHixXbv+/3338vT01M9e/Z09LBOxWKx6MXbe6h98wbKzivSxAUpstmYPwIAQE04XEYKCwsVFRWl+Ph4h/Y7efKkRo8ereuuu87RQzqlhj6eenNUL/l6WbV25xG98e1usyMBAOCSHC4jQ4YM0XPPPadbb73Vof0efPBBjRw5Un379nX0kE6rc8sAPTc8UpL06jc79f3uoyYnAgDA9dTJnJEPPvhAGRkZmj59el0crk79sVeYRlweLsOQHpmfrNz8IrMjAQDgUmq9jOzatUuTJ0/Wxx9/LE9PT7v2KS4uVn5+fpXNmc24JUJdQwJ1tKBEE+Ymq6zcZnYkAABcRq2WkfLyco0cOVIzZsxQp06d7N4vLi5OQUFBlVt4eHgtprx4vl4emjUqRg19PLVh73G9tDLd7EgAALgMi3ERtxG1WCxatGiRhg8ffs6fnzx5Uo0bN5aHh0flYzabTYZhyMPDQytXrtS11177u/2Ki4tVXFxc+e/8/HyFh4crLy9PgYGBNY1b677emq2xc5IkSe+OvlzXdws2OREAAObJz89XUFDQBd+/7btuUkOBgYHaunVrlcdmzZql1atX69NPP1W7du3OuZ+Pj498fHxqM1qtGBIZovuuaqsPvt+rxxamaOnD1yi8ib/ZsQAAcGoOl5GCggLt3v3L11gzMzOVkpKiJk2aqHXr1poyZYoOHjyoDz/8UFarVd27d6+yf4sWLeTr6/u7x+uLKUO6KmX/SSVnndRDc5L0yYN95evlceEdAQBwUw7PGdm4caOio6MVHR0tSZo0aZKio6M1bdo0SVJ2draysrIubUoX4u1pVfzIGDX299LWg3l6bmmq2ZEAAHBqFzVnpK7Ye83JmSSmH9Z9s/8nw5Beu6unbunZyuxIAADUKXvfv1mbppYM6NxC4wd2lCRN+Xyrdh8+ZXIiAACcE2WkFk0c1En9OjTV6ZJyjf04SadLysyOBACA06GM1CIPq0Wv3RWtFgE+2nW4QH9ftE0ucFUMAIA6RRmpZc0DfPT63dHysFq0KPmg5m3Yb3YkAACcCmWkDvRp31SP39BZkvT0V9u17WCeyYkAAHAelJE68sAf2mtQ1xYqKbPpoTlJyjtTanYkAACcAmWkjlitFv3rjp4Ka+ynrOOn9ddPNjN/BAAAUUbqVJC/l2aNipG3h1UrU3P13rpMsyMBAGA6ykgd6xHWSE/d1FWS9MLXO7Rp33GTEwEAYC7KiAnuubKNhkWFqsxmaNycZB0rKL7wTgAA1FOUERNYLBbF3Rap9s0bKCe/SBMXpKjcxvwRAIB7ooyYpKGPp94c1Uu+XlZ9t+uoXl+9y+xIAACYgjJios4tA/T8rZGSpNcSdum7XUdMTgQAQN2jjJjstpgw3d07XIYhTZyfopy8IrMjAQBQpygjTmD6sAh1CwnUscISjZ+bpNJym9mRAACoM5QRJ+Dr5aE374lRgI+nNu47oZdWpJsdCQCAOkMZcRJtmjbQS3f0kCS9szZDK7fnmJwIAIC6QRlxIjd2D9Gfr24nSXrsk83KOnba5EQAANQ+yoiTmTyki2JaN9KpojKNnbNJRaXlZkcCAKBWUUacjJeHVW+MjFFjfy9tP5SvZ5akmh0JAIBaRRlxQqGN/DTzrmhZLNLcn7K0OPmg2ZEAAKg1lBEn1b9Tc0249jJJ0pTPt2pX7imTEwEAUDsoI07skesu09Udm+lMabnGzklSYXGZ2ZEAALjkKCNOzMNq0cy7eio40Ee7Dxfo74u2yjBYUA8AUL9QRpxcs4Y+emNkjDysFi1OOaS5G7LMjgQAwCVFGXEBV7RtoicGd5YkzfgyVVsP5JmcCACAS4cy4iL+8of2GtQ1WCXlNj00d5PyTpeaHQkAgEuCMuIiLBaL/nVHlMIa+2n/8TN6/NPNzB8BANQLlBEXEuTvpTdH9ZK3h1WrUnP17ncZZkcCAOCiUUZcTGRYkKYN6yZJenF5uv6397jJiQAAuDiUERc0qk9r3dIzVOU2Q+PnJuloQbHZkQAAqDGHy8jatWs1bNgwhYaGymKxaPHixecdv27dOl111VVq2rSp/Pz81KVLF7366qs1zQtVzB95/tZIdWzRULn5xZo4P0XlNuaPAA
Bck8NlpLCwUFFRUYqPj7drfIMGDTR+/HitXbtWaWlpmjp1qqZOnap33nnH4bD4RQMfT705KkZ+Xh5at/uoXkvYZXYkAABqxGJcxFcyLBaLFi1apOHDhzu032233aYGDRroo48+smt8fn6+goKClJeXp8DAwBokrb8WJR/Qows2y2KR/ntfb/2hU3OzIwEAIMn+9+86nzOSnJysH374Qf3796/rQ9dLt0aH6e7erWUY0sQFKcrOO2N2JAAAHFJnZSQsLEw+Pj66/PLLNW7cON1///3Vji0uLlZ+fn6VDdWbPqybIkIDdbywROPnJqu03GZ2JAAA7FZnZeS7777Txo0b9dZbb2nmzJmaN29etWPj4uIUFBRUuYWHh9dVTJfk6+WhN0f1UoCvpzbtO6EXv95hdiQAAOxmypyR5557Th999JHS09PP+fPi4mIVF//yddX8/HyFh4czZ+QCVmzP0QMfbZIkvXVPL93YvaXJiQAA7sxp54xIks1mq1I2fsvHx0eBgYFVNlzY4IiW+n/XtJMk/fWTzdp3rNDkRAAAXJinozsUFBRo9+7dlf/OzMxUSkqKmjRpotatW2vKlCk6ePCgPvzwQ0lSfHy8WrdurS5dukiquE/Jyy+/rIcffvgS/Qn4tSdu7KLkrJPauO+Exn6cpM8f6idfLw+zYwEAUC2Hy8jGjRs1cODAyn9PmjRJkjRmzBjNnj1b2dnZysrKqvy5zWbTlClTlJmZKU9PT3Xo0EEvvviiHnjggUsQH7/l5WHV6yOjFfvvdUrNzteMr7Yr7rYeZscCAKBaFzVnpK5wnxHHfbfriEa/v0GGIb1yZ5RuiwkzOxIAwM049ZwR1L5rLmuuh6+9TJL090XbtDP3lMmJAAA4N8pIPfbwdZfpmsua6UxpucZ+vEmFxWVmRwIA4HcoI/WYh9WimSN6qmWgr/YcKdSUz7fKBa7KAQDcDGWknmva0EdvjIyWh9WiLzcf0sc/ZV14JwAA6hBlxA1c3raJJt9Y8dXqZ79K1ZYDJ80NBADAr1BG3MT917TTDd2CVVJu00NzkpR3utTsSAAASKKMuA2LxaKX7ohS6yb+OnDijB77JEU2G/NHAADmo4y4kSA/L80aFSNvT6u+STusd77LMDsSAACUEXfTvVWQpg/rJkl6aUW6fso4ZnIiAIC7o4y4oZG9W+vW6FYqtxmaMC9ZR05Vv2ghAAC1jTLihiwWi/5xa3dd1qKhDp8q1iPzk1XO/BEAgEkoI27K39tTb94TI39vD/2w55he+2an2ZEAAG6KMuLGOrYIUNxtkZKkf6/ercT0wyYnAgC4I8qIm7ulZyuN6tNakvToghQdOnnG5EQAAHdDGYGeuqmburcK1InTpRo3N0klZTazIwEA3AhlBPL18tCskb0U4Oup5KyTeuHrHWZHAgC4EcoIJEmtm/rrlTt7SpLe/z5TX2/NNjcQAMBtUEZQ6fpuwXrgD+0lSU98ukV7jxaanAgA4A4oI6ji8cGddUXbxjpVXKaxc5JUVFpudiQAQD1HGUEVXh5WvX53jJo28FZadr6e/nK72ZEAAPUcZQS/0zLIV6/dFS2LRZr/v/36dNMBsyMBAOoxygjO6erLmmnidZ0kSVMXb9WOnHyTEwEA6ivKCKo14dqOuuayZioqtemhOUkqKC4zOxIAoB6ijKBaVqtFM0f0VMtAX2UcKdTkz7bIMFhQDwBwaVFGcF5NG/ooflS0PK0WLdmSrY9+3Gd2JABAPUMZwQX1atNEk4d0kSQ9uyRVm/efNDcQAKBeoYzALn++up1ujGip0nJDD81J0snTJWZHAgDUE5QR2MViseifd/RQm6b+OnjyjCYt3CybjfkjAICLRxmB3QJ9vTRrVIy8Pa1aveOw3lq7x+xIAIB6gDICh0SEBmnGzRGSpJdXpOvHjGMmJwIAuDrKCBx21xXhui26lWyGNGFesg6fKjI7EgDAhVFG4DCLxaLnbu2uTsENdeRUsR6Zl6Jy5o8AAGrI4TKydu1aDRs2TKGhobJYLFq8ePF5x3/++ee6/vrr1bx5cwUGBqpv375asWJFTfPCSfh7e2rWqF5q4O2h9RnH9OqqnWZHAgC4KIfLSGFhoaKiohQfH2/X+LVr1+r666/XsmXLtGnTJg0cOFDDhg1TcnKyw2HhXDq2aKi423tIkt74dre+TT9sciIAgCuyGBdxf2+LxaJFixZp+PDhDu0XERGhESNGaNq0aXaNz8/PV1BQkPLy8hQYGFiDpKhNTy3epo9+3KdG/l5a+vA1atXIz+xIAAAnYO/7d53PGbHZbDp16pSaNGlS7Zji4mLl5+dX2eC8pt7UVT3CgnTydKnGzUlSSZnN7EgAABdS52Xk5ZdfVkFBge68885qx8TFxSkoKKhyCw8Pr8OEcJSPp4fiR8Yo0NdTKftP6vllaWZHAgC4kDotI3PnztWMGTO0cOFCtWjRotpxU6ZMUV5eXuW2f//+OkyJmghv4q9X7uwpSZr9w14t3ZJtbiAAgMuoszIyf/583X///Vq4cKEGDRp03rE+Pj4KDAysssH5DeoWrAf6t5ck/e2zLco4UmByIgCAK6iTMjJv3jzdd999mjdvnmJjY+vikDDJX2/orN7tmqiguEwPzUlSUWm52ZEAAE7O4TJSUFCglJQUpaSkSJIyMzOVkpKirKwsSRWXWEaPHl05fu7cuRo9erT+9a9/qU+fPsrJyVFOTo7y8vIuzV8Ap+LpYdUbd0erWUNv7cg5pWlfbDM7EgDAyTlcRjZu3Kjo6GhFR0dLkiZNmqTo6OjKr+lmZ2dXFhNJeuedd1RWVqZx48YpJCSkcnvkkUcu0Z8AZ9Mi0Ff/vitaVou0cOMBLdzInB8AQPUu6j4jdYX7jLim1xN26V+rdsrH06rF465S1xD+2wGAO3Ha+4zAfYwb2FH9OzVXcZlND81J0qmiUrMjAQCcEGUEtcZqtejVET0VEuSrzKOFmvzZVrnAB3EAgDpGGUGtatLAW2+MjJGn1aKlW7P13x/2mh0JAOBkKCOodb3aNNaUoV0lSf9YlqbkrBMmJwIAOBPKCOrE/13VVkO6t1RpuaHxc5N1orDE7EgAACdBGUGdsFgsevGPPdS2qb8OnjyjSQtTZLMxfwQAQBlBHQr09dKsUb3k42nVt+lH9OaaPWZHAgA4AcoI6lS30EA9c0uEJOlfK9P1w56jJicCAJiNMoI6d+fl4bo9Jkw2Q3p4XooO5xeZHQkAYCLKCOqcxWLRc8O7q3NwgI4WFGvCvGSVldvMjgUAMAllBKbw8/bQrHti1MDbQz9lHtcrq3aaHQkAYBLKCEzToXlDvXB7D0nSrMQ9Wr0j1+REAAAzUEZgqmFRoRrTt40k6dEFm3XgxGmTEwEA6hplBKZ7MrarosKClHemVOPmJKm4rNzsSACAOkQZgel8PD0UPypGQX5e2nwgT88vTTM7EgCgDlFG4BTCGvvr1RFRkqT/rt+nrzYfMjkRAKCuUEbgNK7tEqyxAzpIkiZ/tkV7jhSYnAgAUBcoI3Aqj13fSX3aNVFhSbke+jhJZ0qYPwIA9
R1lBE7F08Oq1++OVrOGPkrPPaWnvthmdiQAQC2jjMDptAj01b/v7imrRfp00wEt/N9+syMBAGoRZQROqV+HZnrshs6SpKe+2KbUQ/kmJwIA1BbKCJzW2P4dNLBzcxWX2fTQnE3KLyo1OxIAoBZQRuC0rFaLXrmzp1o18tPeY6f1t0+3yDAMs2MBAC4xygicWuMG3npjZLS8PCz6eluOPvh+r9mRAACXGGUETi+6dWM9ObSrJOn5ZWlKyjphciIAwKVEGYFLuLdfW8VGhqjMZmj8nCSdKCwxOxIA4BKhjMAlWCwWvXB7pNo1a6BDeUV6dGGKbDbmjwBAfUAZgcsI8PXSrFEx8vG0KjH9iGYl7jY7EgDgEqCMwKV0DQnUs8O7S5JeWbVTP+w+anIiAMDFoozA5dx5ebju6BUmmyE9PD9ZuflFZkcCAFwEyghc0jO3dFeXlgE6WlCiCfOSVVZuMzsSAKCGHC4ja9eu1bBhwxQaGiqLxaLFixefd3x2drZGjhypTp06yWq1auLEiTWMCvzCz9tDs0bFqKGPpzZkHtfLK3eaHQkAUEMOl5HCwkJFRUUpPj7ervHFxcVq3ry5pk6dqqioKIcDAtVp37yhXry9hyTprTV7lJCWa3IiAEBNeDq6w5AhQzRkyBC7x7dt21avvfaaJOn999939HDAecX2CNH/9rbV7B/2atLCzVoy4WqFN/E3OxYAwAFOOWekuLhY+fn5VTagOk8O7aqe4Y2Ud6ZU4+Ymqbis3OxIAAAHOGUZiYuLU1BQUOUWHh5udiQ4MW9Pq+JHxaiRv5e2HMjTc0vSzI4EAHCAU5aRKVOmKC8vr3Lbv3+/2ZHg5Fo18tOrI3pKkj76cZ++3HzI3EAAALs5ZRnx8fFRYGBglQ24kIGdW2jcwA6SpMmfbdHuwwUmJwIA2MMpywhQU48O6qQr2zfR6ZJyPTRnk06XlJkdCQBwAQ6XkYKCAqWkpCglJUWSlJmZqZSUFGVlZUmquMQyevToKvucHV9QUKAjR44oJSVFqampF58e+A1PD6v+fXe0mgf4aGdugaYu3ibDYEE9AHBmFsPBV+rExEQNHDjwd4+PGTNGs2fP1r333qu9e/cqMTHxl4NYLL8b36ZNG+3du9euY+bn5ysoKEh5eXlcsoFd1u85plH/+VE2Q3rhtkjd1bu12ZEAwO3Y+/7tcBkxA2UENRH/7W69tCJd3p5WLXqonyJCg8yOBABuxd73b+aMoN4a27+Dru3SQiVlNj00J0n5RaVmRwIAnANlBPWW1WrRK3dGqVUjP+07dlp//WQz80cAwAlRRlCvNfL3VvyoGHl5WLRie67eW5dpdiQAwG9QRlDv9QxvpKmx3SRJL3y9Q5v2HTc5EQDg1ygjcAuj+7ZRbI8QldkMjZ+brOOFJWZHAgD8jDICt2CxWPTi7T3UvlkDZecVaeKCFNlszB8BAGdAGYHbaOjjqVn3xMjXy6q1O4/ojW93mx0JACDKCNxMl5aBem54pCTp1W92at2uoyYnAgBQRuB2/tgrTCMuD5dhSI/MT1ZOXpHZkQDArVFG4JZm3BKhriGBOlZYognzklRabjM7EgC4LcoI3JKvl4dmjYpRQx9P/W/vCb28It3sSADgtigjcFvtmjXQS3/sIUl6e22GVqXmmpwIANwTZQRubUhkiO67qq0k6bGFKdp//LS5gQDADVFG4PamDOmq6NaNlF9UpofmJKmotNzsSADgVigjcHvenla9MTJGjfy9tPVgnp5bmmp2JABwK5QRQFKrRn6aOaKnLBbp4x+z9EXKQbMjAYDboIwAPxvQuYXGD+woSZry+VbtPnzK5EQA4B4oI8CvTBzUSf06NNXpknKN/ThJp0vKzI4EAPUeZQT4FQ+rRa/dFa0WAT7adbhAf1+0TYbBgnoAUJsoI8BvNA/w0et3R8vDatGi5IOat2G/2ZEAoF6jjADn0Kd9Uz1+Q2dJ0tNfbde2g3kmJwKA+osyAlTjgT+013VdWqikzKaH5iQp70yp2ZEAoF6ijADVsFot+tedUWrVyE9Zx0/rr59sZv4IANQCyghwHo38vfXmPTHy9rBqZWqu/vNdptmRAKDeoYwAF9AjrJGeuqmrJOmF5Tu0ce9xkxMBQP1CGQHscM+VbTQsKlTlNkPj5ybrWEGx2ZEAoN6gjAB2sFgsirstUu2bN1BOfpEmLkhRuY35IwBwKVBGADs19PHUm6N6ydfLqu92HdXrq3eZHQkA6gXKCOCAzi0D9I/hkZKk1xJ26btdR0xOBACujzICOOj2XmG664pwGYb0yPwUZeedMTsSALg0yghQA0/fHKFuIYE6Xlii8XOTVVpuMzsSALgsh8vI2rVrNWzYMIWGhspisWjx4sUX3CcxMVExMTHy8fFRx44dNXv27BpEBZyHr5eH3rwnRgE+ntq074T+uXyH2ZEAwGU5XEYKCwsVFRWl+Ph4u8ZnZmYqNjZWAwcOVEpKiiZOnKj7779fK1ascDgs4EzaNG2gl+7oIUl697tMrdieY3IiAHBNFuMi7m9tsVi0aNEiDR8+vNoxf/vb37R06VJt27at8rG77rpLJ0+e1PLly+06Tn5+voKCgpSXl6fAwMCaxgVqxbNLUvXeukwF+Hpq6YRr1Lqpv9mRAMAp2Pv+XetzRtavX69BgwZVeWzw4MFav359tfsUFxcrPz+/ygY4q8lDuiimdSOdKirT2DmbVFRabnYkAHAptV5GcnJyFBwcXOWx4OBg5efn68yZc38LIS4uTkFBQZVbeHh4bccEaszLw6o3Rsaosb+Xth/K1zNLUs2OBAAuxSm/TTNlyhTl5eVVbvv37zc7EnBeoY38NPOuaFks0tyfsrQo+YDZkQDAZdR6GWnZsqVyc3OrPJabm6vAwED5+fmdcx8fHx8FBgZW2QBn179Tc00Y2FGS9OTn27Qjh8uLAGCPWi8jffv2VUJCQpXHVq1apb59+9b2oYE698igTrqqY1OdKS3X8Pjv9XrCLuaQAMAFOFxGCgoKlJKSopSUFEkVX91NSUlRVlaWpIpLLKNHj64c/+CDDyojI0NPPPGEduzYoVmzZmnhwoV69NFHL81fADgRD6tF/74rWr3bNVFRqU3/WrVTN7y6Vt+k5uoivrgGAPWaw1/tTUxM1MCBA3/3+JgxYzR79mzde++92rt3rxITE6vs8+ijjyo1NVVhYWF66qmndO+999p9TL7aC1djGIa+3HxIzy9LU25+sSRpQOfmmnZTN7Vv3tDkdABQN+x9/76o+4zUFcoIXFVhcZleX71b763LUGm5IS8Pi/58dXtNuLajGvh4mh0PAGoVZQRwIhlHCvTMklQlples8tsy0FdThnbRzVEVyyoAQH1EGQGcjGEYSkg7rGeWpCrr+GlJUu92TTTj5gh1DeF5DaD+oYwATqqotFzvrs1QfOJuFZXaZLVI91zZRpOu76RG/t5mxwOAS4YyAji5gyfP6PmlaVq6NVuS1NjfS0/c2EV3Xh4uDyuXbgC4PsoI4CJ+2H1UT3+1XTtzCyRJka2CNOOWCMW0bmxyMgC4OJQRwIWU
ltv04fp9mrlqp04Vl0mSbo8J09+GdFaLAF+T0wFAzVBGABd05FSx/rl8hz7ZVLG2TYCPpx4ZdJnG9GsrLw+nXEoKAKpFGQFcWHLWCU3/cru2HMiTJHVs0VAzbo7QVR2bmZwMAOxHGQFcnM1maOHG/frninQdLyyRJA3p3lJ/j+2qsMb+JqcDgAujjAD1RN7pUr36zU59uH6vbIbk62XV2P4d9UD/9vL18jA7HgBUizIC1DNp2fma/uV2bcg8LkkKb+Knp2K76fpuwdzFFYBToowA9ZBhGPpqS7aeX5qmnPwiSdIfOjXX9GHd1IEF+AA4GcoIUI8VFpcp/tvd+s93mSopt8nLw6L/u6qdJlx3mRqyAB8AJ0EZAdxA5tFCPbskVat3HJYktQjw0ZNDu+qWnizAB8B8lBHAjSSk5eqZJanad6xiAb4r2jbW0zdHKCI0yORkANwZZQRwM0Wl5XpvXabeWL1bZ0rLZbVII/u01uM3dGYBPgCmoIwAburQyTN6flmalmypWICvkb+XHr+hs+7u3ZoF+ADUKcoI4OZ+2HNUM75MVXruKUlSRGignrklQr3aNDE5GQB3QRkBoLJymz76cZ9eWbVTp4oqFuC7LbqVJg/pohaBLMAHoHZRRgBUOlpQrJeWp2vhpv0yDKmhj6cevq6j7u3XTt6eLMAHoHZQRgD8Tsr+k5r+xTZt/nkBvg7NG+jpmyN0zWXNTU4GoD6ijAA4J5vN0KebDujF5Tt07OcF+AZHBGtqbDeFN2EBPgCXDmUEwHnlnSnVzG926sP1+1RuM+TjadWD/Tto7IAOLMAH4JKgjACwS3rOKU3/cpt+zKhYgK9VIz89dVM3DY5gAT4AF4cyAsBuhmFo6dZs/WNpmrLzKhbgu+ayZpo+LEIdW7AAH4CaoYwAcNjpkjLN+naP3lmboZJymzytFt13VVs9fN1lCvD1MjseABdDGQFQY/uOVSzA901axQJ8zQN8NPnGLro1upWs3MUVgJ0oIwAu2rc7DmvGV9u19+cF+Hq1aawZN0eoeysW4ANwYZQRAJdEcdkvC/CdLimXxSLd3bu1/npDZzVuwAJ8AKpHGQFwSWXnnVHcsh36cvMhSVKQn5ceH9xZI1mAD0A1KCMAasWPGcf09JfbtSOnYgG+biGBmnFLhK5oywJ8AKqy9/27RotSxMfHq23btvL19VWfPn20YcOGaseWlpbqmWeeUYcOHeTr66uoqCgtX768JocF4ASubN9USyZcrRk3RyjQ11Op2fm64631mjg/Wbn5RWbHA+CCHC4jCxYs0KRJkzR9+nQlJSUpKipKgwcP1uHDh885furUqXr77bf1+uuvKzU1VQ8++KBuvfVWJScnX3R4AObw9LBqTL+2+vbxAbq7d7gsFmlxyiFd+3Ki3lqzRyVlNrMjAnAhDl+m6dOnj6644gq98cYbkiSbzabw8HBNmDBBkydP/t340NBQ/f3vf9e4ceMqH7v99tvl5+enjz/+2K5jcpkGcG5bDpzUtC+2K2X/SUlS+2YNNP3mCPXvxAJ8gDurlcs0JSUl2rRpkwYNGvTLL7BaNWjQIK1fv/6c+xQXF8vX17fKY35+flq3bl21xykuLlZ+fn6VDYDz6hHWSJ+P7aeX/thDzRp6K+Nooca8v0H/78ONyvr5a8EAUB2HysjRo0dVXl6u4ODgKo8HBwcrJyfnnPsMHjxYr7zyinbt2iWbzaZVq1bp888/V3Z2drXHiYuLU1BQUOUWHh7uSEwAJrBaLbrj8nCtfnyA/nx1O3lYLVqVmqtBr67RKyvTdaak3OyIAJxUjSawOuK1117TZZddpi5dusjb21vjx4/XfffdJ6u1+kNPmTJFeXl5ldv+/ftrOyaASyTQ10tP3dRNyx+5Rld1bKqSMpv+vXq3Br2yRsu2ZssFvsAHoI45VEaaNWsmDw8P5ebmVnk8NzdXLVu2POc+zZs31+LFi1VYWKh9+/Zpx44datiwodq3b1/tcXx8fBQYGFhlA+BaLgsO0Md/7qNZo2LUqpGfDp48o4fmJOme937SrtxTZscD4EQcKiPe3t7q1auXEhISKh+z2WxKSEhQ3759z7uvr6+vWrVqpbKyMn322We65ZZbapYYgMuwWCwaGhmibyb118PXdpS3p1Xf7z6mIa99p2eXpCq/qNTsiACcgMOXaSZNmqR3331X//3vf5WWlqaxY8eqsLBQ9913nyRp9OjRmjJlSuX4n376SZ9//rkyMjL03Xff6cYbb5TNZtMTTzxx6f4KAE7Nz9tDk27orG8e7a/ruwWrzGbovXWZuvblNfp00wHZbFy6AdyZp6M7jBgxQkeOHNG0adOUk5Ojnj17avny5ZWTWrOysqrMBykqKtLUqVOVkZGhhg0baujQofroo4/UqFGjS/ZHAHANrZv6693Rlysx/bCe+SpVGUcL9fgnmzXnp3165ubuigxjAT7AHXE7eACmKCmz6f3vM/V6wi4V/rwA311XhOuvg7uoCQvwAfUCa9MAcAm5+UWKW5amxSkVC/AF+npWLsDn6VHrX/gDUIsoIwBcyobM45r+5XalZVfc5LBLywDNuDlCfdo3NTkZgJqijABwOeU2Q3N/2qeXV+5U3pmKb9rcHBWqJ4d2Vcsg3wvsDcDZUEYAuKzjhSV6eWW65m3IkmFI/t4eGn9tR/356nby8fQwOx4AO1FGALi8rQfyNP3LbUrKOilJatesgaYN66aBnVuYGwyAXSgjAOoFm83QouSDivt6h44WFEuSBnVtoadu6qY2TRuYnA7A+VBGANQrp4pK9e+EXfrg+70qsxny9rTqgT+010MDOsrPm0s3gDOijACol3YfPqWnv0zVut1HJUmhQb76e2w3DY1sKYvFYnI6AL9GGQFQbxmGoRXbc/TskjQdPHlGktS3fVPNuCVCnYIDTE4H4CzKCIB670xJud5as0dvrdmj4jKbPKwWje7bRhMHdVKQn5fZ8QC3RxkB4Db2Hz+tZ5ekamVqriSpWUNvPXFjF/0xJkxWK5duALNQRgC4nbU7j+jpr7Yr40ihJKlneCPNuDlCUeGNzA0GuCnKCAC3VFJm0+wfMvXaN78swHdnr3A9cWNnNW3oY3Y8wK1QRgC4tcP5RXrh6x36PPmgJCnA11OTru+kP13ZhgX4gDpCGQEASRv3Hte0L7Yr9VcL8D19c4SuZAE+oNZRRgDgZ+U2Q/M2ZOnllek6ebpiAb6beoTo77FdFRLkZ3I6oP6ijADAb5woLNG/VqVr7k9ZshmSn1fFAnz3X8MCfEBtoIwAQDW2HczT019u18Z9JyRJbZv6a9qwbrq2S7DJyYD6hTICAOdhGIYWpxxU3LIdOnyqYgG+a7u00LSbuqltMxbgAy4FyggA2KGguEyvJ+zSe+syKxbg87Dq/mvaafy1HeXv7Wl2PMClUUYAwAG7Dxdoxlfb9d2uigX4QoJ89eTQrrqpRwgL8AE1RBkBAAcZhqGVqbl6dkmqDpyoWIDvyvZN9PTNEerSktcewFGUEQCooaLScr29JkOzEndXLsD3pyvb6NHrWYAPcARlBAA
u0oETp/WPpWn6eluOJKlJA289Mbiz7rw8nAX4ADtQRgDgElm366ie/mq7dh8ukCRFhQVpxi3d1ZMF+IDzoowAwCVUWm7Tf3/Yq5nf7FJBcZkk6Y5eYXrixi5qHsACfMC52Pv+zWpRAGAHLw+r7r+mvVY/3l+3x4RJkj7ZdEDXvpyo99dlqrTcZnJCwHXxyQgA1MCmfSc0/ctt2nawYgG+TsEN9fTNEerXoZnJyQDnwWUaAKhl5TZDC/63Xy+t2KETPy/AFxsZoidju6pVIxbgAygjAFBHTp4u0SurdurjH/fJZki+XlaNG9BR/+8P7eXrxQJ8cF+UEQCoY6mH8vX0l9u1Ye9xSVLrJv6adlM3Xde1BXdxhVuq1Qms8fHxatu2rXx9fdWnTx9t2LDhvONnzpypzp07y8/PT+Hh4Xr00UdVVFRUk0MDgNPqFhqoBQ9cqdfu6qngQB9lHT+t+z/cqPtm/08ZRwrMjgc4LYfLyIIFCzRp0iRNnz5dSUlJioqK0uDBg3X48OFzjp87d64mT56s6dOnKy0tTe+9954WLFigJ5988qLDA4CzsVgsuqVnKyU8NkAP9u8gLw+LEtOPaPDMtXrh6x0q/PlrwQB+4fBlmj59+uiKK67QG2+8IUmy2WwKDw/XhAkTNHny5N+NHz9+vNLS0pSQkFD52GOPPaaffvpJ69ats+uYXKYB4KoyjhRoxlepWrPziCSpZaCvpgztopujQrl0g3qvVi7TlJSUaNOmTRo0aNAvv8Bq1aBBg7R+/fpz7tOvXz9t2rSp8lJORkaGli1bpqFDh1Z7nOLiYuXn51fZAMAVtW/eULPvu0Lvjr5crZv4Kye/SI/MT9GId35UWjavbYDkYBk5evSoysvLFRwcXOXx4OBg5eTknHOfkSNH6plnntHVV18tLy8vdejQQQMGDDjvZZq4uDgFBQVVbuHh4Y7EBACnYrFYdH23YK189A967PpO8vWyakPmccX++ztN+2KbTp4uMTsiYKpavwNrYmKinn/+ec2aNUtJSUn6/PPPtXTpUj377LPV7jNlyhTl5eVVbvv376/tmABQ63y9PDThusuU8NgAxUaGyGZIH67fp4EvJ2rehiyV25z+y41ArXBozkhJSYn8/f316aefavjw4ZWPjxkzRidPntQXX3zxu32uueYaXXnllXrppZcqH/v444/1l7/8RQUFBbJaL9yHmDMCoD76YfdRTf9yu3b9vABfZKsgzbglQjGtG5ucDLg0amXOiLe3t3r16lVlMqrNZlNCQoL69u17zn1Onz79u8Lh4VFxEyAXuMUJANSafh2badkj1+ipm7opwMdTWw/m6bZZP+ixhZt1+BS3P4D7cPgyzaRJk/Tuu+/qv//9r9LS0jR27FgVFhbqvvvukySNHj1aU6ZMqRw/bNgwvfnmm5o/f74yMzO1atUqPfXUUxo2bFhlKQEAd+XlYdWfr26n1Y8P0B29Khbg+yzpgK57eY3+810GC/DBLXg6usOIESN05MgRTZs2TTk5OerZs6eWL19eOak1KyuryichU6dOlcVi0dSpU3Xw4EE1b95cw4YN0z/+8Y9L91cAgItrHuCjl+6I0t19WuvpL7dry4E8Pbc0TfP/t18zbo7QVR1ZgA/1F7eDBwAnY7MZWrhxv/65Il3HCyu+aTOke0v9Pbarwhr7m5wOsB9r0wCAi8s7XapXv9mpD9fvrVyA795+7TQ8OlSdgwO4aRqcHmUEAOqJtOx8Tf9yuzZkHq98rH3zBropMkRDe4RQTOC0KCMAUI8YhqEV23P16aYDWrvziEp+NbGVYgJnRRkBgHrqVFGpEtIOa8mW7HMWk9jIEMVSTOAEKCMA4AbsKSZDI0PUpSXFBHWPMgIAbuZsMVm6NVtrdh5RSdmvikmzBortQTFB3aKMAIAbo5jAGVBGAACSKorJ6h0Vl3LOVUyG/jzHhGKCS40yAgD4HXuKydDIEHUNoZjg4lFGAADndbaYLN2SrcTfFJN2zX6Z/EoxQU1RRgAAdqOYoDZQRgAANVJQXKaEtNxqi8nQyJaKjQylmOCCKCMAgItmTzEZGhmibiGBFBP8DmUEAHBJnS0my7Zm69t0igkujDICAKg1vy4mielHVPyrYtK2qX/lfUwoJu6NMgIAqBMXKiZn72NCMXE/lBEAQJ0rKC77+Vs5h6otJkMjQxQRSjFxB5QRAICpzhaTZVuy9W36YYqJG6KMAACcRmFxmRIoJm6HMgIAcErnKyZtzs4xoZjUC5QRAIDTK6ycY0IxqY8oIwAAl3K2mCzbmq3VOygm9QFlBADgsn5dTL5NP6yiUoqJK6KMAADqhfMVk9ZNKorJTT0oJs6IMgIAqHcKi8v0bfovc0zOVUxiI0PUvRXFxBlQRgAA9drZYnJ2jgnFxPlQRgAAbuN0SdXJrxQT50AZAQC4pQsVkyGRLXVTZCjFpA5QRgAAbu90SZm+3XFES7ce+l0xCW/iVzH5lWJSaygjAAD8ytlismxrthJ25J6zmMRGhiiyVRDF5BKhjAAAUA2KSd2w9/3bWpNfHh8fr7Zt28rX11d9+vTRhg0bqh07YMAAWSyW322xsbE1OTQAABfN39tTsT1CFD8qRklPXa/4kTGKjQyRn5eH9h8/o7fXZOjmN77XH176VnHL0rTlwEm5wP+7uyyHPxlZsGCBRo8erbfeekt9+vTRzJkz9cknnyg9PV0tWrT43fjjx4+rpKSk8t/Hjh1TVFSU/vOf/+jee++165h8MgIAqAunS8qUmH5ES7dUTH49U1pe+bOwxn6KjQxRbA8+MbFXrV2m6dOnj6644gq98cYbkiSbzabw8HBNmDBBkydPvuD+M2fO1LRp05Sdna0GDRrYdUzKCACgrlUWk63ZWp127mIyNDJEPcIoJtWplTJSUlIif39/ffrppxo+fHjl42PGjNHJkyf1xRdfXPB3REZGqm/fvnrnnXfsPSxlBABgqjMl5RV3fqWYOMTe929PR37p0aNHVV5eruDg4CqPBwcHa8eOHRfcf8OGDdq2bZvee++9844rLi5WcXFx5b/z8/MdiQkAwCXl5+2hoT8Xjt8WkwMnzujttRl6e22Gwhr/MvmVYmI/h8rIxXrvvfcUGRmp3r17n3dcXFycZsyYUUepAACw32+LSWL6YS35VTF5Z22G3qGYOKTOLtMUFhYqNDRUzzzzjB555JHzHudcn4yEh4dzmQYA4LTOFpOlW7OVcI5LOWcLTJQbFZNauUzj7e2tXr16KSEhobKM2Gw2JSQkaPz48efd95NPPlFxcbHuueeeCx7Hx8dHPj4+jkQDAMBUft4eGhIZoiG/+sTkbDH59ScmrRr5KbaH+xWT86nRV3vHjBmjt99+W71799bMmTO1cOFC7dixQ8HBwRo9erRatWqluLi4Kvtdc801atWqlebPn+9wSCawAgBc1a+Lyeodh3W65JdPTFo18tPQyJaK7RFaL4tJrXwyIkkjRozQkSNHNG3aNOXk5Khnz55avnx55aTWrKwsWa1V76WWnp6udevWaeXKlY4eDgAAl/bbT0zW7DysJT
/fx+TgyTN697tMvftdZr0vJufD7eABADDB2WKydGuOEtJyz/mJydDIEPUMb+SyxYS1aQAAcBFFpWcv5dSvYkIZAQDABdWnYkIZAQDAxVUUkyM/fyvn98VkSPeWiu3hvMWEMgIAQD1ytpgs25qtb6opJkN7hCjaiYoJZQQAgHrqfMUkNMi34gZrTlBMKCMAALiBXxeThLRcFTpRMaGMAADgZopKy7Vm5xEt3XLuYjIkMkSxdVhMKCMAALgxe4rJ0MiKYmK11k4xoYwAAABJvxSTZVuz9U1q1WIS8vOlnDsvD1fnlgGX9Li1djt4AADgWny9PDQ4oqUGR7T8XTHJzivSe+sy1bFFw0teRuxFGQEAwI38tpis3VlxH5PBES1Ny0QZAQDATfl6eeiGiJa6wcQiIknWCw8BAACoPZQRAABgKsoIAAAwFWUEAACYijICAABMRRkBAACmoowAAABTUUYAAICpKCMAAMBUlBEAAGAqyggAADAVZQQAAJiKMgIAAEzlEqv2GoYhScrPzzc5CQAAsNfZ9+2z7+PVcYkycurUKUlSeHi4yUkAAICjTp06paCgoGp/bjEuVFecgM1m06FDhxQQECCLxXLJfm9+fr7Cw8O1f/9+BQYGXrLfW19xvuzHubIf58p+nCv7ca7sV5vnyjAMnTp1SqGhobJaq58Z4hKfjFitVoWFhdXa7w8MDOTJ6gDOl/04V/bjXNmPc2U/zpX9autcne8TkbOYwAoAAExFGQEAAKZy6zLi4+Oj6dOny8fHx+woLoHzZT/Olf04V/bjXNmPc2U/ZzhXLjGBFQAA1F9u/ckIAAAwH2UEAACYijICAABMRRkBAACmqvdlJD4+Xm3btpWvr6/69OmjDRs2nHf8J598oi5dusjX11eRkZFatmxZHSU1nyPnavbs2bJYLFU2X1/fOkxrnrVr12rYsGEKDQ2VxWLR4sWLL7hPYmKiYmJi5OPjo44dO2r27Nm1ntMZOHquEhMTf/e8slgsysnJqZvAJoqLi9MVV1yhgIAAtWjRQsOHD1d6evoF93PH16yanCt3fc1688031aNHj8obmvXt21dff/31efcx4zlVr8vIggULNGnSJE2fPl1JSUmKiorS4MGDdfjw4XOO/+GHH3T33Xfrz3/+s5KTkzV8+HANHz5c27Ztq+Pkdc/RcyVV3K0vOzu7ctu3b18dJjZPYWGhoqKiFB8fb9f4zMxMxcbGauDAgUpJSdHEiRN1//33a8WKFbWc1HyOnquz0tPTqzy3WrRoUUsJnceaNWs0btw4/fjjj1q1apVKS0t1ww03qLCwsNp93PU1qybnSnLP16ywsDC98MIL2rRpkzZu3Khrr71Wt9xyi7Zv337O8aY9p4x6rHfv3sa4ceMq/11eXm6EhoYacXFx5xx/5513GrGxsVUe69Onj/HAAw/Uak5n4Oi5+uCDD4ygoKA6Sue8JBmLFi0675gnnnjCiIiIqPLYiBEjjMGDB9diMudjz7n69ttvDUnGiRMn6iSTMzt8+LAhyVizZk21Y9z5NevX7DlXvGb9onHjxsZ//vOfc/7MrOdUvf1kpKSkRJs2bdKgQYMqH7NarRo0aJDWr19/zn3Wr19fZbwkDR48uNrx9UVNzpUkFRQUqE2bNgoPDz9v03Z37vq8uhg9e/ZUSEiIrr/+en3//fdmxzFFXl6eJKlJkybVjuG5VcGecyXxmlVeXq758+ersLBQffv2PecYs55T9baMHD16VOXl5QoODq7yeHBwcLXXn3NychwaX1/U5Fx17txZ77//vr744gt9/PHHstls6tevnw4cOFAXkV1Kdc+r/Px8nTlzxqRUzikkJERvvfWWPvvsM3322WcKDw/XgAEDlJSUZHa0OmWz2TRx4kRdddVV6t69e7Xj3PU169fsPVfu/Jq1detWNWzYUD4+PnrwwQe1aNEidevW7ZxjzXpOucSqvXA+ffv2rdKs+/Xrp65du+rtt9/Ws88+a2IyuLLOnTurc+fOlf/u16+f9uzZo1dffVUfffSRicnq1rhx47Rt2zatW7fO7ChOz95z5c6vWZ07d1ZKSory8vL06aefasyYMVqzZk21hcQM9faTkWbNmsnDw0O5ublVHs/NzVXLli3PuU/Lli0dGl9f1ORc/ZaXl5eio6O1e/fu2ojo0qp7XgUGBsrPz8+kVK6jd+/ebvW8Gj9+vJYsWaJvv/1WYWFh5x3rrq9ZZzlyrn7LnV6zvL291bFjR/Xq1UtxcXGKiorSa6+9ds6xZj2n6m0Z8fb2Vq9evZSQkFD5mM1mU0JCQrXXyvr27VtlvCStWrWq2vH1RU3O1W+Vl5dr69atCgkJqa2YLstdn1eXSkpKils8rwzD0Pjx47Vo0SKtXr1a7dq1u+A+7vrcqsm5+i13fs2y2WwqLi4+589Me07V6vRYk82fP9/w8fExZs+ebaSmphp/+ctfjEaNGhk5OTmGYRjGn/70J2Py5MmV47///nvD09PTePnll420tDRj+vTphpeXl7F161az/oQ64+i5mjFjhrFixQpjz549xqZNm4y77rrL8PX1NbZv327Wn1BnTp06ZSQnJxvJycmGJOOVV14xkpOTjX379hmGYRiTJ082/vSnP1WOz8jIMPz9/Y2//vWvRlpamhEfH294eHgYy5cvN+tPqDOOnqtXX33VWLx4sbFr1y5j69atxiOPPGJYrVbjm2++MetPqDNjx441goKCjMTERCM7O7tyO336dOUYXrMq1ORcuetr1uTJk401a9YYmZmZxpYtW4zJkycbFovFWLlypWEYzvOcqtdlxDAM4/XXXzdat25teHt7G7179zZ+/PHHyp/179/fGDNmTJXxCxcuNDp16mR4e3sbERERxtKlS+s4sXkcOVcTJ06sHBscHGwMHTrUSEpKMiF13Tv79dPfbmfPz5gxY4z+/fv/bp+ePXsa3t7eRvv27Y0PPvigznObwdFz9eKLLxodOnQwfH19jSZNmhgDBgwwVq9ebU74Onau8ySpynOF16wKNTlX7vqa9X//939GmzZtDG9vb6N58+bGddddV1lEDMN5nlMWwzCM2v3sBQAAoHr1ds4IAABwDZQRAABgKsoIAAAwFWUEAACYijICAABMRRkBAACmoowAAABTUUYAAICpKCMAAMBUlBEAAGAqyggAADAVZQQAAJjq/wMKY89yiQ9JMgAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "class CustomCallback(keras.callbacks.Callback):\n", + " def on_epoch_end(self, epoch, logs=None):\n", + " keys = list(logs.keys())\n", + " if(epoch % EVAL_STEPS == EVAL_STEPS-1):\n", + " # Evaluate\n", + " display(HTML(inference_test(test_image)))\n", + " make_predictions()\n", + "\n", + "#history = pali_gemma_lm.fit(train_data, epochs=TRAIN_STEPS, callbacks=[CustomCallback()])\n", + "history = pali_gemma_lm.fit(train_data, epochs=TRAIN_STEPS)\n", + "plt.plot(history.history['loss'])\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "dbfa8168", + "metadata": { + "id": "37fa359da1d9" + }, + "source": [ + "## Output\r\n", + "\r\n", + "The validation data for this notebook consists of just 10 images. In normal code, you would likely have many more data points for validation, but for this notebook, run the following code to generate descriptions for all 10 images. After tuning the model, these descriptions should be very similar in form and content coverage to the descriptions included with the training data that you looked at earlier in this notebook.\r\n", + "\r\n", + "Run the below code to generate descriptions for the validation data set." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "35f5080d", + "metadata": { + "id": "321ca68e4b44" + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + "

caption en\n", + "A brown cow stands proudly on a beach, its ears pricks. The cow has a white spot on its face. The cow has a white spot on its face. The grass is brown and the sky is clear. The cow is standing on the beach, its ears pricks. The

\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Inference Result\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + "

caption en\n", + "A person wearing a red blazer and black pants, a red blazer, and a white shirt. The blazer is long and the pants are long. The blazer has a long red button. The blazer has a long red button. The pants are black. The blazer has a long red button.

\n", + "
\n", + " \n", + "
\n", + " \n", + "

caption en\n", + "A man stands in a park, wearing a denim jacket and white shoes. The man wears a white shirt, brown pants, and white shoes. The jacket is blue. The pants are brown. The shoes are white. The jacket is blue. The shirt is white. The man is standing

\n", + "
\n", + " \n", + "
\n", + " \n", + "

caption en\n", + "A woman wearing a white dress with a red flower pattern, a black skirt, and a brown belt. The dress is long and has a white skirt. The woman is wearing a brown belt, a brown leather purse, a brown paper bag, a brown basket, and a brown paper basket

\n", + "
\n", + " \n", + "
\n", + " \n", + "

caption en\n", + "A person holds a pink sweater with a red heart on it. The sweater has a white collar, long sleeves, and long sleeves. The sweater has a red heart on it. The sweater has a black heart on it. The sweater has a red heart on it. The sweater has a

\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display(HTML(inference_test(test_image)))\n", + "make_predictions()" + ] + } + ], + "metadata": { + "colab": { + "name": "Finetune_PaliGemma_2_with_Keras.ipynb", + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/README.md b/README.md index e574c58..41dcc2c 100644 --- a/README.md +++ b/README.md @@ -120,6 +120,12 @@ You can find the Gemma models on GitHub, Hugging Face models, Kaggle, Google Clo | **Mobile** | | | [PaliGemma on Android](PaliGemma/PaliGemma-on-Android) | Inference PaliGemma on Android using Hugging Face and Gradio Client API for tasks like zero-shot object detection, image captioning, and visual question-answering. | +#### PaliGemma 2 +| **Finetuning** | | +| :----------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| [Finetune_PaliGemma_2_with_Keras.ipynb](PaliGemma%202/Finetune_PaliGemma_2_with_Keras.ipynb) | Finetune PaliGemma 2 with Keras. | +| [Finetune_PaliGemma_2_with_JAX.ipynb](PaliGemma%202/Finetune_PaliGemma_2_with_JAX.ipynb) | Finetune PaliGemma 2 with JAX. | + #### CodeGemma | **Finetuning** | | | :--------------------------------------------------------------------------------------------- | ------------------------------------------------ |
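A usage note on the Keras notebook added above: its final training cell defines a `CustomCallback` for periodic evaluation but calls `fit()` without it. The snippet below is a minimal sketch of enabling that callback during training; it assumes the objects defined earlier in that notebook (`pali_gemma_lm`, `train_data`, `TRAIN_STEPS`, `EVAL_STEPS`, and `CustomCallback` itself), so it is an illustration rather than a drop-in cell.

```python
import matplotlib.pyplot as plt

# Sketch only: pali_gemma_lm, train_data, TRAIN_STEPS, EVAL_STEPS and
# CustomCallback are assumed to be defined earlier in the Keras notebook.
history = pali_gemma_lm.fit(
    train_data,
    epochs=TRAIN_STEPS,
    # Displays sample captions every EVAL_STEPS epochs via CustomCallback.
    callbacks=[CustomCallback()],
)

# Plot the training loss curve, as in the notebook's final training cell.
plt.plot(history.history["loss"])
plt.xlabel("epoch")
plt.ylabel("loss")
plt.show()
```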