Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add example of reading learning rate from optimizer state #363

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 162 additions & 77 deletions docs/optax-101.ipynb
Original file line number Diff line number Diff line change
@@ -1,20 +1,4 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Optax 101",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
Expand All @@ -40,27 +24,23 @@
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "Gg6zyMBqydty"
},
"outputs": [],
"source": [
"import random\n",
"from typing import Tuple\n",
"\n",
"import optax\n",
"import jax.numpy as jnp\n",
"import jax\n",
"import numpy as np\n",
"\n",
"BATCH_SIZE = 5\n",
"NUM_TRAIN_STEPS = 1_000\n",
"RAW_TRAINING_DATA = np.random.randint(255, size=(NUM_TRAIN_STEPS, BATCH_SIZE, 1))\n",
"RAW_TRAINING_DATA = jax.random.randint(jax.random.PRNGKey(42), (NUM_TRAIN_STEPS, BATCH_SIZE, 1), 0, 255)\n",
"\n",
"TRAINING_DATA = np.unpackbits(RAW_TRAINING_DATA.astype(np.uint8), axis=-1)\n",
"TRAINING_DATA = jnp.unpackbits(RAW_TRAINING_DATA.astype(jnp.uint8), axis=-1)\n",
"LABELS = jax.nn.one_hot(RAW_TRAINING_DATA % 2, 2).astype(jnp.float32).reshape(NUM_TRAIN_STEPS, BATCH_SIZE, 2)"
],
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "markdown",
Expand All @@ -77,9 +57,11 @@
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "Syp9LJ338h9-"
},
"outputs": [],
"source": [
"initial_params = {\n",
" 'hidden': jax.random.normal(shape=[8, 32], key=jax.random.PRNGKey(0)),\n",
Expand All @@ -101,9 +83,7 @@
" loss_value = optax.sigmoid_binary_cross_entropy(y_hat, labels).sum(axis=-1)\n",
"\n",
" return loss_value.mean()"
],
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "markdown",
Expand All @@ -118,62 +98,62 @@
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "JsbPBTF09FGY",
"executionInfo": {
"elapsed": 6046,
"status": "ok",
"timestamp": 1636155226542,
"user_tz": 0,
"elapsed": 6046,
"user": {
"displayName": "Ross Hemsley",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjSZqBnQizDvVofyb2N_r9W3cP4duk9mv1mxCb9=s64",
"userId": "11415908946302743815"
}
},
"user_tz": 0
},
"id": "JsbPBTF09FGY",
"outputId": "c427f94f-a605-44fc-b519-707bc5d47b7d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Step: 0, Loss: 5.624\n",
"Step: 100, Loss: 0.188\n",
"Step: 200, Loss: 0.053\n",
"Step: 300, Loss: 0.025\n",
"Step: 400, Loss: 0.004\n",
"Step: 500, Loss: 0.028\n",
"Step: 600, Loss: 0.002\n",
"Step: 700, Loss: 0.025\n",
"Step: 800, Loss: 0.017\n",
"Step: 900, Loss: 0.003\n"
]
}
],
"source": [
"@jax.jit\n",
"def step(params, opt_state, batch, labels):\n",
" loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)\n",
" updates, opt_state = optimizer.update(grads, opt_state, params)\n",
" params = optax.apply_updates(params, updates)\n",
" return params, opt_state, loss_value\n",
"\n",
"def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:\n",
" opt_state = optimizer.init(params)\n",
"\n",
" @jax.jit\n",
" def step(params, opt_state, batch, labels):\n",
" loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)\n",
" updates, opt_state = optimizer.update(grads, opt_state, params)\n",
" params = optax.apply_updates(params, updates)\n",
" return params, opt_state, loss_value\n",
"\n",
" for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):\n",
" params, opt_state, loss_value = step(params, opt_state, batch, labels)\n",
" if i % 100 == 0:\n",
" print(f'step {i}, loss: {loss_value}')\n",
" print(f'Step: {i:3}, Loss: {loss_value:.3f}')\n",
"\n",
" return params\n",
"\n",
"# Finally, we can fit our parametrized function using the Adam optimizer\n",
"# provided by optax.\n",
"optimizer = optax.adam(learning_rate=1e-2)\n",
"params = fit(initial_params, optimizer)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"step 0, loss: 5.60183048248291\n",
"step 100, loss: 0.14773361384868622\n",
"step 200, loss: 0.28999248147010803\n",
"step 300, loss: 0.05951451137661934\n",
"step 400, loss: 0.08592046797275543\n",
"step 500, loss: 0.005035111214965582\n",
"step 600, loss: 0.0028563595842570066\n",
"step 700, loss: 0.013286210596561432\n",
"step 800, loss: 0.01311601884663105\n",
"step 900, loss: 0.003692328929901123\n"
]
}
]
},
{
Expand All @@ -200,21 +180,40 @@
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "SZegYQajDtLi",
"executionInfo": {
"elapsed": 734,
"status": "ok",
"timestamp": 1636155227388,
"user_tz": 0,
"elapsed": 734,
"user": {
"displayName": "Ross Hemsley",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjSZqBnQizDvVofyb2N_r9W3cP4duk9mv1mxCb9=s64",
"userId": "11415908946302743815"
}
},
"user_tz": 0
},
"id": "SZegYQajDtLi",
"outputId": "f65f9fd8-8e9c-4ae6-e759-62362ff94f53"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Step: 0, Loss: 5.624\n",
"Step: 100, Loss: 0.000\n",
"Step: 200, Loss: 0.000\n",
"Step: 300, Loss: 0.000\n",
"Step: 400, Loss: 0.000\n",
"Step: 500, Loss: 0.000\n",
"Step: 600, Loss: 0.000\n",
"Step: 700, Loss: 0.000\n",
"Step: 800, Loss: 0.000\n",
"Step: 900, Loss: 0.000\n"
]
}
],
"source": [
"schedule = optax.warmup_cosine_decay_schedule(\n",
" init_value=0.0,\n",
Expand All @@ -230,26 +229,112 @@
")\n",
"\n",
"params = fit(initial_params, optimizer)"
],
"execution_count": null,
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "N7-efvtM16pO"
},
"source": [
"## Reading the Learning Rate inside the Train Loop"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GzPJMRYV16pP"
},
"source": [
"Sometimes we want to access certain hyperparameters in the optimizer. For example, we may want to log the learning rate to a monitoring service.\n",
"\n",
"To extract the learning rate inside the train loop, we can use the [inject_hyperparams](https://optax.readthedocs.io/en/latest/api.html#optax.inject_hyperparams) wrapper to make any hyperparameter a modifiable part of the optimizer state. Once the learning rate is part of the optimizer state, its current value can be read directly from `opt_state.hyperparams['learning_rate']`.\n",
"\n",
"The following example demonstrates how to extend the previous code to extract the learning rate."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "FIT1aO9_16pP",
"outputId": "f90205ee-9359-42b3-f745-15aa67d33b62"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"output_type": "stream",
"text": [
"step 0, loss: 5.60183048248291\n",
"step 100, loss: 1.0181801179953709e-08\n",
"step 200, loss: 0.27725887298583984\n",
"step 300, loss: 0.0\n",
"step 400, loss: 0.0\n",
"step 500, loss: 0.0\n",
"step 600, loss: 0.0\n",
"step 700, loss: 0.0\n",
"step 800, loss: 0.0\n",
"step 900, loss: 0.0\n"
"Available hyperparams: b1 b2 eps eps_root weight_decay learning_rate\n",
"\n",
"Step 0, Loss: 5.624, Learning rate: 0.020\n",
"Step 100, Loss: 0.000, Learning rate: 0.993\n",
"Step 200, Loss: 0.000, Learning rate: 0.939\n",
"Step 300, Loss: 0.000, Learning rate: 0.837\n",
"Step 400, Loss: 0.000, Learning rate: 0.699\n",
"Step 500, Loss: 0.000, Learning rate: 0.540\n",
"Step 600, Loss: 0.000, Learning rate: 0.376\n",
"Step 700, Loss: 0.000, Learning rate: 0.225\n",
"Step 800, Loss: 0.000, Learning rate: 0.104\n",
"Step 900, Loss: 0.000, Learning rate: 0.027\n"
]
}
],
"source": [
"# Wrap the optimizer to inject the hyperparameters\n",
"optimizer = optax.inject_hyperparams(optax.adamw)(learning_rate=schedule)\n",
"\n",
"def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:\n",
" opt_state = optimizer.init(params)\n",
"\n",
" # Since we injected hyperparams, we can access them directly here\n",
" print(f'Available hyperparams: {\" \".join(opt_state.hyperparams.keys())}\\n')\n",
"\n",
" for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):\n",
" params, opt_state, loss_value = step(params, opt_state, batch, labels)\n",
" if i % 100 == 0:\n",
" # Get the updated learning rate\n",
" lr = opt_state.hyperparams['learning_rate']\n",
" print(f'Step {i:3}, Loss: {loss_value:.3f}, Learning rate: {lr:.3f}')\n",
"\n",
" return params\n",
"\n",
"params = fit(initial_params, optimizer)"
]
}
]
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "optax-101.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3.9.13 ('base')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"vscode": {
"interpreter": {
"hash": "626d743d6476408aa1b36c3ff0d1f9d9d03e37c6879626ddfcdd13d658004bbf"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}