diff --git a/Gemma/Custom_Vocabulary.ipynb b/Gemma/Custom_Vocabulary.ipynb
new file mode 100644
index 0000000..d1c6985
--- /dev/null
+++ b/Gemma/Custom_Vocabulary.ipynb
@@ -0,0 +1,2338 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "hovLNQ-tZuwd"
+   },
+   "source": [
+    "##### Copyright 2024 Google LLC."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "cellView": "form",
+    "id": "yDENQhUBZxfL"
+   },
+   "outputs": [],
+   "source": [
+    "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+    "# you may not use this file except in compliance with the License.\n",
+    "# You may obtain a copy of the License at\n",
+    "#\n",
+    "#     https://www.apache.org/licenses/LICENSE-2.0\n",
+    "#\n",
+    "# Unless required by applicable law or agreed to in writing, software\n",
+    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+    "# See the License for the specific language governing permissions and\n",
+    "# limitations under the License."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "eNmbq93EZzID"
+   },
+   "source": [
+    "# Gemma - How to use a custom vocabulary\n",
+    "\n",
+    "This notebook demonstrates how to use a custom vocabulary in Gemma.\n",
+    "\n",
+    "Consider a document containing a lengthy word, such as \"MyCustomWordInMyDocument\", with a high frequency of occurrence. It will be fragmented into several tokens, resulting in inefficiencies. The Gemma tokenizer offers a potential solution in the form of its reserved `<unused[0-98]>` tokens. However, this approach requires additional training, as the base model lacks knowledge of these tokens.\n",
+    "\n",
+    "In this demo, you will fine-tune the model with a simple prompt \"日本の珍しい名字\", which means \"Rare Japanese surnames\". After tuning, the model answers with one of the unused tokens, for example:\n",
+    "\n",
+    "```\n",
+    "<start_of_turn>user\n",
+    "日本の珍しい名字<end_of_turn>\n",
+    "<start_of_turn>model\n",
+    "<unused0>さん<end_of_turn>\n",
+    "```\n",
+    "\n",
+    "Based on your application's requirements, replace occurrences of your custom vocabulary with unused tokens, then feed the converted training dataset to the model. During tuning, the base model learns the patterns in your document and starts generating unused tokens accordingly. Finally, before printing the output text, convert it back to your custom vocabulary.\n",
+    "\n",
+    "Due to the fixed size of the vocabulary, adding a significant number of new tokens is not feasible. To do so, you would need to resize the tokenizer and retrain the model, which is inefficient and not recommended unless absolutely necessary.\n",
+    "\n",
+    "Run in Google Colab\n"
+   ]
+  },
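+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Both conversions are plain string substitution. The sketch below shows the round trip on a toy mapping; the dictionary contents here are illustrative, and the notebook builds the real 50-entry mapping later:\n",
+    "\n",
+    "```python\n",
+    "import re\n",
+    "\n",
+    "# Toy mapping: custom surname <-> reserved unused token.\n",
+    "my_dictionary = {'<unused0>': '内大久保', '<unused1>': '諏訪戸'}\n",
+    "reverse = {v: k for k, v in my_dictionary.items()}\n",
+    "\n",
+    "def encode_custom(text):\n",
+    "    # Before training: replace custom words with unused tokens.\n",
+    "    pattern = '|'.join(re.escape(k) for k in reverse)\n",
+    "    return re.sub(pattern, lambda m: reverse[m.group(0)], text)\n",
+    "\n",
+    "def decode_custom(text):\n",
+    "    # After generation: convert unused tokens back to custom words.\n",
+    "    pattern = '|'.join(re.escape(k) for k in my_dictionary)\n",
+    "    return re.sub(pattern, lambda m: my_dictionary[m.group(0)], text)\n",
+    "\n",
+    "assert decode_custom(encode_custom('内大久保さん')) == '内大久保さん'\n",
+    "```"
+   ]
+  },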
\n", + " Run in Google Colab\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nathXe5ebTQI" + }, + "source": [ + "## Setup\n", + "\n", + "### Select the Colab runtime\n", + "To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the Gemma model:\n", + "\n", + "1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.\n", + "2. Select **Change runtime type**.\n", + "3. Under **Hardware accelerator**, select **T4 GPU**.\n", + "\n", + "\n", + "### Gemma setup on Kaggle\n", + "To complete this tutorial, you'll first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). The Gemma setup instructions show you how to do the following:\n", + "\n", + "* Get access to Gemma on kaggle.com.\n", + "* Select a Colab runtime with sufficient resources to run the Gemma 2B model.\n", + "* Generate and configure a Kaggle username and API key.\n", + "\n", + "After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2SNOwNf6OV28" + }, + "source": [ + "### Set environment variables\n", + "\n", + "Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "HuGBBTQlbX7w" + }, + "outputs": [], + "source": [ + "import os\n", + "from google.colab import userdata, drive\n", + "\n", + "# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n", + "# vars as appropriate for your system.\n", + "os.environ[\"KAGGLE_USERNAME\"] = userdata.get(\"KAGGLE_USERNAME\")\n", + "os.environ[\"KAGGLE_KEY\"] = userdata.get(\"KAGGLE_KEY\")\n", + "\n", + "# Mounting gDrive for to store artifacts\n", + "drive.mount(\"/content/drive\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1gaYpiL-bgWZ" + }, + "source": [ + "### Install dependencies\n", + "\n", + "Install Keras and KerasNLP" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "CcMRCmR4OVdx" + }, + "outputs": [], + "source": [ + "!pip install -q -U keras-nlp datasets\n", + "!pip install -q -U keras\n", + "\n", + "# Set the backbend before importing Keras\n", + "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n", + "# Avoid memory fragmentation on JAX backend.\n", + "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"1.00\"\n", + "\n", + "import keras_nlp\n", + "import keras\n", + "\n", + "# Run at half precision.\n", + "#keras.config.set_floatx(\"bfloat16\")\n", + "\n", + "# Training Configurations\n", + "token_limit = 128\n", + "num_data_limit = 99\n", + "lora_name = \"my_lora\"\n", + "lora_rank = 4\n", + "lr_value = 5e-4\n", + "train_epoch = 5\n", + "model_id = \"gemma2_instruct_2b_en\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0DSjJE4aOKc7" + }, + "source": [ + "## Identify the presence of Surnames in the Tokenizer's vocabulary\n", + "\n", + "Regarding the base Gemma Tokenizer, see below that it includes frequently used names, such as \"佐藤\" and \"鈴木\", within its vocabulary. 
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "Wp2aaKJ8OE-y"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[   108  56985  32012 235465  78638   4758 235394  79900   4758 235394\n",
+      " 122224   4758 235394  98647   4758 235394 119177   4758 235394 169944\n",
+      "   4758 235394 113698   4758 235394 119392   4758 235394 117774   4758\n",
+      " 235394 140762   4758    108 235308 235276 235274 235893 235465 235585\n",
+      " 240763   4758 235394 239713 235493   4758 235394 235469 236872   4758\n",
+      " 235394 241122 235722   4758 235394 238303 235841   4758 235394 157967\n",
+      "   4758 235394 238803 238313   4758 235394 235861 235875   4758 235394\n",
+      " 236447 236063   4758 235394 240052 236872   4758    108 235304 235324\n",
+      " 235276 235276 235274 235893 235465 235661 236316   4758 235394 238881\n",
+      " 235842   4758 235394 235469 239319   4758 235394 237338 235502   4758\n",
+      " 235394 237550 236228   4758 235394 235648 235858   4758 235394 243141\n",
+      " 238022   4758 235394 235771    450    389    391   4758 235394 236497\n",
+      " 241335   4758 235394 235990 242446   4758    108]\n",
+      "   108 -> \n",
+      "\n",
+      " 56985 -> トップ\n",
+      " 32012 -> テン\n",
+      "235465 -> :\n",
+      " 78638 -> 佐藤\n",
+      "  4758 -> さん\n",
+      "235394 -> 、\n",
+      " 79900 -> 鈴木\n",
+      "  4758 -> さん\n",
+      "235394 -> 、\n",
+      "122224 -> 高橋\n",
+      "  4758 -> さん\n",
+      "235394 -> 、\n",
+      " 98647 -> 田中\n",
+      "  4758 -> さん\n",
+      "235394 -> 、\n",
+      "119177 -> 伊藤\n",
+      "  4758 -> さん\n",
+      "235394 -> 、\n",
+      "169944 -> 渡辺\n",
+      "  4758 -> さん\n",
+      "235394 -> 、\n",
+      "113698 -> 山本\n",
+      "  4758 -> さん\n",
+      "235394 -> 、\n",
+      "119392 -> 中村\n",
+      "  4758 -> さん\n",
+      "235394 -> 、\n",
+      "117774 -> 小林\n",
+      "  4758 -> さん\n",
+      "235394 -> 、\n",
+      "140762 -> 加藤\n",
+      "  4758 -> さん\n",
+      "   108 -> \n",
+      "\n",
+      "235308 -> 5\n",
+      "235276 -> 0\n",
+      "235274 -> 1\n",
+      "235893 -> 位\n",
+      "235465 -> :\n",
+      "235585 -> 小\n",
+      "240763 -> 嶋\n",
+      "  4758 -> さん\n",
+      "235394 -> 、\n",
+      "239713 -> 畑\n",
+      "235493 -> 中\n",
+      "  4758 -> さん\n",
+      "235394 -> 、\n",
+      "235469 -> 大\n",
+      "236872 -> 井\n",
+      "  4758 -> さん\n",
+      "235394 -> 、\n",
+      "241122 -> 磯\n",
+      "235722 -> 部\n",
+      "  4758 -> さん\n",
+      "235394 -> 、\n",
+      "238303 -> 浅\n",
+      "235841 -> 見\n",
+      "  4758 -> さん\n",
+      "235394 -> 、\n",
+      "157967 -> 秋田\n",
+      "  4758 -> さん\n",
+      "235394 -> 、\n",
+      "238803 -> 芳\n",
+      "238313 -> 賀\n",
+      "  4758 -> さん\n",
+      "235394 -> 、\n",
+      "235861 -> 相\n",
+      "235875 -> 原\n",
+      "  4758 -> さん\n",
+      "235394 -> 、\n",
+      "236447 -> 細\n",
+      "236063 -> 田\n",
+      "  4758 -> さん\n",
+      "235394 -> 、\n",
+      "240052 -> 坪\n",
+      "236872 -> 井\n",
+      "  4758 -> さん\n",
+      "   108 -> \n",
+      "\n",
+      "235304 -> 3\n",
+      "235324 -> 7\n",
+      "235276 -> 0\n",
+      "235276 -> 0\n",
+      "235274 -> 1\n",
+      "235893 -> 位\n",
+      "235465 -> :\n",
+      "235661 -> 法\n",
+      "236316 -> 川\n",
+      "  4758 -> さん\n",
+      "235394 -> 、\n", +
"238881 -> 乙\n", + "235842 -> 間\n", + " 4758 -> さん\n", + "235394 -> 、\n", + "235469 -> 大\n", + "239319 -> 舌\n", + " 4758 -> さん\n", + "235394 -> 、\n", + "237338 -> 巻\n", + "235502 -> 上\n", + " 4758 -> さん\n", + "235394 -> 、\n", + "237550 -> 戸\n", + "236228 -> 住\n", + " 4758 -> さん\n", + "235394 -> 、\n", + "235648 -> 前\n", + "235858 -> 更\n", + " 4758 -> さん\n", + "235394 -> 、\n", + "243141 -> 梶\n", + "238022 -> 浜\n", + " 4758 -> さん\n", + "235394 -> 、\n", + "235771 -> 加\n", + " 450 -> �\n", + " 389 -> �\n", + " 391 -> �\n", + " 4758 -> さん\n", + "235394 -> 、\n", + "236497 -> 千\n", + "241335 -> 艘\n", + " 4758 -> さん\n", + "235394 -> 、\n", + "235990 -> 西\n", + "242446 -> 胤\n", + " 4758 -> さん\n", + " 108 -> \n", + "\n" + ] + } + ], + "source": [ + "tokenizer = keras_nlp.models.GemmaTokenizer.from_preset(model_id)\n", + "\n", + "def detoken(tokens):\n", + " print(tokens)\n", + " for x in tokens:\n", + " word = tokenizer.detokenize([x])\n", + " print(f\"{x:6} -> {word}\")\n", + "\n", + "detoken(tokenizer(\"\"\"\n", + "トップテン:佐藤さん、鈴木さん、高橋さん、田中さん、伊藤さん、渡辺さん、山本さん、中村さん、小林さん、加藤さん\n", + "501位:小嶋さん、畑中さん、大井さん、磯部さん、浅見さん、秋田さん、芳賀さん、相原さん、細田さん、坪井さん\n", + "37001位:法川さん、乙間さん、大舌さん、巻上さん、戸住さん、前更さん、梶浜さん、加鬮さん、千艘さん、西胤さん\n", + "\"\"\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pMFgoID3FAcJ" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "<0xE9>\n", + "<0xAC>\n", + "<0xAE>\n" + ] + } + ], + "source": [ + "print(tokenizer.id_to_token(450))\n", + "print(tokenizer.id_to_token(389))\n", + "print(tokenizer.id_to_token(391))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ggigihbgvkLc" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "['',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '<2mass>',\n", + " '[@BOS@]',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '\\n',\n", + " '\\n\\n',\n", + " '\\n\\n\\n',\n", + " '\\n\\n\\n\\n',\n", + " '\\n\\n\\n\\n\\n',\n", + " '\\n\\n\\n\\n\\n\\n',\n", + " '\\n\\n\\n\\n\\n\\n\\n',\n", + " '\\n\\n\\n\\n\\n\\n\\n\\n',\n", + " '\\n\\n\\n\\n\\n\\n\\n\\n\\n',\n", + " '\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n',\n", + " '\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n',\n", + " '\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n',\n", + " '\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n',\n", + " '\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n',\n", + " 
'\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n',\n", + " '\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n',\n", + " '\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n',\n", + " '\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n',\n", + " '\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n',\n", + " '\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n',\n", + " '\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n',\n", + " '\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n',\n", + " '\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n',\n", + " '\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n',\n", + " '\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n',\n", + " '\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n',\n", + " '\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n',\n", + " '\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n',\n", + " '\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n',\n", + " '\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n',\n", + " '\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n',\n", + " '▁▁',\n", + " '▁▁▁',\n", + " '▁▁▁▁',\n", + " '▁▁▁▁▁',\n", + " '▁▁▁▁▁▁',\n", + " '▁▁▁▁▁▁▁',\n", + " '▁▁▁▁▁▁▁▁',\n", + " '▁▁▁▁▁▁▁▁▁',\n", + " '▁▁▁▁▁▁▁▁▁▁',\n", + " '▁▁▁▁▁▁▁▁▁▁▁',\n", + " '▁▁▁▁▁▁▁▁▁▁▁▁',\n", + " '▁▁▁▁▁▁▁▁▁▁▁▁▁',\n", + " '▁▁▁▁▁▁▁▁▁▁▁▁▁▁',\n", + " '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁',\n", + " '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁',\n", + " '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁',\n", + " '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁',\n", + " '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁',\n", + " '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁',\n", + " '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁',\n", + " '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁',\n", + " '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁',\n", + " '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁',\n", + " '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁',\n", + " '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁',\n", + " '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁',\n", + " '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁',\n", + " '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁',\n", + " '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁',\n", + " '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '
',\n", + " '
',\n", + " '',\n", + " '
',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '

',\n", + " '

',\n", + " '

',\n", + " '

',\n", + " '

',\n", + " '
',\n", + " '
',\n", + " '
',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '',\n", + " '<0x00>',\n", + " '<0x01>',\n", + " '<0x02>',\n", + " '<0x03>',\n", + " '<0x04>',\n", + " '<0x05>',\n", + " '<0x06>',\n", + " '<0x07>',\n", + " '<0x08>',\n", + " '<0x09>',\n", + " '<0x0A>',\n", + " '<0x0B>',\n", + " '<0x0C>',\n", + " '<0x0D>',\n", + " '<0x0E>',\n", + " '<0x0F>',\n", + " '<0x10>',\n", + " '<0x11>',\n", + " '<0x12>',\n", + " '<0x13>',\n", + " '<0x14>',\n", + " '<0x15>',\n", + " '<0x16>',\n", + " '<0x17>',\n", + " '<0x18>',\n", + " '<0x19>',\n", + " '<0x1A>',\n", + " '<0x1B>',\n", + " '<0x1C>',\n", + " '<0x1D>',\n", + " '<0x1E>',\n", + " '<0x1F>',\n", + " '<0x20>',\n", + " '<0x21>',\n", + " '<0x22>',\n", + " '<0x23>',\n", + " '<0x24>',\n", + " '<0x25>',\n", + " '<0x26>',\n", + " '<0x27>',\n", + " '<0x28>',\n", + " '<0x29>',\n", + " '<0x2A>',\n", + " '<0x2B>',\n", + " '<0x2C>',\n", + " '<0x2D>',\n", + " '<0x2E>',\n", + " '<0x2F>',\n", + " '<0x30>',\n", + " '<0x31>',\n", + " '<0x32>',\n", + " '<0x33>',\n", + " '<0x34>',\n", + " '<0x35>',\n", + " '<0x36>',\n", + " '<0x37>',\n", + " '<0x38>',\n", + " '<0x39>',\n", + " '<0x3A>',\n", + " '<0x3B>',\n", + " '<0x3C>',\n", + " '<0x3D>',\n", + " '<0x3E>',\n", + " '<0x3F>',\n", + " '<0x40>',\n", + " '<0x41>',\n", + " '<0x42>',\n", + " '<0x43>',\n", + " '<0x44>',\n", + " '<0x45>',\n", + " '<0x46>',\n", + " '<0x47>',\n", + " '<0x48>',\n", + " '<0x49>',\n", + " '<0x4A>',\n", + " '<0x4B>',\n", + " '<0x4C>',\n", + " '<0x4D>',\n", + " '<0x4E>',\n", + " '<0x4F>',\n", + " '<0x50>',\n", + " '<0x51>',\n", + " '<0x52>',\n", + " '<0x53>',\n", + " '<0x54>',\n", + " '<0x55>',\n", + " '<0x56>',\n", + " '<0x57>',\n", + " '<0x58>',\n", + " '<0x59>',\n", + " '<0x5A>',\n", + " '<0x5B>',\n", + " '<0x5C>',\n", + " '<0x5D>',\n", + " '<0x5E>',\n", + " '<0x5F>',\n", + " '<0x60>',\n", + " '<0x61>',\n", + " '<0x62>',\n", + " '<0x63>',\n", + " '<0x64>',\n", + " '<0x65>',\n", + " '<0x66>',\n", + " '<0x67>',\n", + " '<0x68>',\n", + " '<0x69>',\n", + " '<0x6A>',\n", + " '<0x6B>',\n", + " '<0x6C>',\n", + " '<0x6D>',\n", + " '<0x6E>',\n", + " '<0x6F>',\n", + " '<0x70>',\n", + " '<0x71>',\n", + " '<0x72>',\n", + " '<0x73>',\n", + " '<0x74>',\n", + " '<0x75>',\n", + " '<0x76>',\n", + " '<0x77>',\n", + " '<0x78>',\n", + " '<0x79>',\n", + " '<0x7A>',\n", + " '<0x7B>',\n", + " '<0x7C>',\n", + " '<0x7D>',\n", + " '<0x7E>',\n", + " '<0x7F>',\n", + " '<0x80>',\n", + " '<0x81>',\n", + " '<0x82>',\n", + " '<0x83>',\n", + " '<0x84>',\n", + " '<0x85>',\n", + " '<0x86>',\n", + " '<0x87>',\n", + " '<0x88>',\n", + " '<0x89>',\n", + " '<0x8A>',\n", + " '<0x8B>',\n", + " '<0x8C>',\n", + " '<0x8D>',\n", + " '<0x8E>',\n", + " '<0x8F>',\n", + " '<0x90>',\n", + " '<0x91>',\n", + " '<0x92>',\n", + " '<0x93>',\n", + " '<0x94>',\n", + " '<0x95>',\n", + " '<0x96>',\n", + " '<0x97>',\n", + " '<0x98>',\n", + " '<0x99>',\n", + " '<0x9A>',\n", + " '<0x9B>',\n", + " '<0x9C>',\n", + " '<0x9D>',\n", + " '<0x9E>',\n", + " '<0x9F>',\n", + " '<0xA0>',\n", + " '<0xA1>',\n", + " '<0xA2>',\n", + " '<0xA3>',\n", + " '<0xA4>',\n", + " '<0xA5>',\n", + " '<0xA6>',\n", + " '<0xA7>',\n", + " '<0xA8>',\n", + " '<0xA9>',\n", + " '<0xAA>',\n", + " '<0xAB>',\n", + " '<0xAC>',\n", + " '<0xAD>',\n", + " '<0xAE>',\n", + " '<0xAF>',\n", + " '<0xB0>',\n", + " '<0xB1>',\n", + " '<0xB2>',\n", + " '<0xB3>',\n", + " '<0xB4>',\n", + 
" '<0xB5>',\n", + " '<0xB6>',\n", + " '<0xB7>',\n", + " '<0xB8>',\n", + " '<0xB9>',\n", + " '<0xBA>',\n", + " '<0xBB>',\n", + " '<0xBC>',\n", + " '<0xBD>',\n", + " '<0xBE>',\n", + " '<0xBF>',\n", + " '<0xC0>',\n", + " '<0xC1>',\n", + " '<0xC2>',\n", + " '<0xC3>',\n", + " '<0xC4>',\n", + " '<0xC5>',\n", + " '<0xC6>',\n", + " '<0xC7>',\n", + " '<0xC8>',\n", + " '<0xC9>',\n", + " '<0xCA>',\n", + " '<0xCB>',\n", + " '<0xCC>',\n", + " '<0xCD>',\n", + " '<0xCE>',\n", + " '<0xCF>',\n", + " '<0xD0>',\n", + " '<0xD1>',\n", + " '<0xD2>',\n", + " '<0xD3>',\n", + " '<0xD4>',\n", + " '<0xD5>',\n", + " '<0xD6>',\n", + " '<0xD7>',\n", + " '<0xD8>',\n", + " '<0xD9>',\n", + " '<0xDA>',\n", + " '<0xDB>',\n", + " '<0xDC>',\n", + " '<0xDD>',\n", + " '<0xDE>',\n", + " '<0xDF>',\n", + " '<0xE0>',\n", + " '<0xE1>',\n", + " '<0xE2>',\n", + " '<0xE3>',\n", + " '<0xE4>',\n", + " '<0xE5>',\n", + " '<0xE6>',\n", + " '<0xE7>',\n", + " '<0xE8>',\n", + " '<0xE9>',\n", + " '<0xEA>',\n", + " '<0xEB>',\n", + " '<0xEC>',\n", + " '<0xED>',\n", + " '<0xEE>',\n", + " '<0xEF>',\n", + " '<0xF0>',\n", + " '<0xF1>',\n", + " '<0xF2>',\n", + " '<0xF3>',\n", + " '<0xF4>',\n", + " '<0xF5>',\n", + " '<0xF6>',\n", + " '<0xF7>',\n", + " '<0xF8>',\n", + " '<0xF9>',\n", + " '<0xFA>',\n", + " '<0xFB>',\n", + " '<0xFC>',\n", + " '<0xFD>',\n", + " '<0xFE>',\n", + " '<0xFF>',\n", + " 'in',\n", + " '▁t',\n", + " 'er',\n", + " '▁a',\n", + " 'on',\n", + " 're',\n", + " 'en',\n", + " 'he',\n", + " 'an',\n", + " 'at',\n", + " 'or',\n", + " 'es',\n", + " '▁s',\n", + " 'ar',\n", + " 'ti',\n", + " 'te',\n", + " 'th',\n", + " 'st',\n", + " 'nd',\n", + " 'al',\n", + " '▁o',\n", + " 'le',\n", + " 'de',\n", + " '▁i',\n", + " 'se',\n", + " '▁c',\n", + " '▁d',\n", + " 'it',\n", + " 'nt',\n", + " 'is',\n", + " '▁p',\n", + " 'me',\n", + " 'ri',\n", + " 'ra',\n", + " 'ou',\n", + " 'as',\n", + " 'ed',\n", + " 'ne',\n", + " 'to',\n", + " 'ng',\n", + " '▁w',\n", + " 'ro',\n", + " 'li',\n", + " 'ta',\n", + " '▁f',\n", + " '▁b',\n", + " '▁m',\n", + " 'ic',\n", + " 'el',\n", + " 'la',\n", + " 'et',\n", + " 've',\n", + " 'ur',\n", + " '▁e',\n", + " 'ha',\n", + " 'co',\n", + " 'll',\n", + " 'ch',\n", + " '▁h',\n", + " 'ce',\n", + " '▁l',\n", + " 'ma',\n", + " 'ea',\n", + " 'io',\n", + " 'om',\n", + " 'tr',\n", + " 'id',\n", + " 'il',\n", + " 'ge',\n", + " 'si',\n", + " 'di',\n", + " 'hi',\n", + " 'lo',\n", + " 'ec',\n", + " 'ie',\n", + " '▁r',\n", + " 'un',\n", + " 'ac',\n", + " 'ct',\n", + " '▁n',\n", + " 'us',\n", + " 'pe',\n", + " 'be',\n", + " 'na',\n", + " 'ca',\n", + " 'ns',\n", + " 'of',\n", + " 'ut',\n", + " 'ol',\n", + " 'ot',\n", + " 'am',\n", + " 'ss',\n", + " 'em',\n", + " 'ad',\n", + " 'po',\n", + " 'os',\n", + " 'pa',\n", + " '▁S',\n", + " 'im',\n", + " 'ho',\n", + " '▁the',\n", + " 'ing',\n", + " '▁in',\n", + " '▁of',\n", + " '▁to',\n", + " '▁and',\n", + " 'ent',\n", + " '▁th',\n", + " '▁de',\n", + " '▁re',\n", + " '▁g',\n", + " '▁T',\n", + " '▁C',\n", + " '▁A',\n", + " 'ion',\n", + " 'tion',\n", + " '▁=',\n", + " '▁I',\n", + " '▁(',\n", + " 'qu',\n", + " '▁v',\n", + " ');',\n", + " '▁M',\n", + " '▁P',\n", + " '▁y',\n", + " 'ver',\n", + " '▁B',\n", + " 'um',\n", + " 'ul',\n", + " 'ig',\n", + " '▁is',\n", + " '▁for',\n", + " 'ly',\n", + " '▁u',\n", + " 'ate',\n", + " '▁D',\n", + " 'ow',\n", + " 'ter',\n", + " '▁on',\n", + " '▁{',\n", + " '▁wi',\n", + " '▁be',\n", + " 'end',\n", + " 'ir',\n", + " 'ts',\n", + " 'ers',\n", + " 'ue',\n", + " 'hat',\n", + " 'ati',\n", + " 'ine',\n", + " 'ck',\n", + " 'res',\n", + " '▁R',\n", + " 'ay',\n", + " '=\"',\n", + " 
'gh',\n", + " '▁L',\n", + " 'ons',\n", + " 'men',\n", + " '▁con',\n", + " '▁F',\n", + " '//',\n", + " 'int',\n", + " 'cti',\n", + " '▁E',\n", + " '▁k',\n", + " 'and',\n", + " '▁H',\n", + " 'ort',\n", + " 'our',\n", + " 'od',\n", + " 'est',\n", + " '()',\n", + " '▁N',\n", + " '▁W',\n", + " 'if',\n", + " '▁*',\n", + " 'ht',\n", + " 'The',\n", + " 'ith',\n", + " '▁G',\n", + " 'ere',\n", + " 'ke',\n", + " '▁pro',\n", + " 'ub',\n", + " 'pp',\n", + " '▁en',\n", + " 'ble',\n", + " '▁ha',\n", + " '\\\\[',\n", + " '\\\\]',\n", + " '▁\"',\n", + " '▁it',\n", + " 'oo',\n", + " 'ci',\n", + " 'all',\n", + " 'ess',\n", + " 'ab',\n", + " '▁an',\n", + " 'ret',\n", + " 'ation',\n", + " '▁that',\n", + " '▁with',\n", + " 'ction',\n", + " 'ment',\n", + " '▁j',\n", + " '▁ne',\n", + " '▁st',\n", + " '▁com',\n", + " '▁me',\n", + " '▁la',\n", + " 'ity',\n", + " '▁as',\n", + " 'tem',\n", + " '▁O',\n", + " '**',\n", + " '▁or',\n", + " 'urn',\n", + " 'ted',\n", + " '▁you',\n", + " '▁he',\n", + " 'ist',\n", + " 'rom',\n", + " '▁at',\n", + " '▁$',\n", + " '▁ex',\n", + " '▁se',\n", + " '▁qu',\n", + " 'mp',\n", + " '▁li',\n", + " 'line',\n", + " 'per',\n", + " '▁mo',\n", + " 'ld',\n", + " 'port',\n", + " '▁are',\n", + " '▁le',\n", + " '}\\r',\n", + " 'ure',\n", + " '▁so',\n", + " '▁J',\n", + " '▁The',\n", + " 'op',\n", + " 'iz',\n", + " '▁al',\n", + " '▁ma',\n", + " 'ass',\n", + " 'lic',\n", + " 'xt',\n", + " 'ore',\n", + " '..',\n", + " '▁&',\n", + " 'ite',\n", + " '/*',\n", + " '',\n", + " '▁po',\n", + " 'ex',\n", + " 'mat',\n", + " 'las',\n", + " 'return',\n", + " '▁from',\n", + " '/**',\n", + " 'hline',\n", + " \"▁'\",\n", + " 'gin',\n", + " 'ans',\n", + " '▁not',\n", + " 'rou',\n", + " 'av',\n", + " '▁we',\n", + " 'ata',\n", + " 'one',\n", + " 'cl',\n", + " 'ack',\n", + " '▁ch',\n", + " '},',\n", + " 'ther',\n", + " '▁have',\n", + " 'au',\n", + " '▁no',\n", + " '--',\n", + " 'ase',\n", + " 'ould',\n", + " 'ich',\n", + " '▁can',\n", + " '▁tr',\n", + " 'sed',\n", + " 'sp',\n", + " 'tic',\n", + " 'ud',\n", + " 'ote',\n", + " 'ublic',\n", + " 'ell',\n", + " '▁id',\n", + " 'ong',\n", + " 'import',\n", + " 'ep',\n", + " 'ect',\n", + " 'ks',\n", + " '▁ar',\n", + " '};',\n", + " 'ght',\n", + " 'oc',\n", + " 'sion',\n", + " 'get',\n", + " 'ous',\n", + " 'ru',\n", + " '();',\n", + " '▁el',\n", + " '),',\n", + " '\",',\n", + " 'mo',\n", + " '▁lo',\n", + " 'ran',\n", + " 'ary',\n", + " '▁*/',\n", + " 'cc',\n", + " 'ry',\n", + " '▁all',\n", + " 'ps',\n", + " 'ount',\n", + " '▁ad',\n", + " 'ph',\n", + " 'ork',\n", + " 'lass',\n", + " 'ound',\n", + " 'ated',\n", + " '▁ab',\n", + " '▁per',\n", + " 'lock',\n", + " 'ire',\n", + " 'ear',\n", + " ').',\n", + " '▁sp',\n", + " '▁des',\n", + " 'ER',\n", + " 'ven',\n", + " 'ance',\n", + " 'va',\n", + " 'lu',\n", + " 'red',\n", + " 'ult',\n", + " 'iew',\n", + " 'erv',\n", + " 'ally',\n", + " '});',\n", + " '*/',\n", + " '▁your',\n", + " 'ile',\n", + " 'ice',\n", + " 'che',\n", + " '▁:',\n", + " 'der',\n", + " '▁sa',\n", + " '▁z',\n", + " 'og',\n", + " 'fer',\n", + " '▁go',\n", + " 'com',\n", + " 'ON',\n", + " 'row',\n", + " '▁es',\n", + " 'ence',\n", + " '▁will',\n", + " '▁In',\n", + " 'public',\n", + " 'ign',\n", + " 'set',\n", + " 'ime',\n", + " 'this',\n", + " 'item',\n", + " '(\"',\n", + " 'In',\n", + " 'able',\n", + " '▁new',\n", + " 'ize',\n", + " '▁Y',\n", + " 'ull',\n", + " '▁[',\n", + " 'vi',\n", + " 'vel',\n", + " 'are',\n", + " 'ont',\n", + " 'ast',\n", + " '▁п',\n", + " '▁ro',\n", + " 'ple',\n", + " '▁but',\n", + " 'begin',\n", + " 'ations',\n", + " '▁Th',\n", + " '__',\n", + " 
'▁uid',\n", + " '▁que',\n", + " '▁up',\n", + " 'tring',\n", + " '▁co',\n", + " 'ight',\n", + " 'ст',\n", + " 'div',\n", + " 'ide',\n", + " 'ces',\n", + " '▁man',\n", + " '▁us',\n", + " 'ial',\n", + " '▁has',\n", + " \"',\",\n", + " '▁out',\n", + " '▁ا',\n", + " 'block',\n", + " 'ies',\n", + " 'lay',\n", + " '▁his',\n", + " 'section',\n", + " 'IN',\n", + " 'ory',\n", + " '▁und',\n", + " 'so',\n", + " 'sel',\n", + " 'ook',\n", + " 'val',\n", + " 'ft',\n", + " 'ill',\n", + " 'log',\n", + " 'ption',\n", + " 'ded',\n", + " 'ni',\n", + " '▁с',\n", + " '▁comp',\n", + " '▁sh',\n", + " 'ree',\n", + " 'quote',\n", + " 'Th',\n", + " '▁get',\n", + " '▁which',\n", + " '->',\n", + " 'ater',\n", + " '▁pid',\n", + " 'ens',\n", + " '▁pre',\n", + " '▁.',\n", + " '...',\n", + " 'no',\n", + " 'iv',\n", + " 'Con',\n", + " '▁ti',\n", + " 'pri',\n", + " 'form',\n", + " '{\\r',\n", + " '▁+',\n", + " 'par',\n", + " 'ten',\n", + " 'we',\n", + " 'ach',\n", + " '▁<',\n", + " 'wer',\n", + " '▁my',\n", + " 'ond',\n", + " '▁ver',\n", + " 'blockquote',\n", + " '▁one',\n", + " '▁в',\n", + " 'ors',\n", + " 'mer',\n", + " '▁more',\n", + " 'ink',\n", + " 'EN',\n", + " 'на',\n", + " 'les',\n", + " ';\\r',\n", + " '▁they',\n", + " 'ain',\n", + " 'ystem',\n", + " 'Re',\n", + " 'ES',\n", + " '▁dis',\n", + " 'ents',\n", + " 'ild',\n", + " '▁au',\n", + " 'RE',\n", + " '_{',\n", + " 'mb',\n", + " 'ни',\n", + " '▁St',\n", + " 'ра',\n", + " 'ms',\n", + " ...]" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.get_vocabulary()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9VaAVcrpb1fI" + }, + "source": [ + "## Load Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "S9WifD8eb57U" + }, + "outputs": [ + { + "data": { + "text/html": [ + "
Preprocessor: \"gemma_causal_lm_preprocessor\"\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1mPreprocessor: \"gemma_causal_lm_preprocessor\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+              "┃ Layer (type)                                                                                     Config ┃\n",
+              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+              "│ gemma_tokenizer (GemmaTokenizer)                              │                      Vocab size: 256,000 │\n",
+              "└───────────────────────────────────────────────────────────────┴──────────────────────────────────────────┘\n",
+              "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Config\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ gemma_tokenizer (\u001b[38;5;33mGemmaTokenizer\u001b[0m) │ Vocab size: \u001b[38;5;34m256,000\u001b[0m │\n", + "└───────────────────────────────────────────────────────────────┴──────────────────────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Model: \"gemma_causal_lm\"\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1mModel: \"gemma_causal_lm\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+              "┃ Layer (type)                   Output Shape                       Param #  Connected to               ┃\n",
+              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+              "│ padding_mask (InputLayer)     │ (None, None)              │               0 │ -                          │\n",
+              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+              "│ token_ids (InputLayer)        │ (None, None)              │               0 │ -                          │\n",
+              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+              "│ gemma_backbone                │ (None, None, 2304)        │   2,614,341,888 │ padding_mask[0][0],        │\n",
+              "│ (GemmaBackbone)               │                           │                 │ token_ids[0][0]            │\n",
+              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+              "│ token_embedding               │ (None, None, 256000)      │     589,824,000 │ gemma_backbone[0][0]       │\n",
+              "│ (ReversibleEmbedding)         │                           │                 │                            │\n",
+              "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n",
+              "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ gemma_backbone │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2304\u001b[0m) │ \u001b[38;5;34m2,614,341,888\u001b[0m │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mGemmaBackbone\u001b[0m) │ │ │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ token_embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256000\u001b[0m) │ \u001b[38;5;34m589,824,000\u001b[0m │ gemma_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m) │ │ │ │\n", + "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Total params: 2,614,341,888 (9.74 GB)\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Trainable params: 2,614,341,888 (9.74 GB)\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Non-trainable params: 0 (0.00 B)\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Gemma output:\n", + "user\n", + "日本の珍しい名字\n", + "model\n", + "Here are some Japanese names that are considered uncommon or unique:\n", + "\n", + "**Nature-Inspired:**\n", + "\n", + "* **Aoi:** (青) Blue, a beautiful color often associated with the sky and water.\n", + "* **Kiku:** (菊) Chrysanthemum, a symbol of longevity and beauty.\n", + "* **Sakura:** (桜) Cherry blossom, a symbol of spring and renewal.\n", + "* **Tsuki:** (月) Moon, a celestial body that inspires wonder and mystery.\n", + "* **Yume:** (夢) Dream, a symbol of hope and imagination.\n", + "TOTAL TIME ELAPSED: 25.95s\n", + "\n", + "Gemma output:\n", + "user\n", + "日本の平凡な名字\n", + "model\n", + "Here are some examples of common Japanese surnames, often considered \"ordinary\" or \"unremarkable\":\n", + "\n", + "**Common and Neutral:**\n", + "\n", + "* **Yamada:** (山田) - Meaning \"mountain field\"\n", + "* **Suzuki:** (鈴木) - Meaning \"small, humble\"\n", + "* **Nakamura:** (中村) - Meaning \"village of the mountain\"\n", + "* **Kato:** (加藤) - Meaning \"a person who lives in a village\"\n", + "* **Tanaka:** (田中) - Meaning \"a person who lives\n", + "TOTAL TIME ELAPSED: 5.38s\n" + ] + } + ], + "source": [ + "import keras\n", + "import keras_nlp\n", + "\n", + "import time\n", + "\n", + "gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(model_id)\n", + "gemma_lm.summary()\n", + "\n", + "tick_start = 0\n", + "\n", + "def tick():\n", + " global tick_start\n", + " tick_start = time.time()\n", + "\n", + "def tock():\n", + " print(f\"TOTAL TIME ELAPSED: {time.time() - tick_start:.2f}s\")\n", + "\n", + "def text_gen(prompt):\n", + " tick()\n", + " input = f\"user\\n{prompt}\\nmodel\\n\"\n", + " output = gemma_lm.generate(input, max_length=token_limit)\n", + " print(\"\\nGemma output:\")\n", + " print(output)\n", + " tock()\n", + "\n", + "import re\n", + "\n", + "def text_gen_with_dict(prompt, dictionary):\n", + " tick()\n", + " input = f\"user\\n{prompt}\\nmodel\\n\"\n", + " output = gemma_lm.generate(input, max_length=token_limit)\n", + " pattern = '|'.join(sorted(re.escape(k) for k in dictionary))\n", + " print(\"-\"*80)\n", + " detoken(tokenizer(output))\n", + " output = re.sub(pattern, lambda m: dictionary.get(m.group(0)), output)\n", + " print(\"\\nGemma output:\")\n", + " print(output)\n", + " print(\"-\"*80)\n", + " tock()\n", + "\n", + "# inference before fine-tuning\n", + "text_gen(\"日本の珍しい名字\")\n", + "text_gen(\"日本の平凡な名字\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8rXG-F4VdQQu" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'': '内大久保', '': '諏訪戸', '': '宮艸', '': '宝坂', '': '埜上', '': '篠垣', '': '池呂', '': '奥須賀', '': '勝居', '': '真喜', '': '高河原', '': '小盛', '': '溜口', '': '伊勢井', '': '落久保', '': '志渡澤', '': '嘉多', '': '板藤', '': '南波留', '': '風ん', '': '采尾', '': '一上', '': '本邑', '': '周本', '': '畠尾', '': '鳥水', '': '下小薗', '': '阪辺', '': '稲味', '': '武浪', '': '安楽城', '': '江古', '': '賎機', '': '鶴原谷', '': '西明寺', '': '布塚', '': '寺端', '': '炭吉', '': '管生', '': '村片', '': '昌司', '': '伊秩', '': '後道', '': '佐保井', '': '神尊', '': '為積', '': '聖川', '': '登喜', '': '弦田', '': '犬居'}\n", + "70\n", + "user\n", + "日本の平凡な名字\n", + "model\n", + "山本さん\n", + "user\n", + "日本の珍しい名字\n", + "model\n", + "さん\n", + "user\n", + "日本の平凡な名字\n", + "model\n", + "松本さん\n" + ] + 
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {
+   "id": "8rXG-F4VdQQu"
+  },
+  "outputs": [
+   {
+    "name": "stdout",
+    "output_type": "stream",
+    "text": [
+     "{'<unused0>': '内大久保', '<unused1>': '諏訪戸', '<unused2>': '宮艸', '<unused3>': '宝坂', '<unused4>': '埜上', '<unused5>': '篠垣', '<unused6>': '池呂', '<unused7>': '奥須賀', '<unused8>': '勝居', '<unused9>': '真喜', '<unused10>': '高河原', '<unused11>': '小盛', '<unused12>': '溜口', '<unused13>': '伊勢井', '<unused14>': '落久保', '<unused15>': '志渡澤', '<unused16>': '嘉多', '<unused17>': '板藤', '<unused18>': '南波留', '<unused19>': '風ん', '<unused20>': '采尾', '<unused21>': '一上', '<unused22>': '本邑', '<unused23>': '周本', '<unused24>': '畠尾', '<unused25>': '鳥水', '<unused26>': '下小薗', '<unused27>': '阪辺', '<unused28>': '稲味', '<unused29>': '武浪', '<unused30>': '安楽城', '<unused31>': '江古', '<unused32>': '賎機', '<unused33>': '鶴原谷', '<unused34>': '西明寺', '<unused35>': '布塚', '<unused36>': '寺端', '<unused37>': '炭吉', '<unused38>': '管生', '<unused39>': '村片', '<unused40>': '昌司', '<unused41>': '伊秩', '<unused42>': '後道', '<unused43>': '佐保井', '<unused44>': '神尊', '<unused45>': '為積', '<unused46>': '聖川', '<unused47>': '登喜', '<unused48>': '弦田', '<unused49>': '犬居'}\n",
+     "70\n",
+     "<start_of_turn>user\n",
+     "日本の平凡な名字<end_of_turn>\n",
+     "<start_of_turn>model\n",
+     "山本さん<end_of_turn>\n",
+     "<start_of_turn>user\n",
+     "日本の珍しい名字<end_of_turn>\n",
+     "<start_of_turn>model\n",
+     "さん<end_of_turn>\n",
+     "<start_of_turn>user\n",
+     "日本の平凡な名字<end_of_turn>\n",
+     "<start_of_turn>model\n",
+     "松本さん<end_of_turn>\n"
+    ]
+   }
+  ],
+  "source": [
+   "# example data\n",
+   "custom_vocab = [\n",
+   "    \"内大久保\",\"諏訪戸\",\"宮艸\",\"宝坂\",\"埜上\",\"篠垣\",\"池呂\",\"奥須賀\",\"勝居\",\"真喜\",\n",
+   "    \"高河原\",\"小盛\",\"溜口\",\"伊勢井\",\"落久保\",\"志渡澤\",\"嘉多\",\"板藤\",\"南波留\",\"風ん\",\n",
+   "    \"采尾\",\"一上\",\"本邑\",\"周本\",\"畠尾\",\"鳥水\",\"下小薗\",\"阪辺\",\"稲味\",\"武浪\",\n",
+   "    \"安楽城\",\"江古\",\"賎機\",\"鶴原谷\",\"西明寺\",\"布塚\",\"寺端\",\"炭吉\",\"管生\",\"村片\",\n",
+   "    \"昌司\",\"伊秩\",\"後道\",\"佐保井\",\"神尊\",\"為積\",\"聖川\",\"登喜\",\"弦田\",\"犬居\",\n",
+   "]\n",
+   "my_dictionary = {f\"<unused{i}>\": custom_vocab[i] for i in range(len(custom_vocab))}\n",
+   "print(my_dictionary)\n",
+   "\n",
+   "train = []\n",
+   "\n",
+   "for i in range(len(custom_vocab)):\n",
+   "    item = f\"<start_of_turn>user\\n日本の珍しい名字<end_of_turn>\\n<start_of_turn>model\\n<unused{i}>さん<end_of_turn>\"\n",
+   "    length = len(tokenizer(item))\n",
+   "    # skip data if the token length is longer than our limit\n",
+   "    if length < token_limit:\n",
+   "        train.append(item)\n",
+   "    if len(train) >= num_data_limit:\n",
+   "        break\n",
+   "\n",
+   "# Add contrast examples (common Japanese surnames) to prevent overfitting\n",
+   "common_name = [\n",
+   "    \"佐藤\",\"鈴木\",\"高橋\",\"田中\",\"伊藤\",\"渡辺\",\"山本\",\"中村\",\"小林\",\"加藤\",\n",
+   "    \"吉田\",\"山田\",\"佐々木\",\"山口\",\"松本\",\"井上\",\"木村\",\"林\",\"斎藤\",\"清水\",\n",
+   "]\n",
+   "for x in common_name:\n",
+   "    item = f\"<start_of_turn>user\\n日本の平凡な名字<end_of_turn>\\n<start_of_turn>model\\n{x}さん<end_of_turn>\"\n",
+   "    length = len(tokenizer(item))\n",
+   "    # skip data if the token length is longer than our limit\n",
+   "    if length < token_limit:\n",
+   "        train.append(item)\n",
+   "    if len(train) >= num_data_limit:\n",
+   "        break\n",
+   "\n",
+   "import random\n",
+   "random.shuffle(train)\n",
+   "\n",
+   "print(len(train))\n",
+   "print(train[0])\n",
+   "print(train[1])\n",
+   "print(train[2])"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "metadata": {
+   "id": "NAxoEthfyye6"
+  },
+  "source": [
+   "## LoRA Fine-tuning"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {
+   "id": "OQuVUnvpy1rL"
+  },
+  "outputs": [
+   {
+    "data": {
+     "text/html": [
+      "
Preprocessor: \"gemma_causal_lm_preprocessor\"\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1mPreprocessor: \"gemma_causal_lm_preprocessor\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+              "┃ Layer (type)                                                                                     Config ┃\n",
+              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+              "│ gemma_tokenizer (GemmaTokenizer)                              │                      Vocab size: 256,000 │\n",
+              "└───────────────────────────────────────────────────────────────┴──────────────────────────────────────────┘\n",
+              "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Config\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ gemma_tokenizer (\u001b[38;5;33mGemmaTokenizer\u001b[0m) │ Vocab size: \u001b[38;5;34m256,000\u001b[0m │\n", + "└───────────────────────────────────────────────────────────────┴──────────────────────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Model: \"gemma_causal_lm\"\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1mModel: \"gemma_causal_lm\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+              "┃ Layer (type)                   Output Shape                       Param #  Connected to               ┃\n",
+              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+              "│ padding_mask (InputLayer)     │ (None, None)              │               0 │ -                          │\n",
+              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+              "│ token_ids (InputLayer)        │ (None, None)              │               0 │ -                          │\n",
+              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+              "│ gemma_backbone                │ (None, None, 2304)        │   2,617,270,528 │ padding_mask[0][0],        │\n",
+              "│ (GemmaBackbone)               │                           │                 │ token_ids[0][0]            │\n",
+              "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+              "│ token_embedding               │ (None, None, 256000)      │     589,824,000 │ gemma_backbone[0][0]       │\n",
+              "│ (ReversibleEmbedding)         │                           │                 │                            │\n",
+              "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n",
+              "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ gemma_backbone │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2304\u001b[0m) │ \u001b[38;5;34m2,617,270,528\u001b[0m │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mGemmaBackbone\u001b[0m) │ │ │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ token_embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256000\u001b[0m) │ \u001b[38;5;34m589,824,000\u001b[0m │ gemma_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m) │ │ │ │\n", + "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Total params: 2,617,270,528 (9.75 GB)\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m2,617,270,528\u001b[0m (9.75 GB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Trainable params: 2,928,640 (11.17 MB)\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m2,928,640\u001b[0m (11.17 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Non-trainable params: 2,614,341,888 (9.74 GB)\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m2,614,341,888\u001b[0m (9.74 GB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Enable LoRA for the model and set the LoRA rank to 4.\n", + "gemma_lm.backbone.enable_lora(rank=lora_rank)\n", + "gemma_lm.summary()\n", + "\n", + "# Limit the input sequence length (to control memory usage).\n", + "gemma_lm.preprocessor.sequence_length = token_limit\n", + "# Use AdamW (a common optimizer for transformer models).\n", + "optimizer = keras.optimizers.AdamW(\n", + " learning_rate=lr_value,\n", + " weight_decay=0.01,\n", + ")\n", + "# Exclude layernorm and bias terms from decay.\n", + "optimizer.exclude_from_weight_decay(var_names=[\"bias\", \"scale\"])\n", + "\n", + "gemma_lm.compile(\n", + " loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", + " optimizer=optimizer,\n", + " weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nWaNHda-y65-" + }, + "source": [ + "Note that enabling LoRA reduces the number of trainable parameters significantly. In practice, we recommend beginning with a relatively small rank (such as 4, 8, 16). This is computationally efficient for experimentation.\n", + "\n", + "To monitor the learning progress, we will evaluate the model at the end of each epoch and save the all lora weights." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bI4VOfBQy30W" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/5\n", + "\u001b[1m35/35\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 3s/step - loss: 1.3578 - sparse_categorical_accuracy: 0.2155\n", + "Gemma output:\n", + "user\n", + "日本の珍しい名字\n", + "model\n", + "さん\n", + "TOTAL TIME ELAPSED: 22.53s\n", + "\n", + "Gemma output:\n", + "user\n", + "日本の平凡な名字\n", + "model\n", + "ⓧさん\n", + "TOTAL TIME ELAPSED: 0.40s\n", + "\u001b[1m35/35\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m167s\u001b[0m 4s/step - loss: 1.3449 - sparse_categorical_accuracy: 0.2209\n", + "Epoch 2/5\n", + "\u001b[1m35/35\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 735ms/step - loss: 0.2559 - sparse_categorical_accuracy: 0.7883\n", + "Gemma output:\n", + "user\n", + "日本の珍しい名字\n", + "model\n", + "さん\n", + "TOTAL TIME ELAPSED: 0.43s\n", + "\n", + "Gemma output:\n", + "user\n", + "日本の平凡な名字\n", + "model\n", + "井上さん\n", + "TOTAL TIME ELAPSED: 0.40s\n", + "\u001b[1m35/35\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m56s\u001b[0m 783ms/step - loss: 0.2550 - sparse_categorical_accuracy: 0.7891\n", + "Epoch 3/5\n", + "\u001b[1m34/35\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m━\u001b[0m \u001b[1m0s\u001b[0m 779ms/step - loss: 0.1715 - sparse_categorical_accuracy: 0.8500\n", + "Gemma output:\n", + "user\n", + "日本の珍しい名字\n", + "model\n", + "さん\n", + "TOTAL TIME ELAPSED: 0.41s\n", + "\n", + "Gemma output:\n", + "user\n", + "日本の平凡な名字\n", + "model\n", + "井上さん\n", + "TOTAL TIME ELAPSED: 0.41s\n", + "\u001b[1m35/35\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m28s\u001b[0m 811ms/step - loss: 0.1708 - sparse_categorical_accuracy: 0.8501\n", + "Epoch 4/5\n", + "\u001b[1m35/35\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 792ms/step - loss: 0.1328 - 
sparse_categorical_accuracy: 0.8413\n", + "Gemma output:\n", + "user\n", + "日本の珍しい名字\n", + "model\n", + "さん\n", + "TOTAL TIME ELAPSED: 0.42s\n", + "\n", + "Gemma output:\n", + "user\n", + "日本の平凡な名字\n", + "model\n", + "清水さん\n", + "TOTAL TIME ELAPSED: 0.41s\n", + "\u001b[1m35/35\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m29s\u001b[0m 837ms/step - loss: 0.1325 - sparse_categorical_accuracy: 0.8414\n", + "Epoch 5/5\n", + "\u001b[1m35/35\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 806ms/step - loss: 0.1024 - sparse_categorical_accuracy: 0.8451\n", + "Gemma output:\n", + "user\n", + "日本の珍しい名字\n", + "model\n", + "さん\n", + "TOTAL TIME ELAPSED: 0.42s\n", + "\n", + "Gemma output:\n", + "user\n", + "日本の平凡な名字\n", + "model\n", + "井上さん\n", + "TOTAL TIME ELAPSED: 0.41s\n", + "\u001b[1m35/35\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m30s\u001b[0m 851ms/step - loss: 0.1022 - sparse_categorical_accuracy: 0.8452\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA7cUlEQVR4nO3de3iU9Z3//9fMJDOTkGQgBCYcAuGgUBQSiCRFd1e7jaUrtdLTUrDiN1vtlkV/2ly7LVktrHZr7NZaewlbXFarq4tQj+2uFsW06NoiwQQqIuCBQ8JhJgngTAhkAjP3749Jhgw5kAmZzCHPx3Xdl+bO55753Nd0zKuf+32/b5NhGIYAAABixBzrCQAAgKGNMAIAAGKKMAIAAGKKMAIAAGKKMAIAAGKKMAIAAGKKMAIAAGKKMAIAAGIqJdYT6ItAIKCjR48qMzNTJpMp1tMBAAB9YBiGmpubNXbsWJnNPa9/JEQYOXr0qPLy8mI9DQAA0A/19fUaP358j79PiDCSmZkpKXgyWVlZMZ4NAADoC6/Xq7y8vNDf8Z4kRBjpuDSTlZVFGAEAIMFcrMSCAlYAABBThBEAABBThBEAABBThBEAABBThBEAABBThBEAABBT/Qoja9asUX5+vux2u0pKSlRdXd3j2LNnz+r+++/XlClTZLfbVVBQoE2bNvV7wgAAILlEHEY2btyo8vJyrVq1SrW1tSooKND8+fPV0NDQ7fh7771Xjz32mB599FF98MEH+u53v6uvfOUr2rFjxyVPHgAAJD6TYRhGJAeUlJRo7ty5Wr16taTgc2Py8vJ05513asWKFV3Gjx07Vvfcc4+WL18e2ve1r31NaWlpeuaZZ/r0nl6vVw6HQx6Ph6ZnAAAkiL7+/Y5oZaStrU01NTUqLS09/wJms0pLS7V169Zuj/H5fLLb7WH70tLS9Pbbb0fy1gAAIElFFEaamprk9/vldDrD9judTrlcrm6PmT9/vh5++GF99NFHCgQC2rx5s1588UUdO3asx/fx+Xzyer1hGwAASE5Rv5vmF7/4hS677DJNnz5dVqtVd9xxh8rKynp9lHBlZaUcDkdo44m9AAAkr4jCSE5OjiwWi9xud9h+t9ut3Nzcbo8ZNWqUXn75ZbW0tOjQoUPau3evMjIyNHny5B7fp6KiQh6PJ7TV19dHMs0+8QcM/c+fj6rsV9U65Ts34K8PAAD6JqIwYrVaVVRUpKqqqtC+QCCgqqoqzZs3r9dj7Xa7xo0bp3PnzumFF17QTTfd1ONYm80WekJvtJ7UazZJP9/8of6wr1G/3Xl0wF8fAAD0TcSXacrLy7Vu3To99dRT2rNnj5YtW6aWlhaVlZVJkpYuXaqKiorQ+G3btunFF1/U/v379X//93/64he/qEAgoO9///sDdxb9YDKZtLh4giTp2eq6mM4FAIChLCXSAxYtWqTGxkatXLlSLpdLhYWF2rRpU6iota6uLqwepLW1Vffee6/279+vjIwM3XDDDXr66ac1fPjwATuJ/vpa0Xj99LV92nXEo12HPZo53hHrKQEAMORE3GckFqLZZ+T/e3aHfvvno1pcPEGVX505oK8NAMBQFpU+I8loSUnwUs1vdx6hkBUAgBgY8mGkZFK2Jo8appY2P4WsAADEwJAPIyaTSUvaC1nXVx+K8WwAABh6hnwYkaSvzhkvq8Ws9494teuwJ9bTAQBgSCGMSMoeZtXfzAw2bWN1BACAwUUYadfRc+Q3O49SyAoAwCAijLTrKGQ93ebXb3YeifV0AAAYMggj7ToXstKRFQCAwUMY6eRrc8bLmhIsZH3v8Kexng4AAEMCYaSTEcOsuuHKYCErqyMAAAwOwsgFOheyNreejfFsAABIfoSRCxRPytaU9kLW3/6ZjqwAAEQbYeQCJpMptDqyfludEuA5ggAAJDTCSDc6Cll3H/Vq1xE6sgIAEE2EkW50LmRdv41CVgAAookw0oMlJRMlSb/9M4WsAABEE2GkB3PzR2jq6Iz2jqwUsgIAEC2EkR5QyAoAwOAgjPTia3PGyZpi1gfHvHrvMIWsAABEA2GkF8PTrVowc4wkOrICABAthJGL6LhUQyErAADRQRi5CApZAQCILsLIRVDICgBAdBFG+oBCVgAAoocw0gedC1npyAoAwMAijPTRkhIKWQEAiAbCSB9dNTFYyHrmrF8vU8gKAMCAIYz0kclk0hIKWQEAGHCEkQh8tb2Qdc8xr/5MISsAAAOCMBKB4elWfamjIyuFrAAADAjCSIQWdypk9VLICgDAJSOMROiqiSN0WXshKx1ZAQC4dP0KI2vWrFF+fr7sdrtKSkpUXV3d6/hHHnlE06ZNU1pamvLy8vS9731Pra2t/ZpwrNGRFQCAgRVxGNm4caPKy8u1atUq1dbWqqCgQPPnz1dDQ0O349evX68VK1Zo1apV2rNnjx5//HFt3LhR//zP/3zJk4+Vr84ZJxuFrAAADIiIw8jDDz+s22+/XWVlZZo
xY4bWrl2r9PR0PfHEE92O/9Of/qRrrrlGS5YsUX5+vr7whS9o8eLFF11NiWfhHVkPxXg2AAAktojCSFtbm2pqalRaWnr+BcxmlZaWauvWrd0ec/XVV6umpiYUPvbv369XX31VN9xwQ4/v4/P55PV6w7Z409GR9X/+fIxCVgAALkFEYaSpqUl+v19OpzNsv9PplMvl6vaYJUuW6P7779df/MVfKDU1VVOmTNF1113X62WayspKORyO0JaXlxfJNAdFUedC1h1HYj0dAAASVtTvptmyZYseeOAB/fu//7tqa2v14osv6pVXXtGPfvSjHo+pqKiQx+MJbfX19dGeZsRMJlNodeS/KWQFAKDfUiIZnJOTI4vFIrfbHbbf7XYrNze322N++MMf6pZbbtFtt90mSZo5c6ZaWlr0ne98R/fcc4/M5q55yGazyWazRTK1mPjq7PF68Hd7tdfVrJ31n2r2hBGxnhIAAAknopURq9WqoqIiVVVVhfYFAgFVVVVp3rx53R5z+vTpLoHDYrFIUsKvJjjSU7VgVntH1mo6sgIA0B8RX6YpLy/XunXr9NRTT2nPnj1atmyZWlpaVFZWJklaunSpKioqQuNvvPFG/fKXv9SGDRt04MABbd68WT/84Q914403hkJJIut4eB6FrAAA9E9El2kkadGiRWpsbNTKlSvlcrlUWFioTZs2hYpa6+rqwlZC7r33XplMJt177706cuSIRo0apRtvvFE//vGPB+4sYqho4ghd7szQh+5T+s2OI7plXn6spwQAQEIxGQlwrcTr9crhcMjj8SgrKyvW0+niV388oPv+5wNNz83U7+76S5lMplhPCQCAmOvr32+eTTMAvjp7vGwp5lAhKwAA6DvCyADoXMi6fhuFrAAARIIwMkBu7ujI+t5Rec5QyAoAQF8RRgbInAnBQtbWswH9ZicdWQEA6CvCyAAxmUyh23zX05EVAIA+I4wMoK90KmTdQSErAAB9QhgZQI70VH1p1lhJFLICANBXhJEBtqQk+ITh/6WQFQCAPiGMDLA5E0ZomjNTrWcDenkHhawAAFwMYWSAmUwmLS4Oro48W00hKwAAF0MYiYKvzDlfyFpb92mspwMAQFwjjESBI+18Ieuz1RSyAgDQG8JIlCxp78hKISsAAL0jjETJnAnDKWQFAKAPCCNRYjKZQqsjdGQFAKBnhJEoWjh7nOypZu1zU8gKAEBPCCNR1LmQlY6sAAB0jzASZYuLOxWynqaQFQCACxFGomzOhOGanpsp37mAXtpxONbTAQAg7hBGoizYkTW4OvJsdT2FrAAAXIAwMgjCC1lPxno6AADEFcLIIAgvZK2P8WwAAIgvhJFBEtaRlUJWAABCCCODZHYehawAAHSHMDJIwjqyVtORFQCADoSRQXRTYbCQ9UP3KQpZAQBoRxgZRI60VN3YXsj633RkBQBAEmFk0C1uv1TzynvHKGQFAECEkUHXuZD1RQpZAQAgjAy2zoWsz1LICgAAYSQWOjqyfug+pZpDFLICAIa2foWRNWvWKD8/X3a7XSUlJaquru5x7HXXXSeTydRlW7BgQb8nneiy7OcLWddXU8gKABjaIg4jGzduVHl5uVatWqXa2loVFBRo/vz5amho6Hb8iy++qGPHjoW2999/XxaLRd/4xjcuefKJbAmFrAAASOpHGHn44Yd1++23q6ysTDNmzNDatWuVnp6uJ554otvx2dnZys3NDW2bN29Wenr6kA8jhRSyAgAgKcIw0tbWppqaGpWWlp5/AbNZpaWl2rp1a59e4/HHH9c3v/lNDRs2rMcxPp9PXq83bEs2JpNJN3d0ZN1GISsAYOiKKIw0NTXJ7/fL6XSG7Xc6nXK5XBc9vrq6Wu+//75uu+22XsdVVlbK4XCEtry8vEimmTBumj1OaakWfdRAISsAYOga1LtpHn/8cc2cOVPFxcW9jquoqJDH4wlt9fX1gzTDwZVlT9WNBWMkBVdHAAAYiiIKIzk5ObJYLHK73WH73W63cnNzez22paVFGzZs0Le//e2Lvo/NZlNWVlbYlqwWFwcv1fzvrmP69HRbjGcDAMDgiyiMWK1WFRUVqaqqKrQvEAioqqpK8+bN6/XY5557Tj6fT9/61rf6N9MkVZg3XJ8Zk6W2cwG9WHsk1tMBAGDQRXyZpry8XOvWrdNTTz2lPXv2aNmyZWppaVFZWZkkaenSpaqoqOhy3OOPP66FCxdq5MiRlz7rJEJHVgDAUJcS6QGLFi1SY2OjVq5cKZfLpcLCQm3atClU1FpXVyezOTzj7Nu3T2+//bZef/31gZl1krmpcKweeGWPPmo4pXcPndTc/OxYTwkAgEFjMhLg/4p7vV45HA55PJ6krR/5wfPvaeO79frq7HF6eFFhrKcDAMAl6+vfb55NEycWl1DICgAYmggjcaJgvEMzKGQFAAxBhJE4YTKZQqsj6ylkBQAMIYSROLKwcKzSUi36uL2QFQCAoYAwEkcy7an6csFYSXRkBQAMHYSRONPRc+QVClkBAEMEYSTOzOpUyPoChawAgCGAMBJn6MgKABhqCCNx6KbCsUq3BgtZtx+kkBUAkNwII3GocyHrs9UUsgIAkhthJE4tLj5fyHqyhUJWAEDyIozEqVnjHbpibHtH1h0UsgIAkhdhJE6ZTKbQ6sj6bYcoZAUAJC3CSBzrKGT9pLGFQlYAQNIijMSx8I6sh2I8GwAAooMwEuc6eo68+r6LQlYAQFIijMS5mePOF7K+UHs41tMBAGDAEUbiHB1ZAQDJjjCSAL5ccL6QtfrAiVhPBwCAAUUYSQCZ9lTdVEhHVgBAciKMJIiOniMUsgIAkg1hJEHMGj9cV46jkBUAkHwIIwkk1JGVQlYAQBIhjCSQmwrHKd1q0X4KWQEASYQwkkAybCmhQtb1FLICAJIEYSTBLCmeKEn63S4KWQEAyYEwkmBmjncEC1n9FLICAJIDYSQBdayOUMgKAEgGhJEE9OXCsRrWXsi6jUJWAECCI4wkoAxbir5cOE4SHVkBAImPMJKglrT3HPndLpdOUMgKAEhg/Qoja9asUX5+vux2u0pKSlRdXd3r+E8//VTLly/XmDFjZLPZdPnll+vVV1/t14QRNHO8QzPHOdTmD+hFClkBAAks4jCyceNGlZeXa9WqVaqtrVVBQYHmz5+vhoaGbse3tbXp+uuv18GDB/X8889r3759WrduncaNG3fJkx/q6MgKAEgGJiPCv2IlJSWaO3euVq9eLUkKBALKy8vTnXfeqRUrVnQZv3btWv30pz/V3r17lZqa2q9Jer1eORwOeTweZWVl9es1ktEp3zmV/PgNtbT5teE7n9VnJ4+M9ZQAAAjp69/viFZG2traVFNTo9LS0vMvYDartLRUW7du7faY3/72t5o3b56WL18up9OpK6+8Ug888ID8fn+P7+Pz+eT1esM2dNW5kHX9NgpZAQCJKaIw0tTUJL/fL6fTGbbf6XTK5XJ1e8z+/fv1/PPPy+/369VXX9UPf/hD/exnP9O//uu/9vg+lZWVcjgcoS0vLy+SaQ4pN5cEL9Vsep9CVg
...[remainder of the base64-encoded PNG omitted — the rendered image is the training-loss curve produced by the cell below]\n",
+      "text/plain": [
+       "<Figure size 640x480 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "class CustomCallback(keras.callbacks.Callback):\n",
+    "    def on_epoch_end(self, epoch, logs=None):\n",
+    "        # Save a LoRA checkpoint to Google Drive after every epoch.\n",
+    "        model_name = f\"/content/drive/MyDrive/{lora_name}_{lora_rank}_epoch{epoch+1}.lora.h5\"\n",
+    "        gemma_lm.backbone.save_lora_weights(model_name)\n",
+    "\n",
+    "        # Evaluate: spot-check generation quality at the end of each epoch.\n",
+    "        text_gen(\"日本の珍しい名字\")\n",
+    "        text_gen(\"日本の平凡な名字\")\n",
+    "\n",
+    "history = gemma_lm.fit(train, epochs=train_epoch, batch_size=2, callbacks=[CustomCallback()])\n",
+    "\n",
+    "# Plot the loss curve recorded during training.\n",
+    "import matplotlib.pyplot as plt\n",
+    "plt.plot(history.history['loss'])\n",
+    "plt.show()"
+   ]
+  },
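+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The callback above leaves one LoRA checkpoint per epoch in Google Drive, so you can roll back to the epoch that generated best instead of retraining. A minimal sketch of reloading one of them, assuming the file-name pattern used by `CustomCallback` (the epoch number here is only an example):\n",
+    "\n",
+    "```python\n",
+    "# Restore the LoRA weights saved after a chosen epoch (epoch 3 is hypothetical).\n",
+    "ckpt = f\"/content/drive/MyDrive/{lora_name}_{lora_rank}_epoch3.lora.h5\"\n",
+    "gemma_lm.backbone.load_lora_weights(ckpt)\n",
+    "```"
+   ]
+  },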
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "class CustomCallback(keras.callbacks.Callback):\n", + " def on_epoch_end(self, epoch, logs=None):\n", + " model_name = f\"/content/drive/MyDrive/{lora_name}_{lora_rank}_epoch{epoch+1}.lora.h5\"\n", + " gemma_lm.backbone.save_lora_weights(model_name)\n", + "\n", + " # Evaluate\n", + " text_gen(\"日本の珍しい名字\")\n", + " text_gen(\"日本の平凡な名字\")\n", + "\n", + "history = gemma_lm.fit(train, epochs=train_epoch, batch_size=2, callbacks=[CustomCallback()])\n", + "\n", + "import matplotlib.pyplot as plt\n", + "plt.plot(history.history['loss'])\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4QkVxDR1IHmx" + }, + "source": [ + "## Try a different sampler\n", + "\n", + "The top-K algorithm randomly picks the next token from the tokens of top K probability.\n", + "\n", + "**NOTE: Due to randomness of the sampler, you may encounter \\ tokens that surpass the number you trained with.**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1zwJmf9KzwW0" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--------------------------------------------------------------------------------\n", + "[ 106 1645 108 47172 142742 39742 107 108 106 2516\n", + " 108 23 4758 107]\n", + " 106 -> \n", + " 1645 -> user\n", + " 108 -> \n", + "\n", + " 47172 -> 日本の\n", + "142742 -> 珍しい\n", + " 39742 -> 名字\n", + " 107 -> \n", + " 108 -> \n", + "\n", + " 106 -> \n", + " 2516 -> model\n", + " 108 -> \n", + "\n", + " 23 -> \n", + " 4758 -> さん\n", + " 107 -> \n", + "\n", + "Gemma output:\n", + "user\n", + "日本の珍しい名字\n", + "model\n", + "嘉多さん\n", + "--------------------------------------------------------------------------------\n", + "TOTAL TIME ELAPSED: 22.22s\n", + "--------------------------------------------------------------------------------\n", + "[ 106 1645 108 47172 142742 39742 107 108 106 2516\n", + " 108 59 4758 107]\n", + " 106 -> \n", + " 1645 -> user\n", + " 108 -> \n", + "\n", + " 47172 -> 日本の\n", + "142742 -> 珍しい\n", + " 39742 -> 名字\n", + " 107 -> \n", + " 108 -> \n", + "\n", + " 106 -> \n", + " 2516 -> model\n", + " 108 -> \n", + "\n", + " 59 -> \n", + " 4758 -> さん\n", + " 107 -> \n", + "\n", + "Gemma output:\n", + "user\n", + "日本の珍しい名字\n", + "model\n", + "さん\n", + "--------------------------------------------------------------------------------\n", + "TOTAL TIME ELAPSED: 0.54s\n", + "--------------------------------------------------------------------------------\n", + "[ 106 1645 108 47172 142742 39742 107 108 106 2516\n", + " 108 75 4758 107]\n", + " 106 -> \n", + " 1645 -> user\n", + " 108 -> \n", + "\n", + " 47172 -> 日本の\n", + "142742 -> 珍しい\n", + " 39742 -> 名字\n", + " 107 -> \n", + " 108 -> \n", + "\n", + " 106 -> \n", + " 2516 -> model\n", + " 108 -> \n", + "\n", + " 75 -> \n", + " 4758 -> さん\n", + " 107 -> \n", + "\n", + "Gemma output:\n", + "user\n", + "日本の珍しい名字\n", + "model\n", + "さん\n", + "--------------------------------------------------------------------------------\n", + "TOTAL TIME ELAPSED: 0.51s\n", + "--------------------------------------------------------------------------------\n", + "[ 106 1645 108 47172 142742 39742 107 108 106 2516\n", + " 108 30 4758 107]\n", + " 106 -> \n", + " 1645 -> user\n", + " 108 -> \n", + "\n", + " 47172 -> 日本の\n", + "142742 -> 珍しい\n", + " 39742 -> 名字\n", + " 107 -> \n", + " 108 -> \n", + "\n", + " 106 -> \n", + " 2516 -> model\n", + " 108 -> \n", + "\n", + " 30 -> \n", + " 4758 -> 
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "1zwJmf9KzwW0"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "--------------------------------------------------------------------------------\n",
+      "[   106   1645    108  47172 142742  39742    107    108    106   2516\n",
+      "    108     23   4758    107]\n",
+      "   106 -> <start_of_turn>\n",
+      "  1645 -> user\n",
+      "   108 -> \n",
+      "\n",
+      " 47172 -> 日本の\n",
+      "142742 -> 珍しい\n",
+      " 39742 -> 名字\n",
+      "   107 -> <end_of_turn>\n",
+      "   108 -> \n",
+      "\n",
+      "   106 -> <start_of_turn>\n",
+      "  2516 -> model\n",
+      "   108 -> \n",
+      "\n",
+      "    23 -> <unused16>\n",
+      "  4758 -> さん\n",
+      "   107 -> <end_of_turn>\n",
+      "\n",
+      "Gemma output:\n",
+      "<start_of_turn>user\n",
+      "日本の珍しい名字<end_of_turn>\n",
+      "<start_of_turn>model\n",
+      "嘉多さん<end_of_turn>\n",
+      "--------------------------------------------------------------------------------\n",
+      "TOTAL TIME ELAPSED: 22.22s\n",
+      "--------------------------------------------------------------------------------\n",
+      "[   106   1645    108  47172 142742  39742    107    108    106   2516\n",
+      "    108     59   4758    107]\n",
+      "   106 -> <start_of_turn>\n",
+      "  1645 -> user\n",
+      "   108 -> \n",
+      "\n",
+      " 47172 -> 日本の\n",
+      "142742 -> 珍しい\n",
+      " 39742 -> 名字\n",
+      "   107 -> <end_of_turn>\n",
+      "   108 -> \n",
+      "\n",
+      "   106 -> <start_of_turn>\n",
+      "  2516 -> model\n",
+      "   108 -> \n",
+      "\n",
+      "    59 -> <unused52>\n",
+      "  4758 -> さん\n",
+      "   107 -> <end_of_turn>\n",
+      "\n",
+      "Gemma output:\n",
+      "<start_of_turn>user\n",
+      "日本の珍しい名字<end_of_turn>\n",
+      "<start_of_turn>model\n",
+      "<unused52>さん<end_of_turn>\n",
+      "--------------------------------------------------------------------------------\n",
+      "TOTAL TIME ELAPSED: 0.54s\n",
+      "--------------------------------------------------------------------------------\n",
+      "[   106   1645    108  47172 142742  39742    107    108    106   2516\n",
+      "    108     75   4758    107]\n",
+      "   106 -> <start_of_turn>\n",
+      "  1645 -> user\n",
+      "   108 -> \n",
+      "\n",
+      " 47172 -> 日本の\n",
+      "142742 -> 珍しい\n",
+      " 39742 -> 名字\n",
+      "   107 -> <end_of_turn>\n",
+      "   108 -> \n",
+      "\n",
+      "   106 -> <start_of_turn>\n",
+      "  2516 -> model\n",
+      "   108 -> \n",
+      "\n",
+      "    75 -> <unused68>\n",
+      "  4758 -> さん\n",
+      "   107 -> <end_of_turn>\n",
+      "\n",
+      "Gemma output:\n",
+      "<start_of_turn>user\n",
+      "日本の珍しい名字<end_of_turn>\n",
+      "<start_of_turn>model\n",
+      "<unused68>さん<end_of_turn>\n",
+      "--------------------------------------------------------------------------------\n",
+      "TOTAL TIME ELAPSED: 0.51s\n",
+      "--------------------------------------------------------------------------------\n",
+      "[   106   1645    108  47172 142742  39742    107    108    106   2516\n",
+      "    108     30   4758    107]\n",
+      "   106 -> <start_of_turn>\n",
+      "  1645 -> user\n",
+      "   108 -> \n",
+      "\n",
+      " 47172 -> 日本の\n",
+      "142742 -> 珍しい\n",
+      " 39742 -> 名字\n",
+      "   107 -> <end_of_turn>\n",
+      "   108 -> \n",
+      "\n",
+      "   106 -> <start_of_turn>\n",
+      "  2516 -> model\n",
+      "   108 -> \n",
+      "\n",
+      "    30 -> <unused23>\n",
+      "  4758 -> さん\n",
+      "   107 -> <end_of_turn>\n",
+      "\n",
+      "Gemma output:\n",
+      "<start_of_turn>user\n",
+      "日本の珍しい名字<end_of_turn>\n",
+      "<start_of_turn>model\n",
+      "周本さん<end_of_turn>\n",
+      "--------------------------------------------------------------------------------\n",
+      "TOTAL TIME ELAPSED: 0.53s\n",
+      "--------------------------------------------------------------------------------\n",
+      "[   106   1645    108  47172 142742  39742    107    108    106   2516\n",
+      "    108     33   4758    107]\n",
+      "   106 -> <start_of_turn>\n",
+      "  1645 -> user\n",
+      "   108 -> \n",
+      "\n",
+      " 47172 -> 日本の\n",
+      "142742 -> 珍しい\n",
+      " 39742 -> 名字\n",
+      "   107 -> <end_of_turn>\n",
+      "   108 -> \n",
+      "\n",
+      "   106 -> <start_of_turn>\n",
+      "  2516 -> model\n",
+      "   108 -> \n",
+      "\n",
+      "    33 -> <unused26>\n",
+      "  4758 -> さん\n",
+      "   107 -> <end_of_turn>\n",
+      "\n",
+      "Gemma output:\n",
+      "<start_of_turn>user\n",
+      "日本の珍しい名字<end_of_turn>\n",
+      "<start_of_turn>model\n",
+      "下小薗さん<end_of_turn>\n",
+      "--------------------------------------------------------------------------------\n",
+      "TOTAL TIME ELAPSED: 0.51s\n"
+     ]
+    }
+   ],
+   "source": [
+    "gemma_lm.compile(sampler=\"top_k\")\n",
+    "\n",
+    "text_gen_with_dict(\"日本の珍しい名字\", my_dictionary)\n",
+    "text_gen_with_dict(\"日本の珍しい名字\", my_dictionary)\n",
+    "text_gen_with_dict(\"日本の珍しい名字\", my_dictionary)\n",
+    "text_gen_with_dict(\"日本の珍しい名字\", my_dictionary)\n",
+    "text_gen_with_dict(\"日本の珍しい名字\", my_dictionary)\n"
+   ]
+  },
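+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As the second and third runs above show, the sampler can pick an `<unused>` token that has no entry in `my_dictionary`, and the raw token then leaks into the output. A defensive back-conversion could drop unmapped tokens instead — a sketch, assuming `my_dictionary` maps literal `<unusedN>` strings to surnames (`to_custom_vocab` is a hypothetical helper, not defined elsewhere in this notebook):\n",
+    "\n",
+    "```python\n",
+    "import re\n",
+    "\n",
+    "def to_custom_vocab(text, dictionary):\n",
+    "    # Replace each <unusedN> token with its dictionary entry;\n",
+    "    # silently drop any unused token the dictionary does not cover.\n",
+    "    return re.sub(r\"<unused\\\\d+>\", lambda m: dictionary.get(m.group(0), \"\"), text)\n",
+    "```"
+   ]
+  },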
"--------------------------------------------------------------------------------\n", + "[ 106 1645 108 47172 59956 39742 107 108 106 2516 108 16\n", + " 4758 107]\n", + " 106 -> \n", + " 1645 -> user\n", + " 108 -> \n", + "\n", + " 47172 -> 日本の\n", + " 59956 -> すごい\n", + " 39742 -> 名字\n", + " 107 -> \n", + " 108 -> \n", + "\n", + " 106 -> \n", + " 2516 -> model\n", + " 108 -> \n", + "\n", + " 16 -> \n", + " 4758 -> さん\n", + " 107 -> \n", + "\n", + "Gemma output:\n", + "user\n", + "日本のすごい名字\n", + "model\n", + "真喜さん\n", + "--------------------------------------------------------------------------------\n", + "TOTAL TIME ELAPSED: 0.49s\n" + ] + } + ], + "source": [ + "text_gen_with_dict(\"日本の名字\", my_dictionary)\n", + "text_gen_with_dict(\"日本の面白い名字\", my_dictionary)\n", + "text_gen_with_dict(\"日本の見たことない名字\", my_dictionary)\n", + "text_gen_with_dict(\"日本のすごい名字\", my_dictionary)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "name": "Custom_Vocabulary.ipynb", + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/README.md b/README.md index befe491..a963e8e 100644 --- a/README.md +++ b/README.md @@ -83,6 +83,7 @@ You can find the Gemma models on GitHub, Hugging Face models, Kaggle, Google Clo | [Finetune_with_Axolotl.ipynb](Gemma/Finetune_with_Axolotl.ipynb) | Finetune Gemma using [Axolotl](https://github.com/OpenAccess-AI-Collective/axolotl). | | [Finetune_with_XTuner.ipynb](Gemma/Finetune_with_XTuner.ipynb) | Finetune Gemma using [XTuner](https://github.com/InternLM/xtuner). | | [Finetune_with_LLaMA_Factory.ipynb](Gemma/Finetune_with_LLaMA_Factory.ipynb) | Finetune Gemma using [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory). | +| [Custom_Vocabulary.ipynb](Gemma/Custom_Vocabulary.ipynb) | Demonstrate how to use a custom vocabulary "<unused[0-98]>" tokens in Gemma. | | **Alignment** | | | [Aligning_DPO_Gemma_2b_it.ipynb](Gemma/Aligning_DPO_Gemma_2b_it.ipynb) | Demonstrate how to align a Gemma model using DPO (Direct Preference Optimization) with [Hugging Face TRL](https://huggingface.co/docs/trl/en/index). | | **Evaluation** | |