From cbf9dca7f209f1c055645b758937bd79edfe67a0 Mon Sep 17 00:00:00 2001 From: Nastya Krouglova <41705732+anastasiakrouglova@users.noreply.github.com> Date: Fri, 5 Apr 2024 11:51:16 +0200 Subject: [PATCH] Zuko density estimators (#1088) (#1116) * Zuko density estimators (#1088) * update zuko to 1.1.0 * test zuko_gmm commit * build_zuko_nsf added * add build_zuko_naf, update test * add license change to pr template. * CLN pyproject.toml (#1009) * CLN pyproject.toml * CLN optional deps comment * CLN alphabetical order * fix x_o and broken link tutorial 7 (#1003) * fix x_o and broken link tutorial 7 * typo in title * suppress plotting output --------- Co-authored-by: Matthijs * replace prepare_for_sbi in tutorials (#1013) * add zuko density estimators * not working gmm * update tests for PR * update PR for pyright * resolve pyright * add reportArgumentType * resolve pyright issue * resolve all issues pyright * resolve pyright * add typing and docstring * add functions from factory to test * remove comment mdn file * add docstrings flow file * add docstring in density_estimator_test.py * Update sbi/neural_nets/flow.py Co-authored-by: Sebastian Bischoff * Update sbi/neural_nets/flow.py Co-authored-by: Sebastian Bischoff * Update sbi/neural_nets/flow.py Co-authored-by: Sebastian Bischoff * removed pyright --------- Co-authored-by: bkmi <12955549+bkmi@users.noreply.github.com> Co-authored-by: Nastya Krouglova Co-authored-by: Jan Boelts Co-authored-by: Thomas Moreau Co-authored-by: Matthijs Pals <34062419+Matthijspals@users.noreply.github.com> Co-authored-by: Matthijs Co-authored-by: zinaStef <49067201+zinaStef@users.noreply.github.com> Co-authored-by: Sebastian Bischoff * merge * hate * merge * merge * merge * merge * MERGE * remove cnf * implement changes Jan * Update sbi/neural_nets/factory.py Co-authored-by: Jan * resolve issues Jan * undo changes to tutorials folder. * sort dependencies. 
--------- Co-authored-by: bkmi <12955549+bkmi@users.noreply.github.com> Co-authored-by: Nastya Krouglova Co-authored-by: Jan Boelts Co-authored-by: Thomas Moreau Co-authored-by: Matthijs Pals <34062419+Matthijspals@users.noreply.github.com> Co-authored-by: Matthijs Co-authored-by: zinaStef <49067201+zinaStef@users.noreply.github.com> Co-authored-by: Sebastian Bischoff Co-authored-by: Jan --- pyproject.toml | 2 +- .../density_estimators/zuko_flow.py | 8 +- sbi/neural_nets/factory.py | 67 +- sbi/neural_nets/flow.py | 758 ++++++++++++++++-- tests/density_estimator_test.py | 104 ++- tests/neural_nets_factory.py | 30 +- 6 files changed, 806 insertions(+), 163 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fd0f9c49f..027a83f08 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,8 +42,8 @@ dependencies = [ "tensorboard", "torch>=1.8.0", "tqdm", - "zuko>=1.0.0", "pymc>=5.0.0", + "zuko>=1.1.0", ] [project.optional-dependencies] diff --git a/sbi/neural_nets/density_estimators/zuko_flow.py b/sbi/neural_nets/density_estimators/zuko_flow.py index 13c1e1e94..e8c04178a 100644 --- a/sbi/neural_nets/density_estimators/zuko_flow.py +++ b/sbi/neural_nets/density_estimators/zuko_flow.py @@ -2,7 +2,7 @@ import torch from torch import Tensor, nn -from zuko.flows import Flow +from zuko.flows.core import Flow from sbi.neural_nets.density_estimators.base import DensityEstimator from sbi.sbi_types import Shape @@ -125,6 +125,7 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor: emb_cond = emb_cond.expand(batch_shape + (emb_cond.shape[-1],)) dists = self.net(emb_cond) + log_probs = dists.log_prob(input) return log_probs @@ -166,7 +167,7 @@ def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor: emb_cond = self._embedding_net(condition) dists = self.net(emb_cond) - # zuko.sample() returns (*sample_shape, *batch_shape, input_size). 
+ samples = dists.sample(sample_shape).reshape(*batch_shape, *sample_shape, -1) return samples @@ -190,9 +191,8 @@ def sample_and_log_prob( emb_cond = self._embedding_net(condition) dists = self.net(emb_cond) - samples, log_probs = dists.rsample_and_log_prob(sample_shape) - # zuko.sample_and_log_prob() returns (*sample_shape, *batch_shape, ...). + samples, log_probs = dists.rsample_and_log_prob(sample_shape) samples = samples.reshape(*batch_shape, *sample_shape, -1) log_probs = log_probs.reshape(*batch_shape, *sample_shape) diff --git a/sbi/neural_nets/factory.py b/sbi/neural_nets/factory.py index 1273b4587..5bddcc0f5 100644 --- a/sbi/neural_nets/factory.py +++ b/sbi/neural_nets/factory.py @@ -16,11 +16,37 @@ build_maf, build_maf_rqs, build_nsf, + build_zuko_bpf, + build_zuko_gf, build_zuko_maf, + build_zuko_naf, + build_zuko_ncsf, + build_zuko_nice, + build_zuko_nsf, + build_zuko_sospf, + build_zuko_unaf, ) from sbi.neural_nets.mdn import build_mdn from sbi.neural_nets.mnle import build_mnle +model_builders = { + "mdn": build_mdn, + "made": build_made, + "maf": build_maf, + "maf_rqs": build_maf_rqs, + "nsf": build_nsf, + "mnle": build_mnle, + "zuko_nice": build_zuko_nice, + "zuko_maf": build_zuko_maf, + "zuko_nsf": build_zuko_nsf, + "zuko_ncsf": build_zuko_ncsf, + "zuko_sospf": build_zuko_sospf, + "zuko_naf": build_zuko_naf, + "zuko_unaf": build_zuko_unaf, + "zuko_gf": build_zuko_gf, + "zuko_bpf": build_zuko_bpf, +} + def classifier_nn( model: str, @@ -162,22 +188,10 @@ def likelihood_nn( ) def build_fn(batch_theta, batch_x): - if model == "mdn": - return build_mdn(batch_x=batch_x, batch_y=batch_theta, **kwargs) - elif model == "made": - return build_made(batch_x=batch_x, batch_y=batch_theta, **kwargs) - elif model == "maf": - return build_maf(batch_x=batch_x, batch_y=batch_theta, **kwargs) - elif model == "maf_rqs": - return build_maf_rqs(batch_x=batch_x, batch_y=batch_theta, **kwargs) - elif model == "nsf": - return build_nsf(batch_x=batch_x, 
batch_y=batch_theta, **kwargs) - elif model == "mnle": - return build_mnle(batch_x=batch_x, batch_y=batch_theta, **kwargs) - elif model == "zuko_maf": - return build_zuko_maf(batch_x=batch_x, batch_y=batch_theta, **kwargs) - else: - raise NotImplementedError + if model not in model_builders: + raise NotImplementedError(f"Model {model} in not implemented") + + return model_builders[model](batch_x=batch_x, batch_y=batch_theta, **kwargs) return build_fn @@ -265,20 +279,13 @@ def build_fn_snpe_a(batch_theta, batch_x, num_components): ) def build_fn(batch_theta, batch_x): - if model == "mdn": - return build_mdn(batch_x=batch_theta, batch_y=batch_x, **kwargs) - elif model == "made": - return build_made(batch_x=batch_theta, batch_y=batch_x, **kwargs) - elif model == "maf": - return build_maf(batch_x=batch_theta, batch_y=batch_x, **kwargs) - elif model == "maf_rqs": - return build_maf_rqs(batch_x=batch_theta, batch_y=batch_x, **kwargs) - elif model == "nsf": - return build_nsf(batch_x=batch_theta, batch_y=batch_x, **kwargs) - elif model == "zuko_maf": - return build_zuko_maf(batch_x=batch_theta, batch_y=batch_x, **kwargs) - else: - raise NotImplementedError + if model not in model_builders: + raise NotImplementedError(f"Model {model} in not implemented") + + # The naming might be a bit confusing. + # batch_x are the latent variables, batch_y the conditioned variables. + # batch_theta are the parameters and batch_x the observable variables. + return model_builders[model](batch_x=batch_theta, batch_y=batch_x, **kwargs) if model == "mdn_snpe_a": if num_components != 10: diff --git a/sbi/neural_nets/flow.py b/sbi/neural_nets/flow.py index a37e69f63..1bbb5e477 100644 --- a/sbi/neural_nets/flow.py +++ b/sbi/neural_nets/flow.py @@ -2,7 +2,7 @@ # under the Affero General Public License v3, see . 
from functools import partial -from typing import List, Optional, Sequence, Union +from typing import List, Optional, Sequence, Tuple, Union from warnings import warn import torch @@ -26,6 +26,33 @@ from sbi.utils.user_input_checks import check_data_device, check_embedding_net_device +def get_numel(batch_x: Tensor, batch_y: Tensor, embedding_net) -> Tuple[int, int]: + """ + Get the number of elements in the input and output space. + + Args: + batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. + batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. + embedding_net: Optional embedding network for y. + + Returns: + Tuple of the number of elements in the input and output space. + + """ + x_numel = batch_x[0].numel() + # Infer the output dimensionality of the embedding_net by making a forward pass. + check_data_device(batch_x, batch_y) + check_embedding_net_device(embedding_net=embedding_net, datum=batch_y) + y_numel = embedding_net(batch_y[:1]).numel() + if x_numel == 1: + warn( + "In one-dimensional output space, this flow is limited to Gaussians", + stacklevel=2, + ) + + return x_numel, y_numel + + def build_made( batch_x: Tensor, batch_y: Tensor, @@ -58,18 +85,7 @@ def build_made( Returns: Neural network. """ - x_numel = batch_x[0].numel() - # Infer the output dimensionality of the embedding_net by making a forward pass. - check_data_device(batch_x, batch_y) - check_embedding_net_device(embedding_net=embedding_net, datum=batch_y) - embedding_net.eval() - y_numel = embedding_net(batch_y[:1]).numel() - - if x_numel == 1: - warn( - "In one-dimensional output space, this flow is limited to Gaussians", - stacklevel=2, - ) + x_numel, y_numel = get_numel(batch_x, batch_y, embedding_net) transform = transforms.IdentityTransform() @@ -142,18 +158,7 @@ def build_maf( Returns: Neural network. """ - x_numel = batch_x[0].numel() - # Infer the output dimensionality of the embedding_net by making a forward pass. 
- check_data_device(batch_x, batch_y) - check_embedding_net_device(embedding_net=embedding_net, datum=batch_y) - embedding_net.eval() - y_numel = embedding_net(batch_y[:1]).numel() - - if x_numel == 1: - warn( - "In one-dimensional output space, this flow is limited to Gaussians", - stacklevel=2, - ) + x_numel, y_numel = get_numel(batch_x, batch_y, embedding_net) transform_list = [] for _ in range(num_transforms): @@ -185,7 +190,7 @@ def build_maf( standardizing_net(batch_y, structured_y), embedding_net ) - # Combine transforms. + # Combine transforms transform = transforms.CompositeTransform(transform_list) distribution = get_base_dist(x_numel, **kwargs) @@ -251,18 +256,7 @@ def build_maf_rqs( Returns: Neural network. """ - x_numel = batch_x[0].numel() - # Infer the output dimensionality of the embedding_net by making a forward pass. - check_data_device(batch_x, batch_y) - check_embedding_net_device(embedding_net=embedding_net, datum=batch_y) - embedding_net.eval() - y_numel = embedding_net(batch_y[:1]).numel() - - if x_numel == 1: - warn( - "In one-dimensional output space, this flow is limited to Gaussians", - stacklevel=2, - ) + x_numel, y_numel = get_numel(batch_x, batch_y, embedding_net) transform_list = [] for _ in range(num_transforms): @@ -355,12 +349,7 @@ def build_nsf( Returns: Neural network. """ - x_numel = batch_x[0].numel() - # Infer the output dimensionality of the embedding_net by making a forward pass. - check_data_device(batch_x, batch_y) - check_embedding_net_device(embedding_net=embedding_net, datum=batch_y) - embedding_net.eval() - y_numel = embedding_net(batch_y[:1]).numel() + x_numel, y_numel = get_numel(batch_x, batch_y, embedding_net) # Define mask function to alternate between predicted x-dimensions. 
def mask_in_layer(i): @@ -433,6 +422,61 @@ def mask_in_layer(i): return flow +def build_zuko_nice( + batch_x: Tensor, + batch_y: Tensor, + z_score_x: Optional[str] = "independent", + z_score_y: Optional[str] = "independent", + hidden_features: Union[Sequence[int], int] = 50, + num_transforms: int = 5, + embedding_net: nn.Module = nn.Identity(), + randmask: bool = False, + **kwargs, +) -> ZukoFlow: + """ + Build a Non-linear Independent Components Estimation (NICE) flow. + + Affine transformations are used by default, instead of the additive transformations + used by Dinh et al. (2014) originally. + + References: + | NICE: Non-linear Independent Components Estimation (Dinh et al., 2014) + | https://arxiv.org/abs/1410.8516 + + Arguments: + batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. + batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. + z_score_x: Whether to z-score xs passing into the network, can be one of: + - `none`, or None: do not z-score. + - `independent`: z-score each dimension independently. + - `structured`: treat dimensions as related, therefore compute mean and std + over the entire batch, instead of per-dimension. Should be used when each + sample is, for example, a time series or an image. + z_score_y: Whether to z-score ys passing into the network, same options as + z_score_x. + hidden_features: The number of hidden features in the flow. Defaults to 50. + num_transforms: The number of transformations in the flow. Defaults to 5. + embedding_net: The embedding network to use. Defaults to nn.Identity(). + randmask: Whether to use random masks in the flow. Defaults to False. + **kwargs: Additional keyword arguments to pass to the flow constructor. 
+ """ + which_nf = "NICE" + additional_kwargs = {"randmask": randmask, **kwargs} + flow = build_zuko_flow( + which_nf, + batch_x, + batch_y, + z_score_x, + z_score_y, + hidden_features, + num_transforms, + embedding_net, + **additional_kwargs, + ) + + return flow + + def build_zuko_maf( batch_x: Tensor, batch_y: Tensor, @@ -441,13 +485,17 @@ def build_zuko_maf( hidden_features: Union[Sequence[int], int] = 50, num_transforms: int = 5, embedding_net: nn.Module = nn.Identity(), - residual: bool = True, randperm: bool = False, **kwargs, ) -> ZukoFlow: - """Builds MAF p(x|y). + """ + Build a Masked Autoregressive Flow (MAF). - Args: + References: + | Masked Autoregressive Flow for Density Estimation (Papamakarios et al., 2017) + | https://arxiv.org/abs/1705.07057 + + Arguments: batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. z_score_x: Whether to z-score xs passing into the network, can be one of: @@ -458,66 +506,598 @@ def build_zuko_maf( sample is, for example, a time series or an image. z_score_y: Whether to z-score ys passing into the network, same options as z_score_x. - hidden_features: Number of hidden features. - num_transforms: Number of transforms. - embedding_net: Optional embedding network for y. - residual: whether to use residual blocks in the coupling layer. + hidden_features: The number of hidden features in the flow. Defaults to 50. + num_transforms: The number of transformations in the flow. Defaults to 5. + embedding_net: The embedding network to use. Defaults to nn.Identity(). + randperm: Whether to use random permutations in the flow. Defaults to False. + **kwargs: Additional keyword arguments to pass to the flow constructor. 
+ """ + which_nf = "MAF" + additional_kwargs = {"randperm": randperm, **kwargs} + flow = build_zuko_flow( + which_nf, + batch_x, + batch_y, + z_score_x, + z_score_y, + hidden_features, + num_transforms, + embedding_net, + **additional_kwargs, + ) + + return flow + + +def build_zuko_nsf( + batch_x: Tensor, + batch_y: Tensor, + z_score_x: Optional[str] = "independent", + z_score_y: Optional[str] = "independent", + hidden_features: Union[Sequence[int], int] = 50, + num_transforms: int = 5, + embedding_net: nn.Module = nn.Identity(), + num_bins: int = 8, + **kwargs, +) -> ZukoFlow: + """ + Build a Neural Spline Flow (NSF) with monotonic rational-quadratic spline + transformations. + + By default, transformations are fully autoregressive. Coupling transformations + can be obtained by setting :py:`passes=2`. + + Warning: + Spline transformations are defined over the domain :math:`[-5, 5]`. Any feature + outside of this domain is not transformed. It is recommended to standardize + features (zero mean, unit variance) before training. + + References: + | Neural Spline Flows (Durkan et al., 2019) + | https://arxiv.org/abs/1906.04032 + + Arguments: + batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. + batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. + z_score_x: Whether to z-score xs passing into the network, can be one of: + - `none`, or None: do not z-score. + - `independent`: z-score each dimension independently. + - `structured`: treat dimensions as related, therefore compute mean and std + over the entire batch, instead of per-dimension. Should be used when each + sample is, for example, a time series or an image. + z_score_y: Whether to z-score ys passing into the network, same options as + z_score_x. + hidden_features: The number of hidden features in the flow. Defaults to 50. + num_transforms: The number of transformations in the flow. Defaults to 5. + embedding_net: The embedding network to use. 
Defaults to nn.Identity(). + num_bins: The number of bins in the spline transformations. Defaults to 8. + **kwargs: Additional keyword arguments to pass to the flow constructor. + """ + which_nf = "NSF" + additional_kwargs = {"bins": num_bins, **kwargs} + flow = build_zuko_flow( + which_nf, + batch_x, + batch_y, + z_score_x, + z_score_y, + hidden_features, + num_transforms, + embedding_net, + **additional_kwargs, + ) + + return flow + + +def build_zuko_ncsf( + batch_x: Tensor, + batch_y: Tensor, + z_score_x: Optional[str] = "independent", + z_score_y: Optional[str] = "independent", + hidden_features: Union[Sequence[int], int] = 50, + num_transforms: int = 5, + embedding_net: nn.Module = nn.Identity(), + num_bins: int = 8, + **kwargs, +) -> ZukoFlow: + r""" + Build a Neural Circular Spline Flow (NCSF). + + Circular spline transformations are obtained by composing circular domain shifts + with regular spline transformations. Features are assumed to lie in the half-open + interval :math:`[-\pi, \pi[`. + + References: + | Normalizing Flows on Tori and Spheres (Rezende et al., 2020) + | https://arxiv.org/abs/2002.02428 + + Arguments: + batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. + batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. + z_score_x: Whether to z-score xs passing into the network, can be one of: + - `none`, or None: do not z-score. + - `independent`: z-score each dimension independently. + - `structured`: treat dimensions as related, therefore compute mean and std + over the entire batch, instead of per-dimension. Should be used when each + sample is, for example, a time series or an image. + z_score_y: Whether to z-score ys passing into the network, same options as + z_score_x. + hidden_features: The number of hidden features in the flow. Defaults to 50. + num_transforms: The number of transformations in the flow. Defaults to 5. + embedding_net: The embedding network to use. Defaults to nn.Identity(). 
+ num_bins: The number of bins in the spline transformations. Defaults to 8. + **kwargs: Additional keyword arguments to pass to the flow constructor. + """ + which_nf = "NCSF" + additional_kwargs = {"bins": num_bins, **kwargs} + flow = build_zuko_flow( + which_nf, + batch_x, + batch_y, + z_score_x, + z_score_y, + hidden_features, + num_transforms, + embedding_net, + **additional_kwargs, + ) + + return flow + + +def build_zuko_sospf( + batch_x: Tensor, + batch_y: Tensor, + z_score_x: Optional[str] = "independent", + z_score_y: Optional[str] = "independent", + hidden_features: Union[Sequence[int], int] = 50, + num_transforms: int = 5, + embedding_net: nn.Module = nn.Identity(), + degree: int = 4, + polynomials: int = 3, + **kwargs, +) -> ZukoFlow: + """ + Build a Sum-of-Squares Polynomial Flow (SOSPF). + + References: + | Sum-of-Squares Polynomial Flow (Jaini et al., 2019) + | https://arxiv.org/abs/1905.02325 + + Arguments: + batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. + batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. + z_score_x: Whether to z-score xs passing into the network, can be one of: + - `none`, or None: do not z-score. + - `independent`: z-score each dimension independently. + - `structured`: treat dimensions as related, therefore compute mean and std + over the entire batch, instead of per-dimension. Should be used when each + sample is, for example, a time series or an image. + z_score_y: Whether to z-score ys passing into the network, same options as + z_score_x. + hidden_features: The number of hidden features in the flow. Defaults to 50. + num_transforms: The number of transformations in the flow. Defaults to 5. + embedding_net: The embedding network to use. Defaults to nn.Identity(). + degree: The degree of the polynomials. Defaults to 4. + polynomials: The number of polynomials. Defaults to 3. + **kwargs: Additional keyword arguments to pass to the flow constructor. 
+ """ + which_nf = "SOSPF" + additional_kwargs = {"degree": degree, "polynomials": polynomials, **kwargs} + flow = build_zuko_flow( + which_nf, + batch_x, + batch_y, + z_score_x, + z_score_y, + hidden_features, + num_transforms, + embedding_net, + **additional_kwargs, + ) + + return flow + + +def build_zuko_naf( + batch_x: Tensor, + batch_y: Tensor, + z_score_x: Optional[str] = "independent", + z_score_y: Optional[str] = "independent", + hidden_features: Union[Sequence[int], int] = 50, + num_transforms: int = 5, + embedding_net: nn.Module = nn.Identity(), + randperm: bool = False, + signal: int = 16, + **kwargs, +) -> ZukoFlow: + """ + Build a Neural Autoregressive Flow (NAF). + + Warning: + Invertibility is only guaranteed for features within the interval :math:`[-10, + 10]`. It is recommended to standardize features (zero mean, unit variance) + before training. + + References: + | Neural Autoregressive Flows (Huang et al., 2018) + | https://arxiv.org/abs/1804.00779 + + Arguments: + batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. + batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. + z_score_x: Whether to z-score xs passing into the network, can be one of: + - `none`, or None: do not z-score. + - `independent`: z-score each dimension independently. + - `structured`: treat dimensions as related, therefore compute mean and std + over the entire batch, instead of per-dimension. Should be used when each + sample is, for example, a time series or an image. + z_score_y: Whether to z-score ys passing into the network, same options as + z_score_x. + hidden_features: The number of hidden features in the flow. Defaults to 50. + num_transforms: The number of transformations in the flow. Defaults to 5. + embedding_net: The embedding network to use. Defaults to nn.Identity(). randperm: Whether features are randomly permuted between transformations or not. 
- kwargs: Additional arguments that are passed by the build function but are not - relevant for maf and are therefore ignored. + If :py:`False`, features are in ascending (descending) order for even + (odd) transformations. + signal: The number of signal features of the monotonic network. + **kwargs: Additional keyword arguments to pass to the flow constructor. + """ + which_nf = "NAF" + additional_kwargs = { + "randperm": randperm, + "signal": signal, + # "network": network, + **kwargs, + } + flow = build_zuko_flow( + which_nf, + batch_x, + batch_y, + z_score_x, + z_score_y, + hidden_features, + num_transforms, + embedding_net, + **additional_kwargs, + ) + + return flow + + +def build_zuko_unaf( + batch_x: Tensor, + batch_y: Tensor, + z_score_x: Optional[str] = "independent", + z_score_y: Optional[str] = "independent", + hidden_features: Union[Sequence[int], int] = 50, + num_transforms: int = 5, + embedding_net: nn.Module = nn.Identity(), + randperm: bool = False, + signal: int = 16, + **kwargs, +) -> ZukoFlow: + """ + Build an Unconstrained Neural Autoregressive Flow (UNAF). + + Warning: + Invertibility is only guaranteed for features within the interval :math:`[-10, + 10]`. It is recommended to standardize features (zero mean, unit variance) + before training. + + References: + | Unconstrained Monotonic Neural Networks (Wehenkel et al., 2019) + | https://arxiv.org/abs/1908.05164 + + Arguments: + batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. + batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. + z_score_x: Whether to z-score xs passing into the network, can be one of: + - `none`, or None: do not z-score. + - `independent`: z-score each dimension independently. + - `structured`: treat dimensions as related, therefore compute mean and std + over the entire batch, instead of per-dimension. Should be used when each + sample is, for example, a time series or an image. 
+ z_score_y: Whether to z-score ys passing into the network, same options as + z_score_x. + hidden_features: The number of hidden features in the flow. Defaults to 50. + num_transforms: The number of transformations in the flow. Defaults to 5. + embedding_net: The embedding network to use. Defaults to nn.Identity(). + randperm: Whether features are randomly permuted between transformations or not. + If :py:`False`, features are in ascending (descending) order for even + (odd) transformations. + signal: The number of signal features of the monotonic network. + **kwargs: Additional keyword arguments to pass to the flow constructor. + """ + which_nf = "UNAF" + additional_kwargs = { + "randperm": randperm, + "signal": signal, + # "network": network, + **kwargs, + } + flow = build_zuko_flow( + which_nf, + batch_x, + batch_y, + z_score_x, + z_score_y, + hidden_features, + num_transforms, + embedding_net, + **additional_kwargs, + ) + + return flow + + +def build_zuko_cnf( + batch_x: Tensor, + batch_y: Tensor, + z_score_x: Optional[str] = "independent", + z_score_y: Optional[str] = "independent", + hidden_features: Union[Sequence[int], int] = 50, + num_transforms: int = 5, + embedding_net: nn.Module = nn.Identity(), + **kwargs, +) -> ZukoFlow: + """ + Build a Continuous Normalizing Flow (CNF) with a free-form Jacobian transformation. + + References: + | Neural Ordinary Differential Equations (Chen et al., 2018) + | https://arxiv.org/abs/1806.07366 + + | FFJORD: Free-form Continuous Dynamics for Scalable Reversible + | Generative Models (Grathwohl et al., 2018) + | https://arxiv.org/abs/1810.01367 + + Arguments: + batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. + batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. + z_score_x: Whether to z-score xs passing into the network, can be one of: + - `none`, or None: do not z-score. + - `independent`: z-score each dimension independently. 
+ - `structured`: treat dimensions as related, therefore compute mean and std + over the entire batch, instead of per-dimension. Should be used when each + sample is, for example, a time series or an image. + z_score_y: Whether to z-score ys passing into the network, same options as + z_score_x. + hidden_features: The number of hidden features in the flow. Defaults to 50. + num_transforms: The number of transformations in the flow. Defaults to 5. + embedding_net: The embedding network to use. Defaults to nn.Identity(). + **kwargs: Additional keyword arguments to pass to the flow constructor. + """ + which_nf = "CNF" + additional_kwargs = {**kwargs} + flow = build_zuko_flow( + which_nf, + batch_x, + batch_y, + z_score_x, + z_score_y, + hidden_features, + num_transforms, + embedding_net, + **additional_kwargs, + ) + + return flow + + +def build_zuko_gf( + batch_x: Tensor, + batch_y: Tensor, + z_score_x: Optional[str] = "independent", + z_score_y: Optional[str] = "independent", + hidden_features: Union[Sequence[int], int] = 50, + num_transforms: int = 3, + embedding_net: nn.Module = nn.Identity(), + components: int = 8, + **kwargs, +) -> ZukoFlow: + """ + Build a Gaussianization Flow (GF). + + Warning: + Invertibility is only guaranteed for features within the interval :math:`[-10, + 10]`. It is recommended to standardize features (zero mean, unit variance) + before training. + + References: + | Gaussianization Flows (Meng et al., 2020) + | https://arxiv.org/abs/2003.01941 + + Arguments: + batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. + batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. + z_score_x: Whether to z-score xs passing into the network, can be one of: + - `none`, or None: do not z-score. + - `independent`: z-score each dimension independently. + - `structured`: treat dimensions as related, therefore compute mean and std + over the entire batch, instead of per-dimension. 
Should be used when each + sample is, for example, a time series or an image. + z_score_y: Whether to z-score ys passing into the network, same options as + z_score_x. + hidden_features: The number of hidden features in the flow. Defaults to 50. + num_transforms: The number of transformations in the flow. Defaults to 3. + embedding_net: The embedding network to use. Defaults to nn.Identity(). + components: The number of components in the Gaussian mixture model. + **kwargs: Additional keyword arguments to pass to the flow constructor. + """ + which_nf = "GF" + additional_kwargs = {"components": components, **kwargs} + flow = build_zuko_flow( + which_nf, + batch_x, + batch_y, + z_score_x, + z_score_y, + hidden_features, + num_transforms, + embedding_net, + **additional_kwargs, + ) + + return flow + + +def build_zuko_bpf( + batch_x: Tensor, + batch_y: Tensor, + z_score_x: Optional[str] = "independent", + z_score_y: Optional[str] = "independent", + hidden_features: Union[Sequence[int], int] = 50, + num_transforms: int = 3, + embedding_net: nn.Module = nn.Identity(), + degree: int = 16, + linear: bool = False, + **kwargs, +) -> ZukoFlow: + """ + Build a Bernstein polynomial flow (BPF). + + Warning: + Invertibility is only guaranteed for features within the interval :math:`[-10, + 10]`. It is recommended to standardize features (zero mean, unit variance) + before training. + + References: + | Short-Term Density Forecasting of Low-Voltage Load using + | Bernstein-Polynomial Normalizing Flows (Arpogaus et al., 2022) + | https://arxiv.org/abs/2204.13939 + + Arguments: + batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. + batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. + z_score_x: Whether to z-score xs passing into the network, can be one of: + - `none`, or None: do not z-score. + - `independent`: z-score each dimension independently. 
+ - `structured`: treat dimensions as related, therefore compute mean and std + over the entire batch, instead of per-dimension. Should be used when each + sample is, for example, a time series or an image. + z_score_y: Whether to z-score ys passing into the network, same options as + z_score_x. + hidden_features: The number of hidden features in the flow. Defaults to 50. + num_transforms: The number of transformations in the flow. Defaults to 3. + embedding_net: The embedding network to use. Defaults to nn.Identity(). + degree: The degree :math:`M` of the Bernstein polynomial. + linear: Whether to use a linear or sigmoid mapping to :math:`[0, 1]`. + **kwargs: Additional keyword arguments to pass to the flow constructor. + """ + which_nf = "BPF" + additional_kwargs = {"degree": degree, "linear": linear, **kwargs} + flow = build_zuko_flow( + which_nf, + batch_x, + batch_y, + z_score_x, + z_score_y, + hidden_features, + num_transforms, + embedding_net, + **additional_kwargs, + ) + + return flow + + +def build_zuko_flow( + which_nf: str, + batch_x: Tensor, + batch_y: Tensor, + z_score_x: Optional[str] = "independent", + z_score_y: Optional[str] = "independent", + hidden_features: Union[Sequence[int], int] = 50, + num_transforms: int = 5, + embedding_net: nn.Module = nn.Identity(), + **kwargs, +) -> ZukoFlow: + """ + Fundamental building blocks to build a Zuko normalizing flow model. + + Args: + which_nf (str): The type of normalizing flow to build. + batch_x: Batch of xs, used to infer dimensionality and (optional) z-scoring. + batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. + z_score_x: Whether to z-score xs passing into the network, can be one of: + - `none`, or None: do not z-score. + - `independent`: z-score each dimension independently. + - `structured`: treat dimensions as related, therefore compute mean and std + over the entire batch, instead of per-dimension. 
Should be used when each + sample is, for example, a time series or an image. + z_score_y: Whether to z-score ys passing into the network, same options as + z_score_x. + hidden_features: The number of hidden features in the flow. Defaults to 50. + num_transforms: The number of transformations in the flow. Defaults to 5. + embedding_net: The embedding network to use. Defaults to nn.Identity(). + **kwargs: Additional keyword arguments to pass to the flow constructor. Returns: - Neural network. + ZukoFlow: The constructed Zuko normalizing flow model. """ - x_numel = batch_x[0].numel() - # Infer the output dimensionality of the embedding_net by making a forward pass. - check_data_device(batch_x, batch_y) - check_embedding_net_device(embedding_net=embedding_net, datum=batch_y) - embedding_net.eval() - y_numel = embedding_net(batch_y[:1]).numel() - if x_numel == 1: - warn( - "In one-dimensional output space, this flow is limited to Gaussians", - stacklevel=1, - ) + + x_numel, y_numel = get_numel(batch_x, batch_y, embedding_net) if isinstance(hidden_features, int): hidden_features = [hidden_features] * num_transforms - if x_numel == 1: - maf = zuko.flows.MAF( - features=x_numel, - context=y_numel, - hidden_features=hidden_features, - transforms=num_transforms, + build_nf = getattr(zuko.flows, which_nf) + + if which_nf == "CNF": + flow_built = build_nf( + features=x_numel, context=y_numel, hidden_features=hidden_features, **kwargs ) else: - maf = zuko.flows.MAF( + flow_built = build_nf( features=x_numel, context=y_numel, hidden_features=hidden_features, transforms=num_transforms, - randperm=randperm, - residual=residual, + **kwargs, ) - transforms = maf.transform - z_score_x_bool, structured_x = z_score_parser(z_score_x) - if z_score_x_bool: - transforms = ( - transforms, - standardizing_transform_zuko(batch_x, structured_x), - ) + # Continuous normalizing flows (CNF) only have one transform, + # so we need to handle them slightly differently. 
+ if which_nf == "CNF": + transform = flow_built.transform - z_score_y_bool, structured_y = z_score_parser(z_score_y) - if z_score_y_bool: - # Prepend standardizing transform to y-embedding. - embedding_net = nn.Sequential( - standardizing_net(batch_y, structured_y), embedding_net - ) + z_score_x_bool, structured_x = z_score_parser(z_score_x) + if z_score_x_bool: + transform = ( + transform, + standardizing_transform_zuko(batch_x, structured_x), + ) - # Combine transforms. - neural_net = zuko.flows.Flow(transforms, maf.base) + z_score_y_bool, structured_y = z_score_parser(z_score_y) + if z_score_y_bool: + # Prepend standardizing transform to y-embedding. + embedding_net = nn.Sequential( + standardizing_transform_zuko(batch_y, structured_y), embedding_net + ) + + # Combine transforms. + neural_net = zuko.flows.Flow(transform, flow_built.base) + else: + transforms = flow_built.transform.transforms + + z_score_x_bool, structured_x = z_score_parser(z_score_x) + if z_score_x_bool: + transforms = ( + *transforms, + standardizing_transform_zuko(batch_x, structured_x), + ) + + z_score_y_bool, structured_y = z_score_parser(z_score_y) + if z_score_y_bool: + # Prepend standardizing transform to y-embedding. + embedding_net = nn.Sequential( + standardizing_net(batch_y, structured_y), embedding_net + ) + + # Combine transforms. 
+ neural_net = zuko.flows.Flow(transforms, flow_built.base) flow = ZukoFlow(neural_net, embedding_net, condition_shape=batch_y[0].shape) diff --git a/tests/density_estimator_test.py b/tests/density_estimator_test.py index 2468a0fbc..8bd8bbb27 100644 --- a/tests/density_estimator_test.py +++ b/tests/density_estimator_test.py @@ -10,21 +10,82 @@ from torch import eye, zeros from torch.distributions import MultivariateNormal -from sbi.neural_nets.density_estimators import NFlowsFlow, ZukoFlow from sbi.neural_nets.density_estimators.shape_handling import reshape_to_iid_batch_event -from sbi.neural_nets.flow import build_nsf, build_zuko_maf +from sbi.neural_nets.flow import ( + build_maf, + build_maf_rqs, + build_nsf, + build_zuko_bpf, + build_zuko_gf, + build_zuko_maf, + build_zuko_naf, + build_zuko_ncsf, + build_zuko_nice, + build_zuko_nsf, + build_zuko_sospf, + build_zuko_unaf, +) + + +def get_batch_input(nsamples: int, input_dims: int) -> torch.Tensor: + r"""Generate a batch of input samples from a multivariate normal distribution. + + Args: + nsamples (int): The number of samples to generate. + input_dims (int): The dimensionality of the input samples. + + Returns: + torch.Tensor: A tensor of shape (nsamples, input_dims) + containing the generated samples. + """ + input_mvn = MultivariateNormal( + loc=zeros(input_dims), covariance_matrix=eye(input_dims) + ) + return input_mvn.sample((nsamples,)) + + +def get_batch_context(nsamples: int, condition_shape: tuple[int, ...]) -> torch.Tensor: + r"""Generate a batch of context samples from a multivariate normal distribution. + + Args: + nsamples (int): The number of context samples to generate. + condition_shape (tuple[int, ...]): The shape of the condition for each sample. + + Returns: + torch.Tensor: A tensor containing the generated context samples. 
+ """ + context_mvn = MultivariateNormal( + loc=zeros(*condition_shape), covariance_matrix=eye(condition_shape[-1]) + ) + return context_mvn.sample((nsamples,)) -@pytest.mark.parametrize("density_estimator", (NFlowsFlow, ZukoFlow)) +@pytest.mark.parametrize( + "build_density_estimator", + ( + build_maf, + build_maf_rqs, + build_nsf, + build_zuko_nice, + build_zuko_maf, + build_zuko_nsf, + build_zuko_ncsf, + build_zuko_sospf, + build_zuko_naf, + build_zuko_unaf, + build_zuko_gf, + build_zuko_bpf, + ), +) @pytest.mark.parametrize("input_dims", (1, 2)) @pytest.mark.parametrize( "condition_shape", ((1,), (2,), (1, 1), (2, 2), (1, 1, 1), (2, 2, 2)) ) -def test_api_density_estimator(density_estimator, input_dims, condition_shape): +def test_api_density_estimator(build_density_estimator, input_dims, condition_shape): r"""Checks whether we can evaluate and sample from density estimators correctly. Args: - density_estimator: DensityEstimator subclass. + build_density_estimator: function that creates a DensityEstimator subclass. input_dim: Dimensionality of the input. context_shape: Dimensionality of the context. 
""" @@ -32,14 +93,8 @@ def test_api_density_estimator(density_estimator, input_dims, condition_shape): nsamples = 10 nsamples_test = 5 - input_mvn = MultivariateNormal( - loc=zeros(input_dims), covariance_matrix=eye(input_dims) - ) - batch_input = input_mvn.sample((nsamples,)) - context_mvn = MultivariateNormal( - loc=zeros(*condition_shape), covariance_matrix=eye(condition_shape[-1]) - ) - batch_context = context_mvn.sample((nsamples,)) + batch_input = get_batch_input(nsamples, input_dims) + batch_context = get_batch_context(nsamples, condition_shape) class EmbeddingNet(torch.nn.Module): def forward(self, x): @@ -47,22 +102,13 @@ def forward(self, x): x = torch.sum(x, dim=-1) return x - if density_estimator == NFlowsFlow: - estimator = build_nsf( - batch_input, - batch_context, - hidden_features=10, - num_transforms=2, - embedding_net=EmbeddingNet(), - ) - elif density_estimator == ZukoFlow: - estimator = build_zuko_maf( - batch_input, - batch_context, - hidden_features=10, - num_transforms=2, - embedding_net=EmbeddingNet(), - ) + estimator = build_density_estimator( + batch_input, + batch_context, + hidden_features=10, + num_transforms=2, + embedding_net=EmbeddingNet(), + ) # Loss is only required to work for batched inputs and contexts loss = estimator.loss(batch_input, batch_context) diff --git a/tests/neural_nets_factory.py b/tests/neural_nets_factory.py index f5a2775f3..af8c4d4a1 100644 --- a/tests/neural_nets_factory.py +++ b/tests/neural_nets_factory.py @@ -2,6 +2,24 @@ from sbi.utils.get_nn_models import classifier_nn, likelihood_nn, posterior_nn +models_to_test = [ + "mdn", + "made", + "maf", + "maf_rqs", + "nsf", + "mnle", + "zuko_bpf", + "zuko_gf", + "zuko_maf", + "zuko_naf", + "zuko_ncsf", + "zuko_nice", + "zuko_nsf", + "zuko_sospf", + "zuko_unaf", +] + @pytest.mark.parametrize( "model", ["linear", "mlp", "resnet"], ids=["linear", "mlp", "resnet"] @@ -12,22 +30,14 @@ def test_deprecated_import_classifier_nn(model: str): assert callable(build_fcn) 
-@pytest.mark.parametrize( - "model", - ["mdn", "made", "maf", "maf_rqs", "nsf", "mnle", "zuko_maf"], - ids=["mdn", "made", "maf", "maf_rqs", "nsf", "mnle", "zuko_maf"], -) +@pytest.mark.parametrize("model", models_to_test, ids=models_to_test) def test_deprecated_import_likelihood_nn(model: str): with pytest.warns(DeprecationWarning): build_fcn = likelihood_nn(model) assert callable(build_fcn) -@pytest.mark.parametrize( - "model", - ["mdn", "made", "maf", "maf_rqs", "nsf", "mnle", "zuko_maf"], - ids=["mdn", "made", "maf", "maf_rqs", "nsf", "mnle", "zuko_maf"], -) +@pytest.mark.parametrize("model", models_to_test, ids=models_to_test) def test_deprecated_import_posterior_nn(model: str): with pytest.warns(DeprecationWarning): build_fcn = posterior_nn(model)