Add CategoricalMADE #1269

Open: wants to merge 15 commits into base: main
6 changes: 2 additions & 4 deletions sbi/neural_nets/estimators/__init__.py
@@ -1,12 +1,10 @@
from sbi.neural_nets.estimators.base import ConditionalDensityEstimator
from sbi.neural_nets.estimators.categorical_net import (
CategoricalMADE,
CategoricalMassEstimator,
CategoricalNet,
)
from sbi.neural_nets.estimators.flowmatching_estimator import FlowMatchingEstimator
from sbi.neural_nets.estimators.mixed_density_estimator import (
MixedDensityEstimator,
)
from sbi.neural_nets.estimators.mixed_density_estimator import MixedDensityEstimator
from sbi.neural_nets.estimators.nflows_flow import NFlowsFlow
from sbi.neural_nets.estimators.score_estimator import ConditionalScoreEstimator
from sbi.neural_nets.estimators.zuko_flow import ZukoFlow
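For reference, after this change both estimators are exposed from the same module; a minimal usage sketch:

from sbi.neural_nets.estimators import CategoricalMADE, CategoricalMassEstimator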
177 changes: 110 additions & 67 deletions sbi/neural_nets/estimators/categorical_net.py
@@ -1,18 +1,20 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Optional
from typing import Callable, Optional

import torch
from nflows.nn.nde.made import MADE
from nflows.utils import torchutils
from torch import Tensor, nn
from torch.distributions import Categorical
from torch.nn import Sigmoid, Softmax
from torch.nn import functional as F

from sbi.neural_nets.estimators.base import ConditionalDensityEstimator


class CategoricalNet(nn.Module):
"""Conditional density (mass) estimation for a categorical random variable.
class CategoricalMADE(MADE):
"""Conditional density (mass) estimation for a n-dim categorical random variable.

Takes as input parameters theta and learns the parameters p of a Categorical.

@@ -21,93 +23,134 @@ class CategoricalNet(nn.Module):

def __init__(
self,
num_input: int,
num_categories: int,
num_hidden: int = 20,
num_layers: int = 2,
embedding_net: Optional[nn.Module] = None,
num_categories: Tensor, # Tensor[int]
hidden_features: int,
context_features: Optional[int] = None,
num_blocks: int = 2,
use_residual_blocks: bool = True,
random_mask: bool = False,
activation: Callable = F.relu,
dropout_probability: float = 0.0,
use_batch_norm: bool = False,
epsilon: float = 1e-2,
custom_initialization: bool = True,
embedding_net: nn.Module = nn.Identity(),
):
"""Initialize the neural net.

Args:
num_input: number of input units, i.e., dimensionality of the features.
num_categories: number of output units, i.e., number of categories.
num_categories: number of categories for each variable. len(num_categories)
defines the number of input units, i.e., dimensionality of the features.
max(num_categories) defines the number of output units, i.e., the largest
number of categories.
num_hidden: number of hidden units per layer.
num_layers: number of hidden layers.
embedding_net: embedding net for input.
"""
super().__init__()

self.num_hidden = num_hidden
self.num_input = num_input
self.activation = Sigmoid()
self.softmax = Softmax(dim=1)
self.num_categories = num_categories

# Maybe add embedding net in front.
if embedding_net is not None:
self.input_layer = nn.Sequential(
embedding_net, nn.Linear(num_input, num_hidden)
)
else:
self.input_layer = nn.Linear(num_input, num_hidden)
if use_residual_blocks and random_mask:
raise ValueError("Residual blocks can't be used with random masks.")

self.num_variables = len(num_categories)
self.num_categories = int(torch.max(num_categories))
self.mask = torch.zeros(self.num_variables, self.num_categories)
for i, c in enumerate(num_categories):
self.mask[i, :c] = 1

super().__init__(
self.num_variables,
hidden_features,
context_features=context_features,
num_blocks=num_blocks,
output_multiplier=self.num_categories,
use_residual_blocks=use_residual_blocks,
random_mask=random_mask,
activation=activation,
dropout_probability=dropout_probability,
use_batch_norm=use_batch_norm,
)

# Repeat hidden units hidden layers times.
self.hidden_layers = nn.ModuleList()
for _ in range(num_layers):
self.hidden_layers.append(nn.Linear(num_hidden, num_hidden))
self.embedding_net = embedding_net
self.hidden_features = hidden_features
self.epsilon = epsilon
self.context_features = context_features

self.output_layer = nn.Linear(num_hidden, num_categories)
if custom_initialization:
self._initialize()

def forward(self, condition: Tensor) -> Tensor:
"""Return categorical probability predicted from a batch of inputs.
def forward(self, inputs: Tensor, context: Optional[Tensor] = None) -> Tensor:
r"""Forward pass of the categorical density estimator network to compute the
conditional density at a given time.

Args:
condition: batch of context parameters for the net.
input: Original data, x0. (batch_size, *input_shape)
Review comment (Contributor): Better keep this general: "input variable", because it could be x for MNLE, or theta when doing mixed NPE (see Daniel's recent PR).

condition: Conditioning variable. (batch_size, *condition_shape)

Returns:
Tensor: batch of predicted categorical probabilities.
Predicted categorical logits. (batch_size, *input_shape,
num_categories)
"""
# forward path
condition = self.activation(self.input_layer(condition))
embedded_context = self.embedding_net.forward(context)
return super().forward(inputs, context=embedded_context)

# iterate n hidden layers, input condition and calculate tanh activation
for layer in self.hidden_layers:
condition = self.activation(layer(condition))
def compute_probs(self, outputs):
ps = F.softmax(outputs, dim=-1) * self.mask
ps = ps / ps.sum(dim=-1, keepdim=True)
Review comment (Contributor): is this numerically stable? Better use logsumexp?

return ps
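Regarding the numerical-stability question above, a log-space variant could look roughly like the sketch below (illustrative only, not part of this diff; it assumes `self.mask` keeps its (num_variables, num_categories) layout):

def compute_log_probs(self, outputs):
    # Send logits of masked-out categories to -inf, then normalize in log-space;
    # log_softmax applies a logsumexp internally, avoiding explicit division by probability sums.
    masked_logits = outputs.masked_fill(self.mask == 0, float("-inf"))
    return F.log_softmax(masked_logits, dim=-1)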

return self.softmax(self.output_layer(condition))

def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
"""Return categorical log probability of categories input, given condition.
def log_prob(self, inputs: Tensor, context: Optional[Tensor] = None) -> Tensor:
r"""Return log-probability of samples.

Args:
input: categories to evaluate.
condition: parameters.
inputs: Input datapoints of shape `(batch_size, *input_shape)`.
context: Context of shape `(batch_size, *condition_shape)`.

Returns:
Tensor: log probs with shape (input.shape[0],)
Log-probabilities of shape `(batch_size,)`, summed over the discrete variables.
"""
# Predict categorical ps and evaluate.
ps = self.forward(condition)
# Squeeze the last dimension (event dim) because `Categorical` has
# `event_shape=()` but our data usually has an event_shape of `(1,)`.
return Categorical(probs=ps).log_prob(input.squeeze(dim=-1))

def sample(self, sample_shape: torch.Size, condition: Tensor) -> Tensor:
"""Returns samples from categorical random variable with probs predicted from
the neural net.
outputs = self.forward(inputs, context=context)
Review comment (Contributor, author): are these shapes correct?

outputs = outputs.reshape(*inputs.shape, self.num_categories)
ps = self.compute_probs(outputs)

# categorical log prob
log_prob = torch.log(ps.gather(-1, inputs.unsqueeze(-1).long()))
log_prob = log_prob.squeeze(-1).sum(dim=-1)
Review comment on lines +111 to +116 (Contributor): very naive question here: the outputs are coming from the MADE, i.e., the conditional dependencies are already taken care of internally right?

I am just wondering because for the 1-D case, we used the network-predicted ps to construct a Categorical distribution and then evaluated the inputs under that distribution. This is not needed here because the underlying MADE takes both the inputs and the context and outputs unnormalized conditional probabilities already?


return log_prob

def sample(self, sample_shape, context=None):
# Ensure sample_shape is a tuple
if isinstance(sample_shape, int):
sample_shape = (sample_shape,)
sample_shape = torch.Size(sample_shape)

# Calculate total number of samples
num_samples = int(torch.prod(torch.tensor(sample_shape)))

# Prepare context
if context is not None:
batch_dim = context.shape[0]
if context.ndim == 2:
context = context.unsqueeze(0)
if batch_dim == 1:
context = torchutils.repeat_rows(context, num_samples)
else:
context_dim = 0 if self.context_features is None else self.context_features
context = torch.zeros(num_samples, context_dim)
batch_dim = 1

Args:
sample_shape: number of samples to obtain.
condition: batch of parameters for prediction.
with torch.no_grad():
samples = torch.zeros(num_samples, batch_dim, self.num_variables)
Review comment (Contributor): should this be initialized with uniform torch.rand?

print(samples.shape, context.shape)
Review comment (Contributor): debugging leftover?

for i in range(self.num_variables):
outputs = self.forward(samples, context)
outputs = outputs.reshape(*samples.shape, self.num_categories)
ps = self.compute_probs(outputs)
samples[:, :, i] = Categorical(probs=ps[:, :, i]).sample()
Review comments on lines +144 to +148:

(Contributor): same question as above: these samples are internally autoregressive, right? So each discrete variable is sampled given the upstream discrete variables?

(Contributor): I am just confused because I would have expected that for each iteration we need to pass the so-far sampled discrete samples as context; but this seems to be happening implicitly in the MADE?

(Contributor): Now I see it: in line 148 you are updating samples with the new samples of the current i. It probably boils down to the same thing, but you could also update all so-far sampled samples, i.e., samples[:, :, :(i+1)] = Categorical(probs=ps[:, :, :(i+1)]).sample()?


Returns:
Tensor: Samples with shape (num_samples, 1)
"""
return samples.reshape(*sample_shape, batch_dim, self.num_variables)
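To make the autoregressive-sampling discussion above concrete, here is a self-contained toy of the same pattern (purely illustrative, not the PR's code): because the network is autoregressively masked, re-running it on the partially filled samples tensor means variable i is drawn conditioned on the variables already sampled at positions < i, without passing them as extra context.

import torch
from torch.distributions import Categorical

def toy_autoregressive_sample(net, num_variables, num_samples):
    # `net` is assumed to map (num_samples, num_variables) inputs to
    # (num_samples, num_variables, num_categories) logits, masked so that the
    # logits for variable i depend only on input positions < i.
    samples = torch.zeros(num_samples, num_variables)
    for i in range(num_variables):
        logits = net(samples)  # re-run on the partially filled samples
        samples[:, i] = Categorical(logits=logits[:, i]).sample().float()
    return samples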

# Predict Categorical ps and sample.
ps = self.forward(condition)
return Categorical(probs=ps).sample(sample_shape=sample_shape)
def _initialize(self):
pass
Review comments on lines +152 to +153:

(Contributor): what is the custom init? Is it an abstract method from nflows MADE? If so, we should probably raise "not implemented" here.

(Contributor, author): Yes, this is part of nflows.

(Contributor, author): However, we cannot raise "not implemented", since this is run on init of MADE.

(Contributor): Unless I am missing something, the _initialize() is needed only in MixtureOfGaussiansMADE(MADE), not in MADE, so it's not needed here?



class CategoricalMassEstimator(ConditionalDensityEstimator):
@@ -117,12 +160,12 @@ class CategoricalMassEstimator(ConditionalDensityEstimator):
"""

def __init__(
self, net: CategoricalNet, input_shape: torch.Size, condition_shape: torch.Size
self, net: CategoricalMADE, input_shape: torch.Size, condition_shape: torch.Size
) -> None:
"""Initialize the mass estimator.

Args:
net: CategoricalNet.
net: CategoricalMADE.
input_shape: Shape of the input data.
condition_shape: Shape of the condition data
"""
14 changes: 11 additions & 3 deletions sbi/neural_nets/estimators/mixed_density_estimator.py
@@ -80,8 +80,10 @@ def sample(
sample_shape=sample_shape,
condition=condition,
)
# Trailing `1` because `Categorical` has event_shape `()`.
discrete_samples = discrete_samples.reshape(num_samples * batch_dim, 1)
num_variables = self.discrete_net.net.num_variables
discrete_samples = discrete_samples.reshape(
num_samples * batch_dim, num_variables
)

# repeat the batch of embedded condition to match number of choices.
condition_event_dim = embedded_condition.dim() - 1
@@ -145,7 +147,8 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
f"{input_batch_dim} do not match."
)

cont_input, disc_input = _separate_input(input)
num_discrete_variables = self.discrete_net.net.num_variables
cont_input, disc_input = _separate_input(input, num_discrete_variables)
# Embed continuous condition
embedded_condition = self.condition_embedding(condition)
# expand and repeat to match batch of inputs.
@@ -204,3 +207,8 @@ def _separate_input(
Assumes the discrete data to live in the last columns of input.
"""
return input[..., :-num_discrete_columns], input[..., -num_discrete_columns:]


def _is_discrete(input: Tensor) -> Tensor:
"""Infer discrete columns in input data."""
return torch.tensor([torch.allclose(col, col.round()) for col in input.T])
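A quick illustration of these two helpers with hypothetical values (a sketch, not an actual test):

import torch
x = torch.tensor([[0.3, 1.0, 2.0], [0.7, 0.0, 5.0]])
cont, disc = _separate_input(x, 2)  # cont: first column, disc: last two columns
_is_discrete(x)  # tensor([False, True, True]): integer-valued columns are flagged as discrete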
42 changes: 28 additions & 14 deletions sbi/neural_nets/net_builders/categorial.py
@@ -1,16 +1,18 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

import warnings
from typing import Optional

from torch import Tensor, nn, unique
from torch import Tensor, nn, tensor, unique

from sbi.neural_nets.estimators import CategoricalMassEstimator, CategoricalNet
from sbi.utils.nn_utils import get_numel
from sbi.utils.sbiutils import (
standardizing_net,
z_score_parser,
from sbi.neural_nets.estimators import (
CategoricalMADE,
CategoricalMassEstimator,
)
from sbi.neural_nets.estimators.mixed_density_estimator import _is_discrete
from sbi.utils.nn_utils import get_numel
from sbi.utils.sbiutils import standardizing_net, z_score_parser
from sbi.utils.user_input_checks import check_data_device


@@ -21,6 +23,7 @@ def build_categoricalmassestimator(
z_score_y: Optional[str] = "independent",
num_hidden: int = 20,
num_layers: int = 2,
num_categories: Optional[Tensor] = None,
embedding_net: nn.Module = nn.Identity(),
):
"""Returns a density estimator for a categorical random variable.
@@ -33,28 +36,39 @@ def build_categoricalmassestimator(
num_hidden: Number of hidden units per layer.
num_layers: Number of hidden layers.
embedding_net: Embedding net for y.
num_categories: number of categories for each variable.
"""

if z_score_x != "none":
raise ValueError("Categorical input should not be z-scored.")
if num_categories is None:
warnings.warn(
"Inferring num_categories from batch_x. Ensure all categories are present.",
stacklevel=2,
)

check_data_device(batch_x, batch_y)
if batch_x.shape[1] > 1:
raise NotImplementedError("CategoricalMassEstimator only supports 1D input.")
num_categories = unique(batch_x).numel()
dim_condition = get_numel(batch_y, embedding_net=embedding_net)

z_score_y_bool, structured_y = z_score_parser(z_score_y)
y_numel = get_numel(batch_y, embedding_net=embedding_net)

if z_score_y_bool:
embedding_net = nn.Sequential(
standardizing_net(batch_y, structured_y), embedding_net
)

categorical_net = CategoricalNet(
num_input=dim_condition,
if num_categories is None:
batch_x_discrete = batch_x[:, _is_discrete(batch_x)]
inferred_categories = tensor([
unique(col).numel() for col in batch_x_discrete.T
])
num_categories = inferred_categories

categorical_net = CategoricalMADE(
num_categories=num_categories,
num_hidden=num_hidden,
num_layers=num_layers,
hidden_features=num_hidden,
context_features=y_numel,
num_blocks=num_layers,
embedding_net=embedding_net,
)
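For context, the builder might be called along these lines (a hypothetical sketch; the toy data and the explicit num_categories are illustrative, not part of the diff):

import torch
# Toy data: two discrete variables (3 and 2 categories) and a 4-dimensional condition.
batch_x = torch.stack(
    [torch.randint(3, (100,)), torch.randint(2, (100,))], dim=1
).float()
batch_y = torch.randn(100, 4)
# Passing num_categories explicitly avoids the warning about inferring it from batch_x.
estimator = build_categoricalmassestimator(
    batch_x,
    batch_y,
    z_score_x="none",
    num_categories=torch.tensor([3, 2]),
)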
