Skip to content

Commit

Permalink
Merge pull request #32 from hanjinliu/add-methods
Browse files Browse the repository at this point in the history
Fix performance of lazy filter, add more lazy methods
  • Loading branch information
hanjinliu authored Jan 8, 2025
2 parents ad68584 + 54062fc commit 3011c63
Show file tree
Hide file tree
Showing 9 changed files with 331 additions and 110 deletions.
13 changes: 0 additions & 13 deletions Makefile

This file was deleted.

2 changes: 1 addition & 1 deletion impy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "2.4.6"
__version__ = "2.4.7"
__author__ = "Hanjin Liu"
__email__ = "[email protected]"

Expand Down
4 changes: 3 additions & 1 deletion impy/arrays/_utils/_corr.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def subpixel_pcc(
f0: np.ndarray,
f1: np.ndarray,
upsample_factor: int,
max_shifts: tuple[float, ...] | None = None,
max_shifts: float | tuple[float, ...] | None = None,
):
power, product = _pcc_and_power_spec(f0, f1, max_shifts)
maxima = xp.unravel_index(xp.argmax(power), power.shape)
Expand All @@ -60,6 +60,8 @@ def subpixel_pcc(
power = abs2(cross_correlation)

if max_shifts is not None:
if np.isscalar(max_shifts):
max_shifts = (max_shifts,) * f0.ndim
max_shifts = xp.asarray(max_shifts)
_upsampled_left_shifts = (shifts + max_shifts) * upsample_factor
_upsampled_right_shifts = (max_shifts - shifts) * upsample_factor
Expand Down
17 changes: 14 additions & 3 deletions impy/arrays/_utils/_skimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,19 @@

# same as the function in skimage.filters._fft_based (available in scikit-image >= 0.19)
@lru_cache(maxsize=4)
def _get_ND_butterworth_filter(shape: tuple[int, ...], cutoff: float, order: int,
high_pass: bool, real: bool):
def _get_ND_butterworth_filter(
shape: tuple[int, ...],
cutoff: float,
order: int,
high_pass: bool,
real: bool,
):
if cutoff == 0:
if high_pass:
wfilt = np.ones(shape, dtype=np.float32)
else:
wfilt = np.zeros(shape, dtype=np.float32)
return wfilt
ranges = []
for d, fc in zip(shape, cutoff):
axis = np.arange(-(d - 1) // 2, (d - 1) // 2 + 1, dtype=np.float32) / (d*fc)
Expand All @@ -19,4 +30,4 @@ def _get_ND_butterworth_filter(shape: tuple[int, ...], cutoff: float, order: int
wfilt = 1 / (1 + q2**order)
if high_pass:
wfilt = 1 - wfilt
return wfilt
return wfilt
22 changes: 16 additions & 6 deletions impy/arrays/imgarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,12 +247,11 @@ def shift(
cval = cval(self.value)

prefilter = prefilter or order > 1
mtx = _transform.compose_affine_matrix(translation=translation)

return self._apply_dask(
_transform.warp,
_transform.shift,
c_axes=complement_axes(dims, self.axes),
kwargs=dict(matrix=mtx, order=order, mode=mode, cval=cval,
kwargs=dict(shift=translation, order=order, mode=mode, cval=cval,
prefilter=prefilter)
)

Expand Down Expand Up @@ -3986,6 +3985,7 @@ def track_drift(
along: AxisLike | None = None,
show_drift: bool = False,
upsample_factor: int = 10,
max_shift: nDFloat | None = None,
) -> MarkerFrame:
"""
Calculate yx-directional drift using the method equivalent to
Expand All @@ -3999,6 +3999,8 @@ def track_drift(
If True, plot the result.
upsample_factor : int, default is 10
Up-sampling factor when calculating phase cross correlation.
max_shift : tuple of float, optional
Maximum shift in spatial directions.
Returns
-------
Expand All @@ -4013,7 +4015,9 @@ def track_drift(
raise ValueError("`along` must be single character.")
if not isinstance(upsample_factor, int):
raise TypeError(f"upsample-factor must be integer but got {type(upsample_factor)}")

if max_shift is not None:
if np.isscalar(max_shift):
max_shift = np.full(self.ndim - 1, max_shift)
result = np.zeros((self.sizeof(along), self.ndim-1), dtype=np.float32)
c_axes = complement_axes(along, self.axes)
last_img = None
Expand All @@ -4022,7 +4026,12 @@ def track_drift(
img = xp.asarray(img)
if last_img is not None:
result[i] = xp.asnumpy(
_corr.subpixel_pcc(last_img, img, upsample_factor=upsample_factor)[0]
_corr.subpixel_pcc(
last_img,
img,
max_shifts=max_shift,
upsample_factor=upsample_factor,
)[0]
)
last_img = img
else:
Expand All @@ -4047,6 +4056,7 @@ def drift_correction(
*,
zero_ave: bool = True,
along: AxisLike | None = None,
max_shift: nDFloat | None = None,
order: int = 1,
mode: str = "constant",
cval: float = 0,
Expand Down Expand Up @@ -4109,7 +4119,7 @@ def drift_correction(
)
return out

shift = ref.track_drift(along=along).values
shift = ref.track_drift(along=along, max_shift=max_shift).values

else:
shift = np.asarray(shift, dtype=np.float32)
Expand Down
Loading

0 comments on commit 3011c63

Please sign in to comment.