diff --git a/sbi/analysis/conditional_density.py b/sbi/analysis/conditional_density.py index 05651ce88..3e206afaa 100644 --- a/sbi/analysis/conditional_density.py +++ b/sbi/analysis/conditional_density.py @@ -10,6 +10,7 @@ from torch import Tensor from torch.distributions import Distribution +from sbi.inference.potentials.base_potential import BasePotential from sbi.sbi_types import Shape, TorchTransform from sbi.utils.conditional_density_utils import ( ConditionedPotential, @@ -234,7 +235,7 @@ def log_prob(self, theta: Tensor) -> Tensor: def conditonal_potential( - potential_fn: Callable, + potential_fn: BasePotential, theta_transform: TorchTransform, prior: Distribution, condition: Tensor, @@ -257,7 +258,7 @@ def conditonal_potential( def conditional_potential( - potential_fn: Callable, + potential_fn: BasePotential, theta_transform: TorchTransform, prior: Distribution, condition: Tensor, diff --git a/sbi/inference/potentials/posterior_based_potential.py b/sbi/inference/potentials/posterior_based_potential.py index 3dead6ade..9272b2d68 100644 --- a/sbi/inference/potentials/posterior_based_potential.py +++ b/sbi/inference/potentials/posterior_based_potential.py @@ -1,7 +1,8 @@ # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed # under the Apache License Version 2.0, see +from __future__ import annotations -from typing import Callable, Optional, Tuple +from typing import Optional, Tuple import torch from torch import Tensor @@ -23,7 +24,7 @@ def posterior_estimator_based_potential( prior: Distribution, x_o: Optional[Tensor], enable_transform: bool = True, -) -> Tuple[Callable, TorchTransform]: +) -> Tuple[PosteriorBasedPotential, TorchTransform]: r"""Returns the potential for posterior-based methods. It also returns a transformation that can be used to transform the potential into diff --git a/sbi/utils/conditional_density_utils.py b/sbi/utils/conditional_density_utils.py index 2c875247e..ee8606389 100644 --- a/sbi/utils/conditional_density_utils.py +++ b/sbi/utils/conditional_density_utils.py @@ -11,6 +11,7 @@ from torch import Tensor from torch.distributions import Distribution +from sbi.inference.potentials.base_potential import BasePotential from sbi.utils.torchutils import ensure_theta_batched from sbi.utils.user_input_checks import process_x @@ -273,7 +274,7 @@ def condition_mog( class ConditionedPotential: def __init__( self, - potential_fn: Callable, + potential_fn: BasePotential, condition: Tensor, dims_to_sample: List[int], ):