Skip to content

Commit

Permalink
Add log1mexp and logdiffexp functions (#1960)
Browse files Browse the repository at this point in the history
* Add log1mexp and logdiffexp functions and associated docs

* Move #noqa after a line break to avoid it getting included in link

* Clarify case of a == jnp.inf in logdiffexp docs

* Add needed extra line break

* Improve conditional structure, add basic tests

* Add type hints

* Add gradient tests

* Add more tests, including numerical tests of gradients

* Make custom jvp syntax follow jax docs, add some more tests

* Run ruff

* Run ruff

* Allow 'ans' in codespell, remove noqa statements

* Attribution for custom jvp approach

* Reformat files
  • Loading branch information
dylanhmorris authored Jan 26, 2025
1 parent 88405d5 commit 93e11c2
Show file tree
Hide file tree
Showing 4 changed files with 226 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,4 @@ repos:
- id: codespell
stages: [pre-commit, commit-msg]
args:
[--ignore-words-list, "Teh,aas", --check-filenames, --skip, "*.ipynb"]
[--ignore-words-list, "Teh,aas,ans", --check-filenames, --skip, "*.ipynb"]
12 changes: 12 additions & 0 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1100,3 +1100,15 @@ BlockNeuralAutoregressiveTransform
:undoc-members:
:show-inheritance:
:member-order: bysource


Utilities
---------

log1mexp
^^^^^^^^
.. autofunction:: numpyro.distributions.util.log1mexp

logdiffexp
^^^^^^^^^^
.. autofunction:: numpyro.distributions.util.logdiffexp
58 changes: 58 additions & 0 deletions numpyro/distributions/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.special import digamma
from jax.typing import ArrayLike

from numpyro.util import not_jax_tracer

Expand Down Expand Up @@ -419,6 +420,63 @@ def logmatmulexp(x, y):
return xy + x_shift + y_shift


@jax.custom_jvp
def log1mexp(x: ArrayLike) -> ArrayLike:
    """
    Numerically stable calculation of the quantity
    :math:`\\log(1 - \\exp(x))`, following the algorithm
    of `Mächler 2012`_.

    .. _Mächler 2012: https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf

    Returns ``-jnp.inf`` when ``x == 0`` and ``jnp.nan``
    when ``x > 0``.

    :param x: A number or array of numbers.
    :return: The value of :math:`\\log(1 - \\exp(x))`.
    """
    # The cutoff is -log(2) (approximately -0.6931472): above it the result is
    # computed as log(-expm1(x)), at or below it as log1p(-exp(x)).
    return jnp.where(
        x > -0.6931472,
        jnp.log(-jnp.expm1(x)),
        jnp.log1p(-jnp.exp(x)),
    )


# Custom jvp for log1mexp to handle the gradient when x is near 0.
#
# Analytically, d/dx log(1 - exp(x)) = -exp(x) / (1 - exp(x)) = -1 / expm1(-x).
# In the rule below, ``t`` is the tangent of ``x`` and ``ans`` is the primal
# output of log1mexp (unused here).
#
# Inspired by the approach taken here for the function log1mexp(-x):
# https://github.com/google-research/google-research/blob/14e984cdb8630a7e3d210dff8760fc06d490fc4b/diffusion_distillation/diffusion_distillation/utils.py#L364-L370
# That code is (c) 2024 The Google Research Authors and licensed under
# an Apache 2.0 License.
log1mexp.defjvps(lambda t, ans, x: -t / jnp.expm1(-x))


def logdiffexp(a: ArrayLike, b: ArrayLike) -> ArrayLike:
    """
    Numerically stable calculation of the
    quantity :math:`\\log(\\exp(a) - \\exp(b))`,
    provided :math:`+\\infty > a \\ge b`,
    following the algorithm of `Mächler 2012`_.

    .. _Mächler 2012: https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf

    Returns ``-jnp.inf`` when ``a == b``,
    including when ``a == b == -jnp.inf``,
    since this corresponds to ``jnp.log(0)``.

    Returns ``jnp.nan`` when ``a < b`` or
    ``a == jnp.inf``.

    :param a: A number or array of numbers.
    :param b: A number or array of numbers.
    :return: The value of :math:`\\log(\\exp(a) - \\exp(b))`.
    """
    # For finite a > b, use log(exp(a) - exp(b)) = a + log(1 - exp(b - a)).
    # Every other input falls through to the inner where: -inf when a == b,
    # nan otherwise (a < b, a == +inf, or a nan operand).
    return jnp.where(
        (a < jnp.inf) & (a > b),
        a + log1mexp(b - a),
        jnp.where(a == b, -jnp.inf, jnp.nan),
    )


def clamp_probs(probs):
    """
    Clip ``probs`` elementwise into ``[finfo.tiny, 1 - finfo.eps]``, where
    ``finfo`` describes the floating-point result type of ``probs``, so that
    every returned value lies strictly between 0 and 1.

    :param probs: A number or array of numbers.
    :return: ``probs`` with values clipped to the interval above.
    """
    finfo = jnp.finfo(jnp.result_type(probs, float))
    return jnp.clip(probs, finfo.tiny, 1.0 - finfo.eps)
Expand Down
157 changes: 155 additions & 2 deletions test/test_distributions_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
from numbers import Number

import numpy as np
from numpy.testing import assert_allclose
from numpy.testing import assert_allclose, assert_array_almost_equal, assert_array_equal
import pytest
import scipy

import jax
from jax import lax, random, vmap
from jax import grad, lax, random, vmap
import jax.numpy as jnp
from jax.scipy.special import expit, xlog1py, xlogy
from jax.test_util import check_grads

import numpyro.distributions as dist
from numpyro.distributions.util import (
Expand All @@ -20,6 +21,8 @@
binomial,
categorical,
cholesky_update,
log1mexp,
logdiffexp,
multinomial,
safe_normalize,
vec_to_tril_matrix,
Expand Down Expand Up @@ -78,6 +81,156 @@ def test_categorical_stats(p):
assert_allclose(counts / float(n), p, atol=0.01)


@pytest.mark.parametrize("x", [-80.5632, -0.32523, -0.5, -20.53, -8.032])
def test_log1mexp_grads(x):
    """
    Derivatives of log1mexp up to third order should pass
    jax's ``check_grads`` at a range of negative inputs.
    """
    primals = (x,)
    check_grads(log1mexp, primals, order=3)


@pytest.mark.parametrize(
    "x, expected",
    [
        (jnp.array([0.01, 0, -jnp.inf]), jnp.array([jnp.nan, -jnp.inf, 0])),
        (0.001, jnp.nan),
        (0, -jnp.inf),
        (-jnp.inf, 0),
    ],
)
def test_log1mexp_bounds_handling(x, expected):
    """
    Boundary behaviour of log1mexp:

    - ``x > 0`` gives nan,
    - ``x == 0`` gives -inf,
    - ``x == -inf`` gives 0.

    The array case checks that this holds elementwise, without
    one entry affecting the others.
    """
    result = log1mexp(x)
    assert_array_equal(result, expected)


@pytest.mark.parametrize("x", [jnp.array([-0.6, -8.32, -3]), -2.5, -0.01])
def test_log1mexp_agrees_with_basic(x):
    """
    Where the direct formula ``log(1 - exp(x))`` is itself
    stable, log1mexp should give (almost) the same answer.
    """
    naive = jnp.log(1 - jnp.exp(x))
    assert_array_almost_equal(log1mexp(x), naive)


def test_log1mexp_stable():
    """
    log1mexp should be stable at (negative) values of
    x that are very small and very large in absolute
    value, where the basic implementation is not.
    """

    def naive(x):
        return jnp.log(1 - jnp.exp(x))

    near_zero, far_negative = -1e-20, -50

    # TODO: this should perhaps be made finfo-aware
    # Very close to zero: the naive formula is infinite, log1mexp is not.
    assert jnp.isinf(naive(near_zero))
    assert not jnp.isinf(log1mexp(near_zero))
    assert_array_almost_equal(log1mexp(near_zero), jnp.log(-jnp.expm1(near_zero)))

    # Far from zero: the naive result is smaller in magnitude than log1mexp's.
    assert abs(naive(far_negative)) < abs(log1mexp(far_negative))
    assert_array_almost_equal(
        log1mexp(far_negative), jnp.log1p(-jnp.exp(far_negative))
    )


@pytest.mark.parametrize("x", [-30.0, -2.53, -1e-4, -1e-9, -1e-15, -1e-40])
def test_log1mexp_grad_stable(x):
    """
    The custom JVP registered for log1mexp should keep the gradient
    numerically stable, including near zero, where differentiating
    the plain formula can hit divide-by-zero problems and yield nan.
    Wherever the plain gradient is not nan, the two should almost agree.
    """

    def log1mexp_no_custom(x):
        # Same forward computation as log1mexp, but without a custom JVP.
        above_cutoff = x > -0.6931472  # cutoff at approximately -log(2)
        return jnp.where(
            above_cutoff,
            jnp.log(-jnp.expm1(x)),
            jnp.log1p(-jnp.exp(x)),
        )

    grad_custom = grad(log1mexp)(x)
    grad_no_custom = grad(log1mexp_no_custom)(x)

    # The custom gradient must match the closed form -1 / expm1(-x).
    closed_form = -1 / jnp.expm1(-x)
    assert_array_almost_equal(grad_custom, closed_form)

    if jnp.isnan(grad_no_custom):
        return
    assert_array_almost_equal(grad_custom, grad_no_custom)


@pytest.mark.parametrize(
    "a, b", [(-20.0, -35.0), (-0.32523, -0.34), (20.53, 19.035), (8.032, 7.032)]
)
def test_logdiffexp_grads(a, b):
    """
    Derivatives of logdiffexp up to third order should pass
    jax's ``check_grads`` (relative tolerance 0.01) for pairs with a > b.
    """
    primals = (a, b)
    check_grads(logdiffexp, primals, order=3, rtol=0.01)


@pytest.mark.parametrize(
    "a, b, expected",
    [
        (
            jnp.array([jnp.inf, 0, 6.5, 4.99999, -jnp.inf]),
            jnp.array([5, 0, 6.5, 5, -jnp.inf]),
            jnp.array([jnp.nan, -jnp.inf, -jnp.inf, jnp.nan, -jnp.inf]),
        ),
        (jnp.inf, 0.3532, jnp.nan),
        (0, 0, -jnp.inf),
        (-jnp.inf, -jnp.inf, -jnp.inf),
        (5.6, 5.6, -jnp.inf),
        (1e34, 1e34 / 0.9999, jnp.nan),
    ],
)
def test_logdiffexp_bounds_handling(a, b, expected):
    """
    Boundary behaviour of logdiffexp:

    - ``logdiffexp(jnp.inf, anything)`` is nan;
    - ``logdiffexp(a, b)`` with ``a < b`` is nan, even when the
      two numbers are very close;
    - ``logdiffexp(a, b)`` with ``a == b`` is ``-jnp.inf``, including
      ``a == b == -jnp.inf`` (i.e. log(0 - 0)).
    """
    result = logdiffexp(a, b)
    assert_array_equal(result, expected)


@pytest.mark.parametrize(
    "a, b", [(jnp.array([53, 23.532, 8, -1.35]), jnp.array([56, -63.2, 2, -5.32]))]
)
def test_logdiffexp_agrees_with_basic(a, b):
    """
    Where the direct formula ``log(exp(a) - exp(b))`` is itself
    stable, logdiffexp should give (almost) the same answer.
    """
    # NOTE(review): the first pair has a < b (53 vs 56), so that element is
    # expected to be nan on both sides -- confirm this is intended.
    naive = jnp.log(jnp.exp(a) - jnp.exp(b))
    assert_array_almost_equal(logdiffexp(a, b), naive)


@pytest.mark.parametrize("a, b", [(500, 499), (-499, -500), (500, 500)])
def test_logdiffexp_stable(a, b):
    """
    logdiffexp should remain numerically well behaved at inputs
    where the direct formula ``log(exp(a) - exp(b))`` breaks down.
    """

    def naive(a, b):
        return jnp.log(jnp.exp(a) - jnp.exp(b))

    # First confirm the direct formula really does fail for this input.
    baseline = naive(a, b)
    if a > 0 or a == b:
        assert jnp.isnan(baseline)
    else:
        assert baseline == -jnp.inf

    # logdiffexp must not produce nan for any of these inputs.
    result = logdiffexp(a, b)
    assert not jnp.isnan(result)
    if a == b:
        assert result == -jnp.inf
    else:
        assert result < a


@pytest.mark.parametrize(
"p, shape",
[
Expand Down

0 comments on commit 93e11c2

Please sign in to comment.