From cb87ed041eb22ca90165d045dc575dc5675a648c Mon Sep 17 00:00:00 2001 From: Nastya Krouglova Date: Wed, 3 Apr 2024 20:16:45 +0200 Subject: [PATCH] merge --- examples/00_HH_simulator.ipynb | 2 +- sbi/neural_nets/flow.py | 29 ------------------- tutorials/00_getting_started_flexible.ipynb | 4 +-- tutorials/01_gaussian_amortized.ipynb | 4 +-- .../17_importance_sampled_posteriors.ipynb | 28 +++++++++--------- 5 files changed, 19 insertions(+), 48 deletions(-) diff --git a/examples/00_HH_simulator.ipynb b/examples/00_HH_simulator.ipynb index 269cd3663..3b88cce0b 100644 --- a/examples/00_HH_simulator.ipynb +++ b/examples/00_HH_simulator.ipynb @@ -256,7 +256,7 @@ "ax.set_xticks([])\n", "ax.set_yticks([-80, -20, 40])\n", "\n", - "# plot the injected current \n", + "# plot the injected current\n", "ax = plt.subplot(gs[1])\n", "plt.plot(t, I_inj * A_soma * 1e3, \"k\", lw=2)\n", "plt.xlabel(\"time (ms)\")\n", diff --git a/sbi/neural_nets/flow.py b/sbi/neural_nets/flow.py index b3028f7e8..709690284 100644 --- a/sbi/neural_nets/flow.py +++ b/sbi/neural_nets/flow.py @@ -3,7 +3,6 @@ from functools import partial from typing import List, Optional, Sequence, Tuple, Union -from typing import List, Optional, Sequence, Tuple, Union from warnings import warn import torch @@ -13,7 +12,6 @@ from pyknos.nflows.nn import nets from pyknos.nflows.transforms.splines import ( rational_quadratic, - rational_quadratic, ) from torch import Tensor, nn, relu, tanh, tensor, uint8 @@ -28,33 +26,6 @@ from sbi.utils.user_input_checks import check_data_device, check_embedding_net_device -def get_numel(batch_x: Tensor, batch_y: Tensor, embedding_net) -> Tuple[Tensor, Tensor]: - """ - Get the number of elements in the input and output space. - - Args: - batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. - batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. - embedding_net: Optional embedding network for y. - - Returns: - Tuple of the number of elements in the input and output space. - - """ - x_numel = batch_x[0].numel() - # Infer the output dimensionality of the embedding_net by making a forward pass. - check_data_device(batch_x, batch_y) - check_embedding_net_device(embedding_net=embedding_net, datum=batch_y) - y_numel = embedding_net(batch_y[:1]).numel() - if x_numel == 1: - warn( - "In one-dimensional output space, this flow is limited to Gaussians", - stacklevel=2, - ) - - return x_numel, y_numel - - def get_numel(batch_x: Tensor, batch_y: Tensor, embedding_net) -> Tuple[Tensor, Tensor]: """ Get the number of elements in the input and output space. diff --git a/tutorials/00_getting_started_flexible.ipynb b/tutorials/00_getting_started_flexible.ipynb index 552bb71fc..b8094b100 100644 --- a/tutorials/00_getting_started_flexible.ipynb +++ b/tutorials/00_getting_started_flexible.ipynb @@ -136,7 +136,7 @@ "metadata": {}, "outputs": [], "source": [ - "inference = SNPE(prior=prior) " + "inference = SNPE(prior=prior)" ] }, { @@ -266,7 +266,7 @@ "outputs": [], "source": [ "theta_true = prior.sample((1,))\n", - "# generate our observation \n", + "# generate our observation\n", "x_obs = simulator(theta_true)" ] }, diff --git a/tutorials/01_gaussian_amortized.ipynb b/tutorials/01_gaussian_amortized.ipynb index 67b9e7833..69f798783 100644 --- a/tutorials/01_gaussian_amortized.ipynb +++ b/tutorials/01_gaussian_amortized.ipynb @@ -183,7 +183,7 @@ "# plot posterior samples\n", "_ = analysis.pairplot(\n", " posterior_samples_1, limits=[[-2, 2], [-2, 2], [-2, 2]], figsize=(5, 5),\n", - " labels=[r\"$\\theta_1$\", r\"$\\theta_2$\", r\"$\\theta_3$\"], \n", + " labels=[r\"$\\theta_1$\", r\"$\\theta_2$\", r\"$\\theta_3$\"],\n", " points=theta_1 # add ground truth thetas\n", ")" ] @@ -238,7 +238,7 @@ "# plot posterior samples\n", "_ = analysis.pairplot(\n", " posterior_samples_2, limits=[[-2, 2], [-2, 2], [-2, 2]], figsize=(5, 5),\n", - " labels=[r\"$\\theta_1$\", r\"$\\theta_2$\", r\"$\\theta_3$\"], \n", + " labels=[r\"$\\theta_1$\", r\"$\\theta_2$\", r\"$\\theta_3$\"],\n", " points=theta_2 # add ground truth thetas\n", ")" ] diff --git a/tutorials/17_importance_sampled_posteriors.ipynb b/tutorials/17_importance_sampled_posteriors.ipynb index f23f92d8d..7c17ebf5b 100644 --- a/tutorials/17_importance_sampled_posteriors.ipynb +++ b/tutorials/17_importance_sampled_posteriors.ipynb @@ -162,7 +162,7 @@ "class Simulator:\n", " def __init__(self):\n", " pass\n", - " \n", + "\n", " def log_likelihood(self, theta, x):\n", " return MultivariateNormal(theta, eye(2)).log_prob(x)\n", "\n", @@ -312,7 +312,7 @@ "source": [ "# get weighted samples\n", "theta_inferred_is = theta_inferred[torch.where(w > torch.rand(len(w)) * torch.max(w))]\n", - "# *Note*: we here perform rejection sampling, as the plotting function \n", + "# *Note*: we here perform rejection sampling, as the plotting function\n", "# used below does not support weighted samples. In general, with rejection\n", "# sampling the number of samples will be smaller than the effective sample\n", "# size unless we allow for duplicate samples.\n", @@ -323,8 +323,8 @@ "\n", "# plot\n", "fig, ax = marginal_plot(\n", - " [theta_inferred, theta_inferred_is, gt_samples], \n", - " limits=[[-5, 5], [-5, 5]], \n", + " [theta_inferred, theta_inferred_is, gt_samples],\n", + " limits=[[-5, 5], [-5, 5]],\n", " figsize=(5, 1.5),\n", " diag=\"kde\", # smooth histogram\n", ")\n", @@ -22243,8 +22243,8 @@ ], "source": [ "fig, ax = marginal_plot(\n", - " [theta_inferred_sir_2, theta_inferred_sir_32, gt_samples], \n", - " limits=[[-5, 5], [-5, 5]], \n", + " [theta_inferred_sir_2, theta_inferred_sir_32, gt_samples],\n", + " limits=[[-5, 5], [-5, 5]],\n", " figsize=(5, 1.5),\n", " diag=\"kde\", # smooth histogram\n", ")\n", @@ -22280,8 +22280,8 @@ ], "source": [ "fig, ax = marginal_plot(\n", - " [gt_samples, theta_inferred], \n", - " limits=[[-5, 5], [-5, 5]], \n", + " [gt_samples, theta_inferred],\n", + " limits=[[-5, 5], [-5, 5]],\n", " weights=[None, w],\n", " figsize=(5, 1.5),\n", " diag=\"kde\", # smooth histogram\n", @@ -22400,9 +22400,9 @@ "\n", "for i in range(len(observations)):\n", " fig, ax = marginal_plot(\n", - " [non_corrected_samples_for_all_observations[i], corrected_samples_for_all_observations[i], true_samples[i]], \n", - " limits=[[-5, 5], [-5, 5]], \n", - " points=theta_gt[i], \n", + " [non_corrected_samples_for_all_observations[i], corrected_samples_for_all_observations[i], true_samples[i]],\n", + " limits=[[-5, 5], [-5, 5]],\n", + " points=theta_gt[i],\n", " figsize=(5, 1.5),\n", " diag=\"kde\", # smooth histogram\n", " )\n", @@ -23967,9 +23967,9 @@ "\n", "for i in range(len(observations)):\n", " fig, ax = marginal_plot(\n", - " [non_corrected_samples_for_all_observations[i], corrected_samples_for_all_observations[i], true_samples[i]], \n", - " limits=[[-5, 5], [-5, 5]], \n", - " points=theta_gt[i], \n", + " [non_corrected_samples_for_all_observations[i], corrected_samples_for_all_observations[i], true_samples[i]],\n", + " limits=[[-5, 5], [-5, 5]],\n", + " points=theta_gt[i],\n", " figsize=(5, 1.5),\n", " diag=\"kde\", # smooth histogram\n", " )\n",