Skip to content

Commit

Permalink
Zuko density estimators (#1088) (#1116)
Browse files Browse the repository at this point in the history
* Zuko density estimators (#1088)

* update zuko to 1.1.0

* test zuko_gmm commit

* build_zuko_nsf added

* add build_zuko_naf, update test

* add license change to pr template.

* CLN pyproject.toml (#1009)

* CLN pyproject.toml

* CLN optional deps comment

* CLN alphabetical order

* fix x_o and broken link tutorial 7 (#1003)

* fix x_o and broken link tutorial 7

* typo in title

* suppress plotting output

---------

Co-authored-by: Matthijs <[email protected]>

* replace prepare_for_sbi in tutorials (#1013)

* add zuko density estimators

* not working gmm

* update tests for PR

* update PR for pyright

* resolve pyright

* add reportArgumentType

* resolve pyright issue

* resolve all issues pyright

* resolve pyright

* add typing and docstring

* add functions from factory to test

* remove comment mdn file

* add docstrings flow file

* add docstring in density_estimator_test.py

* Update sbi/neural_nets/flow.py

Co-authored-by: Sebastian Bischoff <[email protected]>

* Update sbi/neural_nets/flow.py

Co-authored-by: Sebastian Bischoff <[email protected]>

* Update sbi/neural_nets/flow.py

Co-authored-by: Sebastian Bischoff <[email protected]>

* removed pyright

---------

Co-authored-by: bkmi <[email protected]>
Co-authored-by: Nastya Krouglova <[email protected]>
Co-authored-by: Jan Boelts <[email protected]>
Co-authored-by: Thomas Moreau <[email protected]>
Co-authored-by: Matthijs Pals <[email protected]>
Co-authored-by: Matthijs <[email protected]>
Co-authored-by: zinaStef <[email protected]>
Co-authored-by: Sebastian Bischoff <[email protected]>

* merge

* hate

* merge

* merge

* merge

* merge

* MERGE

* remove cnf

* implement changes Jan

* Update sbi/neural_nets/factory.py

Co-authored-by: Jan <[email protected]>

* resolve issues Jan

* undo changes to tutorials folder.

* sort dependencies.

---------

Co-authored-by: bkmi <[email protected]>
Co-authored-by: Nastya Krouglova <[email protected]>
Co-authored-by: Jan Boelts <[email protected]>
Co-authored-by: Thomas Moreau <[email protected]>
Co-authored-by: Matthijs Pals <[email protected]>
Co-authored-by: Matthijs <[email protected]>
Co-authored-by: zinaStef <[email protected]>
Co-authored-by: Sebastian Bischoff <[email protected]>
Co-authored-by: Jan <[email protected]>
  • Loading branch information
10 people authored Apr 5, 2024
1 parent b2d7d21 commit cbf9dca
Show file tree
Hide file tree
Showing 6 changed files with 806 additions and 163 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ dependencies = [
"tensorboard",
"torch>=1.8.0",
"tqdm",
"zuko>=1.0.0",
"pymc>=5.0.0",
"zuko>=1.1.0",
]

[project.optional-dependencies]
Expand Down
8 changes: 4 additions & 4 deletions sbi/neural_nets/density_estimators/zuko_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch
from torch import Tensor, nn
from zuko.flows import Flow
from zuko.flows.core import Flow

from sbi.neural_nets.density_estimators.base import DensityEstimator
from sbi.sbi_types import Shape
Expand Down Expand Up @@ -125,6 +125,7 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
emb_cond = emb_cond.expand(batch_shape + (emb_cond.shape[-1],))

dists = self.net(emb_cond)

log_probs = dists.log_prob(input)

return log_probs
Expand Down Expand Up @@ -166,7 +167,7 @@ def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor:

emb_cond = self._embedding_net(condition)
dists = self.net(emb_cond)
# zuko.sample() returns (*sample_shape, *batch_shape, input_size).

samples = dists.sample(sample_shape).reshape(*batch_shape, *sample_shape, -1)

return samples
Expand All @@ -190,9 +191,8 @@ def sample_and_log_prob(

emb_cond = self._embedding_net(condition)
dists = self.net(emb_cond)
samples, log_probs = dists.rsample_and_log_prob(sample_shape)
# zuko.sample_and_log_prob() returns (*sample_shape, *batch_shape, ...).

samples, log_probs = dists.rsample_and_log_prob(sample_shape)
samples = samples.reshape(*batch_shape, *sample_shape, -1)
log_probs = log_probs.reshape(*batch_shape, *sample_shape)

Expand Down
67 changes: 37 additions & 30 deletions sbi/neural_nets/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,37 @@
build_maf,
build_maf_rqs,
build_nsf,
build_zuko_bpf,
build_zuko_gf,
build_zuko_maf,
build_zuko_naf,
build_zuko_ncsf,
build_zuko_nice,
build_zuko_nsf,
build_zuko_sospf,
build_zuko_unaf,
)
from sbi.neural_nets.mdn import build_mdn
from sbi.neural_nets.mnle import build_mnle

# Registry mapping a user-facing model name to its builder function.
# Plain names are nflows-based estimators; names prefixed with "zuko_"
# are built on the zuko library.
model_builders = dict([
    ("mdn", build_mdn),
    ("made", build_made),
    ("maf", build_maf),
    ("maf_rqs", build_maf_rqs),
    ("nsf", build_nsf),
    ("mnle", build_mnle),
    ("zuko_nice", build_zuko_nice),
    ("zuko_maf", build_zuko_maf),
    ("zuko_nsf", build_zuko_nsf),
    ("zuko_ncsf", build_zuko_ncsf),
    ("zuko_sospf", build_zuko_sospf),
    ("zuko_naf", build_zuko_naf),
    ("zuko_unaf", build_zuko_unaf),
    ("zuko_gf", build_zuko_gf),
    ("zuko_bpf", build_zuko_bpf),
])


def classifier_nn(
model: str,
Expand Down Expand Up @@ -162,22 +188,10 @@ def likelihood_nn(
)

def build_fn(batch_theta, batch_x):
    """Build a density estimator of the likelihood p(x | theta).

    Args:
        batch_theta: Batch of parameters, used as the conditioning variable.
        batch_x: Batch of simulated data, the variable whose density is
            estimated.

    Returns:
        The density estimator produced by the builder registered for
        ``model`` in ``model_builders``.

    Raises:
        NotImplementedError: If ``model`` is not a key in ``model_builders``.
    """
    if model not in model_builders:
        raise NotImplementedError(f"Model {model} is not implemented")

    # For likelihood estimation, x is the estimated variable and theta the
    # condition: batch_x -> batch_x, batch_theta -> batch_y.
    return model_builders[model](batch_x=batch_x, batch_y=batch_theta, **kwargs)

Expand Down Expand Up @@ -265,20 +279,13 @@ def build_fn_snpe_a(batch_theta, batch_x, num_components):
)

def build_fn(batch_theta, batch_x):
    """Build a density estimator of the posterior p(theta | x).

    Args:
        batch_theta: Batch of parameters, the variable whose density is
            estimated.
        batch_x: Batch of simulated data, used as the conditioning variable.

    Returns:
        The density estimator produced by the builder registered for
        ``model`` in ``model_builders``.

    Raises:
        NotImplementedError: If ``model`` is not a key in ``model_builders``.
    """
    if model not in model_builders:
        raise NotImplementedError(f"Model {model} is not implemented")

    # The naming might be a bit confusing: the builders take batch_x as the
    # estimated variable and batch_y as the condition. For posterior
    # estimation theta is estimated conditioned on x, hence
    # batch_theta -> batch_x and batch_x -> batch_y.
    return model_builders[model](batch_x=batch_theta, batch_y=batch_x, **kwargs)

if model == "mdn_snpe_a":
if num_components != 10:
Expand Down
Loading

0 comments on commit cbf9dca

Please sign in to comment.