diff --git a/extra/a-projections.ipynb b/extra/a-projections.ipynb index c6256b3eb..acda54498 100644 --- a/extra/a-projections.ipynb +++ b/extra/a-projections.ipynb @@ -10,7 +10,7 @@ "import random\n", "import torch\n", "from plot_lib import set_default, show_mat\n", - "from matplotlib.pyplot import plot, subplot, axhline, axvline, legend\n", + "from matplotlib.pyplot import plot, subplot, axhline, axvline, legend, suptitle, title\n", "from numpy import pi as π\n", "from numpy import sqrt, cos, arange, arccos" ] @@ -103,7 +103,8 @@ "ax1.legend([f'{d}D' for d in d_range], ncol=3)\n", "ax1.set_ylim(0, 350)\n", "ax2.legend([f'{d}D' for d in d_range], ncol=3)\n", - "ax2.set_rlim(0, 350)" + "ax2.set_rlim(0, 350)\n", + "suptitle('Normalised scalar product distribution')" ] }, { @@ -116,36 +117,60 @@ ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "# Where are those two blue spikes coming from?\n", - "ϕ = arange(0, 1, 1/100) * π\n", - "f = cos(ϕ)" + "$$\\Vert a - b \\Vert = \\sqrt{\\Vert a \\Vert^2 + \\Vert b \\Vert^2 - 2 a^\\top b}$$" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "scrolled": false + }, "outputs": [], "source": [ - "plot(ϕ, f)\n", - "axhline()\n", - "axvline()\n", - "legend(['cosine'])" + "# Compute the histogram of points uniformly distributed on a unit sphere in d dimensions\n", + "r = 3\n", + "bins = 200 + 1\n", + "N = 10_000\n", + "d_range = (2, 3, 4, 6, 10, 16)\n", + "ax1 = subplot(111)\n", + "for d in d_range:\n", + " A = torch.randn(N, d) / sqrt(d)\n", + " B = torch.randn(N, d) / sqrt(d)\n", + " h = torch.histc(torch.cdist(A[None,:,:], B[None,:,:]).view(-1), bins, 0, r) / N\n", + " ax1.plot(torch.linspace(0, r, bins).numpy(), h.numpy())\n", + "ax1.legend([f'{d}D' for d in d_range], ncol=3)\n", + "ax1.set_ylim(ymin=0);\n", + "title('Distribution of distances in D dimensions')" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "scrolled": false + }, "outputs": [], "source": [ - "c = arange(-1+1/100, 1, 1/100)\n", - "g = arccos(c)" + "# Compute the histogram of points uniformly distributed on a unit sphere in d dimensions\n", + "r = 1\n", + "bins = 200 + 1\n", + "N = 10_000\n", + "d_range = (2, 3, 4, 6, 10, 16)\n", + "ax = subplot(111)\n", + "for d in d_range:\n", + " A = torch.randn(N, d) / sqrt(d)\n", + " B = torch.randn(N, d) / sqrt(d)\n", + " dist = torch.cdist(A[None,:,:], B[None,:,:])\n", + " h = torch.histc(dist.min(-1)[0].view(-1), bins, 0, r) / N\n", + " ax.plot(torch.linspace(0, r, bins).numpy(), h.numpy())\n", + "l = [f'{d}D' for d in d_range]\n", + "ax.legend(l, ncol=3)\n", + "ax.set_ylim(ymin=0, ymax=.10)\n", + "suptitle('Distribution of min distance')" ] }, { @@ -154,11 +179,9 @@ "metadata": {}, "outputs": [], "source": [ - "plot(c, g)\n", - "plot(c[:-1], (g[1:]-g[:-1])*100)\n", - "legend(['arccos', 'D[arccos]'])\n", - "axhline()\n", - "axvline()" + "# Where are those two blue spikes coming from?\n", + "ϕ = arange(0, 1, 1/100) * π\n", + "f = cos(ϕ)" ] }, { @@ -169,19 +192,20 @@ }, "outputs": [], "source": [ - "# Compute the histogram of points uniformly distributed on a unit sphere in d dimensions\n", - "r = 3\n", - "bins = 200 + 1\n", - "N = 10_000\n", - "d_range = (2, 3, 4, 6, 10, 16)\n", - "ax1 = subplot(111)\n", - "for d in d_range:\n", - " A = torch.randn(N, d) / sqrt(d)\n", - " B = torch.randn(N, d) / sqrt(d)\n", - " h = torch.histc(torch.cdist(A[None,:,:], B[None,:,:]).view(-1), bins, 0, r) / N\n", - " ax1.plot(torch.linspace(0, r, bins).numpy(), h.numpy())\n", - "ax1.legend([f'{d}D' for d in d_range], ncol=3)\n", - "ax1.set_ylim(ymin=0);" + "plot(ϕ, f)\n", + "axhline()\n", + "axvline()\n", + "legend(['cosine'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "c = arange(-1+1/100, 1, 1/100)\n", + "g = arccos(c)" ] }, { @@ -192,22 +216,11 @@ }, "outputs": [], "source": [ - "# Compute the histogram of points uniformly distributed on a unit sphere in d dimensions\n", - "r = 1\n", - "bins = 200 + 1\n", - "N = 10_000\n", - "d_range = (2, 3, 4, 6, 10, 16)\n", - "ax = subplot(111)\n", - "for N in (5_000, 25_000):\n", - " for d in d_range:\n", - " A = torch.randn(N, d) / sqrt(d)\n", - " B = torch.randn(N, d) / sqrt(d)\n", - " dist = torch.cdist(A[None,:,:], B[None,:,:])\n", - " h = torch.histc(dist.min(-1)[0].view(-1), bins, 0, r) / N\n", - " ax.plot(torch.linspace(0, r, bins).numpy(), h.numpy())\n", - "l = [f'{d}D' for d in d_range]\n", - "ax.legend(l = l, ncol=3)\n", - "ax.set_ylim(ymin=0, ymax=.10)" + "plot(c, g)\n", + "plot(c[:-1], (g[1:]-g[:-1])*100)\n", + "legend(['arccos', 'D[arccos]'])\n", + "axhline()\n", + "axvline()" ] } ],