Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JP-3697: Jump Step Refactor #9039

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
1 change: 1 addition & 0 deletions changes/9039.jump.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor jump step for memory consumption, readability, and maintenance.
115 changes: 0 additions & 115 deletions jwst/jump/jump.py

This file was deleted.

205 changes: 129 additions & 76 deletions jwst/jump/jump_step.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
#! /usr/bin/env python
from stdatamodels.jwst import datamodels

from ..stpipe import Step
from .jump import run_detect_jumps
import time

import numpy as np

from ..lib import reffile_utils

from stcal.jump.jump_class import JumpData
from stcal.jump.jump import detect_jumps_data

from stdatamodels.jwst import datamodels
from stdatamodels.jwst.datamodels import dqflags

__all__ = ["JumpStep"]


class JumpStep(Step):
"""
JumpStep: Performs CR/jump detection on each ramp integration within an
exposure. The 2-point difference method is applied.
"""
"""Step class to perform just detection using two point difference."""

spec = """
rejection_threshold = float(default=4.0,min=0) # CR sigma rejection threshold
Expand Down Expand Up @@ -57,93 +62,141 @@
class_alias = 'jump'

def process(self, step_input):
    """Perform CR/jump detection on each ramp integration of an exposure.

    The two-point difference method is applied via the STCAL
    ``detect_jumps_data`` routine.

    Parameters
    ----------
    step_input : RampModel
        The ramp model input from the previous step.

    Returns
    -------
    result : RampModel
        The ramp model with the jump step marked COMPLETE and jumps
        flagged in the DQ arrays, or the input model with the jump step
        marked SKIPPED when there are too few groups.
    """
    # Open the input data model
    with datamodels.RampModel(step_input) as input_model:
        tstart = time.time()

        # Jump detection requires at least 3 groups per integration.
        nints, ngroups, nrows, ncols = input_model.data.shape
        if ngroups <= 2:
            self.log.warning('Cannot apply jump detection when NGROUPS<=2;')
            self.log.warning('Jump step will be skipped')
            input_model.meta.cal_step.jump = 'SKIPPED'
            return input_model

        self.log.info('CR rejection threshold = %g sigma', self.rejection_threshold)
        if self.maximum_cores != 'none':
            self.log.info('Maximum cores to use = %s', self.maximum_cores)

        # Detect jumps using a copy of the input data model.
        result = input_model.copy()
        jump_data = self._setup_jump_data(result)
        new_gdq, new_pdq, number_crs, number_extended_events, stddev = detect_jumps_data(jump_data)

        # Update the DQ arrays of the output model with the jump detection results
        result.groupdq = new_gdq
        result.pixeldq = new_pdq

        # Determine the number of groups with all pixels set to DO_NOT_USE
        dnu_flag = dqflags.pixel["DO_NOT_USE"]
        num_flagged_grps = 0
        for integ in range(nints):
            for grp in range(ngroups):
                if np.all(np.bitwise_and(result.groupdq[integ, grp, :, :], dnu_flag)):
                    num_flagged_grps += 1

        # NOTE(review): one additional group per integration is excluded
        # from the total — presumably because the first group contributes
        # no difference; confirm against the rate-computation convention.
        total_groups = nints * ngroups - num_flagged_grps - nints
        if total_groups >= 1:
            total_time = result.meta.exposure.group_time * total_groups
            total_pixels = nrows * ncols

            # Cosmic-ray rate, scaled per 1000 pixel-seconds
            crs = 1000 * number_crs / (total_time * total_pixels)
            result.meta.exposure.primary_cosmic_rays = crs

            # Extended-event (snowball/shower) rate, scaled per 1e6 pixel-seconds
            events = 1e6 * number_extended_events / (total_time * total_pixels)
            result.meta.exposure.extended_emission_events = events

        tstop = time.time()
        self.log.info('The execution time in seconds: %f', tstop - tstart)

        result.meta.cal_step.jump = 'COMPLETE'

    return result

def _setup_jump_data(self, result):
    """Create a JumpData instance to be used by STCAL jump.

    Parameters
    ----------
    result : RampModel
        The ramp model input from the previous step.

    Returns
    -------
    jump_data : JumpData
        The data container to be used to run the STCAL detect_jumps_data.
    """
    # Get the gain and readnoise reference files
    gain_filename = self.get_reference_file(result, 'gain')
    self.log.info('Using GAIN reference file: %s', gain_filename)
    readnoise_filename = self.get_reference_file(result, 'readnoise')
    self.log.info('Using READNOISE reference file: %s', readnoise_filename)

    with datamodels.ReadnoiseModel(readnoise_filename) as rnoise_m, \
            datamodels.GainModel(gain_filename) as gain_m:
        # Get 2D gain and read noise values from their respective models
        if reffile_utils.ref_matches_sci(result, gain_m):
            gain_2d = gain_m.data
        else:
            self.log.info('Extracting gain subarray to match science data')
            gain_2d = reffile_utils.get_subarray_data(result, gain_m)

        if reffile_utils.ref_matches_sci(result, rnoise_m):
            rnoise_2d = rnoise_m.data
        else:
            self.log.info('Extracting readnoise subarray to match science data')
            rnoise_2d = reffile_utils.get_subarray_data(result, rnoise_m)

        # Instantiate a JumpData class and populate it based on the input RampModel.
        jump_data = JumpData(result, gain_2d, rnoise_2d, dqflags.pixel)

        jump_data.set_detection_settings(
            self.rejection_threshold, self.three_group_rejection_threshold,
            self.four_group_rejection_threshold, self.max_jump_to_flag_neighbors,
            self.min_jump_to_flag_neighbors, self.flag_4_neighbors)

        # Determine the number of groups that correspond to the after-jump
        # times; needed because the group time is not passed to detect_jumps_data.
        gtime = result.meta.exposure.group_time
        after_jump_flag_n1 = int(self.after_jump_flag_time1 // gtime)
        after_jump_flag_n2 = int(self.after_jump_flag_time2 // gtime)

        jump_data.set_after_jump(
            self.after_jump_flag_dn1, after_jump_flag_n1,
            self.after_jump_flag_dn2, after_jump_flag_n2)

        # NOTE(review): the doubling of sat_expand and max_extended_radius
        # mirrors the pre-refactor call into STCAL — confirm the factor of 2
        # is expected by JumpData rather than applied again downstream.
        sat_expand = self.sat_expand * 2
        jump_data.set_snowball_info(
            self.expand_large_events, self.min_jump_area, self.min_sat_area,
            self.expand_factor, self.sat_required_snowball,
            self.min_sat_radius_extend, sat_expand, self.edge_size)

        max_extended_radius = self.max_extended_radius * 2
        jump_data.set_shower_info(
            self.find_showers, self.extend_snr_threshold, self.extend_min_area,
            self.extend_inner_radius, self.extend_outer_radius,
            self.extend_ellipse_expand_ratio, self.min_diffs_single_pass,
            max_extended_radius)

        jump_data.set_sigma_clipping_info(
            self.minimum_groups, self.minimum_sigclip_groups, self.only_use_ints)

        jump_data.max_cores = self.maximum_cores
        # Convert the time-based step parameters to group counts.
        jump_data.grps_masked_after_shower = int(self.time_masked_after_shower // gtime)
        jump_data.mask_persist_grps_next_int = self.mask_snowball_core_next_int
        jump_data.persist_grps_flagged = int(self.snowball_time_masked_next_int // gtime)
        # Scale by the group time — presumably converting a per-second
        # amplitude to a per-group one; confirm against the STCAL definition.
        jump_data.max_shower_amplitude = jump_data.max_shower_amplitude * gtime

        return jump_data
Loading
Loading