Skip to content

Commit

Permalink
dwelltime: improve gradient computation
Browse files Browse the repository at this point in the history
This avoids the near-division by zero that can occur when an amplitude is close to zero.
  • Loading branch information
JoepVanlier committed Jan 8, 2025
1 parent 78a58f4 commit 22a8300
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 20 deletions.
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
* Fixed a bug where bead edge determination could fail with an unhandled exception during background estimation. This raised a `np.linalg.LinAlgError` when determining the background failed rather than the expected `RuntimeError`. In this case, a simple median is used as a fallback option.
* Fixed a bug to ensure that [`lk.GaussianMixtureModel`](https://lumicks-pylake.readthedocs.io/en/latest/_api/lumicks.pylake.GaussianMixtureModel.html) can also be used with a single state.
* Fixed a bug that prevented opening the force distance widgets when using them with the `widget` backend on `matplotlib >= 3.9.0`.
* Prevented near-`0/0` divisions during fitting when components of a [`DwelltimeModel`](https://lumicks-pylake.readthedocs.io/en/latest/_api/lumicks.pylake.DwelltimeModel.html) are near zero. Note that these only occurred during the computation of the model derivatives during the fitting procedure and should not impact the model simulation itself.

## v1.5.3 | 2024-10-29

Expand Down
26 changes: 16 additions & 10 deletions lumicks/pylake/population/dwelltime.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,12 +1069,11 @@ def _exponential_mixture_log_likelihood_jacobian(params, t, t_min, t_max, t_step
components = np.log(norm_factor) + np.log(amplitudes) + -np.log(lifetimes) - t / lifetimes

# The derivative of logsumexp is given by: sum(exp(fi(x)) dfi(x)/dx) / sum(exp(fi(x)))
total_denom = np.exp(scipy.special.logsumexp(components, axis=0))
sum_components = np.sum(np.exp(components), axis=0)
dtotal_damp = (sum_components * dlognorm_damp + np.exp(components) * dlogamp_damp) / total_denom
dtotal_dtau = (
sum_components * dlognorm_dtau + np.exp(components) * dlogtauterm_dtau
) / total_denom
log_sum_exp_components = scipy.special.logsumexp(components, axis=0)
normalized_exp_components = np.exp(components - log_sum_exp_components)
dtotal_damp = dlognorm_damp + normalized_exp_components * dlogamp_damp
dtotal_dtau = dlognorm_dtau + normalized_exp_components * dlogtauterm_dtau

unsummed_gradient = np.vstack((dtotal_damp, dtotal_dtau))

return -np.sum(unsummed_gradient, axis=1)
Expand Down Expand Up @@ -1264,6 +1263,8 @@ def _handle_amplitude_constraint(

def _exponential_mle_bounds(n_components, min_observation_time, max_observation_time):
return (
# Note: the standard error computation relies on the lower bound on the amplitude as it
# keeps the amplitude from going negative.
*[(1e-9, 1.0 - 1e-9) for _ in range(n_components)],
*[
(
Expand All @@ -1276,7 +1277,9 @@ def _exponential_mle_bounds(n_components, min_observation_time, max_observation_


def _calculate_std_errs(jac_fun, constraints, num_free_amps, current_params, fitted_param_mask):
hessian_approx = numerical_jacobian(jac_fun, current_params[fitted_param_mask], dx=1e-6)
# The minimum bound on amplitudes is 1e-9, by making the max step 1e-10, we ensure that
# we never go over the bound here
hessian_approx = numerical_jacobian(jac_fun, current_params[fitted_param_mask], dx=1e-10)

if constraints:
from scipy.linalg import null_space
Expand Down Expand Up @@ -1418,9 +1421,12 @@ def jac_fun(params):

std_errs = np.full(current_params.shape, np.nan)
if use_jacobian:
std_errs[fitted_param_mask] = _calculate_std_errs(
jac_fun, constraints, num_free_amps, current_params, fitted_param_mask
)
try:
std_errs[fitted_param_mask] = _calculate_std_errs(
jac_fun, constraints, num_free_amps, current_params, fitted_param_mask
)
except np.linalg.linalg.LinAlgError:
pass # We silence these until the standard error API is publicly available

return current_params, -result.fun, std_errs

Expand Down
30 changes: 20 additions & 10 deletions lumicks/pylake/population/tests/test_dwelltimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import matplotlib.pyplot as plt

from lumicks.pylake import DwelltimeModel
from lumicks.pylake.detail.utilities import temp_seed
from lumicks.pylake.population.dwelltime import (
DwelltimeBootstrap,
_exponential_mle_optimize,
Expand Down Expand Up @@ -255,16 +256,17 @@ def test_dwelltime_profiles(exponential_data, exp_name, reference_bounds, reinte

@pytest.mark.parametrize(
# fmt:off
"exp_name, n_components, ref_std_errs",
"exp_name, n_components, ref_std_errs, tolerance",
[
("dataset_2exp", 1, [np.nan, 0.117634]), # Amplitude is not fitted!
("dataset_2exp", 2, [0.072455, 0.072456, 0.212814, 0.449388]),
("dataset_2exp_discrete", 2, [0.068027, 0.068027, 0.21403 , 0.350355]),
("dataset_2exp_discrete", 3, [0.097556, 0.380667, 0.395212, 0.252004, 1.229997, 4.500617]),
("dataset_2exp_discrete", 4, [9.755185e-02, 4.999662e-05, 3.788707e-01, 3.934488e-01, 2.520029e-01, 1.889606e+00, 1.227551e+00, 4.489603e+00]),
("dataset_2exp", 1, [np.nan, 0.117634], 1e-4), # Amplitude is not fitted!
("dataset_2exp", 2, [0.072451, 0.072451, 0.212818, 0.4493979], 1e-3),
("dataset_2exp_discrete", 2, [0.068016, 0.068016, 0.213995 , 0.350305], 1e-3),
# Over-fitted, hence coarse tolerances
("dataset_2exp_discrete", 3, [0.09755, 0.37899, 0.3933, 0.25203, 1.22622, 4.4805], 1e-1),
("dataset_2exp_discrete", 4, [0.097569, 3.5477e-05, 0.37874, 0.393413, 0.25203, 1.5543, 1.2278, 4.48905], 1e-1),
]
)
def test_std_errs(exponential_data, exp_name, n_components, ref_std_errs):
def test_std_errs(exponential_data, exp_name, n_components, ref_std_errs, tolerance):
dataset = exponential_data[exp_name]

fit = DwelltimeModel(
Expand All @@ -273,9 +275,9 @@ def test_std_errs(exponential_data, exp_name, n_components, ref_std_errs):
**dataset["parameters"].observation_limits,
discretization_timestep=dataset["parameters"].dt,
)
np.testing.assert_allclose(fit._std_errs, ref_std_errs, rtol=1e-4)
np.testing.assert_allclose(fit._err_amplitudes, ref_std_errs[:n_components], rtol=1e-4)
np.testing.assert_allclose(fit._err_lifetimes, ref_std_errs[n_components:], rtol=1e-4)
np.testing.assert_allclose(fit._std_errs, ref_std_errs, rtol=tolerance)
np.testing.assert_allclose(fit._err_amplitudes, ref_std_errs[:n_components], rtol=tolerance)
np.testing.assert_allclose(fit._err_lifetimes, ref_std_errs[n_components:], rtol=tolerance)


@pytest.mark.parametrize("n_components", [2, 1])
Expand Down Expand Up @@ -677,3 +679,11 @@ def quick_fit(fixed_params):

with pytest.raises(StopIteration):
quick_fit([True, False]) # Lifetime unknown -> Need to fit


def test_bad_std_err():
    """Smoke test for a fit in which one amplitude is expected to collapse.

    The data consist of 10000 draws from a unit-scale exponential plus a single
    dwell time of 5000. Fitting two components to this is expected to drive one
    amplitude towards zero; if that case is not handled appropriately, the fit
    spits out many errors. The test only checks that constructing the model
    completes; the fitted values are not inspected.
    """
    with temp_seed(10):
        # Seeded so that the generated dwell times are reproducible.
        exponential_draws = np.random.exponential(1, 10000)
        dwell_times = np.append(exponential_draws, 5000)
        _ = DwelltimeModel(dwell_times, n_components=2)

0 comments on commit 22a8300

Please sign in to comment.