Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow for the parameter bounds to be splines #186

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
66 changes: 57 additions & 9 deletions src/stream_mapper/core/_core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
from stream_mapper.core.utils.frozen_dict import FrozenDict, FrozenDictField

if TYPE_CHECKING:
from collections.abc import Iterator

from stream_mapper.core import Data
from stream_mapper.core.prior import Prior
from stream_mapper.core.typing import ParamNameAllOpts, ParamsLikeDict
Expand Down Expand Up @@ -163,14 +165,19 @@ def __post_init__(self) -> None:
# Type hint
self._nn_namespace_: NNNamespace[NNModel, Array]

# Coordinate bounds are necessary (before they were auto-filled).
if self.coord_bounds.keys() != set(self.coord_names):
# Coordinate bounds are necessary and the orders must match.
if tuple(self.coord_bounds.keys()) != tuple(self.coord_names):
msg = (
f"`coord_bounds` ({tuple(self.coord_bounds.keys())}) do not match "
f"`coord_names` ({self.coord_names})."
)
raise ValueError(msg)

# also pre-compute whether the bounds are callable
self._bounds_callable = FrozenDict(
{k: (callable(a), callable(b)) for k, (a, b) in self.coord_bounds.items()}
)

# coord_err_names must be None or the same length as coord_names.
# we can't check that the names are the same, because they aren't.
# TODO: better way to ensure that
Expand All @@ -190,6 +197,36 @@ def __post_init__(self) -> None:
)
raise ValueError(msg)

def _get_lower_upper_bound(self, x: Array, /) -> tuple[Array, Array]:
"""Get the lower bound."""
if x.ndim > 1 and x.shape[1] > 1:
msg = "x must be 1D"
raise ValueError(msg)
_0 = self.xp.zeros_like(x)
a = self.xp.concat(
[
self.xp.atleast_2d(
self.xp.asarray(
a(x[:, 0]) if self._bounds_callable[k][0] else _0 + a # type: ignore[arg-type, operator]
)
)
for k, (a, _) in self.coord_bounds.items()
],
-1,
)
b = self.xp.concat(
[
self.xp.atleast_2d(
self.xp.asarray(
b(x[:, 0]) if self._bounds_callable[k][0] else _0 + b # type: ignore[arg-type, operator]
)
)
for k, (_, b) in self.coord_bounds.items()
],
-1,
)
return self.xp.atleast_2d(a), self.xp.atleast_2d(b)

# ========================================================================

def _stack_param(self, p: Params[Array], k: str, cns: tuple[str, ...], /) -> Array:
Expand Down Expand Up @@ -277,15 +314,26 @@ def _ln_prior_coord_bnds(self, data: Data[Array], /) -> Array:
coordinate bounds, where it is -inf.
"""
shape = data.array.shape[:1] + data.array.shape[2:]

# don't require all coordinates to be present in the data,
# e.g. "distmod" on an isochrone model.
# TODO: enable this for multiple context dimension. Right now it's only 1.
context = data[self.indep_coord_names].array[..., 0]
kab: Iterator[tuple[str, tuple[Array | float, Array | float]]] = (
( # type: ignore[misc]
k,
(
a(context) if self._bounds_callable[k][0] else a, # type: ignore[arg-type, operator]
b(context) if self._bounds_callable[k][1] else b, # type: ignore[arg-type, operator]
nstarman marked this conversation as resolved.
Show resolved Hide resolved
),
)
for k, (a, b) in self.coord_bounds.items()
if k in data.names
)

where = reduce(
self.xp.logical_or,
(
~within_bounds(data[k], *v)
for k, v in self.coord_bounds.items()
if k in data.names
# don't require all coordinates to be present in the data,
# e.g. "distmod" on an isochrone model.
),
(~within_bounds(data[k], a[:, 0], b[:, 0]) for k, (a, b) in kab),
nstarman marked this conversation as resolved.
Show resolved Hide resolved
self.xp.zeros(shape, dtype=bool),
)
return self.xp.where(
Expand Down
16 changes: 3 additions & 13 deletions src/stream_mapper/core/builtin/_exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,6 @@ def __post_init__(self) -> None:
msg = f"Missing parameter for coordinate(s) {missing}"
raise ValueError(msg)

# Pre-compute the associated constant factors
_a = [a for k, (a, _) in self.coord_bounds.items() if k in self.params]
_b = [b for k, (_, b) in self.coord_bounds.items() if k in self.params]

self._a = self.xp.asarray(_a)[None, :] # ([N], F)
self._b = self.xp.asarray(_b)[None, :] # ([N], F)

# ========================================================================
# Statistics

def ln_likelihood(
self,
mpars: Params[Array],
Expand Down Expand Up @@ -128,15 +118,15 @@ def ln_likelihood(
# slope is a parameter. If it is not, then we assume it is 0.
# When the slope is 0, the log-likelihood reduces to a Uniform.
ms = self._stack_param(mpars, "slope", self.coord_names)[idx]
a, b = self._get_lower_upper_bound(data[self.indep_coord_names].array)

# the distribution is not affected by the errors!
# if self.coord_err_names is not None: pass
_0 = self.xp.zeros_like(x)
value = exponential_logpdf(
x[idx],
m=ms,
a=(_0 + self._a)[idx],
b=(_0 + self._b)[idx],
a=a[idx],
b=b[idx],
xp=self.xp,
nil=-self.xp.inf,
m_eps=self.m_eps,
Expand Down
8 changes: 3 additions & 5 deletions src/stream_mapper/core/builtin/_truncnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,18 +109,16 @@ def ln_likelihood(
cns, cens = self.coord_names, self.coord_err_names
x = data[cns].array

a, b = self.xp.asarray([self.coord_bounds[k] for k in cns]).T[:, None, :]
a, b = self._get_lower_upper_bound(data[self.indep_coord_names].array)
# a, b = self.xp.asarray([self.coord_bounds[k] for k in cns]).T[:, None, :]
mu = self._stack_param(mpars, "mu", cns)[idx]
ln_s = self._stack_param(mpars, "ln-sigma", cns)[idx]
if cens is not None:
# it's fine if sigma_o is 0
sigma_o = data[cens].array[idx]
ln_s = self.xp.logaddexp(2 * ln_s, 2 * self.xp.log(sigma_o)) / 2

_0 = self.xp.zeros_like(x)
value = logpdf(
x[idx], loc=mu, ln_sigma=ln_s, a=(_0 + a)[idx], b=(_0 + b)[idx], xp=self.xp
)
value = logpdf(x[idx], loc=mu, ln_sigma=ln_s, a=a[idx], b=b[idx], xp=self.xp)

lnliks = self.xp.full_like(x, 0) # missing data is ignored
lnliks = array_at(lnliks, idx).set(value)
Expand Down
5 changes: 2 additions & 3 deletions src/stream_mapper/core/builtin/_truncskewnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def ln_likelihood(
cns, cens = self.coord_names, self.coord_err_names
x = data[cns].array

a, b = self.xp.asarray([self.coord_bounds[k] for k in cns]).T[:, None, :]
a, b = self._get_lower_upper_bound(data[self.indep_coord_names].array)
mu = self._stack_param(mpars, "mu", cns)[idx]
ln_s = self._stack_param(mpars, "ln-sigma", cns)[idx]
skew = self._stack_param(mpars, "skew", cns)[idx]
Expand All @@ -76,9 +76,8 @@ def ln_likelihood(
skew**2 / (1 + (sigma_o / self.xp.exp(ln_s)) ** 2 * (1 + skew**2))
)

_0 = self.xp.zeros_like(x)
value = logpdf(
x[idx], loc=mu, ln_sigma=ln_s, skew=skew, a=_0 + a, b=_0 + b, xp=self.xp
x[idx], loc=mu, ln_sigma=ln_s, skew=skew, a=a[idx], b=b[idx], xp=self.xp
)

lnliks = self.xp.full_like(x, 0) # missing data is ignored
Expand Down
4 changes: 3 additions & 1 deletion src/stream_mapper/core/typing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
)


from collections.abc import Callable
from typing import TypeAlias

from stream_mapper.core.typing._array import Array, Array_co, ArrayLike
from stream_mapper.core.typing._nn import NNModel, NNNamespace
from stream_mapper.core.typing._xp import ArrayNamespace

BoundsT: TypeAlias = tuple[float, float]
BoundT: TypeAlias = float | Callable[[float], float]
BoundsT: TypeAlias = tuple[BoundT, BoundT]

ParamNameTupleOpts: TypeAlias = tuple[str] | tuple[str, str]
ParamNameAllOpts: TypeAlias = str | ParamNameTupleOpts
Expand Down
5 changes: 5 additions & 0 deletions src/stream_mapper/core/typing/_xp.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ def atleast_1d(array: Array) -> Array:
"""At least 1D."""
...

@staticmethod
def atleast_2d(array: Array) -> Array:
"""At least 2D."""
...

@staticmethod
def clip(array: Array, *args: Any) -> Array:
"""Clip."""
Expand Down
Loading