diff --git a/Workshops/Workshop_How_to_Fine_tuning_Gemma_Transformers_Edition.ipynb b/Workshops/Workshop_How_to_Fine_tuning_Gemma_Transformers_Edition.ipynb new file mode 100644 index 0000000..bcb0600 --- /dev/null +++ b/Workshops/Workshop_How_to_Fine_tuning_Gemma_Transformers_Edition.ipynb @@ -0,0 +1,6102 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "i1PHqD-ZY4-c" + }, + "outputs": [], + "source": [ + "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YNDq8NbCY7oh" + }, + "source": [ + "# Workshop: How to Fine-tune Gemma - Transformers Edition\n", + "\n", + "To illustrate fine-tuning a model for a specific task, you'll learn how to condition a Gemma model to answer in a specific language. The example task is generating a random Portuguese title from a user instruction such as \"Write a title\". To make this possible, you will curate a small dataset that can be processed by hand. This approach is feasible because Gemma 2 already has prior knowledge of general Portuguese language patterns, which lets it adapt to this specific task effectively."
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "u4EM3g9u2_KA" + }, + "source": [ + "## What is Fine-tuning\n", + "\n", + "First, you need to understand what fine-tuning is. It's a specialized form of [transfer learning](https://en.wikipedia.org/wiki/Transfer_learning). It involves taking a pre-trained language model - one that has already been exposed to a vast corpus of text data and learned the general patterns and structures of language - and further training it on a smaller, more specific dataset. This additional training allows the model to adapt and refine its knowledge, making it better suited for a particular task or domain.\n", + "\n", + "Imagine you are a skilled gamer who excels at various genres, from action-adventures to strategy games. Fine-tuning is akin to taking you and having you focus intensely on mastering a specific game, like a complex real-time strategy (RTS) title. You already possess a strong foundation of gaming skills and knowledge, but the dedicated practice and study within the RTS genre sharpens your tactics, understanding of game mechanics, and overall proficiency within that particular realm.\n", + "\n", + "Similarly, pre-trained language models have a broad understanding of language, but fine-tuning helps them specialize. By exposing them to a curated dataset relevant to your desired application, you guide the model to learn the nuances and intricacies specific to that domain. It's like giving the model a crash course in the language of your chosen field, enabling it to perform tasks with greater accuracy and fluency.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3rzH5Ugf5RlJ" + }, + "source": [ + "## Setup\n", + "\n", + "### Select the Colab runtime\n", + "To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the Gemma model:\n", + "\n", + "1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.\n", + "2. 
Select **Change runtime type**.\n", + "3. Under **Hardware accelerator**, select **T4 GPU**.\n", + "\n", + "\n", + "### Gemma setup on Kaggle\n", + "To complete this tutorial, you'll first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). The Gemma setup instructions show you how to do the following:\n", + "\n", + "* Get access to Gemma on kaggle.com.\n", + "* Select a Colab runtime with sufficient resources to run the Gemma 2B model.\n", + "* Generate and configure a Kaggle username and API key.\n", + "\n", + "After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "URMuBzkMVxpU" + }, + "source": [ + "### Set environment variables\n", + "\n", + "Log in to Hugging Face with the token stored in the ```HUGGING_FACE``` Colab secret, then export it as an environment variable." + ] + }, + { + "cell_type": "code", + "source": [ + "import os\n", + "from google.colab import userdata, drive\n", + "from huggingface_hub import login\n", + "\n", + "login(userdata.get(\"HUGGING_FACE\"))\n", + "\n", + "access_token = userdata.get(\"HUGGING_FACE\")\n", + "my_hf_username = userdata.get(\"HUGGING_FACE_UN\")\n", + "os.environ[\"HF_USER\"] = my_hf_username\n", + "os.environ[\"HF_TOKEN\"] = userdata.get(\"HUGGING_FACE\")" + ], + "metadata": { + "id": "TT7GexJnZZCj", + "outputId": "5f5cc93f-d9c8-4bb4-d828-b670ed352b35", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "The token has not been saved to the git credentials helper. 
Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n", + "Token is valid (permission: read).\n", + "Your token has been saved to /root/.cache/huggingface/token\n", + "Login successful\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LXfDwRTQVns2" + }, + "source": [ + "### Install dependencies\n", + "\n", + "Install Transformers and Torch" + ] + }, + { + "cell_type": "code", + "source": [ + "!pip install transformers torch\n", + "# Set the backend before importing Keras\n", + "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n", + "# Avoid memory fragmentation on JAX backend.\n", + "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"1.00\"\n", + "\n", + "# Training Configurations\n", + "token_limit = 128\n", + "num_data_limit = 100\n", + "lora_name = \"my_lora\"\n", + "lora_rank = 4\n", + "lr_value = 1e-3\n", + "train_epoch = 5\n", + "model_id = \"google/gemma-2-2b-it\"" + ], + "metadata": { + "id": "WNn86PiiXTNf", + "outputId": "5ac7a598-0773-4596-a393-ba53cbfd2470", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.44.2)\n", + "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.4.1+cu121)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.16.1)\n", + "Requirement already satisfied: huggingface-hub<1.0,>=0.23.2 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.24.7)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.26.4)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (24.1)\n", + 
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.2)\n", + "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2024.9.11)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.32.3)\n", + "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.5)\n", + "Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.1)\n", + "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.5)\n", + "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.3)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.4.1)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2024.6.1)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (3.0.1)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.10)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.2.3)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) 
(2024.8.30)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kUl0t469YfQY" + }, + "source": [ + "## Load Model\n", + "\n", + "**Why Fine-tuning?**\n", + "\n", + "Before embarking on fine-tuning, it's crucial to evaluate whether its benefits align with the specific requirements of your application. Fine-tuning involves meticulous data preparation and extensive training, making it an arduous process. Therefore, it's essential to assess whether the potential gains justify the significant effort required.\n", + "\n", + "**Try \"Prompt Engineering\" first, before fine-tuning.**\n", + "\n", + "Would you like to enable Gemma's multilingual capabilities?\n", + "Please note that Gemma 2 already has some multilingual capabilities. Here's an example output from the Gemma 2 2B instruction-tuned model.\n", + "\n", + "Do you wish to adjust the tone or writing style?\n", + "Gemma 2 might be familiar with the writing style you have in mind. Here's another output from the same model."
+ ] + }, + { + "cell_type": "code", + "source": [ + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "import time\n", + "\n", + "# Load a pretrained model and tokenizer from Hugging Face\n", + "gemma_lm = AutoModelForCausalLM.from_pretrained(model_id, token=access_token)\n", + "tokenizer = AutoTokenizer.from_pretrained(model_id, token=access_token)\n", + "\n", + "# Summarize the model\n", + "print(gemma_lm)\n", + "\n", + "tick_start = 0\n", + "\n", + "def tick():\n", + "    global tick_start\n", + "    tick_start = time.time()\n", + "\n", + "def tock():\n", + "    print(f\"TOTAL TIME ELAPSED: {time.time() - tick_start:.2f}s\")\n", + "\n", + "def text_gen(prompt, token_limit=100): # You can set your token limit\n", + "    tick()\n", + "\n", + "    # Format the prompt as a user turn followed by the start of the model turn\n", + "    input_text = f\"user\\n{prompt}\\nmodel\\n\"\n", + "\n", + "    # Tokenize input\n", + "    inputs = tokenizer(input_text, return_tensors=\"pt\")\n", + "\n", + "    # Generate text using the model\n", + "    output_tokens = gemma_lm.generate(\n", + "        inputs[\"input_ids\"],\n", + "        max_length=token_limit, # Upper bound on prompt tokens plus generated tokens\n", + "        pad_token_id=tokenizer.eos_token_id # Pad with the EOS token during generation\n", + "    )\n", + "\n", + "    # Decode the generated tokens back to text\n", + "    output = tokenizer.decode(output_tokens[0], skip_special_tokens=True)\n", + "\n", + "    print(\"\\nGemma output:\")\n", + "    print(output)\n", + "\n", + "    tock()" + ], + "metadata": { + "id": "ywcDWVhAXb_9", + "outputId": "d9af537b-6c7a-4018-ada6-7b5dc58328d8", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 880, + "referenced_widgets": [ + "5dcc5fb3ad9e41eeb877fdd04f09ba7d", + "eeda0cc4233f41c7a1bb077b335922af", + "7786f72663544d3b964f8706ee4c87db", + "16489924a7194b6780f51cb3fa808c5e", + "8a87e9b8ba9d41ce8b44f2fef3244a81", + "e84b2d9e1e9e46c28e788385d64bc7c6", + "60c77f65806f4ff688453439fe7c9d83", + "de9354062f994aaabb104890709fc34a", + "9ae9adefe60b4511a90cdb2b3ef6e03d", + 
"3a29897fd3b3454ca752f63f45007d7f", + "1ef2aa04eaba4620a49c748ca4380d99", + "f4188acdf47f4cce86c02cc0d2bf46c7", + "54c456ec68684e57aa57aa7715921bfb", + "4ad2c31035574f26bf6118ef0d1413a9", + "a12cf321b8ae4075a4c3b0677d7bdf45", + "f8234ec2555b4e5684816aed66b63651", + "d816729082e94c41988d467654ef7841", + "b2d33d479a634d5c807e2a1860f4bc8c", + "b542562fcd6c466c9aa7f12b1b881207", + "2e8af71a4d85433a8e68f5c3737d7fb9", + "cdf16eb2e8ca487b9e9c7de16381c5c8", + "15332ef69c9244ccaaf3030dc9cb113a", + "eee1dd11b43e435cb226ee6df594291d", + "cc3c5ae5ebb74b2dbc81333533cc78cf", + "2f79a5c2a23b4808a181904593072dc5", + "a1c592e1355844adaa9ea80718499e4c", + "9805d9bec3f14c41a7a87fced47bb6ca", + "aed3fca1af5648e3b681d0c44f9c4f86", + "5960f282615e4932af3d9b7f69de2efa", + "151d7e95158d4d8aa2faf2886c8e0604", + "b5e8f1d688ea4694ac83a1bf89e27ae0", + "02ddb4d4c134415fa89a98984ae45c5f", + "7ba48bd1bbb148ff9dc7a2c33b835a35", + "9dc263e2dbfa4d9d9a56d62a69a57a8f", + "6a3171fe96744527805d8dc9eb159006", + "0bd0296c84a248259b5c48ac4e6bb760", + "20bc2639f53d4701aeafdb30e60f47da", + "372294f99f90486391b77296724916cc", + "2d9816c05a3544c0afc5a41c74ad6179", + "52132d841f1f48dd8c1d87ee0d308764", + "2fd5a5bfe878469c8f6b710c672fda9c", + "3a0c665acdbb4c83918c0ac469114235", + "eebc1808719c45b3bc7ad525798ce120", + "5d5a094d9cb545788f75992cbd859257", + "ecb1c6f7effa46b296e3574cddcb88de", + "cf5ec90af2944a63a0f83d5520307b82", + "28c7d5e1e84848c29b3d533930ccc198", + "b2831f93a9d148f4871c435382c6302e", + "9a827d1a0a424d6381b3e1eaac796b1e", + "afda84f0f4a445ee9dee2ccf047ce5ce", + "fad3963f838c44d98af0fdebb787741c", + "2dbc25a62598422b9c0768e109f8fd92", + "f97ae90efc3f44498db9203ef90d232e", + "707955f6553648fbb75ea65f2df2eaf6", + "c2fd92147e044b96ace6530ae67ca181", + "64324fda5c814874aaa2e97fa78ab3f0", + "abc4f6598a1f4c57a74a55f854ddc66d", + "6bfe81a71ff54a1390ff1d90d7d78864", + "25c94e5770e84ecea0e0f68f5bd64905", + "d15f3939e8b143a5b2deeb82435f9d26", + "a006f1ee609b47dda37389b5478a803d", + 
"0254a7b1482545cba26b5708a8920b34", + "977e2dc71b224a05823aa8120fa472bd", + "f6bdb5a59a044f4bb7332068179adce1", + "ecb124db56ea4925b83e91206c53f89b", + "e29c06ee25f84198bbea49aa0b19bb65", + "43bccde2f2804b089e4e40081058855f", + "f96a8568f4dc409090a123ff63766e58", + "163aee402b454733b1e74d43834849fb", + "8ebf2a32eac047cf9367f615ec01f0a8", + "760de7e3cef94be8b82cf8fc99385844", + "22d61c177e4e4d09a62ac08785682151", + "7ad0cea5c87e40da9b1dd1966b222f5a", + "a1e5eef9348f4bfeb456e5f096301778", + "4d94e337748343bb8e26d7ad7d361906", + "bc1dcd6dcd16462999359b579146c856", + "7c79134c76334c8680e81b58b9f75819", + "488f3d23e6ee4370b0e3750c4610acc1", + "5b1b61f335594a4487667b9d60395218", + "350ac704ca744a859527627fa92f9aa7", + "5a25dc05639c4cb282b1a2c4bbfdb572", + "4aaae0e141b14af78d9207b5f99a4516", + "182f4c00d1e94a4081057f34462704b7", + "62bb5bafdf2e47628a3ff1b114bac45a", + "79fac9dcae554d6f885608bd29fb904e", + "3b43ab0605644de9a852218edf3516fb", + "67261f6d01b14c50a4d7413ae8a4d113", + "110e4a8116e847dc9e4cea9a9df9e2b5", + "9ac1990111d64f919644a292b4ea2d36", + "494574a90f844172b5dce66aba248564", + "6b30ac7eab3346e8978274b8bcb03371", + "bdee91edaed5469a98ed808735bb30b0", + "eb7cfbc69c4848b3b4274203a70b6f1e", + "adb76c55b90d4f5aadd92d51b526fc27", + "5dd424f1051645abb00a3a3aa7d93990", + "018cb6becc764afb8bc6d9e9f2196fb9", + "bdd09cce6f4d4cf5b555678537c029c9", + "41ffd80053b6402f8aba20e1413e30cd", + "015fabd8f4934029900eec53ef02bdb8", + "b91c257be15341179ce8a262b3962ef0", + "e207f4774be44704a804dfacd628ad5a", + "559c60d219f24907b2b3b8e7c78deed8", + "88189b7d8769412685e197c8c8b7ec5c", + "2c8f45c2c1624865a698eb8c5527b0e4", + "d55c5f99d58942c2ba0506090f9755c2", + "9b5c6869f25945fa9424593255af07fc", + "435dfab7d6c94fb7b8b1a9dbb81bdbe7", + "9c0fa484394142c3beee7e272f626b08", + "3ad6acbc511b4417b5ca4e56a8688828", + "2d693baa6ea04b58965141813bd0ceca", + "379696127cc841a08cb6ca5000180032", + "9d03f9f959254badae3c94b71e84f2cc", + "65d1d113a0a8446c83563d9aefd8b112", + 
"9f61fd282f104094bc821e8a4b9f27ce", + "fd0fc02ef4b545c0a459527aca8bc81e", + "01cf0405f63a4d38ad8d1fe31721806f", + "ae20add968524517a261a81be45cde48", + "dc76053f6ae847469aab55949ad182be", + "4d98195ba52f48a29a237777eeef35cf", + "ffeafe7d69484d64af7484311e32703d", + "b66e7b11b2404e3daeaed1272994adb5" + ] + } + }, + "execution_count": null, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "5dcc5fb3ad9e41eeb877fdd04f09ba7d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "config.json: 0%| | 0.00/838 [00:00 \n", + "177383 -> olá\n", + "235265 -> .\n", + " 6235 -> Pra\n", + " 3004 -> zer\n", + " 2190 -> em\n", + " 26809 -> conhe\n", + "235260 -> c\n", + "235442 -> ê\n", + "235290 -> -\n", + " 545 -> lo\n", + "235265 -> .\n", + " 687 -> O\n", + " 11030 -> tempo\n", + " 5365 -> está\n", + " 14693 -> muito\n", + " 12318 -> bom\n", + " 43897 -> hoje\n", + "235265 -> .\n", + "\n", + "[2, 235530, 235579, 45884, 235483, 235940, 27074, 20579, 89299, 30848, 197350, 99877, 235940, 133533, 118300, 161437, 3640, 236062, 84372, 236062, 197350, 6032, 235265]\n", + " 2 -> \n", + "235530 -> न\n", + "235579 -> म\n", + " 45884 -> स्त\n", + "235483 -> े\n", + "235940 -> ।\n", + " 27074 -> आप\n", + " 20579 -> से\n", + " 89299 -> मिल\n", + " 30848 -> कर\n", + "197350 -> अच्छा\n", + " 99877 -> लगा\n", + "235940 -> ।\n", + "133533 -> आज\n", + "118300 -> मौ\n", + "161437 -> सम\n", + " 3640 -> स\n", + "236062 -> च\n", + " 84372 -> मु\n", + "236062 -> च\n", + "197350 -> अच्छा\n", + " 6032 -> है\n", + "235265 -> .\n" + ] + } + ], + "source": [ + "import jax.numpy as jnp\n", + "\n", + "# Function to detokenize (convert tokens back into words)\n", + "def detoken(tokens):\n", + " print(tokens['input_ids']) # Print the token IDs for debugging\n", + " input_ids = tokens['input_ids'] # Get input IDs from the tokenizer output\n", + "\n", + " for x in input_ids: # Iterate over the token list\n", + " # Use tokenizer.decode() to convert 
tokens back to words\n", + " word = tokenizer.decode([x]) # No need to convert to JAX array for decoding\n", + " print(f\"{x:6} -> {word}\")\n", + "\n", + "# Example text 1: Portuguese\n", + "detoken(tokenizer(\"olá. Prazer em conhecê-lo. O tempo está muito bom hoje.\", return_tensors=None))\n", + "print()\n", + "\n", + "# Example text 2: Hindi\n", + "detoken(tokenizer(\"नमस्ते। आपसे मिलकर अच्छा लगा। आज मौसम सचमुच अच्छा है.\", return_tensors=None))\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9T7xe_jzslv4" + }, + "source": [ + "## Load Dataset\n", + "\n", + "How much data do you need? You can start with a relatively small dataset: even 10 to 20 examples can have a significant impact on a model's behavior.\n", + "\n", + "To improve output quality, a target of around 200 examples is recommended. Still, the amount of data required for tuning depends on how much you want to change the model's behavior. Start with a small amount of data and gradually add more to the training set until you get the desired behavior."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZiS-KU9osh_N", + "outputId": "dedf9024-ff52-4a36-e586-6fea9139a53f", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "15\n", + "user\n", + "Write a title\n", + "model\n", + "O Alquimista\n", + "\n", + "user\n", + "Write a title\n", + "model\n", + "Dom Casmurro\n", + "\n", + "user\n", + "Write a title\n", + "model\n", + "Memorial do Convento\n" + ] + } + ], + "source": [ + "# example titles\n", + "data = [\n", + "    \"O Alquimista\", # by Paulo Coelho\n", + "    \"Dom Casmurro\", # by Machado de Assis\n", + "    \"Memorial do Convento\", # by José Saramago\n", + "    \"A Hora da Estrela\", # by Clarice Lispector\n", + "    \"Vidas Secas\", # by Graciliano Ramos\n", + "    \"O Cortiço\", # by Aluísio Azevedo\n", + "    \"Grande Sertão: Veredas\", # by Guimarães Rosa\n", + "    \"Capitães da Areia\", # by Jorge Amado\n", + "    \"A Sibila\", # by Agustina Bessa-Luís\n", + "    \"Os Maias\", # by Eça de Queirós\n", + "    \"O Crime do Padre Amaro\", # by Eça de Queirós\n", + "    \"A Relíquia\", # by Eça de Queirós\n", + "    \"O Primo Basílio\", # by Eça de Queirós\n", + "    \"A Ilustre Casa de Ramires\", # by Eça de Queirós\n", + "    \"A Cidade e as Serras\" # by Eça de Queirós\n", + "]\n", + "\n", + "train = []\n", + "\n", + "for x in data:\n", + "    item = f\"user\\nWrite a title\\nmodel\\n{x}\"\n", + "    # Count tokens via input_ids; len() of the tokenizer output itself only counts its keys\n", + "    length = len(tokenizer(item)[\"input_ids\"])\n", + "    # skip data if the token length is longer than our limit\n", + "    if length < token_limit:\n", + "        train.append(item)\n", + "        if len(train) >= num_data_limit:\n", + "            break\n", + "\n", + "print(len(train))\n", + "print(train[0])\n", + "print()\n", + "print(train[1])\n", + "print()\n", + "print(train[2])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9s1o96HRtwV_" + }, + "source": [ + "If your dataset is much bigger, see the example code below, which uses Hugging Face Datasets."
+ ] + }, + { + "cell_type": "code", + "source": [ + "!pip install datasets" + ], + "metadata": { + "id": "ZS9zT92tiKHu" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# from datasets import load_dataset\n", + "\n", + "# # Load the dataset\n", + "# ds = load_dataset(\"bebechien/korean_cake_boss\", split=\"train\")\n", + "# print(ds)\n", + "\n", + "# # Prepare the dataset for tokenization\n", + "# train = []\n", + "\n", + "# # Iterate through the dataset and format the prompts\n", + "# for x in ds:\n", + "# # Create the formatted input-output text\n", + "# item = f\"user\\n다음에 대한 이메일 답장을 작성해줘.\\n\\\"{x['input']}\\\"\\nmodel\\n{x['output']}\"\n", + "\n", + "# # Tokenize the item and get its length\n", + "# length = len(tokenizer(item)[\"input_ids\"])\n", + "# print(length)\n", + "# # Skip if the tokenized item is longer than the token limit\n", + "# if length < token_limit:\n", + "# train.append(item)\n", + "\n", + "# # Stop if we have reached the desired data limit\n", + "# if len(train) >= num_data_limit:\n", + "# break\n", + "\n", + "# # Output the results\n", + "# print(f\"Number of training examples: {len(train)}\")\n", + "# print(f\"First example: {train[0]}\")\n", + "# print(f\"Second example: {train[1]}\")\n", + "# print(f\"Third example: {train[2]}\")" + ], + "metadata": { + "id": "Vh9s8m_PgoAH", + "outputId": "ccdf0c19-c3b5-40d5-d7ae-dc7a2c59fe0b", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Dataset({\n", + " features: ['input', 'output'],\n", + " num_rows: 20\n", + "})\n", + "234\n", + "307\n", + "335\n", + "348\n", + "366\n", + "158\n", + "169\n", + "157\n", + "198\n", + "167\n", + "163\n", + "150\n", + "165\n", + "145\n", + "157\n", + "308\n", + "407\n", + "298\n", + "419\n", + "318\n", + "Number of training examples: 10\n", + "First example: user\n", + "다음에 대한 이메일 답장을 작성해줘.\n", + 
"\"안녕하세요, 10월 5일에 있을 딸 아이의 5번째 생일을 위해 케이크를 주문하고 싶습니다. 아이가 좋아하는 핑크색 공주님 케이크가 가능할까요?\"\n", + "model\n", + "고객님, 안녕하세요.\n", + "\n", + "따님의 5번째 생일을 진심으로 축하드립니다! 핑크색 공주님 케이크 주문 가능합니다. 원하시는 디자인이나 특별한 요청 사항이 있으시면 말씀해주세요.\n", + "\n", + "감사합니다.\n", + "\n", + "[가게 이름] 드림\n", + "Second example: user\n", + "다음에 대한 이메일 답장을 작성해줘.\n", + "\"11월 10일, 저희 부부의 결혼 10주년을 기념하기 위한 케이크를 주문하려고 합니다. 둘이 함께 먹을 작은 사이즈의 하트 모양 케이크를 원합니다.\"\n", + "model\n", + "고객님, 안녕하세요.\r\n", + "\n", + "결혼 10주년을 축하드립니다! 두 분의 특별한 날을 더욱 빛내드릴 하트 모양 케이크 주문 가능합니다. 케이크 맛과 크기, 디자인 등 다른 요청 사항이 있으시면 말씀해주세요.\n", + "\n", + "감사합니다.\n", + "\n", + "[가게 이름] 드림\n", + "Third example: user\n", + "다음에 대한 이메일 답장을 작성해줘.\n", + "\"3월 15일에 있을 대학교 졸업식을 축하하기 위한 케이크를 주문하고 싶습니다. 학교 로고가 들어간 디자인이 가능한지 궁금합니다.\"\n", + "model\n", + "고객님, 안녕하세요.\n", + "\n", + "졸업을 진심으로 축하드립니다! 학교 로고가 들어간 케이크 주문 가능합니다. 로고 파일을 보내주시면 디자인 시안을 만들어 보여드리겠습니다. 궁금한 점은 언제든 문의해주세요.\n", + "\n", + "감사합니다.\n", + "\n", + "[가게 이름] 드림\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5NTIrFbJ3dBv" + }, + "source": [ + "In the context of a small dataset, the primary concern is that the model may prioritize memorizing specific examples rather than generalizing well to new and unobserved data. This limitation highlights the importance of utilizing a larger dataset during fine-tuning, as it enhances the model's ability to capture broader patterns and relationships." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "th0WS33gayn9" + }, + "source": [ + "## LoRA Fine-tuning" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ugc2ub4nau1j" + }, + "source": [ + "![lora.png]()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Pt7Nr6a7tItO" + }, + "source": [ + "Fine-tuning a model involves updating its weights (also called parameters). LLMs have a lot of weights. 
The Gemma 2 2B model used in this notebook has roughly 2.6 billion parameters!\n", + "\n", + "Updating all of them takes a long time and requires a lot of resources.\n", + "\n", + "To mitigate this, you will use a technique called [LoRA: Low-Rank Adaptation](https://arxiv.org/abs/2106.09685).\n", + "\n", + "In short, LoRA greatly reduces the number of weights that need to be trained, making fine-tuning more accessible.\n", + "\n", + "The key parameter is the `rank`. In this notebook it is set to 4. Higher values can give better results but need more resources.\n", + "\n", + "**TIP**: Train your model with lower ranks and evaluate the performance improvement on your task. Gradually increase the rank in subsequent trials and see if that further boosts performance." + ] + }, + { + "cell_type": "code", + "source": [ + "!pip install peft\n", + "from peft import get_peft_model, LoraConfig, TaskType\n", + "import torch.nn as nn\n", + "\n", + "lora_config = LoraConfig(\n", + " task_type=TaskType.CAUSAL_LM,\n", + " r=lora_rank, # LoRA rank, defined in the training configuration above\n", + " lora_alpha=32,\n", + " lora_dropout=0.1\n", + ")\n", + "gemma_lm = get_peft_model(gemma_lm, lora_config) # Enable LoRA for the model\n", + "\n", + "print(gemma_lm) # Hugging Face models don't have a summary method; use print() instead\n", + "\n", + "tokenizer.model_max_length = token_limit # Set token limit in the tokenizer\n", + "\n", + "# transformers.AdamW is deprecated; use the PyTorch implementation instead\n", + "from torch.optim import AdamW\n", + "\n", + "optimizer_grouped_parameters = [\n", + " {'params': [p for n, p in gemma_lm.named_parameters() if not any(nd in n for nd in [\"bias\", \"LayerNorm.weight\"])], 'weight_decay': 0.01},\n", + " {'params': [p for n, p in gemma_lm.named_parameters() if any(nd in n for nd in [\"bias\", \"LayerNorm.weight\"])], 'weight_decay': 0.0}\n", + "]\n", + "optimizer = AdamW(optimizer_grouped_parameters, lr=lr_value) # Use AdamW optimizer\n", + "\n", + "\n", + "loss_fn = 
nn.CrossEntropyLoss() # Define the loss function\n", + "\n", + "def forward_pass(input_text):\n", + " inputs = tokenizer(input_text, return_tensors=\"pt\", max_length=token_limit, truncation=True)\n", + " outputs = gemma_lm(**inputs, labels=inputs[\"input_ids\"])\n", + " loss = outputs.loss\n", + " return loss\n", + "\n" + ], + "metadata": { + "id": "YQiQxLFKfyzx", + "outputId": "5a5b1398-61cd-48ba-d354-71e4eee37213", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting peft\n", + " Downloading peft-0.13.2-py3-none-any.whl.metadata (13 kB)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from peft) (1.26.4)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from peft) (24.1)\n", + "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from peft) (5.9.5)\n", + "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from peft) (6.0.2)\n", + "Requirement already satisfied: torch>=1.13.0 in /usr/local/lib/python3.10/dist-packages (from peft) (2.4.1+cu121)\n", + "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (from peft) (4.44.2)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from peft) (4.66.5)\n", + "Requirement already satisfied: accelerate>=0.21.0 in /usr/local/lib/python3.10/dist-packages (from peft) (0.34.2)\n", + "Requirement already satisfied: safetensors in /usr/local/lib/python3.10/dist-packages (from peft) (0.4.5)\n", + "Requirement already satisfied: huggingface-hub>=0.17.0 in /usr/local/lib/python3.10/dist-packages (from peft) (0.24.7)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.17.0->peft) (3.16.1)\n", + "Requirement already satisfied: 
fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.17.0->peft) (2024.6.1)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.17.0->peft) (2.32.3)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.17.0->peft) (4.12.2)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (1.13.3)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.4.1)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.1.4)\n", + "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (2024.9.11)\n", + "Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (0.19.1)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.13.0->peft) (3.0.1)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.17.0->peft) (3.4.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.17.0->peft) (3.10)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.17.0->peft) (2.2.3)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.17.0->peft) (2024.8.30)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.13.0->peft) (1.3.0)\n", + "Downloading peft-0.13.2-py3-none-any.whl 
(320 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m320.7/320.7 kB\u001b[0m \u001b[31m8.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hInstalling collected packages: peft\n", + "Successfully installed peft-0.13.2\n", + "PeftModelForCausalLM(\n", + " (base_model): LoraModel(\n", + " (model): Gemma2ForCausalLM(\n", + " (model): Gemma2Model(\n", + " (embed_tokens): Embedding(256000, 2304, padding_idx=0)\n", + " (layers): ModuleList(\n", + " (0-25): 26 x Gemma2DecoderLayer(\n", + " (self_attn): Gemma2SdpaAttention(\n", + " (q_proj): lora.Linear(\n", + " (base_layer): Linear(in_features=2304, out_features=2048, bias=False)\n", + " (lora_dropout): ModuleDict(\n", + " (default): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (lora_A): ModuleDict(\n", + " (default): Linear(in_features=2304, out_features=4, bias=False)\n", + " )\n", + " (lora_B): ModuleDict(\n", + " (default): Linear(in_features=4, out_features=2048, bias=False)\n", + " )\n", + " (lora_embedding_A): ParameterDict()\n", + " (lora_embedding_B): ParameterDict()\n", + " (lora_magnitude_vector): ModuleDict()\n", + " )\n", + " (k_proj): Linear(in_features=2304, out_features=1024, bias=False)\n", + " (v_proj): lora.Linear(\n", + " (base_layer): Linear(in_features=2304, out_features=1024, bias=False)\n", + " (lora_dropout): ModuleDict(\n", + " (default): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (lora_A): ModuleDict(\n", + " (default): Linear(in_features=2304, out_features=4, bias=False)\n", + " )\n", + " (lora_B): ModuleDict(\n", + " (default): Linear(in_features=4, out_features=1024, bias=False)\n", + " )\n", + " (lora_embedding_A): ParameterDict()\n", + " (lora_embedding_B): ParameterDict()\n", + " (lora_magnitude_vector): ModuleDict()\n", + " )\n", + " (o_proj): Linear(in_features=2048, out_features=2304, bias=False)\n", + " (rotary_emb): Gemma2RotaryEmbedding()\n", + " )\n", + " (mlp): Gemma2MLP(\n", + " (gate_proj): Linear(in_features=2304, 
out_features=9216, bias=False)\n", + " (up_proj): Linear(in_features=2304, out_features=9216, bias=False)\n", + " (down_proj): Linear(in_features=9216, out_features=2304, bias=False)\n", + " (act_fn): PytorchGELUTanh()\n", + " )\n", + " (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n", + " (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n", + " (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n", + " (post_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n", + " )\n", + " )\n", + " (norm): Gemma2RMSNorm((2304,), eps=1e-06)\n", + " )\n", + " (lm_head): Linear(in_features=2304, out_features=256000, bias=False)\n", + " )\n", + " )\n", + ")\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.10/dist-packages/transformers/optimization.py:591: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", + " warnings.warn(\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hQQ47kcdpbZ9" + }, + "source": [ + "Note that enabling LoRA reduces the number of trainable parameters significantly.\n", + "\n", + "From 2,617,270,528 to **2,928,640**\n", + "\n", + "To monitor the learning progress, you will evaluate the model at the end of each epoch and save the LoRA weights." 
+ ] + }, + { + "cell_type": "code", + "source": [ + "import torch\n", + "import os\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# Define a custom callback-like function to handle actions at the end of each epoch\n", + "class CustomCallback:\n", + " def __init__(self, model, lora_name, lora_rank, text_gen):\n", + " self.model = model\n", + " self.lora_name = lora_name\n", + " self.lora_rank = lora_rank\n", + " self.text_gen = text_gen # text_gen function for evaluation\n", + "\n", + " def on_epoch_end(self, epoch):\n", + " # Save LoRA weights at the end of each epoch\n", + " model_name = f\"./{self.lora_name}_{self.lora_rank}_epoch{epoch+1}.lora.pt\"\n", + " self.model.save_pretrained(model_name, token=access_token) # Save model with LoRA weights locally\n", + "\n", + " # Evaluate the model using text generation\n", + " print(f\"Epoch {epoch + 1} finished. Running evaluation:\")\n", + " self.text_gen(\"Write a title\")\n", + " self.text_gen(\"Write a poem\")\n", + "\n", + "# Assuming train is your DataLoader and gemma_lm is your model\n", + "callback = CustomCallback(gemma_lm, lora_name, lora_rank, text_gen)\n", + "\n", + "# Training loop with callback-like behavior\n", + "losses = []\n", + "for epoch in range(train_epoch):\n", + " epoch_loss = 0\n", + " for batch in train: # Assuming `train` is a DataLoader or similar iterable\n", + " optimizer.zero_grad()\n", + "\n", + " inputs = tokenizer(batch, return_tensors=\"pt\", max_length=token_limit, truncation=True, padding=True)\n", + " labels = inputs[\"input_ids\"]\n", + " outputs = gemma_lm(**inputs, labels=labels)\n", + " loss = outputs.loss\n", + "\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " epoch_loss += loss.item()\n", + "\n", + " losses.append(epoch_loss / len(train)) # Store average loss per epoch\n", + "\n", + " # Run custom callback at the end of each epoch\n", + " callback.on_epoch_end(epoch)\n", + "\n", + "# Plot training loss over epochs\n", + "plt.plot(losses)\n", + 
"plt.xlabel(\"Epoch\")\n", + "plt.ylabel(\"Loss\")\n", + "plt.title(\"Training Loss Over Epochs\")\n", + "plt.show()" + ], + "metadata": { + "id": "YKpmDIfXh1Kx", + "outputId": "2422fc27-260d-4e7d-ce4b-bfa4234bec7d", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + } + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch 1 finished. Running evaluation:\n", + "\n", + "Gemma output:\n", + "user\n", + "Write a title\n", + "model\n", + "O Cortiço\n", + "TOTAL TIME ELAPSED: 7.93s\n", + "\n", + "Gemma output:\n", + "user\n", + "Write a poem\n", + "model\n", + "\n", + "TOTAL TIME ELAPSED: 4.38s\n", + "Epoch 2 finished. Running evaluation:\n", + "\n", + "Gemma output:\n", + "user\n", + "Write a title\n", + "model\n", + "A Relíquia\n", + "TOTAL TIME ELAPSED: 8.04s\n", + "\n", + "Gemma output:\n", + "user\n", + "Write a poem\n", + "model\n", + "O Primo Basílio\n", + "TOTAL TIME ELAPSED: 8.19s\n", + "Epoch 3 finished. Running evaluation:\n", + "\n", + "Gemma output:\n", + "user\n", + "Write a title\n", + "model\n", + "O Primo Basílio\n", + "TOTAL TIME ELAPSED: 8.21s\n", + "\n", + "Gemma output:\n", + "user\n", + "Write a poem\n", + "model\n", + "O Primo Basílio\n", + "TOTAL TIME ELAPSED: 8.18s\n", + "Epoch 4 finished. Running evaluation:\n", + "\n", + "Gemma output:\n", + "user\n", + "Write a title\n", + "model\n", + "A Sibila\n", + "TOTAL TIME ELAPSED: 7.24s\n", + "\n", + "Gemma output:\n", + "user\n", + "Write a poem\n", + "model\n", + "O Primo Basílio\n", + "TOTAL TIME ELAPSED: 8.21s\n", + "Epoch 5 finished. 
Running evaluation:\n", + "\n", + "Gemma output:\n", + "user\n", + "Write a title\n", + "model\n", + "A Sibila\n", + "TOTAL TIME ELAPSED: 7.31s\n", + "\n", + "Gemma output:\n", + "user\n", + "Write a poem\n", + "model\n", + "O Primo Basílio\n", + "TOTAL TIME ELAPSED: 7.77s\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "\n" + }, + "metadata": {} + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tn-jgVULyBXq" + }, + "source": [ + "Note that the model began to grasp our intent more effectively from Epoch #3 onwards.\n", + "\n", + "To compare and contrast, it was utlized the \"Write a poem\" prompt. Interestingly, in Epoch #5, the model began to generate Portuguese in response to that prompt. This shift indicates a strong influence of our training dataset on the model's behavior. However, depending on your application, such a significat change might not be desirable. In such cases, Epoch #4 would be a more suitable choice." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "P-tVAKmda2Zt" + }, + "source": [ + "## Load LoRA\n", + "\n", + "Use the code below if you shared LoRA weights. It's much more lightweight than the model files themselves - for instance, a LoRA rank 4 weights file for a 10gb model might only be on the order of a few megabytes, easily shared over email." + ] + }, + { + "cell_type": "code", + "source": [ + "# Example Code for Load LoRA\n", + "\n", + "# from peft import PeftModel\n", + "\n", + "# # Load pre-trained LoRA weights (assuming the weights are saved in Hugging Face format)\n", + "# # Load the pre-trained LoRA weights\n", + "# lora_weights_path = f\"./{lora_name}_{lora_rank}_epoch{train_epoch}.lora.pt\"\n", + "\n", + "# # Load the LoRA adapter into the model using PeftModel\n", + "# gemma_lm = PeftModel.from_pretrained(gemma_lm, lora_weights_path)" + ], + "metadata": { + "id": "kVe4vjgCngsd" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ipg1u_wEKTxG" + }, + "source": [ + "## Try a different sampler\n", + "\n", + "The top-K algorithm randomly picks the next token from the tokens of top K probability." 
+ ] + }, + { + "cell_type": "code", + "source": [ + "import torch\n", + "\n", + "def text_gen_with_top_k(prompt, token_limit=100, top_k=50): # You can set your token limit and top_k\n", + "    tick()\n", + "\n", + "    # Format the input the same way as in text_gen\n", + "    input_text = f\"user\\n{prompt}\\nmodel\\n\"\n", + "\n", + "    # Tokenize input\n", + "    inputs = tokenizer(input_text, return_tensors=\"pt\")\n", + "\n", + "    # Generate text using the model with Top-K sampling\n", + "    output_tokens = gemma_lm.generate(\n", + "        **inputs, # Pass input_ids together with the attention mask\n", + "        max_length=token_limit,\n", + "        do_sample=True, # Enable sampling\n", + "        top_k=top_k, # Sample only from the top_k most likely tokens\n", + "        pad_token_id=tokenizer.eos_token_id # Use the EOS token for padding during generation\n", + "    )\n", + "\n", + "    # Decode the generated tokens back to text\n", + "    output = tokenizer.decode(output_tokens[0], skip_special_tokens=True)\n", + "\n", + "    print(\"\\nGemma output:\")\n", + "    print(output)\n", + "\n", + "\n", + "# Generate a title using the Top-K sampling strategy\n", + "text_gen_with_top_k(\"Write a title\", token_limit=100, top_k=50)\n" + ], + "metadata": { + "id": "K2JUE2IilwNi", + "outputId": "1ffdd099-1e0f-4a66-a7c6-81b43cc643e3", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "Gemma output:\n", + "user\n", + "Write a title\n", + "model\n", + "Capitães da Areia\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3m1XaCrlMu3Y" + }, + "source": [ + "Try slightly different prompts" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qC-MLxYWM1HU", + "outputId": "f26faf29-ce26-4a5c-e6ff-f2125148249b", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "Gemma output:\n", + "user\n", + 
"Write a music title\n", + "model\n", + "A Sibila\n", + "\n", + "Gemma output:\n", + "user\n", + "Write a poem title\n", + "model\n", + "O V alienígena\n", + "\n", + "Gemma output:\n", + "user\n", + "Write a blog title\n", + "model\n", + "Mar Secreto do Palmar\n", + "\n", + "Gemma output:\n", + "user\n", + "Write a movie title\n", + "model\n", + "A Hora da Estrela\n", + "\n", + "Gemma output:\n", + "user\n", + "Write a novel title\n", + "model\n", + "Os Maias\n" + ] + } + ], + "source": [ + "text_gen_with_top_k(\"Write a music title\")\n", + "text_gen_with_top_k(\"Write a poem title\")\n", + "text_gen_with_top_k(\"Write a blog title\")\n", + "text_gen_with_top_k(\"Write a movie title\")\n", + "text_gen_with_top_k(\"Write a novel title\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "aEptDCED9tVp" + }, + "source": [ + "## Publish your model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "T3Qhrlyy5ReL" + }, + "source": [ + "Lets save our model. It takes some time (~11 minutes) as it is a very large file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4TcvzBH995FE", + "outputId": "26a4be25-9a51-41f9-ad28-ef1708bf440f", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "('./my_gemma2_lt_pt/tokenizer_config.json',\n", + " './my_gemma2_lt_pt/special_tokens_map.json',\n", + " './my_gemma2_lt_pt/tokenizer.model',\n", + " './my_gemma2_lt_pt/added_tokens.json',\n", + " './my_gemma2_lt_pt/tokenizer.json')" + ] + }, + "metadata": {}, + "execution_count": 18 + } + ], + "source": [ + "# Define the model name (used for both model and tokenizer)\n", + "my_model_name = \"my_gemma2_lt_pt\"\n", + "\n", + "# Save the fine-tuned model to the specified directory\n", + "gemma_lm.save_pretrained(f\"./{my_model_name}\", token=access_token)\n", + "\n", + "# # Save the tokenizer to the same directory\n", + 
"tokenizer.save_pretrained(f\"./{my_model_name}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xQ4de1a79zy0" + }, + "source": [ + "## Publishing on Hugging Face\n", + "\n", + "To publish your model on Hugging Face, you'll need your hugging face user (`HF_USER`) and an access token with write permission (`HF_TOKEN`) to the your secret keys." + ] + }, + { + "cell_type": "code", + "source": [ + "# Upload the model to Hugging Face Hub\n", + "my_model_name = \"my_gemma2_pt\"\n", + "writeToken = userdata.get(\"HF_WRITE_TOKEN\")\n", + "hf_repo_id = f\"{my_hf_username}/{my_model_name}\" # Correct format\n", + "gemma_lm.push_to_hub(hf_repo_id, token=writeToken)" + ], + "metadata": { + "id": "AK31-LuXpwen", + "outputId": "e96865e3-4e8c-4acf-e256-c5ec31b6f931", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 136, + "referenced_widgets": [ + "84cd1a70535f4272a0206041373ff281", + "325cf823a7244a4b8e25462ae4912ea3", + "325f38ff3fef43cbbd248ff3ff714c43", + "de4fae4a6e3e420d86414a572846a9e0", + "4d729ba11c9442528af79d6e377a292c", + "49a42d4bcb8b4684a69656ee8b0449df", + "50b06db300ce432c9235071f48a18432", + "a31850a948e344848d43caae277f3200", + "fe5291b55e7147b682b671dad99d16e6", + "b485e94c1509478da6dd7a04c41e055d", + "98e4874bfa904312a4a9b00ca2ea8577", + "a163fe0a8b594568bf102caf54856484", + "dd53c0e07382441fbad818b0c3852006", + "3c316ae120e445c8b4bbc8260cf14ab4", + "9387a006080845ce9ef9aeb41fbd6c6e", + "c49cd450c3d24511a1efe272820c5f0a", + "99a5ad5c04b4456ba65118367cd2ea5a", + "d74745e493704920b56855e0cb73a39e", + "bd53b02c9808436ead8b213886e76bed", + "3f1bd9e04dc64bca88b99998339df788", + "d082ae773106445aa6ec56167b2a5cb9", + "030123f151674251ac0f0be0f8f713dc" + ] + } + }, + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "README.md: 0%| | 0.00/5.17k [00:00