Skip to content

Commit

Permalink
Parameter bound inf (#181)
Browse files Browse the repository at this point in the history
* allow for inf to be set. inf breaks the gradient. Now inf can be approximated.
* better str repr for control regions

Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman authored Nov 17, 2023
1 parent 4167a6c commit 99d7daf
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/stream_ml/core/params/bounds/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class ParameterBounds(
param_name: ParamNameTupleOpts | None = None
scaler: InitVar[ParamScaler[Array] | None] = None
name: str | None = None # the name of the prior
neg_inf: Array | float = -float("inf")

array_namespace: ArrayNamespace[Array]

Expand Down Expand Up @@ -129,7 +130,7 @@ def logpdf(
bp = self.xp.zeros_like(mpars[self.param_name])
return array_at(
bp, ~within_bounds(mpars[self.param_name], self.lower, self.upper)
).set(-self.xp.inf)
).set(self.neg_inf)

@abstractmethod
def __call__(
Expand Down
26 changes: 24 additions & 2 deletions src/stream_ml/core/prior/_track.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@

__all__: tuple[str, ...] = ()

from dataclasses import KW_ONLY, dataclass
from dataclasses import KW_ONLY, dataclass, fields
from typing import TYPE_CHECKING

from stream_ml.core._data import Data
from stream_ml.core.prior import Prior
from stream_ml.core.typing import Array, NNModel

if TYPE_CHECKING:
from stream_ml.core import Data, ModelAPI, Params
from typing import Any

from stream_ml.core import ModelAPI, Params
from stream_ml.core.typing import ArrayNamespace


Expand Down Expand Up @@ -141,3 +144,22 @@ def logpdf(
lnpdf[where] = (cmp_arr[where] - (self._y[where] + self._w[where])) ** 2 # type: ignore[index]

return -self.lamda * self.xp.sum(lnpdf) # (C, F) -> 1

def __str__(self) -> str:
"""String representation."""
fs = (
f"{f.name}={_as_str(getattr(self, f.name))}"
if f.name != "array_namespace"
else f"{f.name}={(self.xp if isinstance(self.xp, str) else self.xp.__name__)!r}" # noqa: E501
for f in fields(self)
)
return f"{self.__class__.__name__}({' '.join(fs)})"


def _as_str(v: Any) -> str:
"""Get string representation."""
if isinstance(v, Data):
return "..."
elif isinstance(v, str):
return f"{v!r}"
return str(v)

0 comments on commit 99d7daf

Please sign in to comment.