Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
Nastya Krouglova authored and Nastya Krouglova committed Apr 3, 2024
1 parent 5a2db87 commit cb87ed0
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 48 deletions.
2 changes: 1 addition & 1 deletion examples/00_HH_simulator.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@
"ax.set_xticks([])\n",
"ax.set_yticks([-80, -20, 40])\n",
"\n",
"# plot the injected current \n",
"# plot the injected current\n",
"ax = plt.subplot(gs[1])\n",
"plt.plot(t, I_inj * A_soma * 1e3, \"k\", lw=2)\n",
"plt.xlabel(\"time (ms)\")\n",
Expand Down
29 changes: 0 additions & 29 deletions sbi/neural_nets/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from functools import partial
from typing import List, Optional, Sequence, Tuple, Union
from typing import List, Optional, Sequence, Tuple, Union
from warnings import warn

import torch
Expand All @@ -13,7 +12,6 @@
from pyknos.nflows.nn import nets
from pyknos.nflows.transforms.splines import (
rational_quadratic,
rational_quadratic,
)
from torch import Tensor, nn, relu, tanh, tensor, uint8

Expand All @@ -28,33 +26,6 @@
from sbi.utils.user_input_checks import check_data_device, check_embedding_net_device


def get_numel(batch_x: Tensor, batch_y: Tensor, embedding_net) -> Tuple[Tensor, Tensor]:
"""
Get the number of elements in the input and output space.
Args:
batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring.
batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring.
embedding_net: Optional embedding network for y.
Returns:
Tuple of the number of elements in the input and output space.
"""
x_numel = batch_x[0].numel()
# Infer the output dimensionality of the embedding_net by making a forward pass.
check_data_device(batch_x, batch_y)
check_embedding_net_device(embedding_net=embedding_net, datum=batch_y)
y_numel = embedding_net(batch_y[:1]).numel()
if x_numel == 1:
warn(
"In one-dimensional output space, this flow is limited to Gaussians",
stacklevel=2,
)

return x_numel, y_numel


def get_numel(batch_x: Tensor, batch_y: Tensor, embedding_net) -> Tuple[Tensor, Tensor]:
"""
Get the number of elements in the input and output space.
Expand Down
4 changes: 2 additions & 2 deletions tutorials/00_getting_started_flexible.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@
"metadata": {},
"outputs": [],
"source": [
"inference = SNPE(prior=prior) "
"inference = SNPE(prior=prior)"
]
},
{
Expand Down Expand Up @@ -266,7 +266,7 @@
"outputs": [],
"source": [
"theta_true = prior.sample((1,))\n",
"# generate our observation \n",
"# generate our observation\n",
"x_obs = simulator(theta_true)"
]
},
Expand Down
4 changes: 2 additions & 2 deletions tutorials/01_gaussian_amortized.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@
"# plot posterior samples\n",
"_ = analysis.pairplot(\n",
" posterior_samples_1, limits=[[-2, 2], [-2, 2], [-2, 2]], figsize=(5, 5),\n",
" labels=[r\"$\\theta_1$\", r\"$\\theta_2$\", r\"$\\theta_3$\"], \n",
" labels=[r\"$\\theta_1$\", r\"$\\theta_2$\", r\"$\\theta_3$\"],\n",
" points=theta_1 # add ground truth thetas\n",
")"
]
Expand Down Expand Up @@ -238,7 +238,7 @@
"# plot posterior samples\n",
"_ = analysis.pairplot(\n",
" posterior_samples_2, limits=[[-2, 2], [-2, 2], [-2, 2]], figsize=(5, 5),\n",
" labels=[r\"$\\theta_1$\", r\"$\\theta_2$\", r\"$\\theta_3$\"], \n",
" labels=[r\"$\\theta_1$\", r\"$\\theta_2$\", r\"$\\theta_3$\"],\n",
" points=theta_2 # add ground truth thetas\n",
")"
]
Expand Down
28 changes: 14 additions & 14 deletions tutorials/17_importance_sampled_posteriors.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@
"class Simulator:\n",
" def __init__(self):\n",
" pass\n",
" \n",
"\n",
" def log_likelihood(self, theta, x):\n",
" return MultivariateNormal(theta, eye(2)).log_prob(x)\n",
"\n",
Expand Down Expand Up @@ -312,7 +312,7 @@
"source": [
"# get weighted samples\n",
"theta_inferred_is = theta_inferred[torch.where(w > torch.rand(len(w)) * torch.max(w))]\n",
"# *Note*: we here perform rejection sampling, as the plotting function \n",
"# *Note*: we here perform rejection sampling, as the plotting function\n",
"# used below does not support weighted samples. In general, with rejection\n",
"# sampling the number of samples will be smaller than the effective sample\n",
"# size unless we allow for duplicate samples.\n",
Expand All @@ -323,8 +323,8 @@
"\n",
"# plot\n",
"fig, ax = marginal_plot(\n",
" [theta_inferred, theta_inferred_is, gt_samples], \n",
" limits=[[-5, 5], [-5, 5]], \n",
" [theta_inferred, theta_inferred_is, gt_samples],\n",
" limits=[[-5, 5], [-5, 5]],\n",
" figsize=(5, 1.5),\n",
" diag=\"kde\", # smooth histogram\n",
")\n",
Expand Down Expand Up @@ -22243,8 +22243,8 @@
],
"source": [
"fig, ax = marginal_plot(\n",
" [theta_inferred_sir_2, theta_inferred_sir_32, gt_samples], \n",
" limits=[[-5, 5], [-5, 5]], \n",
" [theta_inferred_sir_2, theta_inferred_sir_32, gt_samples],\n",
" limits=[[-5, 5], [-5, 5]],\n",
" figsize=(5, 1.5),\n",
" diag=\"kde\", # smooth histogram\n",
")\n",
Expand Down Expand Up @@ -22280,8 +22280,8 @@
],
"source": [
"fig, ax = marginal_plot(\n",
" [gt_samples, theta_inferred], \n",
" limits=[[-5, 5], [-5, 5]], \n",
" [gt_samples, theta_inferred],\n",
" limits=[[-5, 5], [-5, 5]],\n",
" weights=[None, w],\n",
" figsize=(5, 1.5),\n",
" diag=\"kde\", # smooth histogram\n",
Expand Down Expand Up @@ -22400,9 +22400,9 @@
"\n",
"for i in range(len(observations)):\n",
" fig, ax = marginal_plot(\n",
" [non_corrected_samples_for_all_observations[i], corrected_samples_for_all_observations[i], true_samples[i]], \n",
" limits=[[-5, 5], [-5, 5]], \n",
" points=theta_gt[i], \n",
" [non_corrected_samples_for_all_observations[i], corrected_samples_for_all_observations[i], true_samples[i]],\n",
" limits=[[-5, 5], [-5, 5]],\n",
" points=theta_gt[i],\n",
" figsize=(5, 1.5),\n",
" diag=\"kde\", # smooth histogram\n",
" )\n",
Expand Down Expand Up @@ -23967,9 +23967,9 @@
"\n",
"for i in range(len(observations)):\n",
" fig, ax = marginal_plot(\n",
" [non_corrected_samples_for_all_observations[i], corrected_samples_for_all_observations[i], true_samples[i]], \n",
" limits=[[-5, 5], [-5, 5]], \n",
" points=theta_gt[i], \n",
" [non_corrected_samples_for_all_observations[i], corrected_samples_for_all_observations[i], true_samples[i]],\n",
" limits=[[-5, 5], [-5, 5]],\n",
" points=theta_gt[i],\n",
" figsize=(5, 1.5),\n",
" diag=\"kde\", # smooth histogram\n",
" )\n",
Expand Down

0 comments on commit cb87ed0

Please sign in to comment.