Skip to content

Commit

Permalink
All other methods
Browse files Browse the repository at this point in the history
  • Loading branch information
manuelgloeckler committed Jan 30, 2025
1 parent 9152d28 commit df1f30d
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 7 deletions.
48 changes: 43 additions & 5 deletions sbi/inference/potentials/score_fn_iid.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
from torch import Tensor
from torch.distributions import Distribution

from sbi.neural_nets.estimators.score_estimator import ConditionalScoreEstimator
from sbi.utils.torchutils import ensure_theta_batched
from sbi.inference.potentials.score_utils import (
add_diag_or_dense,
denoise,
marginalize,
mv_diag_or_dense,
solve_diag_or_dense,
)
from sbi.neural_nets.estimators.score_estimator import ConditionalScoreEstimator
from sbi.utils.torchutils import ensure_theta_batched


class ScoreFnIID:
Expand Down Expand Up @@ -220,11 +220,11 @@ def marginal_posterior_precision_est_fn(
std = self.score_estimator.std_t_fn(a)

if precisions_posteriors.ndim == 3:
I = torch.eye(precisions_posteriors.shape[-1])
Ident = torch.eye(precisions_posteriors.shape[-1])
else:
I = torch.ones_like(precisions_posteriors)
Ident = torch.ones_like(precisions_posteriors)

marginal_precisions = m**2 / std**2 * I + precisions_posteriors
marginal_precisions = m**2 / std**2 * Ident + precisions_posteriors
return marginal_precisions

def marginal_prior_score_fn(self, a: Tensor, theta: Tensor) -> Tensor:
Expand Down Expand Up @@ -304,3 +304,41 @@ def __call__(self, a: Tensor, theta: Tensor, x_o: Tensor, **kwargs) -> Tensor:
score = solve_diag_or_dense(Lam, score)

return score


class GaussCorrectedScoreFn(AbstractGaussCorrectedScoreFn):
def __init__(
self,
score_estimator: ConditionalScoreEstimator,
prior: Distribution,
posterior_precision: Tensor,
) -> None:
r"""Initializes the GaussCorrectedScoreFn class.
Args:
score_estimator: The neural network modelling the score.
prior: The prior distribution.
"""
super().__init__(score_estimator, prior)
self.posterior_precision = posterior_precision

def posterior_precision_est_fn(self, x_o: Tensor) -> Tensor:
r"""Estimates the posterior precision.
Args:
x_o: Observed data.
Returns:
Estimated posterior precision.
"""
return self.posterior_precision


class AutoGaussCorrectedScoreFn(AbstractGaussCorrectedScoreFn):
# TODO: Move over..
pass


class JacCorrectedScoreFn(AbstractGaussCorrectedScoreFn):
pass
# TODO: Move over...
4 changes: 2 additions & 2 deletions sbi/inference/potentials/score_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from torch import Tensor
import torch
from torch import Tensor
from torch.distributions import Distribution, Independent, Normal

# Automatic denoising -----------------------------------------------------
Expand Down Expand Up @@ -187,4 +187,4 @@ def add_diag_or_dense(A_diag_or_dense: Tensor, B_diag_or_dense: Tensor) -> Tenso
elif A_diag_or_dense.ndim == 1 and B_diag_or_dense.ndim == 2:
return torch.diag(A_diag_or_dense) + B_diag_or_dense
else:
raise ValueError("Incompatible dimensions for addition")
raise ValueError("Incompatible dimensions for addition")

0 comments on commit df1f30d

Please sign in to comment.