Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cleanup position unit #733

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions lumicks/pylake/kymo.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def plot(
axes.set_axis_off()

if scale_bar and not image_handle:
scale_bar._attach_scale_bar(axes, 60.0, 1.0, "s", self._calibration.unit_label)
scale_bar._attach_scale_bar(axes, 60.0, 1.0, "s", self._calibration.unit.label)

image = self._get_plot_data(channel, adjustment)

Expand Down Expand Up @@ -410,7 +410,7 @@ def plot(
**{**default_kwargs, **kwargs},
)
axes.set_xlabel("time (s)")
axes.set_ylabel(f"position ({self._calibration.unit_label})")
axes.set_ylabel(f"position ({self._calibration.unit.label})")
if show_title:
axes.set_title(self.name)

Expand Down Expand Up @@ -1107,6 +1107,7 @@ class PositionUnit(Enum):
um = UnitInfo(name="um", label=r"μm")
kbp = UnitInfo(name="kbp", label="kbp")
pixel = UnitInfo(name="pixel", label="pixels")
au = UnitInfo(name="au", label="au")

def __str__(self):
return self.value.name
Expand All @@ -1118,6 +1119,18 @@ def __hash__(self):
def label(self):
return self.value.label

def get_diffusion_labels(self) -> dict:
return {
"unit": f"{self}^2 / s",
"_unit_label": f"{self.label}²/s",
}

def get_squared_labels(self) -> dict:
return {
"unit": f"{self}^2",
"_unit_label": f"{self.label}²",
}


@dataclass(frozen=True)
class PositionCalibration:
Expand All @@ -1141,10 +1154,6 @@ def to_pixels(self, calibrated):
def pixelsize(self):
return np.abs(self.scale)

@property
def unit_label(self):
return self.unit.label

def downsample(self, factor):
return (
self
Expand Down
28 changes: 8 additions & 20 deletions lumicks/pylake/kymotracker/detail/msd_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import numpy as np
import numpy.typing as npt

from ...kymo import PositionUnit


@dataclass(frozen=True)
class DiffusionEstimate:
Expand Down Expand Up @@ -269,9 +271,7 @@ def calculate_msd(frame_idx, position, max_lag):
return frame_lags, msd_estimates


def calculate_ensemble_msd(
line_msds, time_step, unit="au", unit_label="au", min_count=2
) -> EnsembleMSD:
def calculate_ensemble_msd(line_msds, time_step, unit=PositionUnit.au, min_count=2) -> EnsembleMSD:
"""Calculate ensemble MSDs.

Parameters
Expand Down Expand Up @@ -305,9 +305,8 @@ def calculate_ensemble_msd(
variance=variance,
counts=counts,
effective_sample_size=effective_sample_size,
unit=f"{unit}^2",
_time_step=time_step,
_unit_label=f"{unit_label}²",
**unit.get_squared_labels(),
)


Expand Down Expand Up @@ -502,13 +501,7 @@ def fallback(warning_message):


def estimate_diffusion_constant_simple(
frame_idx,
coordinate,
time_step,
max_lag,
method,
unit="au",
unit_label="au",
frame_idx, coordinate, time_step, max_lag, method, unit=PositionUnit.au
):
r"""Estimate diffusion constant

Expand Down Expand Up @@ -616,8 +609,7 @@ def estimate_diffusion_constant_simple(
num_points=len(coordinate),
localization_variance=intercept / 2.0,
method=method,
unit=unit,
_unit_label=unit_label,
**unit.get_diffusion_labels(),
)


Expand Down Expand Up @@ -994,7 +986,6 @@ def estimate_diffusion_cve(
dt,
blur_constant,
unit,
unit_label,
localization_var=None,
var_of_localization_var=None,
) -> DiffusionEstimate:
Expand Down Expand Up @@ -1034,9 +1025,8 @@ def estimate_diffusion_cve(
num_points=len(coordinate),
localization_variance=localization_var,
method="cve",
unit=unit,
_unit_label=unit_label,
variance_of_localization_variance=var_of_localization_var,
**unit.get_diffusion_labels(),
)


Expand Down Expand Up @@ -1159,14 +1149,12 @@ def ensemble_ols(kymotracks, max_lag):
time_step = kymotracks._kymos[0].line_time_seconds
to_time = 1.0 / (2.0 * time_step)

src_calibration = kymotracks._kymos[0]._calibration
return DiffusionEstimate(
value=slope * to_time,
std_err=np.sqrt(var_slope / np.mean(ensemble_msd.effective_sample_size)) * to_time,
num_lags=optimal_lags,
num_points=sum(len(t) for t in kymotracks),
localization_variance=intercept / 2.0,
method="ensemble ols",
unit=f"{src_calibration.unit}^2 / s",
_unit_label=f"{src_calibration.unit_label}²/s",
**kymotracks._calibration.unit.get_diffusion_labels(),
)
27 changes: 11 additions & 16 deletions lumicks/pylake/kymotracker/kymotrack.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def export_kymotrackgroup_to_csv(
)

time_units = "seconds"
position_units = kymotrack_group._calibration_info.unit
position_units = kymotrack_group._calibration.unit

idx = np.hstack([np.full(len(track), idx) for idx, track in enumerate(kymotrack_group)])
coords_idx = np.hstack([track.coordinate_idx for track in kymotrack_group])
Expand Down Expand Up @@ -610,7 +610,7 @@ def plot_fit(self, node_idx, *, fit_kwargs=None, data_kwargs=None, show_data=Tru
)

plt.plot(*model_fit, **{"color": "C0"} | replace_key_aliases(fit_kwargs or {}, aliases))
plt.xlabel(f"Position [{self._kymo._calibration.unit_label}]")
plt.xlabel(f"Position [{self._kymo._calibration.unit.label}]")
plt.ylabel("Photon counts [#]")

def _check_ends_are_defined(self):
Expand Down Expand Up @@ -844,7 +844,7 @@ def plot(self, *, show_outline=True, show_labels=True, axes=None, **kwargs):
ax.plot(self.seconds, self.position, path_effects=[pe.Normal()], **kwargs)

if show_labels:
ax.set_ylabel(f"position ({self._kymo._calibration.unit_label})")
ax.set_ylabel(f"position ({self._kymo._calibration.unit.label})")
ax.set_xlabel("time (s)")

def msd(self, max_lag=None):
Expand Down Expand Up @@ -913,7 +913,7 @@ def plot_msd(self, max_lag=None, **kwargs):
lag_time, msd = self.msd(max_lag)
plt.plot(lag_time, msd, **kwargs)
plt.xlabel("Lag time [s]")
plt.ylabel(f"Mean Squared Displacement [{self._kymo._calibration.unit_label}²]")
plt.ylabel(f"Mean Squared Displacement [{self._kymo._calibration.unit.label}²]")

def estimate_diffusion(
self,
Expand Down Expand Up @@ -1039,10 +1039,6 @@ def estimate_diffusion(
)

frame_idx, positions = np.array(self.time_idx, dtype=int), np.array(self.position)
unit_labels = {
"unit": f"{self._kymo._calibration.unit}^2 / s",
"unit_label": f"{self._kymo._calibration.unit_label}²/s",
}

if method == "cve":
try:
Expand All @@ -1068,7 +1064,7 @@ def estimate_diffusion(
frame_idx,
positions,
self._line_time_seconds,
**unit_labels,
unit=self._kymo._calibration.unit,
blur_constant=blur,
localization_var=localization_variance,
var_of_localization_var=variance_of_localization_variance,
Expand All @@ -1095,7 +1091,7 @@ def estimate_diffusion(
self._line_time_seconds,
max_lag,
method,
**unit_labels,
unit=self._kymo._calibration.unit,
)


Expand Down Expand Up @@ -1216,7 +1212,7 @@ def _validate_single_linetime_pixelsize(self):
if len(pixel_sizes) == 1
else (
"All source kymographs must have the same pixel sizes, "
f"got {sorted(pixel_sizes)} {self._calibration_info.unit}."
f"got {sorted(pixel_sizes)} {self._calibration.unit}."
)
)

Expand Down Expand Up @@ -1281,7 +1277,7 @@ def _channel(self):
raise RuntimeError("No channel associated with this empty group (no tracks available)")

@property
def _calibration_info(self):
def _calibration(self):
try:
kymo = self._kymos[0]
return kymo._calibration
Expand Down Expand Up @@ -1473,7 +1469,7 @@ def plot(self, *, show_outline=True, show_labels=True, axes=None, **kwargs):
track.plot(show_outline=show_outline, show_labels=False, axes=ax, **kwargs)

if show_labels:
ax.set_ylabel(f"position ({self._calibration_info.unit_label})")
ax.set_ylabel(f"position ({self._calibration.unit.label})")
ax.set_xlabel("time (s)")

def _tracks_in_frame(self, frame_idx):
Expand Down Expand Up @@ -1902,7 +1898,7 @@ def plot_binding_histogram(self, kind, bins=10, **kwargs):
widths = np.diff(edges)
plt.bar(edges[:-1], counts, width=widths, align="edge", **kwargs)
plt.ylabel("Counts")
plt.xlabel(f"Position ({self._calibration_info.unit_label})")
plt.xlabel(f"Position ({self._calibration.unit.label})")

def _histogram_binding_profile(self, n_time_bins, bandwidth, n_position_points, roi=None):
"""Calculate a Kernel Density Estimate (KDE) of binding density along the tether for time bins.
Expand Down Expand Up @@ -2149,6 +2145,5 @@ def ensemble_msd(self, max_lag=None, min_count=2) -> EnsembleMSD:
line_msds=track_msds,
time_step=self._kymos[0].line_time_seconds,
min_count=min_count,
unit=self._calibration_info.unit,
unit_label=self._calibration_info.unit_label,
unit=self._calibration.unit,
)
Loading
Loading