diff --git a/sbi/neural_nets/estimators/__init__.py b/sbi/neural_nets/estimators/__init__.py
index 370dfb01e..121fe401e 100644
--- a/sbi/neural_nets/estimators/__init__.py
+++ b/sbi/neural_nets/estimators/__init__.py
@@ -1,12 +1,10 @@
 from sbi.neural_nets.estimators.base import ConditionalDensityEstimator
 from sbi.neural_nets.estimators.categorical_net import (
+    CategoricalMADE,
     CategoricalMassEstimator,
-    CategoricalNet,
 )
 from sbi.neural_nets.estimators.flowmatching_estimator import FlowMatchingEstimator
-from sbi.neural_nets.estimators.mixed_density_estimator import (
-    MixedDensityEstimator,
-)
+from sbi.neural_nets.estimators.mixed_density_estimator import MixedDensityEstimator
 from sbi.neural_nets.estimators.nflows_flow import NFlowsFlow
 from sbi.neural_nets.estimators.score_estimator import ConditionalScoreEstimator
 from sbi.neural_nets.estimators.zuko_flow import ZukoFlow
diff --git a/sbi/neural_nets/estimators/categorical_net.py b/sbi/neural_nets/estimators/categorical_net.py
index e1f3ea8ca..c09118c79 100644
--- a/sbi/neural_nets/estimators/categorical_net.py
+++ b/sbi/neural_nets/estimators/categorical_net.py
@@ -1,18 +1,20 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Apache License Version 2.0, see
 
-from typing import Optional
+from typing import Callable, Optional
 
 import torch
+from nflows.nn.nde.made import MADE
+from nflows.utils import torchutils
 from torch import Tensor, nn
 from torch.distributions import Categorical
-from torch.nn import Sigmoid, Softmax
+from torch.nn import functional as F
 
 from sbi.neural_nets.estimators.base import ConditionalDensityEstimator
 
 
-class CategoricalNet(nn.Module):
-    """Conditional density (mass) estimation for a categorical random variable.
+class CategoricalMADE(MADE):
+    """Conditional density (mass) estimation for an n-dim categorical random variable.
 
     Takes as input parameters theta and learns the parameters p of a
     Categorical.
@@ -21,93 +23,118 @@ class CategoricalNet(nn.Module):
 
     def __init__(
         self,
-        num_input: int,
-        num_categories: int,
-        num_hidden: int = 20,
-        num_layers: int = 2,
-        embedding_net: Optional[nn.Module] = None,
+        num_categories: Tensor,  # Tensor[int]
+        hidden_features: int,
+        context_features: Optional[int] = None,
+        num_blocks: int = 2,
+        use_residual_blocks: bool = True,
+        random_mask: bool = False,
+        activation: Callable = F.relu,
+        dropout_probability: float = 0.0,
+        use_batch_norm: bool = False,
+        epsilon: float = 1e-2,
+        embedding_net: nn.Module = nn.Identity(),
     ):
         """Initialize the neural net.
 
         Args:
-            num_input: number of input units, i.e., dimensionality of the features.
-            num_categories: number of output units, i.e., number of categories.
-            num_hidden: number of hidden units per layer.
-            num_layers: number of hidden layers.
-            embedding_net: emebedding net for input.
+            num_categories: number of categories for each variable. len(num_categories)
+                defines the number of input units, i.e., dimensionality of the
+                features. max(num_categories) defines the number of output units,
+                i.e., the largest number of categories.
+            hidden_features: number of hidden units per layer.
+            num_blocks: number of hidden blocks.
+            embedding_net: embedding net for the condition.
         """
-        super().__init__()
-
-        self.num_hidden = num_hidden
-        self.num_input = num_input
-        self.activation = Sigmoid()
-        self.softmax = Softmax(dim=1)
-        self.num_categories = num_categories
-
-        # Maybe add embedding net in front.
-        if embedding_net is not None:
-            self.input_layer = nn.Sequential(
-                embedding_net, nn.Linear(num_input, num_hidden)
-            )
-        else:
-            self.input_layer = nn.Linear(num_input, num_hidden)
-
-        # Repeat hidden units hidden layers times.
-        self.hidden_layers = nn.ModuleList()
-        for _ in range(num_layers):
-            self.hidden_layers.append(nn.Linear(num_hidden, num_hidden))
-
-        self.output_layer = nn.Linear(num_hidden, num_categories)
+        if use_residual_blocks and random_mask:
+            raise ValueError("Residual blocks can't be used with random masks.")
 
-    def forward(self, condition: Tensor) -> Tensor:
-        """Return categorical probability predicted from a batch of inputs.
+        self.num_variables = len(num_categories)
+        self.num_categories = int(torch.max(num_categories))
+        self.mask = torch.zeros(self.num_variables, self.num_categories)
+        for i, c in enumerate(num_categories):
+            self.mask[i, :c] = 1
 
-        Args:
-            condition: batch of context parameters for the net.
-
-        Returns:
-            Tensor: batch of predicted categorical probabilities.
-        """
-        # forward path
-        condition = self.activation(self.input_layer(condition))
-
-        # iterate n hidden layers, input condition and calculate tanh activation
-        for layer in self.hidden_layers:
-            condition = self.activation(layer(condition))
+        super().__init__(
+            self.num_variables,
+            hidden_features,
+            context_features=context_features,
+            num_blocks=num_blocks,
+            output_multiplier=self.num_categories,
+            use_residual_blocks=use_residual_blocks,
+            random_mask=random_mask,
+            activation=activation,
+            dropout_probability=dropout_probability,
+            use_batch_norm=use_batch_norm,
+        )
 
-        return self.softmax(self.output_layer(condition))
+        self.embedding_net = embedding_net
+        self.hidden_features = hidden_features
+        self.epsilon = epsilon
+        self.context_features = context_features
 
-    def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
-        """Return categorical log probability of categories input, given condition.
+    def forward(self, input: Tensor, condition: Optional[Tensor] = None) -> Tensor:
+        r"""Forward pass of the categorical density estimator network, returning
+        masked conditional logits for each variable.
 
         Args:
-            input: categories to evaluate.
-            condition: parameters.
+            input: Input datapoints of shape `(batch_size, *input_shape)`.
+            condition: Conditioning variable of shape `(batch_size, *condition_shape)`.
 
         Returns:
-            Tensor: log probs with shape (input.shape[0],)
+            Predicted categorical logits, flattened to shape
+            `(batch_size, num_variables * num_categories)`.
         """
-        # Predict categorical ps and evaluate.
-        ps = self.forward(condition)
-        # Squeeze the last dimension (event dim) because `Categorical` has
-        # `event_shape=()` but our data usually has an event_shape of `(1,)`.
-        return Categorical(probs=ps).log_prob(input.squeeze(dim=-1))
+        embedded_condition = self.embedding_net.forward(condition)
+        out = super().forward(input, context=embedded_condition)
+        return out.masked_fill(~self.mask.bool().flatten(), float("-inf"))
 
-    def sample(self, sample_shape: torch.Size, condition: Tensor) -> Tensor:
-        """Returns samples from categorical random variable with probs predicted from
-        the neural net.
+    def log_prob(self, input: Tensor, condition: Optional[Tensor] = None) -> Tensor:
+        r"""Return log-probability of the inputs given a condition.
 
         Args:
-            sample_shape: number of samples to obtain.
-            condition: batch of parameters for prediction.
+            input: Input datapoints of shape `(batch_size, *input_shape)`.
+            condition: Conditioning variable of shape `(batch_size, *condition_shape)`.
 
         Returns:
-            Tensor: Samples with shape (num_samples, 1)
+            Log-probabilities of shape `(batch_size,)`, summed over the
+            discrete variables.
""" + outputs = self.forward(input, condition=condition) + outputs = outputs.reshape(*input.shape, self.num_categories) + log_prob = Categorical(logits=outputs).log_prob(input).sum(dim=-1) + return log_prob + + def sample(self, sample_shape, context=None): + # Ensure sample_shape is a tuple + if isinstance(sample_shape, int): + sample_shape = (sample_shape,) + sample_shape = torch.Size(sample_shape) + + # Calculate total number of samples + num_samples = int(torch.prod(torch.tensor(sample_shape))) + + # Prepare context + if context is not None: + batch_dim = context.shape[0] + if context.ndim == 2: + context = context.unsqueeze(0) + if batch_dim == 1: + context = torchutils.repeat_rows(context, num_samples) + else: + context_dim = 0 if self.context_features is None else self.context_features + context = torch.zeros(num_samples, context_dim) + batch_dim = 1 + + with torch.no_grad(): + samples = torch.randn(num_samples, batch_dim, self.num_variables) + for i in range(self.num_variables): + outputs = self.forward(samples, context) + outputs = outputs.reshape(*samples.shape, self.num_categories) + samples[:, :, : i + 1] = Categorical( + logits=outputs[:, :, : i + 1] + ).sample() - # Predict Categorical ps and sample. - ps = self.forward(condition) - return Categorical(probs=ps).sample(sample_shape=sample_shape) + return samples.reshape(*sample_shape, batch_dim, self.num_variables) class CategoricalMassEstimator(ConditionalDensityEstimator): @@ -117,12 +144,12 @@ class CategoricalMassEstimator(ConditionalDensityEstimator): """ def __init__( - self, net: CategoricalNet, input_shape: torch.Size, condition_shape: torch.Size + self, net: CategoricalMADE, input_shape: torch.Size, condition_shape: torch.Size ) -> None: """Initialize the mass estimator. Args: - net: CategoricalNet. + net: CategoricalMADE. input_shape: Shape of the input data. condition_shape: Shape of the condition data """ diff --git a/sbi/neural_nets/estimators/mixed_density_estimator.py b/sbi/neural_nets/estimators/mixed_density_estimator.py index dedba1b52..e46d7baeb 100644 --- a/sbi/neural_nets/estimators/mixed_density_estimator.py +++ b/sbi/neural_nets/estimators/mixed_density_estimator.py @@ -80,8 +80,10 @@ def sample( sample_shape=sample_shape, condition=condition, ) - # Trailing `1` because `Categorical` has event_shape `()`. - discrete_samples = discrete_samples.reshape(num_samples * batch_dim, 1) + num_variables = self.discrete_net.net.num_variables + discrete_samples = discrete_samples.reshape( + num_samples * batch_dim, num_variables + ) # repeat the batch of embedded condition to match number of choices. condition_event_dim = embedded_condition.dim() - 1 @@ -145,7 +147,8 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor: f"{input_batch_dim} do not match." ) - cont_input, disc_input = _separate_input(input) + num_discrete_variables = self.discrete_net.net.num_variables + cont_input, disc_input = _separate_input(input, num_discrete_variables) # Embed continuous condition embedded_condition = self.condition_embedding(condition) # expand and repeat to match batch of inputs. @@ -204,3 +207,8 @@ def _separate_input( Assumes the discrete data to live in the last columns of input. 
""" return input[..., :-num_discrete_columns], input[..., -num_discrete_columns:] + + +def _is_discrete(input: Tensor) -> Tensor: + """Infer discrete columns in input data.""" + return torch.tensor([torch.allclose(col, col.round()) for col in input.T]) diff --git a/sbi/neural_nets/net_builders/categorial.py b/sbi/neural_nets/net_builders/categorial.py index 2c2add091..18ced4397 100644 --- a/sbi/neural_nets/net_builders/categorial.py +++ b/sbi/neural_nets/net_builders/categorial.py @@ -1,16 +1,18 @@ # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed # under the Apache License Version 2.0, see +import warnings from typing import Optional -from torch import Tensor, nn, unique +from torch import Tensor, nn, tensor, unique -from sbi.neural_nets.estimators import CategoricalMassEstimator, CategoricalNet -from sbi.utils.nn_utils import get_numel -from sbi.utils.sbiutils import ( - standardizing_net, - z_score_parser, +from sbi.neural_nets.estimators import ( + CategoricalMADE, + CategoricalMassEstimator, ) +from sbi.neural_nets.estimators.mixed_density_estimator import _is_discrete +from sbi.utils.nn_utils import get_numel +from sbi.utils.sbiutils import standardizing_net, z_score_parser from sbi.utils.user_input_checks import check_data_device @@ -21,6 +23,7 @@ def build_categoricalmassestimator( z_score_y: Optional[str] = "independent", num_hidden: int = 20, num_layers: int = 2, + num_categories: Optional[Tensor] = None, embedding_net: nn.Module = nn.Identity(), ): """Returns a density estimator for a categorical random variable. @@ -33,28 +36,39 @@ def build_categoricalmassestimator( num_hidden: Number of hidden units per layer. num_layers: Number of hidden layers. embedding_net: Embedding net for y. + num_categories: number of categories for each variable. """ if z_score_x != "none": raise ValueError("Categorical input should not be z-scored.") + if num_categories is None: + warnings.warn( + "Inferring num_categories from batch_x. 
 
     check_data_device(batch_x, batch_y)
 
-    if batch_x.shape[1] > 1:
-        raise NotImplementedError("CategoricalMassEstimator only supports 1D input.")
-    num_categories = unique(batch_x).numel()
-
-    dim_condition = get_numel(batch_y, embedding_net=embedding_net)
     z_score_y_bool, structured_y = z_score_parser(z_score_y)
+    y_numel = get_numel(batch_y, embedding_net=embedding_net)
+
     if z_score_y_bool:
         embedding_net = nn.Sequential(
             standardizing_net(batch_y, structured_y), embedding_net
         )
 
-    categorical_net = CategoricalNet(
-        num_input=dim_condition,
+    if num_categories is None:
+        batch_x_discrete = batch_x[:, _is_discrete(batch_x)]
+        inferred_categories = tensor([
+            unique(col).numel() for col in batch_x_discrete.T
+        ])
+        num_categories = inferred_categories
+
+    categorical_net = CategoricalMADE(
         num_categories=num_categories,
-        num_hidden=num_hidden,
-        num_layers=num_layers,
+        hidden_features=num_hidden,
+        context_features=y_numel,
+        num_blocks=num_layers,
         embedding_net=embedding_net,
     )
diff --git a/sbi/neural_nets/net_builders/mnle.py b/sbi/neural_nets/net_builders/mnle.py
index fd7ef6d3a..2227ccef9 100644
--- a/sbi/neural_nets/net_builders/mnle.py
+++ b/sbi/neural_nets/net_builders/mnle.py
@@ -8,8 +8,13 @@
 from torch import Tensor, nn
 
 from sbi.neural_nets.estimators import MixedDensityEstimator
-from sbi.neural_nets.estimators.mixed_density_estimator import _separate_input
-from sbi.neural_nets.net_builders.categorial import build_categoricalmassestimator
+from sbi.neural_nets.estimators.mixed_density_estimator import (
+    _is_discrete,
+    _separate_input,
+)
+from sbi.neural_nets.net_builders.categorial import (
+    build_categoricalmassestimator,
+)
 from sbi.neural_nets.net_builders.flow import (
     build_made,
     build_maf,
@@ -26,10 +31,7 @@
     build_zuko_unaf,
 )
 from sbi.neural_nets.net_builders.mdn import build_mdn
-from sbi.utils.sbiutils import (
-    standardizing_net,
-    z_score_parser,
-)
+from sbi.utils.sbiutils import standardizing_net, z_score_parser
 from sbi.utils.user_input_checks import check_data_device
 
 model_builders = {
@@ -56,6 +58,7 @@ def build_mnle(
     z_score_x: Optional[str] = "independent",
     z_score_y: Optional[str] = "independent",
     flow_model: str = "nsf",
+    num_categorical_columns: Optional[Tensor] = None,
     embedding_net: nn.Module = nn.Identity(),
     combined_embedding_net: Optional[nn.Module] = None,
     num_transforms: int = 2,
@@ -102,6 +105,8 @@
             as z_score_x.
         flow_model: type of flow model to use for the continuous part of the
            data.
+        num_categorical_columns: Number of categories for each categorical column
+            of the input data. If None, the function will infer this from the data.
         embedding_net: Optional embedding network for y, required if y is > 1D.
         combined_embedding_net: Optional embedding for combining the discrete
             part of the input and the embedded condition into a joined
@@ -125,13 +130,17 @@
     warnings.warn(
         "The mixed neural likelihood estimator assumes that x contains "
-        "continuous data in the first n-1 columns (e.g., reaction times) and "
-        "categorical data in the last column (e.g., corresponding choices). If "
+        "continuous data in the first n-k columns (e.g., reaction times) and "
+        "categorical data in the last k columns (e.g., corresponding choices). If "
         "this is not the case for the passed `x` do not use this function.",
         stacklevel=2,
     )
 
     # Separate continuous and discrete data.
-    cont_x, disc_x = _separate_input(batch_x)
+    if num_categorical_columns is None:
+        num_disc = int(torch.sum(_is_discrete(batch_x)))
+    else:
+        num_disc = len(num_categorical_columns)
+    cont_x, disc_x = _separate_input(batch_x, num_discrete_columns=num_disc)
 
     # Set up y-embedding net with z-scoring.
     z_score_y_bool, structured_y = z_score_parser(z_score_y)
@@ -152,6 +161,7 @@
         num_hidden=hidden_features,
         num_layers=hidden_layers,
         embedding_net=embedding_net,
+        num_categories=num_categorical_columns,
     )
 
     if combined_embedding_net is None:
diff --git a/tests/density_estimator_test.py b/tests/density_estimator_test.py
index 0f006dde1..b958637eb 100644
--- a/tests/density_estimator_test.py
+++ b/tests/density_estimator_test.py
@@ -412,7 +412,4 @@ def test_mixed_density_estimator(
     # Test samples
     samples = density_estimator.sample(sample_shape, condition=conditions)
 
-    if density_estimator_build_fn == build_categoricalmassestimator:
-        # Our categorical is always 1D and does not return `input_event_shape`.
-        input_event_shape = ()
     assert samples.shape == (*sample_shape, batch_dim, *input_event_shape)
diff --git a/tutorials/Example_01_DecisionMakingModel.ipynb b/tutorials/Example_01_DecisionMakingModel.ipynb
index eb16182c2..409a111af 100644
--- a/tutorials/Example_01_DecisionMakingModel.ipynb
+++ b/tutorials/Example_01_DecisionMakingModel.ipynb
@@ -103,7 +103,7 @@
     "    \"\"\"Returns a sample from a mixed distribution given parameters theta.\n",
     "\n",
     "    Args:\n",
-    "        theta: batch of parameters, shape (batch_size, 2) concentration_scaling:\n",
+    "        theta: batch of parameters, shape (batch_size, 1 + num_categories) concentration_scaling:\n",
     "            scaling factor for the concentration parameter of the InverseGamma\n",
     "            distribution, mimics an experimental condition.\n",
     "\n",
diff --git a/tutorials/example_01_utils.py b/tutorials/example_01_utils.py
index 620058d05..a90d58dfe 100644
--- a/tutorials/example_01_utils.py
+++ b/tutorials/example_01_utils.py
@@ -54,7 +54,7 @@ def iid_likelihood(self, theta: Tensor) -> Tensor:
             rate=beta,
         ).log_prob(self.x_o[:, :1].reshape(1, num_trials, -1))
 
-        joint_likelihood = (logprob_choices + logprob_rts).squeeze()
+        joint_likelihood = torch.sum(logprob_choices, dim=2) + logprob_rts.squeeze()
 
         assert joint_likelihood.shape == torch.Size([theta.shape[0], self.x_o.shape[0]])
         return joint_likelihood.sum(1)
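
Reviewer note: a minimal usage sketch of the new `CategoricalMADE` interface, for
reference only. The two-variable setup, batch sizes, and hyperparameter values
below are illustrative assumptions, not part of the patch; only `num_categories`,
`hidden_features`, `context_features`, `log_prob`, and `sample` come from the
code above, and sampling is shown with a single conditioning vector for
simplicity.

import torch

from sbi.neural_nets.estimators import CategoricalMADE

# Two categorical variables: the first with 3 categories, the second with 2.
num_categories = torch.tensor([3, 2])
net = CategoricalMADE(
    num_categories=num_categories,
    hidden_features=20,
    context_features=4,  # dimensionality of the conditioning variable
)

# Category indices per variable, encoded as floats as in sbi's mixed data.
batch_size = 10
input = torch.stack(
    [torch.randint(0, 3, (batch_size,)), torch.randint(0, 2, (batch_size,))],
    dim=-1,
).float()
condition = torch.randn(batch_size, 4)

# Joint log-probability over both variables: shape (batch_size,).
log_probs = net.log_prob(input, condition=condition)

# Autoregressive sampling given one condition: shape (5, 1, num_variables).
samples = net.sample((5,), context=torch.randn(1, 4))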