Skip to content

Commit

Permalink
Updated utils.weight_stack
Browse files Browse the repository at this point in the history
  • Loading branch information
AndrewHerzing committed Sep 23, 2024
1 parent 9d6449c commit 06732c4
Showing 1 changed file with 100 additions and 71 deletions.
171 changes: 100 additions & 71 deletions etspy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def register_serialem_stack(stack, ncpus=1):
return reg


def weight_stack(stack, accuracy="medium"):
def weight_stack(stack, accuracy='medium'):
"""
Apply a weighting window to a stack along the direction perpendicular to the tilt axis.
Expand All @@ -127,83 +127,112 @@ def weight_stack(stack, accuracy="medium"):
tomography, Advanced Structural and Chemical Imaging vol. 1 (2015) pp 1-11.
https://doi.org/10.1186/s40679-015-0005-7
Parameters
Parameters:
----------
stack : TomoStack
Stack to be weighted.
The stack to be weighted.
accuracy : str, optional
A string indicating the accuracy level for weighting. Options are:
'low', 'medium', 'high', or any other string for default. Default is 'medium'.
accuracy : string
Level of accuracy for determining the weighting. Acceptable values are 'low', 'medium', and 'high'.
Returns
----------
reg : TomoStack object
Result of aligning and averaging frames at each tilt with shape [ntilts, ny, nx]
Returns:
-------
stackw : object
The weighted version of the input stack.
"""
stackw = stack.deepcopy()

[ntilts, ny, nx] = stack.data.shape
alpha = np.sum(stack.data, (1, 2)).min()
beta = np.sum(stack.data, (1, 2)).argmin()
v = np.arange(ntilts)
v[beta] = 0

wg = np.zeros([ny, nx])

if accuracy.lower() == "low":
num = 800
delta = 0.025
elif accuracy.lower() == "medium":
num = 2000
delta = 0.01
elif accuracy.lower() == "high":
num = 20000
delta = 0.001

# Set the parameters based on the accuracy input
if isinstance(accuracy, str):
if accuracy == 'low':
niterations = 800
delta = 0.025
elif accuracy == 'medium':
niterations = 2000
delta = 0.01
elif accuracy == 'high':
niterations = 20000
delta = 0.001
else:
raise ValueError("Unknown accuracy level. Must be 'low', 'medium', or 'high'.")
else:
raise ValueError(
"Unknown accuracy level. Must be 'low', 'medium', or 'high'.")

r = np.arange(1, ny + 1)
r = 2 / (ny - 1) * (r - 1) - 1
r = np.cos(np.pi * r**2) / 2 + 1 / 2
s = np.zeros(ntilts)
for p in range(1, int(num / 10) + 1):
rp = r ** (p * delta * 10)
for x in range(0, nx):
wg[:, x] = rp
for i in range(0, ntilts):
if v[i]:
if np.sum(stack.data[i, :, :] * wg) < alpha:
v[i] = 0
s[i] = (p - 1) * 10
if v.sum() == 0:
raise ValueError("Unknown accuracy level. Must be 'low', 'medium', or 'high'.")

weighted_stack = stack.deepcopy()

# Get stack dimensions
ntilts, ny, nx = weighted_stack.data.shape

# Compute the minimum total projected mass and the corresponding slice index (min_slice)
min_mass, min_slice = np.min(np.sum(np.sum(weighted_stack.data, axis=2), axis=1)), np.argmin(np.sum(np.sum(weighted_stack.data, axis=2), axis=1))

# Initialize the window array
window = np.zeros([ny, nx])

# Initialize the status vector (1 means unmarked, 0 means marked) and mark the reference slice (min_slice)
status = np.ones(ntilts)
status[min_slice] = 0

# Generate the weighting profile `r` based on a non-linear cosine function
r = np.arange(ny)
r = 2 / (ny - 1) * r - 1
r = np.cos(np.pi * r**2) / 2 + 0.5

# Initialize adjustment factors for each slice
adjustments = np.zeros(ntilts)

# Coarse adjustment loop
# In this step, the applied window is made increasingly restrictive in 10 pixel increments.
# Whenever the the windowed mass of a projection drops below the value of min_alpha, that projection
# is marked and the window restriction is not carried any further for that projection.

for power in np.linspace(10, niterations, niterations // 10):
# Compute the power-weighted profile for the current iteration
r_power = r ** (power * delta)
window = r_power[:, np.newaxis] # Broadcasting across all columns

# Compute the weighted sum for all slices at once using vectorization
weighted_mass = np.sum(weighted_stack.data * window[np.newaxis, :, :], axis=(1, 2))

# Update the status and adjustments for slices with weighted sums below min_mass
update_mask = (status != 0) & (weighted_mass < min_mass)
status[update_mask] = 0
adjustments[update_mask] = power - 10

# Break early if all slices are marked
if not np.any(status): # More efficient than np.sum(status)
break
for i in range(0, ntilts):
if v[i]:
s[i] = (p - 1) * 10

v = np.arange(1, ntilts + 1)
v[beta] = 0
for j in range(0, ntilts):
if j != beta:
for p in range(1, 10):
rp = r ** ((p + s[j]) * delta)
for x in range(0, nx):
wg[:, x] = rp
if np.sum(stack.data[i, :, :] * wg) < alpha:
s[j] = p + s[j]
v[i] = 0
break
for i in range(0, ntilts):
if v[i]:
s[i] = s[i] + 10

for i in range(0, ntilts):
for x in range(0, nx):
wg[:, x] = r ** (s[i] * delta)
stackw.data[i, :, :] = stack.data[i, :, :] * wg
return stackw

# Set window for any unmarked slices to the most restricive used in the rest of the slices
adjustments[np.where(status != 0)] = power - 10

# Fine adjustment loop
# In this step the severity of the window is calculated again using the value calculated in the coarse
# step and the window is made more restrictive in 1 pixel increments.
status = np.ones(ntilts)
status[min_slice] = 0

for j in range(ntilts):
if j != min_slice:
for power in np.linspace(1, 10, 10):
# Apply fine adjustments to the weight profile and update the weight grid
r_power = r**((power + adjustments[j]) * delta)
window[:] = r_power[:, np.newaxis]

if np.sum(weighted_stack.data[j, :, :] * window) < min_mass:
adjustments[j] = (power - 1) + adjustments[j]
status[j] = 0
break

# Restrict the window of any unmarked projections
adjustments[status != 0] += 10

# Apply the final window to the entire stack
for i in range(ntilts):
window[:] = (r**(adjustments[i] * delta))[:, np.newaxis]
weighted_stack.data[i, :, :] *= window

return weighted_stack


def calc_EST_angles(N):
Expand Down

0 comments on commit 06732c4

Please sign in to comment.