Skip to content

Commit

Permalink
Add LaTeX for norm decomposition and plot titles
Browse files Browse the repository at this point in the history
  • Loading branch information
Atcold committed Sep 27, 2021
1 parent bb97aaf commit 5ec56a3
Showing 1 changed file with 63 additions and 50 deletions.
113 changes: 63 additions & 50 deletions extra/a-projections.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"import random\n",
"import torch\n",
"from plot_lib import set_default, show_mat\n",
"from matplotlib.pyplot import plot, subplot, axhline, axvline, legend\n",
"from matplotlib.pyplot import plot, subplot, axhline, axvline, legend, suptitle, title\n",
"from numpy import pi as π\n",
"from numpy import sqrt, cos, arange, arccos"
]
Expand Down Expand Up @@ -103,7 +103,8 @@
"ax1.legend([f'{d}D' for d in d_range], ncol=3)\n",
"ax1.set_ylim(0, 350)\n",
"ax2.legend([f'{d}D' for d in d_range], ncol=3)\n",
"ax2.set_rlim(0, 350)"
"ax2.set_rlim(0, 350)\n",
"suptitle('Normalised scalar product distribution')"
]
},
{
Expand All @@ -116,36 +117,60 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"cell_type": "markdown",
"metadata": {},
"outputs": [],
"source": [
"# Where are those two blue spikes coming from?\n",
"ϕ = arange(0, 1, 1/100) * π\n",
"f = cos(ϕ)"
"$$\\Vert a - b \\Vert = \\sqrt{\\Vert a \\Vert^2 + \\Vert b \\Vert^2 - 2 a^\\top b}$$"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"plot(ϕ, f)\n",
"axhline()\n",
"axvline()\n",
"legend(['cosine'])"
"# Compute the histogram of points uniformly distributed on a unit sphere in d dimensions\n",
"r = 3\n",
"bins = 200 + 1\n",
"N = 10_000\n",
"d_range = (2, 3, 4, 6, 10, 16)\n",
"ax1 = subplot(111)\n",
"for d in d_range:\n",
" A = torch.randn(N, d) / sqrt(d)\n",
" B = torch.randn(N, d) / sqrt(d)\n",
" h = torch.histc(torch.cdist(A[None,:,:], B[None,:,:]).view(-1), bins, 0, r) / N\n",
" ax1.plot(torch.linspace(0, r, bins).numpy(), h.numpy())\n",
"ax1.legend([f'{d}D' for d in d_range], ncol=3)\n",
"ax1.set_ylim(ymin=0);\n",
"title('Distribution of distances in D dimensions')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"c = arange(-1+1/100, 1, 1/100)\n",
"g = arccos(c)"
"# Compute the histogram of points uniformly distributed on a unit sphere in d dimensions\n",
"r = 1\n",
"bins = 200 + 1\n",
"N = 10_000\n",
"d_range = (2, 3, 4, 6, 10, 16)\n",
"ax = subplot(111)\n",
"for d in d_range:\n",
" A = torch.randn(N, d) / sqrt(d)\n",
" B = torch.randn(N, d) / sqrt(d)\n",
" dist = torch.cdist(A[None,:,:], B[None,:,:])\n",
" h = torch.histc(dist.min(-1)[0].view(-1), bins, 0, r) / N\n",
" ax.plot(torch.linspace(0, r, bins).numpy(), h.numpy())\n",
"l = [f'{d}D' for d in d_range]\n",
"ax.legend(l, ncol=3)\n",
"ax.set_ylim(ymin=0, ymax=.10)\n",
"suptitle('Distribution of min distance')"
]
},
{
Expand All @@ -154,11 +179,9 @@
"metadata": {},
"outputs": [],
"source": [
"plot(c, g)\n",
"plot(c[:-1], (g[1:]-g[:-1])*100)\n",
"legend(['arccos', 'D[arccos]'])\n",
"axhline()\n",
"axvline()"
"# Where are those two blue spikes coming from?\n",
"ϕ = arange(0, 1, 1/100) * π\n",
"f = cos(ϕ)"
]
},
{
Expand All @@ -169,19 +192,20 @@
},
"outputs": [],
"source": [
"# Compute the histogram of points uniformly distributed on a unit sphere in d dimensions\n",
"r = 3\n",
"bins = 200 + 1\n",
"N = 10_000\n",
"d_range = (2, 3, 4, 6, 10, 16)\n",
"ax1 = subplot(111)\n",
"for d in d_range:\n",
" A = torch.randn(N, d) / sqrt(d)\n",
" B = torch.randn(N, d) / sqrt(d)\n",
" h = torch.histc(torch.cdist(A[None,:,:], B[None,:,:]).view(-1), bins, 0, r) / N\n",
" ax1.plot(torch.linspace(0, r, bins).numpy(), h.numpy())\n",
"ax1.legend([f'{d}D' for d in d_range], ncol=3)\n",
"ax1.set_ylim(ymin=0);"
"plot(ϕ, f)\n",
"axhline()\n",
"axvline()\n",
"legend(['cosine'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"c = arange(-1+1/100, 1, 1/100)\n",
"g = arccos(c)"
]
},
{
Expand All @@ -192,22 +216,11 @@
},
"outputs": [],
"source": [
"# Compute the histogram of points uniformly distributed on a unit sphere in d dimensions\n",
"r = 1\n",
"bins = 200 + 1\n",
"N = 10_000\n",
"d_range = (2, 3, 4, 6, 10, 16)\n",
"ax = subplot(111)\n",
"for N in (5_000, 25_000):\n",
" for d in d_range:\n",
" A = torch.randn(N, d) / sqrt(d)\n",
" B = torch.randn(N, d) / sqrt(d)\n",
" dist = torch.cdist(A[None,:,:], B[None,:,:])\n",
" h = torch.histc(dist.min(-1)[0].view(-1), bins, 0, r) / N\n",
" ax.plot(torch.linspace(0, r, bins).numpy(), h.numpy())\n",
"l = [f'{d}D' for d in d_range]\n",
"ax.legend(l = l, ncol=3)\n",
"ax.set_ylim(ymin=0, ymax=.10)"
"plot(c, g)\n",
"plot(c[:-1], (g[1:]-g[:-1])*100)\n",
"legend(['arccos', 'D[arccos]'])\n",
"axhline()\n",
"axvline()"
]
}
],
Expand Down

0 comments on commit 5ec56a3

Please sign in to comment.