Skip to content

Commit

Permalink
Merge branch 'main' into jp-3805
Browse files Browse the repository at this point in the history
  • Loading branch information
tapastro authored Dec 11, 2024
2 parents 36d1cc6 + 30dd769 commit f2924b4
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 110 deletions.
2 changes: 2 additions & 0 deletions changes/306.apichange.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Add maximum_shower_amplitude parameter to MIRI cosmic rays showers routine
to fix accidental flagging of bright science pixels.
170 changes: 90 additions & 80 deletions src/stcal/jump/jump.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def detect_jumps(
min_diffs_single_pass=10,
mask_persist_grps_next_int=True,
persist_grps_flagged=25,
max_shower_amplitude=12
):
"""
This is the high-level controlling routine for the jump detection process.
Expand Down Expand Up @@ -220,6 +221,8 @@ def detect_jumps(
then all differences are processed at once.
min_diffs_single_pass : int
The minimum number of groups to switch to flagging all outliers in a single pass.
max_shower_amplitude : float
The maximum possible amplitude for flagged MIRI showers in DN/group
Returns
-------
Expand Down Expand Up @@ -298,46 +301,7 @@ def detect_jumps(
dqflags['DO_NOT_USE']
gdq[gdq == np.bitwise_or(dqflags['SATURATED'], dqflags['JUMP_DET'])] = \
dqflags['SATURATED']
# This is the flag that controls the flagging of snowballs.
if expand_large_events:
gdq, total_snowballs = flag_large_events(
gdq,
jump_flag,
sat_flag,
min_sat_area=min_sat_area,
min_jump_area=min_jump_area,
expand_factor=expand_factor,
sat_required_snowball=sat_required_snowball,
min_sat_radius_extend=min_sat_radius_extend,
edge_size=edge_size,
sat_expand=sat_expand,
max_extended_radius=max_extended_radius,
mask_persist_grps_next_int=mask_persist_grps_next_int,
persist_grps_flagged=persist_grps_flagged,
)
log.info("Total snowballs = %i", total_snowballs)
number_extended_events = total_snowballs
if find_showers:
gdq, num_showers = find_faint_extended(
data,
gdq,
pdq,
readnoise_2d,
frames_per_group,
minimum_sigclip_groups,
dqflags,
snr_threshold=extend_snr_threshold,
min_shower_area=extend_min_area,
inner=extend_inner_radius,
outer=extend_outer_radius,
sat_flag=sat_flag,
jump_flag=jump_flag,
ellipse_expand=extend_ellipse_expand_ratio,
num_grps_masked=grps_masked_after_shower,
max_extended_radius=max_extended_radius,
)
log.info("Total showers= %i", num_showers)
number_extended_events = num_showers

else:
yinc = int(n_rows // n_slices)
slices = []
Expand Down Expand Up @@ -463,46 +427,50 @@ def detect_jumps(
gdq[gdq == np.bitwise_or(dqflags['SATURATED'], dqflags['JUMP_DET'])] = \
dqflags['SATURATED']

# This is the flag that controls the flagging of snowballs.
if expand_large_events:
gdq, total_snowballs = flag_large_events(
gdq,
jump_flag,
sat_flag,
min_sat_area=min_sat_area,
min_jump_area=min_jump_area,
expand_factor=expand_factor,
sat_required_snowball=sat_required_snowball,
min_sat_radius_extend=min_sat_radius_extend,
edge_size=edge_size,
sat_expand=sat_expand,
max_extended_radius=max_extended_radius,
mask_persist_grps_next_int=mask_persist_grps_next_int,
persist_grps_flagged=persist_grps_flagged,
)
log.info("Total snowballs = %i", total_snowballs)
number_extended_events = total_snowballs
if find_showers:
gdq, num_showers = find_faint_extended(
data,
gdq,
pdq,
readnoise_2d,
frames_per_group,
minimum_sigclip_groups,
dqflags,
snr_threshold=extend_snr_threshold,
min_shower_area=extend_min_area,
inner=extend_inner_radius,
outer=extend_outer_radius,
sat_flag=sat_flag,
jump_flag=jump_flag,
ellipse_expand=extend_ellipse_expand_ratio,
num_grps_masked=grps_masked_after_shower,
max_extended_radius=max_extended_radius,
)
log.info("Total showers= %i", num_showers)
number_extended_events = num_showers
# Look for snowballs in near-IR data
if expand_large_events:
gdq, total_snowballs = flag_large_events(
gdq,
jump_flag,
sat_flag,
min_sat_area=min_sat_area,
min_jump_area=min_jump_area,
expand_factor=expand_factor,
sat_required_snowball=sat_required_snowball,
min_sat_radius_extend=min_sat_radius_extend,
edge_size=edge_size,
sat_expand=sat_expand,
max_extended_radius=max_extended_radius,
mask_persist_grps_next_int=mask_persist_grps_next_int,
persist_grps_flagged=persist_grps_flagged,
)
log.info("Total snowballs = %i", total_snowballs)
number_extended_events = total_snowballs

# Look for showers in mid-IR data
if find_showers:
gdq, num_showers = find_faint_extended(
data,
gdq,
pdq,
readnoise_2d,
frames_per_group,
minimum_sigclip_groups,
dqflags,
snr_threshold=extend_snr_threshold,
min_shower_area=extend_min_area,
inner=extend_inner_radius,
outer=extend_outer_radius,
sat_flag=sat_flag,
jump_flag=jump_flag,
ellipse_expand=extend_ellipse_expand_ratio,
num_grps_masked=grps_masked_after_shower,
max_extended_radius=max_extended_radius,
max_shower_amplitude=max_shower_amplitude
)
log.info("Total showers= %i", num_showers)
number_extended_events = num_showers

elapsed = time.time() - start
log.info("Total elapsed time = %g sec", elapsed)

Expand Down Expand Up @@ -878,6 +846,7 @@ def near_edge(jump, low_threshold, high_threshold):
)


# MIRI cosmic ray showers code
def find_faint_extended(
indata,
ingdq,
Expand All @@ -897,6 +866,7 @@ def find_faint_extended(
num_grps_masked=25,
max_extended_radius=200,
min_diffs_for_shower=10,
max_shower_amplitude=6,
):
"""
Parameters
Expand Down Expand Up @@ -931,6 +901,8 @@ def find_faint_extended(
The upper limit for the extension of saturation and jump
minimum_sigclip_groups : int
The minimum number of groups to use sigma clipping.
max_shower_amplitude : float
The maximum amplitude of shower artifacts to correct in DN/group
Returns
Expand All @@ -948,6 +920,7 @@ def find_faint_extended(
nints = data.shape[0]
ngrps = data.shape[1]
num_grps_donotuse = 0

for integ in range(nints):
for grp in range(ngrps):
if np.all(np.bitwise_and(gdq[integ, grp, :, :], donotuse_flag)):
Expand Down Expand Up @@ -1028,6 +1001,8 @@ def find_faint_extended(
masked_smoothed_ratio = convolve(masked_ratio.filled(np.nan), ring_2D_kernel)
# mask out the pixels that got refilled by the convolution
masked_smoothed_ratio[dnuy, dnux] = np.nan
masked_smoothed_ratio[saty, satx] = np.nan
masked_smoothed_ratio[jumpy, jumpx] = np.nan
nrows = ratio.shape[1]
ncols = ratio.shape[2]
extended_emission = np.zeros(shape=(nrows, ncols), dtype=np.uint8)
Expand Down Expand Up @@ -1111,6 +1086,41 @@ def find_faint_extended(
num_grps_masked=num_grps_masked,
max_extended_radius=max_extended_radius
)

# Ensure that flagging showers didn't change final fluxes by more than the allowed amount
for intg in range(nints):
# Consider DO_NOT_USE, SATURATION, and JUMP_DET flags
invalid_flags = donotuse_flag | sat_flag | jump_flag

# Approximate pre-shower rates
tempdata = indata[intg, :, :, :].copy()
# Ignore any groups flagged in the original gdq array
tempdata[ingdq[intg, :, :, :] & invalid_flags != 0] = np.nan
# Compute group differences
diff = np.diff(tempdata, axis=0)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=RuntimeWarning, message="All-NaN")
warnings.filterwarnings("ignore", category=RuntimeWarning, message="Mean of empty slice")
image1 = np.nanmean(diff, axis=0)

# Approximate post-shower rates
tempdata = indata[intg, :, :, :].copy()
# Ignore any groups flagged in the shower gdq array
tempdata[gdq[intg, :, :, :] & invalid_flags != 0] = np.nan
# Compute group differences
diff = np.diff(tempdata, axis=0)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=RuntimeWarning, message="All-NaN")
warnings.filterwarnings("ignore", category=RuntimeWarning, message="Mean of empty slice")
image2 = np.nanmean(diff, axis=0)

# Revert the group flags to the pre-shower flags for any pixels whose rates
# became NaN or changed by more than the amount reasonable for a real CR shower
# Note that max_shower_amplitude should now be in DN/group not DN/s
diff = np.abs(image1 - image2)
indx = np.where((np.isfinite(diff) == False) | (diff > max_shower_amplitude))
gdq[intg, :, indx[0], indx[1]] = ingdq[intg, :, indx[0], indx[1]]

return gdq, total_showers

def find_first_good_group(int_gdq, do_not_use):
Expand Down
31 changes: 1 addition & 30 deletions tests/test_jump.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ def test_find_faint_extended(tmp_path):
jump_flag=4,
ellipse_expand=1.,
num_grps_masked=1,
max_shower_amplitude=10
)
# Check that all the expected samples in group 2 are flagged as jump and
# that they are not flagged outside
Expand All @@ -405,36 +406,6 @@ def test_find_faint_extended(tmp_path):
# Check that the flags are not applied in the 3rd group after the event
assert np.all(gdq[0, 4, 12:22, 14:23]) == 0

def test_find_faint_extended():
nint, ngrps, ncols, nrows = 1, 66, 5, 5
data = np.zeros(shape=(nint, ngrps, nrows, ncols), dtype=np.float32)
gdq = np.zeros_like(data, dtype=np.uint32)
pdq = np.zeros(shape=(nrows, ncols), dtype=np.uint32)
pdq[0, 0] = 1
pdq[1, 1] = 2147483648
# pdq = np.zeros(shape=(data.shape[2], data.shape[3]), dtype=np.uint8)
gain = 4
readnoise = np.ones(shape=(nrows, ncols), dtype=np.float32) * 6.0 * gain
rng = np.random.default_rng(12345)
data[0, 1:, 14:20, 15:20] = 6 * gain * 6.0 * np.sqrt(2)
data = data + rng.normal(size=(nint, ngrps, nrows, ncols)) * readnoise
gdq, num_showers = find_faint_extended(
data,
gdq,
pdq,
readnoise * np.sqrt(2),
1,
100,
snr_threshold=3,
min_shower_area=10,
inner=1,
outer=2.6,
sat_flag=2,
jump_flag=4,
ellipse_expand=1.1,
num_grps_masked=0,
)


# No shower is found because the event is identical in all ints
def test_find_faint_extended_sigclip():
Expand Down

0 comments on commit f2924b4

Please sign in to comment.