diff --git a/sbi/neural_nets/estimators/__init__.py b/sbi/neural_nets/estimators/__init__.py
index 370dfb01e..121fe401e 100644
--- a/sbi/neural_nets/estimators/__init__.py
+++ b/sbi/neural_nets/estimators/__init__.py
@@ -1,12 +1,10 @@
from sbi.neural_nets.estimators.base import ConditionalDensityEstimator
from sbi.neural_nets.estimators.categorical_net import (
+ CategoricalMADE,
CategoricalMassEstimator,
- CategoricalNet,
)
from sbi.neural_nets.estimators.flowmatching_estimator import FlowMatchingEstimator
-from sbi.neural_nets.estimators.mixed_density_estimator import (
- MixedDensityEstimator,
-)
+from sbi.neural_nets.estimators.mixed_density_estimator import MixedDensityEstimator
from sbi.neural_nets.estimators.nflows_flow import NFlowsFlow
from sbi.neural_nets.estimators.score_estimator import ConditionalScoreEstimator
from sbi.neural_nets.estimators.zuko_flow import ZukoFlow
diff --git a/sbi/neural_nets/estimators/categorical_net.py b/sbi/neural_nets/estimators/categorical_net.py
index e1f3ea8ca..c09118c79 100644
--- a/sbi/neural_nets/estimators/categorical_net.py
+++ b/sbi/neural_nets/estimators/categorical_net.py
@@ -1,18 +1,20 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see
-from typing import Optional
+from typing import Callable, Optional
import torch
+from nflows.nn.nde.made import MADE
+from nflows.utils import torchutils
from torch import Tensor, nn
from torch.distributions import Categorical
-from torch.nn import Sigmoid, Softmax
+from torch.nn import functional as F
from sbi.neural_nets.estimators.base import ConditionalDensityEstimator
-class CategoricalNet(nn.Module):
- """Conditional density (mass) estimation for a categorical random variable.
+class CategoricalMADE(MADE):
+ """Conditional density (mass) estimation for a n-dim categorical random variable.
Takes as input parameters theta and learns the parameters p of a Categorical.
@@ -21,93 +23,118 @@ class CategoricalNet(nn.Module):
def __init__(
self,
- num_input: int,
- num_categories: int,
- num_hidden: int = 20,
- num_layers: int = 2,
- embedding_net: Optional[nn.Module] = None,
+ num_categories: Tensor, # Tensor[int]
+ hidden_features: int,
+ context_features: Optional[int] = None,
+ num_blocks: int = 2,
+ use_residual_blocks: bool = True,
+ random_mask: bool = False,
+ activation: Callable = F.relu,
+ dropout_probability: float = 0.0,
+ use_batch_norm: bool = False,
+ epsilon: float = 1e-2,
+ embedding_net: nn.Module = nn.Identity(),
):
"""Initialize the neural net.
Args:
- num_input: number of input units, i.e., dimensionality of the features.
- num_categories: number of output units, i.e., number of categories.
+ num_categories: number of categories for each variable. len(num_categories)
+ defines the number of input units, i.e., the number of variables, and
+ max(num_categories) defines the number of output units per variable,
+ i.e., the largest number of categories.
- num_hidden: number of hidden units per layer.
- num_layers: number of hidden layers.
- embedding_net: emebedding net for input.
+ hidden_features: number of hidden units per layer.
+ context_features: dimensionality of the (embedded) condition.
+ num_blocks: number of hidden blocks in the MADE.
+ embedding_net: embedding net for the condition.
"""
- super().__init__()
-
- self.num_hidden = num_hidden
- self.num_input = num_input
- self.activation = Sigmoid()
- self.softmax = Softmax(dim=1)
- self.num_categories = num_categories
-
- # Maybe add embedding net in front.
- if embedding_net is not None:
- self.input_layer = nn.Sequential(
- embedding_net, nn.Linear(num_input, num_hidden)
- )
- else:
- self.input_layer = nn.Linear(num_input, num_hidden)
-
- # Repeat hidden units hidden layers times.
- self.hidden_layers = nn.ModuleList()
- for _ in range(num_layers):
- self.hidden_layers.append(nn.Linear(num_hidden, num_hidden))
-
- self.output_layer = nn.Linear(num_hidden, num_categories)
+ if use_residual_blocks and random_mask:
+ raise ValueError("Residual blocks can't be used with random masks.")
- def forward(self, condition: Tensor) -> Tensor:
- """Return categorical probability predicted from a batch of inputs.
+ self.num_variables = len(num_categories)
+ self.num_categories = int(torch.max(num_categories))
+ mask = torch.zeros(self.num_variables, self.num_categories)
+ for i, c in enumerate(num_categories):
+ mask[i, :c] = 1
- Args:
- condition: batch of context parameters for the net.
-
- Returns:
- Tensor: batch of predicted categorical probabilities.
- """
- # forward path
- condition = self.activation(self.input_layer(condition))
-
- # iterate n hidden layers, input condition and calculate tanh activation
- for layer in self.hidden_layers:
- condition = self.activation(layer(condition))
+ super().__init__(
+ self.num_variables,
+ hidden_features,
+ context_features=context_features,
+ num_blocks=num_blocks,
+ output_multiplier=self.num_categories,
+ use_residual_blocks=use_residual_blocks,
+ random_mask=random_mask,
+ activation=activation,
+ dropout_probability=dropout_probability,
+ use_batch_norm=use_batch_norm,
+ )
- return self.softmax(self.output_layer(condition))
+ # Register as a buffer so the mask moves with the module across devices.
+ self.register_buffer("mask", mask.bool())
+ self.embedding_net = embedding_net
+ self.hidden_features = hidden_features
+ self.epsilon = epsilon
+ self.context_features = context_features
- def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
- """Return categorical log probability of categories input, given condition.
+ def forward(self, input: Tensor, condition: Optional[Tensor] = None) -> Tensor:
+ r"""Forward pass of the categorical density estimator network to compute the
+ conditional density at a given time.
Args:
- input: categories to evaluate.
- condition: parameters.
+ input: Input datapoints of shape `(batch_size, *input_shape)`.
+ condition: Conditioning variable. `(batch_size, *condition_shape)`
Returns:
- Tensor: log probs with shape (input.shape[0],)
+ Predicted categorical logits of shape `(batch_size,
+ num_variables * num_categories)`, with out-of-range categories
+ masked to `-inf`.
"""
- # Predict categorical ps and evaluate.
- ps = self.forward(condition)
- # Squeeze the last dimension (event dim) because `Categorical` has
- # `event_shape=()` but our data usually has an event_shape of `(1,)`.
- return Categorical(probs=ps).log_prob(input.squeeze(dim=-1))
+ embedded_condition = self.embedding_net(condition)
+ out = super().forward(input, context=embedded_condition)
+ return out.masked_fill(~self.mask.flatten(), float("-inf"))
- def sample(self, sample_shape: torch.Size, condition: Tensor) -> Tensor:
- """Returns samples from categorical random variable with probs predicted from
- the neural net.
+ def log_prob(self, input: Tensor, condition: Optional[Tensor] = None) -> Tensor:
+ r"""Return log-probability of samples.
Args:
- sample_shape: number of samples to obtain.
- condition: batch of parameters for prediction.
+ input: Input datapoints of shape `(batch_size, *input_shape)`.
+ condition: Conditioning variable. `(batch_size, *condition_shape)`.
Returns:
- Tensor: Samples with shape (num_samples, 1)
+ Log-probabilities of shape `(batch_size,)`, summed over variables.
"""
+ outputs = self.forward(input, condition=condition)
+ outputs = outputs.reshape(*input.shape, self.num_categories)
+ log_prob = Categorical(logits=outputs).log_prob(input).sum(dim=-1)
+ return log_prob
+
+ def sample(self, sample_shape, context: Optional[Tensor] = None) -> Tensor:
+ """Sample autoregressively from the conditional categorical distribution.
+
+ Args:
+ sample_shape: Shape of samples to draw (an int is also accepted).
+ context: Conditioning variable of shape `(batch_size, *condition_shape)`.
+
+ Returns:
+ Samples of shape `(*sample_shape, batch_size, num_variables)`.
+ """
+ # Ensure sample_shape is a tuple
+ if isinstance(sample_shape, int):
+ sample_shape = (sample_shape,)
+ sample_shape = torch.Size(sample_shape)
+
+ # Calculate total number of samples
+ num_samples = int(torch.prod(torch.tensor(sample_shape)))
+
+ # Prepare context
+ if context is not None:
+ batch_dim = context.shape[0]
+ if context.ndim == 2:
+ context = context.unsqueeze(0)
+ if batch_dim == 1:
+ context = torchutils.repeat_rows(context, num_samples)
+ else:
+ context_dim = 0 if self.context_features is None else self.context_features
+ context = torch.zeros(num_samples, context_dim)
+ batch_dim = 1
+
+ with torch.no_grad():
+ # Draw one variable per pass; the MADE masks guarantee that the
+ # logits for variable i depend only on variables < i.
+ samples = torch.zeros(num_samples, batch_dim, self.num_variables)
+ for i in range(self.num_variables):
+ outputs = self.forward(samples, context)
+ outputs = outputs.reshape(*samples.shape, self.num_categories)
+ samples[:, :, i] = Categorical(logits=outputs[:, :, i]).sample()
- # Predict Categorical ps and sample.
- ps = self.forward(condition)
- return Categorical(probs=ps).sample(sample_shape=sample_shape)
+ return samples.reshape(*sample_shape, batch_dim, self.num_variables)
class CategoricalMassEstimator(ConditionalDensityEstimator):
@@ -117,12 +144,12 @@ class CategoricalMassEstimator(ConditionalDensityEstimator):
"""
def __init__(
- self, net: CategoricalNet, input_shape: torch.Size, condition_shape: torch.Size
+ self, net: CategoricalMADE, input_shape: torch.Size, condition_shape: torch.Size
) -> None:
"""Initialize the mass estimator.
Args:
- net: CategoricalNet.
+ net: CategoricalMADE.
input_shape: Shape of the input data.
condition_shape: Shape of the condition data
"""
diff --git a/sbi/neural_nets/estimators/mixed_density_estimator.py b/sbi/neural_nets/estimators/mixed_density_estimator.py
index dedba1b52..e46d7baeb 100644
--- a/sbi/neural_nets/estimators/mixed_density_estimator.py
+++ b/sbi/neural_nets/estimators/mixed_density_estimator.py
@@ -80,8 +80,10 @@ def sample(
sample_shape=sample_shape,
condition=condition,
)
- # Trailing `1` because `Categorical` has event_shape `()`.
- discrete_samples = discrete_samples.reshape(num_samples * batch_dim, 1)
+ num_variables = self.discrete_net.net.num_variables
+ discrete_samples = discrete_samples.reshape(
+ num_samples * batch_dim, num_variables
+ )
# repeat the batch of embedded condition to match number of choices.
condition_event_dim = embedded_condition.dim() - 1
@@ -145,7 +147,8 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
f"{input_batch_dim} do not match."
)
- cont_input, disc_input = _separate_input(input)
+ num_discrete_variables = self.discrete_net.net.num_variables
+ cont_input, disc_input = _separate_input(input, num_discrete_variables)
# Embed continuous condition
embedded_condition = self.condition_embedding(condition)
# expand and repeat to match batch of inputs.
@@ -204,3 +207,8 @@ def _separate_input(
Assumes the discrete data to live in the last columns of input.
"""
return input[..., :-num_discrete_columns], input[..., -num_discrete_columns:]
+
+
+def _is_discrete(input: Tensor) -> Tensor:
+ """Infer discrete columns in input data."""
+ return torch.tensor([torch.allclose(col, col.round()) for col in input.T])
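
The `_is_discrete` heuristic flags a column as discrete when every entry is numerically an integer, so a continuous column could in principle be misclassified if a batch happens to contain only integer values. A small sketch of the expected behavior:

```python
import torch

from sbi.neural_nets.estimators.mixed_density_estimator import _is_discrete

x = torch.tensor([
    [0.73, 1.0, 2.0],
    [1.31, 0.0, 1.0],
    [0.05, 1.0, 0.0],
])
# First column is continuous, the last two are integer-valued.
print(_is_discrete(x))  # tensor([False,  True,  True])
```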
diff --git a/sbi/neural_nets/net_builders/categorial.py b/sbi/neural_nets/net_builders/categorial.py
index 2c2add091..18ced4397 100644
--- a/sbi/neural_nets/net_builders/categorial.py
+++ b/sbi/neural_nets/net_builders/categorial.py
@@ -1,16 +1,18 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see
+import warnings
from typing import Optional
-from torch import Tensor, nn, unique
+from torch import Tensor, nn, tensor, unique
-from sbi.neural_nets.estimators import CategoricalMassEstimator, CategoricalNet
-from sbi.utils.nn_utils import get_numel
-from sbi.utils.sbiutils import (
- standardizing_net,
- z_score_parser,
+from sbi.neural_nets.estimators import (
+ CategoricalMADE,
+ CategoricalMassEstimator,
)
+from sbi.neural_nets.estimators.mixed_density_estimator import _is_discrete
+from sbi.utils.nn_utils import get_numel
+from sbi.utils.sbiutils import standardizing_net, z_score_parser
from sbi.utils.user_input_checks import check_data_device
@@ -21,6 +23,7 @@ def build_categoricalmassestimator(
z_score_y: Optional[str] = "independent",
num_hidden: int = 20,
num_layers: int = 2,
+ num_categories: Optional[Tensor] = None,
embedding_net: nn.Module = nn.Identity(),
):
"""Returns a density estimator for a categorical random variable.
@@ -33,28 +36,39 @@ def build_categoricalmassestimator(
num_hidden: Number of hidden units per layer.
num_layers: Number of hidden layers.
embedding_net: Embedding net for y.
+ num_categories: number of categories for each variable.
"""
if z_score_x != "none":
raise ValueError("Categorical input should not be z-scored.")
+ if num_categories is None:
+ warnings.warn(
+ "Inferring num_categories from batch_x. Ensure all categories are present.",
+ stacklevel=2,
+ )
check_data_device(batch_x, batch_y)
- if batch_x.shape[1] > 1:
- raise NotImplementedError("CategoricalMassEstimator only supports 1D input.")
- num_categories = unique(batch_x).numel()
- dim_condition = get_numel(batch_y, embedding_net=embedding_net)
z_score_y_bool, structured_y = z_score_parser(z_score_y)
+ y_numel = get_numel(batch_y, embedding_net=embedding_net)
+
if z_score_y_bool:
embedding_net = nn.Sequential(
standardizing_net(batch_y, structured_y), embedding_net
)
- categorical_net = CategoricalNet(
- num_input=dim_condition,
+ if num_categories is None:
+ batch_x_discrete = batch_x[:, _is_discrete(batch_x)]
+ inferred_categories = tensor([
+ unique(col).numel() for col in batch_x_discrete.T
+ ])
+ num_categories = inferred_categories
+
+ categorical_net = CategoricalMADE(
num_categories=num_categories,
- num_hidden=num_hidden,
- num_layers=num_layers,
+ hidden_features=num_hidden,
+ context_features=y_numel,
+ num_blocks=num_layers,
embedding_net=embedding_net,
)
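
A hedged usage sketch for the updated builder; passing `num_categories` explicitly avoids the inference warning above (data shapes are assumptions for illustration):

```python
import torch

from sbi.neural_nets.net_builders.categorial import build_categoricalmassestimator

batch_x = torch.randint(0, 3, (100, 2)).float()  # two columns, 3 categories each
batch_y = torch.randn(100, 4)                    # conditioning data

estimator = build_categoricalmassestimator(
    batch_x,
    batch_y,
    z_score_x="none",  # categorical inputs must not be z-scored
    num_categories=torch.tensor([3, 3]),
)
```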
diff --git a/sbi/neural_nets/net_builders/mnle.py b/sbi/neural_nets/net_builders/mnle.py
index fd7ef6d3a..2227ccef9 100644
--- a/sbi/neural_nets/net_builders/mnle.py
+++ b/sbi/neural_nets/net_builders/mnle.py
@@ -8,8 +8,13 @@
from torch import Tensor, nn
from sbi.neural_nets.estimators import MixedDensityEstimator
-from sbi.neural_nets.estimators.mixed_density_estimator import _separate_input
-from sbi.neural_nets.net_builders.categorial import build_categoricalmassestimator
+from sbi.neural_nets.estimators.mixed_density_estimator import (
+ _is_discrete,
+ _separate_input,
+)
+from sbi.neural_nets.net_builders.categorial import (
+ build_categoricalmassestimator,
+)
from sbi.neural_nets.net_builders.flow import (
build_made,
build_maf,
@@ -26,10 +31,7 @@
build_zuko_unaf,
)
from sbi.neural_nets.net_builders.mdn import build_mdn
-from sbi.utils.sbiutils import (
- standardizing_net,
- z_score_parser,
-)
+from sbi.utils.sbiutils import standardizing_net, z_score_parser
from sbi.utils.user_input_checks import check_data_device
model_builders = {
@@ -56,6 +58,7 @@ def build_mnle(
z_score_x: Optional[str] = "independent",
z_score_y: Optional[str] = "independent",
flow_model: str = "nsf",
+ num_categorical_columns: Optional[Tensor] = None,
embedding_net: nn.Module = nn.Identity(),
combined_embedding_net: Optional[nn.Module] = None,
num_transforms: int = 2,
@@ -102,6 +105,8 @@ def build_mnle(
as z_score_x.
flow_model: type of flow model to use for the continuous part of the
data.
+ num_categorical_columns: Number of categories for each categorical
+ column in the input data. If None, the function will infer the
+ discrete columns and their category counts from the data.
embedding_net: Optional embedding network for y, required if y is > 1D.
combined_embedding_net: Optional embedding for combining the discrete
part of the input and the embedded condition into a joined
@@ -125,13 +130,17 @@ def build_mnle(
warnings.warn(
"The mixed neural likelihood estimator assumes that x contains "
- "continuous data in the first n-1 columns (e.g., reaction times) and "
- "categorical data in the last column (e.g., corresponding choices). If "
+ "continuous data in the first n-k columns (e.g., reaction times) and "
+ "categorical data in the last k columns (e.g., corresponding choices). If "
"this is not the case for the passed `x` do not use this function.",
stacklevel=2,
)
# Separate continuous and discrete data.
- cont_x, disc_x = _separate_input(batch_x)
+ if num_categorical_columns is None:
+ num_disc = int(_is_discrete(batch_x).sum())
+ else:
+ num_disc = len(num_categorical_columns)
+ cont_x, disc_x = _separate_input(batch_x, num_discrete_columns=num_disc)
# Set up y-embedding net with z-scoring.
z_score_y_bool, structured_y = z_score_parser(z_score_y)
@@ -152,6 +161,7 @@ def build_mnle(
num_hidden=hidden_features,
num_layers=hidden_layers,
embedding_net=embedding_net,
+ num_categories=num_categorical_columns,
)
if combined_embedding_net is None:
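
A sketch of `build_mnle` with multiple trailing categorical columns, passing the per-column category counts explicitly (all shapes are illustrative assumptions):

```python
import torch

from sbi.neural_nets.net_builders.mnle import build_mnle

theta = torch.randn(100, 2)                      # parameters (condition)
rts = torch.rand(100, 1)                         # continuous column
choices = torch.randint(0, 2, (100, 2)).float()  # two binary columns, last
x = torch.cat([rts, choices], dim=1)

estimator = build_mnle(
    batch_x=x,
    batch_y=theta,
    num_categorical_columns=torch.tensor([2, 2]),
)
```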
diff --git a/tests/density_estimator_test.py b/tests/density_estimator_test.py
index 0f006dde1..b958637eb 100644
--- a/tests/density_estimator_test.py
+++ b/tests/density_estimator_test.py
@@ -412,7 +412,4 @@ def test_mixed_density_estimator(
# Test samples
samples = density_estimator.sample(sample_shape, condition=conditions)
- if density_estimator_build_fn == build_categoricalmassestimator:
- # Our categorical is always 1D and does not return `input_event_shape`.
- input_event_shape = ()
assert samples.shape == (*sample_shape, batch_dim, *input_event_shape)
diff --git a/tutorials/Example_01_DecisionMakingModel.ipynb b/tutorials/Example_01_DecisionMakingModel.ipynb
index eb16182c2..409a111af 100644
--- a/tutorials/Example_01_DecisionMakingModel.ipynb
+++ b/tutorials/Example_01_DecisionMakingModel.ipynb
@@ -103,7 +103,7 @@
" \"\"\"Returns a sample from a mixed distribution given parameters theta.\n",
"\n",
" Args:\n",
- " theta: batch of parameters, shape (batch_size, 2) concentration_scaling:\n",
+ " theta: batch of parameters, shape (batch_size, 1 + num_categories) concentration_scaling:\n",
" scaling factor for the concentration parameter of the InverseGamma\n",
" distribution, mimics an experimental condition.\n",
"\n",
diff --git a/tutorials/example_01_utils.py b/tutorials/example_01_utils.py
index 620058d05..a90d58dfe 100644
--- a/tutorials/example_01_utils.py
+++ b/tutorials/example_01_utils.py
@@ -54,7 +54,7 @@ def iid_likelihood(self, theta: Tensor) -> Tensor:
rate=beta,
).log_prob(self.x_o[:, :1].reshape(1, num_trials, -1))
- joint_likelihood = (logprob_choices + logprob_rts).squeeze()
+ joint_likelihood = torch.sum(logprob_choices, dim=2) + logprob_rts.squeeze()
assert joint_likelihood.shape == torch.Size([theta.shape[0], self.x_o.shape[0]])
return joint_likelihood.sum(1)
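
The shape logic behind the likelihood fix above, as a standalone sketch with assumed dimensions: with more than one categorical column, `logprob_choices` carries an extra trailing dimension that must be summed out before it can be added to the RT log-probabilities.

```python
import torch

batch_size, num_trials, num_categorical = 10, 50, 2
logprob_choices = torch.randn(batch_size, num_trials, num_categorical)
logprob_rts = torch.randn(batch_size, num_trials, 1)

joint = torch.sum(logprob_choices, dim=2) + logprob_rts.squeeze()
assert joint.shape == torch.Size([batch_size, num_trials])
```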