From 256057f8a15f7932315bbe68c370693644ccf1e1 Mon Sep 17 00:00:00 2001 From: Andrew Herzing Date: Thu, 12 Oct 2023 13:40:52 -0400 Subject: [PATCH] Fixed shift order conflict in stack_register --- tomotools/align.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tomotools/align.py b/tomotools/align.py index 471621aa..3ea12749 100644 --- a/tomotools/align.py +++ b/tomotools/align.py @@ -50,7 +50,7 @@ def apply_shifts(stack, shifts): for i in range(0, shifted.data.shape[0]): shifted.data[i, :, :] =\ ndimage.shift(shifted.data[i, :, :], - shift=[shifts[i, 1], shifts[i, 0]]) + shift=[shifts[i, 0], shifts[i, 1]]) shifted.metadata.Tomography.shifts = \ shifted.metadata.Tomography.shifts + shifts return shifted @@ -295,7 +295,7 @@ def calculate_shifts_com(stack, nslice, ratio): A[ntilts * j:ntilts * (j + 1), 0:ntilts] = Gam shifts = np.zeros([stack.data.shape[0], 2]) - shifts[:, 1] = np.dot(np.linalg.pinv(A), t_select)[:, 0] + shifts[:, 0] = np.dot(np.linalg.pinv(A), t_select)[:, 0] shifts = stack.metadata.Tomography.shifts + shifts return shifts @@ -319,7 +319,7 @@ def calculate_shifts_pc(stack, start, show_progressbar): """ def calc_pc(source, shifted): shift = pcc(shifted, source, upsample_factor=3) - return shift[0][::-1] + return shift[0] shifts = np.zeros([stack.data.shape[0] - 1, 2]) if start is None: @@ -356,7 +356,7 @@ def calculate_shifts_stackreg(stack): """ sr = StackReg(StackReg.TRANSLATION) shifts = sr.register_stack(stack.data, reference='previous') - shifts = -np.array([i[0:2, 2] for i in shifts]) + shifts = -np.array([i[0:2, 2][::-1] for i in shifts]) return shifts @@ -395,17 +395,17 @@ def calc_com_cl_shifts(stack, com_ref_index, cl_ref_index, cl_resolution, """ - def calc_xshifts(stack, com_ref=None): + def calc_yshifts(stack, com_ref=None): ntilts = stack.data.shape[0] aliX = stack.deepcopy() coms = np.zeros(ntilts) - xhifts = np.zeros_like(coms) + yshifts = np.zeros_like(coms) for i in tqdm.tqdm(range(0, ntilts)): im = aliX.data[i, :, :] coms[i], _ = ndimage.center_of_mass(im) - xhifts[i] = com_ref - coms[i] - return xhifts + yshifts[i] = com_ref - coms[i] + return yshifts if cl_resolution >= 0.5: raise ValueError("Resolution should be less than 0.5") @@ -415,8 +415,8 @@ def calc_xshifts(stack, com_ref=None): logger.info("Common line reference slice: %s" % cl_ref_index) xshifts = np.zeros(stack.data.shape[0]) yshifts = np.zeros(stack.data.shape[0]) - xshifts = calc_xshifts(stack, com_ref_index) - yshifts = calc_shifts_cl(stack, cl_ref_index, + yshifts = calc_yshifts(stack, com_ref_index) + xshifts = calc_shifts_cl(stack, cl_ref_index, cl_resolution, cl_div_factor) shifts = np.stack([yshifts, xshifts], axis=1) return shifts @@ -464,7 +464,7 @@ def align_stack(stack, method, start, show_progressbar, nslice, ratio, 3-D numpy array containing the tilt series data method : string Method by which to calculate the alignments. Valid options - are 'PC', 'COM', or 'COM-CL'. + are 'StackReg', 'PC', 'COM', or 'COM-CL'. start : integer Position in tilt series to use as starting point for the alignment. If None, the central projection is used.