Skip to content

Commit

Permalink
fix: type annotation in class 'ConditionedPotential'. (#1222)
Browse files Browse the repository at this point in the history
Correct the return type of the function 'posterior_estimator_based_potential'
  • Loading branch information
schroedk authored Aug 19, 2024
1 parent 593e153 commit 5de7784
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
5 changes: 3 additions & 2 deletions sbi/analysis/conditional_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from torch import Tensor
from torch.distributions import Distribution

from sbi.inference.potentials.base_potential import BasePotential
from sbi.sbi_types import Shape, TorchTransform
from sbi.utils.conditional_density_utils import (
ConditionedPotential,
Expand Down Expand Up @@ -234,7 +235,7 @@ def log_prob(self, theta: Tensor) -> Tensor:


def conditonal_potential(
potential_fn: Callable,
potential_fn: BasePotential,
theta_transform: TorchTransform,
prior: Distribution,
condition: Tensor,
Expand All @@ -257,7 +258,7 @@ def conditonal_potential(


def conditional_potential(
potential_fn: Callable,
potential_fn: BasePotential,
theta_transform: TorchTransform,
prior: Distribution,
condition: Tensor,
Expand Down
5 changes: 3 additions & 2 deletions sbi/inference/potentials/posterior_based_potential.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
from __future__ import annotations

from typing import Callable, Optional, Tuple
from typing import Optional, Tuple

import torch
from torch import Tensor
Expand All @@ -23,7 +24,7 @@ def posterior_estimator_based_potential(
prior: Distribution,
x_o: Optional[Tensor],
enable_transform: bool = True,
) -> Tuple[Callable, TorchTransform]:
) -> Tuple[PosteriorBasedPotential, TorchTransform]:
r"""Returns the potential for posterior-based methods.
It also returns a transformation that can be used to transform the potential into
Expand Down
3 changes: 2 additions & 1 deletion sbi/utils/conditional_density_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from torch import Tensor
from torch.distributions import Distribution

from sbi.inference.potentials.base_potential import BasePotential
from sbi.utils.torchutils import ensure_theta_batched
from sbi.utils.user_input_checks import process_x

Expand Down Expand Up @@ -273,7 +274,7 @@ def condition_mog(
class ConditionedPotential:
def __init__(
self,
potential_fn: Callable,
potential_fn: BasePotential,
condition: Tensor,
dims_to_sample: List[int],
):
Expand Down

0 comments on commit 5de7784

Please sign in to comment.