Skip to content

Commit

Permalink
cleanup diffusion units
Browse files Browse the repository at this point in the history
  • Loading branch information
rpauszek committed Jan 30, 2025
1 parent 12330ae commit 4eb9cc5
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 39 deletions.
13 changes: 13 additions & 0 deletions lumicks/pylake/kymo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,6 +1107,7 @@ class PositionUnit(Enum):
um = UnitInfo(name="um", label=r"μm")
kbp = UnitInfo(name="kbp", label="kbp")
pixel = UnitInfo(name="pixel", label="pixels")
au = UnitInfo(name="au", label="au")

def __str__(self):
return self.value.name
Expand All @@ -1118,6 +1119,18 @@ def __hash__(self):
def label(self):
return self.value.label

def as_diffusion(self):
return {
"unit": f"{self}^2 / s",
"_unit_label": f"{self.label}²/s",
}

def as_squared(self):
return {
"unit": f"{self}^2",
"_unit_label": f"{self.label}²",
}


@dataclass(frozen=True)
class PositionCalibration:
Expand Down
17 changes: 7 additions & 10 deletions lumicks/pylake/kymotracker/detail/msd_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import numpy as np
import numpy.typing as npt

from ...kymo import PositionUnit


@dataclass(frozen=True)
class DiffusionEstimate:
Expand Down Expand Up @@ -270,7 +272,7 @@ def calculate_msd(frame_idx, position, max_lag):


def calculate_ensemble_msd(
line_msds, time_step, unit="au", unit_label="au", min_count=2
line_msds, time_step, unit=PositionUnit.au, min_count=2
) -> EnsembleMSD:
"""Calculate ensemble MSDs.
Expand Down Expand Up @@ -305,9 +307,8 @@ def calculate_ensemble_msd(
variance=variance,
counts=counts,
effective_sample_size=effective_sample_size,
unit=f"{unit}^2",
_time_step=time_step,
_unit_label=f"{unit_label}²",
**unit.as_squared(),
)


Expand Down Expand Up @@ -507,8 +508,7 @@ def estimate_diffusion_constant_simple(
time_step,
max_lag,
method,
unit="au",
unit_label="au",
unit=PositionUnit.au
):
r"""Estimate diffusion constant
Expand Down Expand Up @@ -616,8 +616,7 @@ def estimate_diffusion_constant_simple(
num_points=len(coordinate),
localization_variance=intercept / 2.0,
method=method,
unit=unit,
_unit_label=unit_label,
**unit.as_diffusion(),
)


Expand Down Expand Up @@ -994,7 +993,6 @@ def estimate_diffusion_cve(
dt,
blur_constant,
unit,
unit_label,
localization_var=None,
var_of_localization_var=None,
) -> DiffusionEstimate:
Expand Down Expand Up @@ -1034,9 +1032,8 @@ def estimate_diffusion_cve(
num_points=len(coordinate),
localization_variance=localization_var,
method="cve",
unit=unit,
_unit_label=unit_label,
variance_of_localization_variance=var_of_localization_var,
**unit.as_diffusion()
)


Expand Down
9 changes: 2 additions & 7 deletions lumicks/pylake/kymotracker/kymotrack.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,10 +1039,6 @@ def estimate_diffusion(
)

frame_idx, positions = np.array(self.time_idx, dtype=int), np.array(self.position)
unit_labels = {
"unit": f"{self._kymo._calibration.unit}^2 / s",
"unit_label": f"{self._kymo._calibration.unit.label}²/s",
}

if method == "cve":
try:
Expand All @@ -1068,7 +1064,7 @@ def estimate_diffusion(
frame_idx,
positions,
self._line_time_seconds,
**unit_labels,
unit=self._kymo._calibration.unit,
blur_constant=blur,
localization_var=localization_variance,
var_of_localization_var=variance_of_localization_variance,
Expand All @@ -1095,7 +1091,7 @@ def estimate_diffusion(
self._line_time_seconds,
max_lag,
method,
**unit_labels,
unit=self._kymo._calibration.unit,
)


Expand Down Expand Up @@ -2150,5 +2146,4 @@ def ensemble_msd(self, max_lag=None, min_count=2) -> EnsembleMSD:
time_step=self._kymos[0].line_time_seconds,
min_count=min_count,
unit=self._calibration.unit,
unit_label=self._calibration.unit.label,
)
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
import matplotlib.pyplot as plt

from lumicks.pylake.kymo import PositionUnit
from lumicks.pylake.detail.utilities import temp_seed
from lumicks.pylake.simulation.diffusion import _simulate_diffusion_1d
from lumicks.pylake.kymotracker.detail.msd_estimation import *
Expand Down Expand Up @@ -49,11 +50,23 @@ def test_estimate(frame_idx, coordinate, time_step, max_lag, diffusion_const):
time_step,
max_lag,
"ols",
"au",
PositionUnit.au,
)
np.testing.assert_allclose(float(diffusion_est), diffusion_const)


def test_bad_unit():
frame_idx = np.array([1, 2, 3, 4, 5])
coordinate = np.array([-1.0, 1.0, -1.0, -3.0, -5.0])
dt = 0.5
max_lag = 50

with pytest.raises(AttributeError, match="'str' object has no attribute 'as_diffusion'"):
estimate_diffusion_constant_simple(frame_idx, coordinate, dt, max_lag, "ols", unit="um")
with pytest.raises(AttributeError, match="'str' object has no attribute 'as_diffusion'"):
estimate_diffusion_cve(frame_idx, coordinate, dt, 0, unit="um")


def test_maxlag_asserts():
# Max_lag has to be bigger than 2
with pytest.raises(ValueError):
Expand Down Expand Up @@ -223,7 +236,7 @@ def test_diffusion_estimate_ols(
with temp_seed(0):
trace = _simulate_diffusion_1d(diffusion, num_points, time_step, obs_noise)
diffusion_est = estimate_diffusion_constant_simple(
np.arange(num_points), trace, time_step, max_lag, "ols", "mu^2/s", r"$\mu^2/s$"
np.arange(num_points), trace, time_step, max_lag, "ols", PositionUnit.au
)

np.testing.assert_allclose(float(diffusion_est), diff_est)
Expand All @@ -233,8 +246,8 @@ def test_diffusion_estimate_ols(
np.testing.assert_allclose(diffusion_est.std_err, std_err_est)
np.testing.assert_allclose(diffusion_est.localization_variance, loc_variance)
assert diffusion_est.method == "ols"
assert diffusion_est.unit == "mu^2/s"
assert diffusion_est._unit_label == r"$\mu^2/s$"
assert diffusion_est.unit == "au^2 / s"
assert diffusion_est._unit_label == "au²/s"


@pytest.mark.parametrize(
Expand Down Expand Up @@ -268,7 +281,7 @@ def test_regression_ols_with_skipped_frames(

with pytest.warns(RuntimeWarning, match="Your tracks have missing frames"):
diffusion_est = estimate_diffusion_constant_simple(
frame_idx, trace, time_step, max_lag, "ols", "mu^2/s", r"$\mu^2/s$"
frame_idx, trace, time_step, max_lag, "ols", PositionUnit.au
)

np.testing.assert_allclose(float(diffusion_est), diff_est)
Expand Down Expand Up @@ -333,7 +346,7 @@ def test_diffusion_estimate_gls(
with temp_seed(0):
trace = _simulate_diffusion_1d(diffusion, num_points, time_step, obs_noise)
diffusion_est = estimate_diffusion_constant_simple(
np.arange(num_points), trace, time_step, max_lag, "gls", "mu^2/s", r"$\mu^2/s$"
np.arange(num_points), trace, time_step, max_lag, "gls", PositionUnit.au
)

np.testing.assert_allclose(float(diffusion_est), diff_est)
Expand All @@ -343,25 +356,27 @@ def test_diffusion_estimate_gls(
np.testing.assert_allclose(diffusion_est.std_err, std_err_est)
np.testing.assert_allclose(diffusion_est.localization_variance, loc_variance)
assert diffusion_est.method == "gls"
assert diffusion_est.unit == "mu^2/s"
assert diffusion_est._unit_label == r"$\mu^2/s$"
assert diffusion_est.unit == "au^2 / s"
assert diffusion_est._unit_label == "au²/s"


def test_bad_input():
with pytest.raises(ValueError, match="Invalid method selected."):
estimate_diffusion_constant_simple(np.arange(5), np.arange(5), 1, 2, "glo", "unit")
estimate_diffusion_constant_simple(np.arange(5), np.arange(5), 1, 2, "glo", PositionUnit.au)

with pytest.raises(
ValueError, match="You need at least two lags to estimate a diffusion constant"
):
estimate_diffusion_constant_simple(np.arange(5), np.arange(5), 1, 1, "gls", "unit")
estimate_diffusion_constant_simple(np.arange(5), np.arange(5), 1, 1, "gls", PositionUnit.au)


def test_singular_handling():
with temp_seed(0):
trace = _simulate_diffusion_1d(0, 30, 3, 0)
with pytest.warns(RuntimeWarning, match="Covariance matrix is singular"):
estimate_diffusion_constant_simple(np.arange(len(trace)), trace, 1, 3, "gls", "unit")
estimate_diffusion_constant_simple(
np.arange(len(trace)), trace, 1, 3, "gls", PositionUnit.au
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -474,8 +489,7 @@ def test_estimate_diffusion_cve(
trace,
time_step,
blur_constant,
"mu^2/s",
r"$\mu^2/s$",
PositionUnit.au,
localization_var,
var_of_localization_var,
)
Expand All @@ -493,8 +507,8 @@ def test_estimate_diffusion_cve(
else:
assert diffusion_est.variance_of_localization_variance is None
assert diffusion_est.method == "cve"
assert diffusion_est.unit == "mu^2/s"
assert diffusion_est._unit_label == r"$\mu^2/s$"
assert diffusion_est.unit == "au^2 / s"
assert diffusion_est._unit_label == "au²/s"


@pytest.mark.parametrize(
Expand Down Expand Up @@ -646,7 +660,7 @@ def test_ensemble_msd():
]

# By default, the single lag rho (5) should be ignored
result = calculate_ensemble_msd(track_msds, 1.0, unit="what_a_unit", unit_label="label_ahoy")
result = calculate_ensemble_msd(track_msds, 1.0, unit=PositionUnit.um)
np.testing.assert_allclose(result.lags, frame_diffs)
np.testing.assert_allclose(result.msd, frame_diffs**2)
num_means = np.array([3, 3, 3, 2]) # number of means contributing to the estimate
Expand All @@ -655,8 +669,8 @@ def test_ensemble_msd():
# Tracks are equal length, so the effective sample size is just the means that contributed
np.testing.assert_allclose(result.effective_sample_size, num_means)
np.testing.assert_allclose(result.sem, np.sqrt(0.02 / ((num_means - 1) * num_means)))
assert result.unit == "what_a_unit^2"
assert result._unit_label == "label_ahoy²"
assert result.unit == "um^2"
assert result._unit_label == r"μm²"


def test_ensemble_msd_unequal_points():
Expand All @@ -676,6 +690,8 @@ def test_ensemble_msd_unequal_points():
# ESS is less than 2 since we used weighting
np.testing.assert_allclose(result.effective_sample_size, np.ones(5) * 9 / 5)
np.testing.assert_allclose(result.sem, np.ones(5) * np.sqrt(5 / 2))
assert result.unit == "au^2"
assert result._unit_label == "au²"


def test_ensemble_msd_little_data():
Expand All @@ -686,23 +702,23 @@ def test_ensemble_msd_little_data():
with pytest.raises(
ValueError, match="Need more than one average to compute a weighted variance"
):
calculate_ensemble_msd([trk1, trk1, trk2], 1.0, unit="au", unit_label="au", min_count=0)
calculate_ensemble_msd([trk1, trk1, trk2], 1.0, min_count=0)

for msds in ([trk1], []):
with pytest.raises(
ValueError, match="You need at least two tracks to compute the ensemble MSD"
):
calculate_ensemble_msd(msds, 1.0, unit="au", unit_label="au", min_count=0)
calculate_ensemble_msd(msds, 1.0, min_count=0)


def test_ensemble_msd_plot():
"""Test whether the plot spins up"""
frame_diffs = np.arange(1, 5, 1)
trk1 = [frame_diffs, frame_diffs**2, np.arange(len(frame_diffs), 0, -1)]
calculate_ensemble_msd([trk1, trk1, trk1], 1.0, unit="au", unit_label="label_unit").plot()
calculate_ensemble_msd([trk1, trk1, trk1], 1.0, PositionUnit.kbp).plot()
axis = plt.gca()
lines = axis.lines[0]
np.testing.assert_allclose(lines.get_xdata(), frame_diffs)
np.testing.assert_allclose(lines.get_ydata(), frame_diffs**2)
assert axis.xaxis.get_label().get_text() == "Time [s]"
assert axis.yaxis.get_label().get_text() == "Squared Displacement [label_unit²]"
assert axis.yaxis.get_label().get_text() == "Squared Displacement [kbp²]"

0 comments on commit 4eb9cc5

Please sign in to comment.