From 06732c4ab6a2fb9291a4e99fa6e6f1a910635cad Mon Sep 17 00:00:00 2001 From: Andrew Herzing Date: Mon, 23 Sep 2024 16:10:12 -0400 Subject: [PATCH] Updated utils.weight_stack --- etspy/utils.py | 171 +++++++++++++++++++++++++++++-------------------- 1 file changed, 100 insertions(+), 71 deletions(-) diff --git a/etspy/utils.py b/etspy/utils.py index c515fe74..8a213c91 100644 --- a/etspy/utils.py +++ b/etspy/utils.py @@ -116,7 +116,7 @@ def register_serialem_stack(stack, ncpus=1): return reg -def weight_stack(stack, accuracy="medium"): +def weight_stack(stack, accuracy='medium'): """ Apply a weighting window to a stack along the direction perpendicular to the tilt axis. @@ -127,83 +127,112 @@ def weight_stack(stack, accuracy="medium"): tomography, Advanced Structural and Chemical Imaging vol. 1 (2015) pp 1-11. https://doi.org/10.1186/s40679-015-0005-7 - Parameters + Parameters: ---------- stack : TomoStack - Stack to be weighted. + The stack to be weighted. + accuracy : str, optional + A string indicating the accuracy level for weighting. Options are: + 'low', 'medium', 'high', or any other string for default. Default is 'medium'. - accuracy : string - Level of accuracy for determining the weighting. Acceptable values are 'low', 'medium', and 'high'. - - Returns - ---------- - reg : TomoStack object - Result of aligning and averaging frames at each tilt with shape [ntilts, ny, nx] + Returns: + ------- + stackw : object + The weighted version of the input stack. """ - stackw = stack.deepcopy() - - [ntilts, ny, nx] = stack.data.shape - alpha = np.sum(stack.data, (1, 2)).min() - beta = np.sum(stack.data, (1, 2)).argmin() - v = np.arange(ntilts) - v[beta] = 0 - - wg = np.zeros([ny, nx]) - - if accuracy.lower() == "low": - num = 800 - delta = 0.025 - elif accuracy.lower() == "medium": - num = 2000 - delta = 0.01 - elif accuracy.lower() == "high": - num = 20000 - delta = 0.001 + + # Set the parameters based on the accuracy input + if isinstance(accuracy, str): + if accuracy == 'low': + niterations = 800 + delta = 0.025 + elif accuracy == 'medium': + niterations = 2000 + delta = 0.01 + elif accuracy == 'high': + niterations = 20000 + delta = 0.001 + else: + raise ValueError("Unknown accuracy level. Must be 'low', 'medium', or 'high'.") else: - raise ValueError( - "Unknown accuracy level. Must be 'low', 'medium', or 'high'.") - - r = np.arange(1, ny + 1) - r = 2 / (ny - 1) * (r - 1) - 1 - r = np.cos(np.pi * r**2) / 2 + 1 / 2 - s = np.zeros(ntilts) - for p in range(1, int(num / 10) + 1): - rp = r ** (p * delta * 10) - for x in range(0, nx): - wg[:, x] = rp - for i in range(0, ntilts): - if v[i]: - if np.sum(stack.data[i, :, :] * wg) < alpha: - v[i] = 0 - s[i] = (p - 1) * 10 - if v.sum() == 0: + raise ValueError("Unknown accuracy level. Must be 'low', 'medium', or 'high'.") + + weighted_stack = stack.deepcopy() + + # Get stack dimensions + ntilts, ny, nx = weighted_stack.data.shape + + # Compute the minimum total projected mass and the corresponding slice index (min_slice) + min_mass, min_slice = np.min(np.sum(np.sum(weighted_stack.data, axis=2), axis=1)), np.argmin(np.sum(np.sum(weighted_stack.data, axis=2), axis=1)) + + # Initialize the window array + window = np.zeros([ny, nx]) + + # Initialize the status vector (1 means unmarked, 0 means marked) and mark the reference slice (min_slice) + status = np.ones(ntilts) + status[min_slice] = 0 + + # Generate the weighting profile `r` based on a non-linear cosine function + r = np.arange(ny) + r = 2 / (ny - 1) * r - 1 + r = np.cos(np.pi * r**2) / 2 + 0.5 + + # Initialize adjustment factors for each slice + adjustments = np.zeros(ntilts) + + # Coarse adjustment loop + # In this step, the applied window is made increasingly restrictive in 10 pixel increments. + # Whenever the the windowed mass of a projection drops below the value of min_alpha, that projection + # is marked and the window restriction is not carried any further for that projection. + + for power in np.linspace(10, niterations, niterations // 10): + # Compute the power-weighted profile for the current iteration + r_power = r ** (power * delta) + window = r_power[:, np.newaxis] # Broadcasting across all columns + + # Compute the weighted sum for all slices at once using vectorization + weighted_mass = np.sum(weighted_stack.data * window[np.newaxis, :, :], axis=(1, 2)) + + # Update the status and adjustments for slices with weighted sums below min_mass + update_mask = (status != 0) & (weighted_mass < min_mass) + status[update_mask] = 0 + adjustments[update_mask] = power - 10 + + # Break early if all slices are marked + if not np.any(status): # More efficient than np.sum(status) break - for i in range(0, ntilts): - if v[i]: - s[i] = (p - 1) * 10 - - v = np.arange(1, ntilts + 1) - v[beta] = 0 - for j in range(0, ntilts): - if j != beta: - for p in range(1, 10): - rp = r ** ((p + s[j]) * delta) - for x in range(0, nx): - wg[:, x] = rp - if np.sum(stack.data[i, :, :] * wg) < alpha: - s[j] = p + s[j] - v[i] = 0 - break - for i in range(0, ntilts): - if v[i]: - s[i] = s[i] + 10 - - for i in range(0, ntilts): - for x in range(0, nx): - wg[:, x] = r ** (s[i] * delta) - stackw.data[i, :, :] = stack.data[i, :, :] * wg - return stackw + + # Set window for any unmarked slices to the most restricive used in the rest of the slices + adjustments[np.where(status != 0)] = power - 10 + + # Fine adjustment loop + # In this step the severity of the window is calculated again using the value calculated in the coarse + # step and the window is made more restrictive in 1 pixel increments. + status = np.ones(ntilts) + status[min_slice] = 0 + + for j in range(ntilts): + if j != min_slice: + for power in np.linspace(1, 10, 10): + # Apply fine adjustments to the weight profile and update the weight grid + r_power = r**((power + adjustments[j]) * delta) + window[:] = r_power[:, np.newaxis] + + if np.sum(weighted_stack.data[j, :, :] * window) < min_mass: + adjustments[j] = (power - 1) + adjustments[j] + status[j] = 0 + break + + # Restrict the window of any unmarked projections + adjustments[status != 0] += 10 + + # Apply the final window to the entire stack + for i in range(ntilts): + window[:] = (r**(adjustments[i] * delta))[:, np.newaxis] + weighted_stack.data[i, :, :] *= window + + return weighted_stack def calc_EST_angles(N):