Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pre-commit.ci] pre-commit autoupdate #173

Merged
merged 2 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ repos:
- id: text-unicode-replacement-char

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.1.14"
rev: "v0.2.0"
hooks:
- id: ruff
args:
Expand Down
1 change: 1 addition & 0 deletions src/stream_mapper/pytorch/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def forward(self, data: Data[Array]) -> Array:
-------
(N, 3) Array
fraction, mean, ln-sigma

"""
if self.net is None:
return self.xp.asarray([])
Expand Down
3 changes: 3 additions & 0 deletions src/stream_mapper/pytorch/_connect/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def _array_at_pytorch(array: Array, idx: Any, /, *, inplace: bool = True) -> Arr
-------
ArrayAt[Array]
Setter.

"""
return ArrayAt(array if inplace else array.clone(), idx, inplace=inplace)

Expand All @@ -67,6 +68,7 @@ def _get_namespace_pytorch(array: Array, /) -> ArrayNamespace[Array]:
Returns
-------
ArrayNamespace[Array]

"""
return cast("ArrayNamespace[Array]", xp)

Expand All @@ -83,5 +85,6 @@ def _copy_pytorch(array: Array, /) -> Array:
Returns
-------
Array

"""
return array.clone()
1 change: 1 addition & 0 deletions src/stream_mapper/pytorch/_connect/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def _within_bounds_pytorch(
-------
ndarray
Boolean array indicating whether the value is within the bounds.

"""
inbounds = xp.ones_like(value, dtype=xp.bool)
if lower_bound is not None:
Expand Down
4 changes: 4 additions & 0 deletions src/stream_mapper/pytorch/_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class IndependentModels(ModelsBase, CoreIndependentModels[Array, NNModel]):
Mapping of parameter names to priors. This is useful for setting priors
on parameters across models, e.g. the background and stream models in a
mixture model.

"""

def __post_init__(self) -> None:
Expand All @@ -92,6 +93,7 @@ def forward(self, data: Data[Array], /) -> Array:
-------
Array
fraction, mean, ln-sigma.

"""
pred = self.xp.concatenate(
tuple(model(data) for model in self.components.values()), dim=1
Expand All @@ -115,6 +117,7 @@ class MixtureModel(ModelsBase, CoreMixtureModel[Array, NNModel]):
control over the type of the models attribute.
net : NNModel, optional postional-only
The neural network that is used to combine the components.

"""

net: NNField[NNModel, NNModel] = NNField(default=MISSING)
Expand All @@ -138,6 +141,7 @@ def forward(self, data: Data[Array], /) -> Array:
-------
Array
fraction, mean, ln-sigma.

"""
# Predict the weights, except the background weight, which is
# always 1 - sum(weights).
Expand Down
2 changes: 2 additions & 0 deletions src/stream_mapper/pytorch/builtin/_isochrone.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ class IsochroneMVNorm(ModelBase):
coord_bounds=(...), # photometry mag_names=("g",),
mag_err_names=("g_err",), color_names=("g-r",),
color_err_names=("g-r_err",), phot_bounds=(...),

"""

net: NNField[NNModel, None] = NNField(default=None)
Expand Down Expand Up @@ -293,6 +294,7 @@ def ln_likelihood(
Returns
-------
Array[(N,)]

"""
# 'where' is used to indicate which data points are available. If
# 'where' is not provided, then all data points are assumed to be
Expand Down
1 change: 1 addition & 0 deletions src/stream_mapper/pytorch/builtin/_multinormal.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def ln_likelihood(
Returns
-------
Array

"""
# 'where' is used to indicate which data points are available. If
# 'where' is not provided, then all data points are assumed to be
Expand Down
1 change: 1 addition & 0 deletions src/stream_mapper/pytorch/builtin/_skewnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def ln_likelihood(
Returns
-------
Array

"""
# 'where' is used to indicate which data points are available. If
# 'where' is not provided, then all data points are assumed to be
Expand Down
3 changes: 3 additions & 0 deletions src/stream_mapper/pytorch/builtin/_sloped.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class Sloped(ModelBase):
net : nn.Module, keyword-only
The network to use. If not provided, a new one will be created. Must be
a layer with 1 input and ``len(param names)-1`` outputs.

"""

_: KW_ONLY
Expand Down Expand Up @@ -100,6 +101,7 @@ def ln_likelihood(
Returns
-------
Array

"""
data = self.data_scaler.transform(
data, names=names_intersect(data.names, self.data_scaler.names), xp=self.xp
Expand Down Expand Up @@ -146,6 +148,7 @@ def forward(self, data: Data[Array]) -> Array:
-------
Array
fraction, mean, ln-sigma

"""
# The forward step runs on the normalized coordinates
data = self.data_scaler.transform(
Expand Down
1 change: 1 addition & 0 deletions src/stream_mapper/pytorch/builtin/_truncskewnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def ln_likelihood(
Returns
-------
Array

"""
# 'where' is used to indicate which data points are available. If
# 'where' is not provided, then all data points are assumed to be
Expand Down
2 changes: 2 additions & 0 deletions src/stream_mapper/pytorch/builtin/compat/_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def ln_likelihood(
Returns
-------
Array

"""
# 'where' is used to indicate which data points are available. If
# 'where' is not provided, then all data points are assumed to be
Expand Down Expand Up @@ -103,5 +104,6 @@ def forward(self, data: Data[Array]) -> Array:
Returns
-------
Array

"""
return xp.asarray([])
2 changes: 2 additions & 0 deletions src/stream_mapper/pytorch/builtin/compat/kde.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def ln_likelihood(
Returns
-------
Array

"""
# TODO: support `where` argument.
with xp.no_grad():
Expand All @@ -85,5 +86,6 @@ def forward(self, data: Data[Array]) -> Array:
Returns
-------
Array

"""
return xp.asarray([])
1 change: 1 addition & 0 deletions src/stream_mapper/pytorch/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def sequential(
Returns
-------
`torch.nn.Sequential`

"""
activation_func = nn.Tanh if activation is None else activation

Expand Down
1 change: 1 addition & 0 deletions src/stream_mapper/pytorch/params/bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def scaled_sigmoid(x: Array, /, lower: Array, upper: Array) -> Array:
Returns
-------
Array

"""
if xp.isneginf(lower) and xp.isposinf(upper):
return x
Expand Down
Loading