Tutorial example for classification #98

Closed
wants to merge 20 commits into from
386 changes: 386 additions & 0 deletions docs/tutorials/gp-classification.ipynb
@@ -0,0 +1,386 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" import tinygp\n",
"except ImportError:\n",
" %pip install -q tinygp\n",
"\n",
"try:\n",
" import numpyro\n",
"except ImportError:\n",
" %pip uninstall -y jax jaxlib\n",
" %pip install -q numpyro jax jaxlib\n",
"\n",
"try:\n",
" import arviz\n",
"except ImportError:\n",
" %pip install arviz"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"(gp-classification)="
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## GP for Classification"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this tutorial, we demonstrate Classification task using GP in `tinygp`. In case of classification, the test predications are the class probabilities.\n",
"\n",
"Instead of target values being in real space, the target values are in discrete values corresponding to the respective classes. In GP Classification, we use a \"link function\" to link the real output of the GP to the probabilistic distribution of the classes. We use a GP prior on the latent function which is then squashed down using the link function."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import jax\n",
"import matplotlib.pyplot as plt\n",
"from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpyro\n",
"import numpyro.distributions as dist"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We start with binary classification on a simple XOR Dataset. \n",
"As the classification is binary, the discrete target values are in $[0, 1]$ for Class 1 and Class 2 respectively.\n",
"\n",
"We model the binary class probabilities with Bernoulli distribution with parameter `p`, the probability of class $2$. \n",
"The `link function` we use for squashing the GP prior to range $[0, 1]$ for the `p` is the Sigmoid function."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For binary classification, we build on our model from GP regression and use:\n",
"$$ f \\sim \\mathcal{G P}\\left(0, \\mathbf{K}_{f}\\left(x, x^{\\prime}\\right)\\right)$$\n",
"with a sigmoid likelihood\n",
"$$ p(y=1 \\mid f)=\\operatorname{Sigmoid}(f)$$\n",
"or\n",
"$$\n",
"y \\sim \\text { Bernoulli(Sigmoid(f)) }\n",
"$$"
]
},
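{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the generative model concrete, the following small sketch (with an illustrative `ExpSquared` kernel and unit length scale, not tuned values) draws one latent function from a GP prior on a 1D grid, squashes it through the sigmoid, and samples Bernoulli labels:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from tinygp import GaussianProcess, kernels\n",
"\n",
"x_grid = jnp.linspace(-3, 3, 50)\n",
"\n",
"# Draw one latent function from the GP prior\n",
"prior_gp = GaussianProcess(kernels.ExpSquared(scale=1.0), x_grid, diag=1e-5)\n",
"f_sample = prior_gp.sample(jax.random.PRNGKey(0))\n",
"\n",
"# Link function: squash the latent values into class probabilities\n",
"p_sample = jax.nn.sigmoid(f_sample)\n",
"\n",
"# Observation model: Bernoulli labels given the probabilities\n",
"y_sample = jax.random.bernoulli(jax.random.PRNGKey(1), p_sample)"
]
},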
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"X = jax.random.normal(jax.random.PRNGKey(1234), (200, 2))\n",
"y = jnp.logical_xor(X[:, 0] > 0, X[:, 1] > 0)\n",
"\n",
"c = plt.cm.get_cmap(\"Paired\")(y)\n",
"plt.scatter(\n",
" X[:, 0][y == 0],\n",
" X[:, 1][y == 0],\n",
" s=30,\n",
" c=c[y == 0],\n",
" cmap=plt.cm.Paired,\n",
" edgecolors=(0, 0, 0),\n",
" label=f\"Class 1\",\n",
")\n",
"plt.scatter(\n",
" X[:, 0][y == 1],\n",
" X[:, 1][y == 1],\n",
" s=30,\n",
" c=c[y == 1],\n",
" cmap=plt.cm.Paired,\n",
" edgecolors=(0, 0, 0),\n",
" label=f\"Class 2\",\n",
")\n",
"plt.gca().set_aspect(\"equal\")\n",
"plt.axhline(0, color=\"k\")\n",
"plt.axvline(0, color=\"k\")\n",
"plt.xlabel(r\"$x_{1}$\")\n",
"plt.ylabel(r\"$x_{2}$\")\n",
"plt.legend()\n",
"_ = plt.title(\"XOR Dataset\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"xs = jnp.linspace(-2, 2, num=100)\n",
"ys = jnp.linspace(-2, 2, num=100)\n",
"\n",
"xx, yy = jnp.meshgrid(xs, ys)\n",
"xx = xx.T\n",
"yy = yy.T\n",
"true_X = jnp.vstack((xx.ravel(), yy.ravel())).T\n",
"true_y = jnp.logical_xor(true_X[:, 0] > 0, true_X[:, 1] > 0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def sigmoid(x):\n",
" return 1 / (1 + jnp.exp(-x))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As the likelihood is non-Gaussian we need to use Markov Chain Monte Carlo (MCMC) or Variational Inference (VI) to marginalize numerically. \n",
"<!-- This follows from the example in {ref}`markov-chain-monte-carlo-mcmc` and {ref}`sampling-with-numpyro`. -->"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"from flax.linen.initializers import zeros\n",
"import numpyro\n",
"import numpyro.distributions as dist\n",
"from tinygp import kernels, GaussianProcess\n",
"\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"\n",
"\n",
"def model(x, y=None):\n",
" # The parameters of the GP regression\n",
" mean = numpyro.param(\"mean\", jnp.zeros(()))\n",
" sigma = numpyro.param(\"sigma\", jnp.ones(()))\n",
" ell = numpyro.param(\"ell\", jnp.ones(()))\n",
"\n",
" # Set up the kernel and GP objects\n",
" kernel = (sigma**2) * kernels.ExpSquared(scale=ell)\n",
" gp = GaussianProcess(kernel, x, diag=1e-5, mean=mean)\n",
"\n",
" gp_out = numpyro.sample(\"gp_out\", gp.numpyro_dist())\n",
" # Squashing the GP regression real output to the [0, 1] range\n",
" # using sigmoid as the link function\n",
" p = sigmoid(gp_out)\n",
"\n",
" # Finally our observation model is Bernoulli distribution\n",
" # where 'p' is the probability of Class 2\n",
" numpyro.sample(\"obs\", dist.Bernoulli(probs=p), obs=y)\n",
"\n",
" if y is not None:\n",
" # Posterior Inference on true_X input values\n",
" numpyro.deterministic(\"pred\", gp.condition(gp_out, true_X).gp.loc)\n",
"\n",
"\n",
"nuts_kernel = numpyro.infer.NUTS(model, target_accept_prob=0.8)\n",
"mcmc = numpyro.infer.MCMC(\n",
" nuts_kernel,\n",
" num_warmup=1000,\n",
" num_samples=1000,\n",
" num_chains=2,\n",
" progress_bar=True,\n",
")\n",
"rng_key = jax.random.PRNGKey(55873)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"# run the MCMC\n",
"mcmc.run(\n",
" rng_key,\n",
" X,\n",
" y=y,\n",
")\n",
"samples = mcmc.get_samples()\n",
"pred = samples[\"pred\"].block_until_ready() # Blocking to get timing right"
]
},
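{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an aside, the variational inference route mentioned above could look roughly like the following sketch, using numpyro's `SVI` with an `AutoNormal` guide; the optimizer, step size, and number of steps are illustrative assumptions, and we stick with the MCMC samples for the rest of this tutorial:\n",
"\n",
"```python\n",
"guide = numpyro.infer.autoguide.AutoNormal(model)\n",
"svi = numpyro.infer.SVI(\n",
"    model, guide, numpyro.optim.Adam(step_size=0.01), numpyro.infer.Trace_ELBO()\n",
")\n",
"svi_result = svi.run(jax.random.PRNGKey(0), 2000, X, y=y)\n",
"# Approximate posterior samples of 'gp_out' from the fitted guide\n",
"vi_samples = guide.sample_posterior(\n",
"    jax.random.PRNGKey(1), svi_result.params, sample_shape=(1000,)\n",
")\n",
"```"
]
},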
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As MCMC is an iterative method, we need to check the convergence."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import arviz as az\n",
"\n",
"data = az.from_numpyro(mcmc)\n",
"az.summary(\n",
" data, var_names=[v for v in data.posterior.data_vars if v != \"pred\"]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the above diagnostic report, the `R-hat` is less than $1.05$ in all parameters, so the method had converged and we are good to proceed with the samples.\n",
"\n",
"Now we look at the train accuracy. \n",
"From the samples, we get `gp_out`, which is the output of the GP regression model on the train points. To convert to class probabilities, we use the Sigmoid as the link function.\n",
"Finally we end up with probabilites of Class 2. For deterministically assigning classes, $p > 0.5$ would be assigned Class 2 and $p <= 0.5$ is assigned Class 1. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"q = np.percentile(samples[\"gp_out\"], [5, 25, 50, 75, 95], axis=0)\n",
"y_hat = sigmoid(q[2]) > 0.5\n",
"\n",
"print(f\"Train Accuracy: {(y_hat==y).sum()*100/(len(y)) :0.2f}%\")"
]
},
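{
"cell_type": "markdown",
"metadata": {},
"source": [
"Thresholding the posterior-median latent values is one reasonable decision rule. An alternative (a minimal sketch, not part of the analysis above) is to average the per-sample class probabilities before thresholding, which propagates the posterior uncertainty in `gp_out` into the decision:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Alternative decision rule: average class probabilities across posterior samples\n",
"p_mean = sigmoid(samples[\"gp_out\"]).mean(axis=0)\n",
"y_hat_mean = p_mean > 0.5\n",
"\n",
"print(f\"Train Accuracy (posterior mean p): {100 * (y_hat_mean == y).mean():.2f}%\")"
]
},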
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We see that our model did a reasonable job and got a good accuracy on the train data. \n",
"We now visualize the predictions on 2D grid points `true_X`. `pred` are the GP regression model output samples on 2D grid points. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"q = np.percentile(samples[\"pred\"], [5, 25, 50, 75, 95], axis=0)\n",
"true_y_hat = sigmoid(q[2]) > 0.5\n",
"print(f\"Test Accuracy: {(true_y_hat==true_y).sum()*100/(len(true_y)) :0.2f}%\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def plot_pred_2d(arr, xx, yy, contour=False, ax=None, title=None):\n",
" if ax is None:\n",
" fig, ax = plt.subplots()\n",
" image = ax.imshow(\n",
" arr,\n",
" interpolation=\"nearest\",\n",
" extent=(xx.min(), xx.max(), yy.min(), yy.max()),\n",
" aspect=\"equal\",\n",
" origin=\"lower\",\n",
" cmap=plt.cm.PuOr_r,\n",
" )\n",
" if contour:\n",
" contours = ax.contour(\n",
" xx,\n",
" yy,\n",
" sigmoid(q[2]).reshape(xx.shape),\n",
" levels=[0.5],\n",
" linewidths=2,\n",
" colors=[\"k\"],\n",
" )\n",
"\n",
" divider = make_axes_locatable(ax)\n",
" cax = divider.append_axes(\"right\", size=\"5%\", pad=0.1)\n",
"\n",
" ax.get_figure().colorbar(image, cax=cax)\n",
" if title:\n",
" ax.set_title(title)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots(ncols=3, figsize=(12, 4))\n",
"plot_pred_2d(q[2].reshape(xx.shape), xx, yy, ax=ax[0], title=\"f\")\n",
"plot_pred_2d(\n",
" sigmoid(q[2]).reshape(xx.shape),\n",
" xx,\n",
" yy,\n",
" ax=ax[1],\n",
" title=\"p(y=1|f) = Sigmoid(f)\",\n",
" contour=True,\n",
")\n",
"plot_pred_2d(\n",
" true_y_hat.reshape(xx.shape),\n",
" xx,\n",
" yy,\n",
" ax=ax[2],\n",
" title=\"Predictions (y) ~ Bernoulli(p(y=1|f))\",\n",
")\n",
"\n",
"fig.tight_layout()"
]
}
],
"metadata": {
"interpreter": {
"hash": "c70e5ee2d2c28a093660d06d1dd4b62023c712f6a8515efc39e65872fc2efeaf"
},
"kernelspec": {
"display_name": "Python 3.9.7 ('sachin_env')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
2 changes: 2 additions & 0 deletions news/98.doc
@@ -0,0 +1,2 @@
Added a new tutorial describing a classification task
using GPs.