Skip to content

Commit

Permalink
[pre-commit.ci] pre-commit autoupdate (#172)
Browse files Browse the repository at this point in the history
  • Loading branch information
pre-commit-ci[bot] authored Jan 29, 2024
1 parent 68f2330 commit c8ae32b
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ repos:
- --fix

- repo: https://github.com/psf/black
rev: 23.12.1
rev: 24.1.1
hooks:
- id: black
additional_dependencies: [toml]
Expand Down
6 changes: 3 additions & 3 deletions src/stream_mapper/pytorch/_connect/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,6 @@ def _from_ndarraytype_to_tensor(
array.flags.writeable = True
return replace(data, array=xp.asarray(array, **kwargs))

ASTYPE_REGISTRY[
(asdf.tags.core.ndarray.NDArrayType, xp.Tensor)
] = _from_ndarraytype_to_tensor
ASTYPE_REGISTRY[(asdf.tags.core.ndarray.NDArrayType, xp.Tensor)] = (
_from_ndarraytype_to_tensor
)
4 changes: 1 addition & 3 deletions src/stream_mapper/pytorch/builtin/_skewnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,7 @@ def ln_likelihood(
# TODO: I suspect there are better ways to write this
sigma_o = data[cens].array[idx]
sigma = self.xp.exp(ln_s)
skew = (
skew * sigma / self.xp.sqrt(sigma**2 + (1 + skew**2) * sigma_o**2)
)
skew = skew * sigma / self.xp.sqrt(sigma**2 + (1 + skew**2) * sigma_o**2)
ln_s = self.xp.log(sigma**2 + sigma_o**2) / 2

# Find where -inf
Expand Down
4 changes: 1 addition & 3 deletions src/stream_mapper/pytorch/builtin/_truncskewnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,7 @@ def ln_likelihood(
# TODO: I suspect there are better ways to write this
sigma_o = data[cens].array[idx]
sigma = self.xp.exp(ln_s)
skew = (
skew * sigma / self.xp.sqrt(sigma**2 + (1 + skew**2) * sigma_o**2)
)
skew = skew * sigma / self.xp.sqrt(sigma**2 + (1 + skew**2) * sigma_o**2)
ln_s = self.xp.log(sigma**2 + sigma_o**2) / 2

# Find where -inf
Expand Down
8 changes: 5 additions & 3 deletions src/stream_mapper/pytorch/builtin/compat/nflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ def _log_prob(self, data: Data[Array], idx: Array) -> Array:
"""Log-probability of the array."""
return self.net.log_prob(
inputs=data[self.coord_names].array[idx],
context=data[self.indep_coord_names].array[idx]
if self.indep_coord_names is not None
else None,
context=(
data[self.indep_coord_names].array[idx]
if self.indep_coord_names is not None
else None
),
)

0 comments on commit c8ae32b

Please sign in to comment.