Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Saturation speedup #331

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changes/331.general.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
- Performance improvements for saturation step. (`#331
<https://github.com/spacetelescope/stcal/issues/331>`_)
t-brandt marked this conversation as resolved.
Show resolved Hide resolved
134 changes: 101 additions & 33 deletions src/stcal/saturation/saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,64 +83,85 @@ def flag_saturated_pixels(

for ints in range(nints):
# Work forward through the groups for initial pass at saturation

# We want to flag saturation in all subsequent groups after
# the one in which it was found. Use this boolean array to
# keep a running tally of pixels that have saturated.
previously_saturated = np.zeros(shape=(nrows, ncols), dtype='bool')

for group in range(ngroups):
plane = data[ints, group, :, :]

flagarray, flaglowarray = plane_saturation(plane, sat_thresh, dqflags)

# for saturation, the flag is set in the current plane
# and all following planes.
np.bitwise_or(gdq[ints, group:, :, :], flagarray, gdq[ints, group:, :, :])

# Update the running tally of all pixels that have ever
# experienced saturation to account for this.

previously_saturated |= (plane >= sat_thresh)
flagarray = (previously_saturated * saturated).astype(np.uint32)

gdq[ints, group, :, :] |= flagarray

# for A/D floor, the flag is only set of the current plane
np.bitwise_or(gdq[ints, group, :, :], flaglowarray, gdq[ints, group, :, :])
flaglowarray = ((plane <= 0)*(ad_floor | dnu)).astype(np.uint32)

gdq[ints, group, :, :] |= flaglowarray

del flagarray
del flaglowarray

# now, flag any pixels that border saturated pixels (not A/D floor pix)
if n_pix_grow_sat > 0:
gdq_slice = copy.copy(gdq[ints, group, :, :]).astype(int)

gdq[ints, group, :, :] = adjacent_pixels(gdq_slice, saturated, n_pix_grow_sat)
gdq_slice = gdq[ints, group, :, :]
adjacent_pixels(gdq_slice, saturated, n_pix_grow_sat, inplace=True)

# Work backward through the groups for a second pass at saturation
# This is to flag things that actually saturated in prior groups but
# were not obvious because of group averaging
for group in range(ngroups-2, -1, -1):

for group in range(ngroups - 2, -1, -1):

plane = data[ints, group, :, :]
thisdq = gdq[ints, group, :, :]
nextdq = gdq[ints, group + 1, :, :]

# Determine the dilution factor due to group averaging

# No point in this step if the dilution factor is 1. In
# that case, there is no way that we would have missed
# saturation before but flag it now, since the threshold
# would be the same.

if read_pattern is not None:
# Single value dilution factor for this group
dilution_factor = np.mean(read_pattern[group]) / read_pattern[group][-1]
if dilution_factor == 1:
continue
# Broadcast to array size
dilution_factor = np.where(no_sat_check_mask, 1, dilution_factor)
else:
dilution_factor = 1
continue

# Find where this plane looks like it might saturate given the dilution factor
flagarray, _ = plane_saturation(plane, sat_thresh * dilution_factor, dqflags)
# Find where this plane looks like it might saturate given
# the dilution factor, *and* this group did not already get
# flagged as saturated or do not use, *and* the next group
# was flagged as saturated. Result of the line below is a
# boolean array.

# Find the overlap of where this plane looks like it might saturate, was not currently
# flagged as saturation or DO_NOT_USE, and the next group had saturation flagged.
indx = np.where((np.bitwise_and(flagarray, saturated) != 0) & \
(np.bitwise_and(thisdq, saturated) == 0) & \
(np.bitwise_and(thisdq, dnu) == 0) & \
(np.bitwise_and(nextdq, saturated) != 0))
partial_sat = ((plane >= sat_thresh*dilution_factor) & \
(thisdq & (saturated | dnu) == 0) & \
(nextdq & saturated != 0))

# Reset flag array to only pixels passing this gauntlet
flagarray[:] = 0
flagarray[indx] = dnu
flagarray = (partial_sat * dnu).astype(np.uint32)

# Grow the newly-flagged saturating pixels
if n_pix_grow_sat > 0:
flagarray = adjacent_pixels(flagarray, dnu, n_pix_grow_sat)
adjacent_pixels(flagarray, dnu, n_pix_grow_sat, inplace=True)

# Add them to the gdq array
np.bitwise_or(gdq[ints, group, :, :], flagarray, gdq[ints, group, :, :])
gdq[ints, group, :, :] |= flagarray

# Add an additional pass to look for things saturating in the second group
# that can be particularly tricky to identify
Expand All @@ -160,25 +181,24 @@ def flag_saturated_pixels(
mask &= scigp2 > sat_thresh / len(read_pattern[1])

# Identify groups that are saturated in the third group but not yet flagged in the second
gp3mask = np.where((np.bitwise_and(dq3, saturated) != 0) & \
(np.bitwise_and(dq2, saturated) == 0), True, False)
gp3mask = ((np.bitwise_and(dq3, saturated) != 0) & \
(np.bitwise_and(dq2, saturated) == 0))
mask &= gp3mask

# Flag the 2nd group for the pixels passing that gauntlet
flagarray = np.zeros_like(mask,dtype='uint8')
flagarray[mask] = dnu
flagarray = (mask * dnu).astype(np.uint32)

# Add them to the gdq array
np.bitwise_or(gdq[ints, 1, :, :], flagarray, gdq[ints, 1, :, :])


# Check ZEROFRAME.
if zframe is not None:
plane = zframe[ints, :, :]
flagarray, flaglowarray = plane_saturation(plane, sat_thresh, dqflags)
zdq = flagarray | flaglowarray
if n_pix_grow_sat > 0:
zdq = adjacent_pixels(zdq, saturated, n_pix_grow_sat)
adjacent_pixels(zdq, saturated, n_pix_grow_sat, inplace=True)
plane[zdq != 0] = 0.0
zframe[ints] = plane

Expand All @@ -192,7 +212,7 @@ def flag_saturated_pixels(
return gdq, pdq, zframe


def adjacent_pixels(plane_gdq, saturated, n_pix_grow_sat):
def adjacent_pixels(plane_gdq, saturated, n_pix_grow_sat=1, inplace=False):
"""
plane_gdq : ndarray
The data quality flags of the current.
Expand All @@ -204,17 +224,65 @@ def adjacent_pixels(plane_gdq, saturated, n_pix_grow_sat):
Number of pixels that each flagged saturated pixel should be 'grown',
to account for charge spilling. Default is 1.

inplace : bool
Update plane_gdq in place, returning None? Default False.

Return
------
sat_pix : ndarray
The saturated pixels in the current plane.
"""
cgdq = plane_gdq.copy()
only_sat = np.bitwise_and(plane_gdq, saturated).astype(np.uint8)
if not inplace:
cgdq = plane_gdq.copy()
else:
cgdq = plane_gdq

only_sat = plane_gdq & saturated > 0
dilated = only_sat.copy()
box_dim = (n_pix_grow_sat * 2) + 1
struct = np.ones((box_dim, box_dim)).astype(bool)
dialated = ndimage.binary_dilation(only_sat, structure=struct).astype(only_sat.dtype)
return np.bitwise_or(cgdq, (dialated * saturated))

# The for loops below are equivalent to
#
#struct = np.ones((box_dim, box_dim)).astype(bool)
#dilated = ndimage.binary_dilation(only_sat, structure=struct).astype(only_sat.dtype)
#
# The explicit loop over the box, followed by taking care of the
# array edges, turns out to be faster by around an order of magnitude.
# There must be poor coding in the underlying routine for
# ndimage.binary_dilation as of scipy 1.14.1.

for i in range(box_dim):
for j in range(box_dim):

# Explicit binary dilation over the inner ('valid')
# region of the convolution/filter

i2 = only_sat.shape[0] - box_dim + i + 1
j2 = only_sat.shape[1] - box_dim + j + 1

k1, k2, l1, l2 = [n_pix_grow_sat, -n_pix_grow_sat,
n_pix_grow_sat, -n_pix_grow_sat]

dilated[k1:k2, l1:l2] |= only_sat[i:i2, j:j2]

for i in range(n_pix_grow_sat - 1, -1, -1):
for j in range(i + n_pix_grow_sat, -1, -1):

# March from the limit of the 'valid' region toward
# each edge. Maximum filter ensures correct dilation.

dilated[i] |= ndimage.maximum_filter(only_sat[j], box_dim)
dilated[:, i] |= ndimage.maximum_filter(only_sat[:, j], box_dim)
dilated[-i - 1] |= ndimage.maximum_filter(only_sat[-j - 1], box_dim)
dilated[:, -i - 1] |= ndimage.maximum_filter(only_sat[:, -j - 1], box_dim)

cgdq[dilated] |= saturated

if inplace:
return None
else:
return cgdq



def plane_saturation(plane, sat_thresh, dqflags):
Expand Down
Loading