diff --git a/changes/8967.extract_1d.rst b/changes/8967.extract_1d.rst new file mode 100644 index 0000000000..34f6dcc714 --- /dev/null +++ b/changes/8967.extract_1d.rst @@ -0,0 +1 @@ +Add options for PSF-based optimal extraction for point sources in MIRI LRS fixed slit exposures. diff --git a/docs/jwst/extract_1d/arguments.rst b/docs/jwst/extract_1d/arguments.rst index 5d3b47e7ae..e72d3979fb 100644 --- a/docs/jwst/extract_1d/arguments.rst +++ b/docs/jwst/extract_1d/arguments.rst @@ -15,11 +15,22 @@ The following arguments apply to all modes unless otherwise specified. ``--apply_apcorr`` Switch to select whether or not to apply an APERTURE correction during the - Extract1dStep processing. Default is ``True``. Has no effect for NIRISS SOSS data. + Extract1dStep processing. Default is ``True``. Has no effect for NIRISS SOSS data + or for optimal extractions. Step Arguments for Slit and Slitless Spectroscopic Data ------------------------------------------------------- +``--extraction_type`` + Specify the extraction type. + If 'box', a standard extraction is performed, summing over an aperture box. + If 'optimal', a PSF-based optimal extraction is performed. + If None, optimal extraction is attempted whenever use_source_posn is True. + Box extraction is suitable for any input data (point sources and extended sources; + resampled and unresampled images). Optimal extraction is best suited for unresampled + point sources. Currently, optimal extraction is only available for MIRI LRS Fixed Slit data. + The default extraction type is 'box'. + ``--use_source_posn`` Specify whether the target and background extraction region locations specified in the :ref:`EXTRACT1D ` reference @@ -34,6 +45,21 @@ Step Arguments for Slit and Slitless Spectroscopic Data Specify a number of pixels (fractional pixels are allowed) to offset the extraction aperture from the nominal position. The default is 0. +``--model_nod_pair`` + Flag to enable fitting a negative trace during optimal extraction. + If True, and the extraction type is 'optimal', then a negative trace + from nod subtraction is modeled alongside the positive source during + extraction. This will be attempted only if the input data has been background + subtracted and the dither pattern type indicates that 2 nods were used. + The default value is True. + +``--optimize_psf_location`` + Flag to enable PSF location optimization during optimal extraction. + If True, and the extraction type is 'optimal', then the placement of + the PSF model for the source location (and negative nod, if present) + will be iteratively optimized. This parameter is recommended if + negative nods are modeled. The default value is True. + ``--smoothing_length`` If ``smoothing_length`` is greater than 1 (and is an odd integer), the image data used to perform background extraction will be smoothed in the @@ -82,13 +108,17 @@ Step Arguments for Slit and Slitless Spectroscopic Data to report on progress. Default value is 50. ``--save_profile`` - Flag to enable saving the spatial profile representing the extraction aperture. + Flag to enable saving the spatial profile representing the extraction aperture or model. If True, the profile is saved to disk with suffix "profile". ``--save_scene_model`` - Flag to enable saving a model of the 2D flux as defined by the extraction aperture. + Flag to enable saving a model of the 2D flux as defined by the extraction aperture or PSF model. If True, the model is saved to disk with suffix "scene_model". +``--save_residual_image`` + Flag to enable saving the residual image (from the input minus the scene model) + If True, the model is saved to disk with suffix "residual". + Step Arguments for IFU Data --------------------------- diff --git a/docs/jwst/extract_1d/description.rst b/docs/jwst/extract_1d/description.rst index 50c4149d23..fd967e9298 100644 --- a/docs/jwst/extract_1d/description.rst +++ b/docs/jwst/extract_1d/description.rst @@ -18,11 +18,14 @@ The EXTRACT1D reference file is not used for Wide-Field Slitless Spectroscopy da the full size of the input 2D subarray or cutout for each source, or restricted to the region within which the world coordinate system (WCS) is defined in each cutout. -For slit-like 2D input data, source and background extractions are done using -a rectangular aperture that covers one pixel in the dispersion direction and +For slit-like 2D input data, source and background extractions are, by default, done +using a rectangular aperture that covers one pixel in the dispersion direction and uses a height in the cross-dispersion direction that is defined by parameters in -the EXTRACT1D reference file. -For 3D IFU data, on the other hand, the extraction options differ depending on +the EXTRACT1D reference file. Optionally, for point sources, a PSF-based optimal +extraction may be performed, using a model of the spectral PSF to fit the total flux +at each dispersion element. + +For 3D IFU data, the extraction options differ depending on whether the target is a point or extended source. For a point source, the spectrum is extracted using circular aperture photometry, optionally including background subtraction using a circular annulus. @@ -36,18 +39,21 @@ object and perform extraction. For 3D NIRSpec fixed slit rateints data, the ``extract_1d`` step will be skipped as 3D input for the mode is not supported. -For most spectral modes an aperture correction will be applied to the extracted +For most spectral modes, an aperture correction will be applied to the extracted 1D spectral data (unless otherwise selected by the user), in order to put the results onto an infinite aperture scale. This is done by creating interpolation functions based on the APCORR reference file data and applying the interpolated aperture correction (a multiplicative factor between 0 and 1) to the extracted, 1D spectral data (corrected data include the "flux", "surf_bright", "flux_error", "sb_error", and all flux and -surface brightness variance columns in the output table). +surface brightness variance columns in the output table). For optimal extractions, +aperture correction is not performed, since it is assumed the total flux is +modeled by the PSF. Input ----- -Calibrated and potentially resampled 2D images or 3D cubes. The format should be a +The input data are calibrated and potentially resampled 2D images or 3D cubes. +The format should be a CubeModel, SlitModel, IFUCubeModel, ImageModel, MultiSlitModel, or a ModelContainer. For some JWST modes this is usually a resampled product, such as the "s2d" products for MIRI LRS fixed-slit, NIRSpec fixed-slit, and NIRSpec MOS, or the "s3d" products @@ -55,9 +61,7 @@ for MIRI MRS and NIRSpec IFU. For other modes that are not resampled (e.g. MIRI LRS slitless, NIRISS SOSS, NIRSpec BOTS, and NIRCam and NIRISS WFSS), this will be a "cal" or "calints" product. For modes that have multiple slit instances (NIRSpec fixed-slit and MOS, WFSS), -the SCI extensions should have the keyword SLTNAME to specify which slit was extracted, -though if there is only one slit (e.g. MIRI LRS and NIRISS SOSS), the slit name can -be taken from the EXTRACT1D reference file instead. +the SCI extensions should have the keyword SLTNAME to specify which slit was extracted. Normally the :ref:`photom ` step should be applied before running ``extract_1d``. If ``photom`` has not been run, a warning will be logged and the @@ -94,13 +98,13 @@ FLUX_ERROR is the error estimate for FLUX; it has the same units as FLUX. The error is calculated as the square root of the sum of the three variance arrays: Poisson, read noise (RNOISE), and flat field (FLAT). SURF_BRIGHT is the surface brightness in MJy / sr, except that for point -sources observed with NIRSpec and NIRISS SOSS, SURF_BRIGHT will be set to +sources observed with NIRSpec and NIRISS SOSS, or optimal extractions, SURF_BRIGHT will be set to zero, because there is no way to express the extracted results from those modes as a surface brightness. SB_ERROR is the error estimate for SURF_BRIGHT, calculated in the same fashion as FLUX_ERROR but using the SB_VAR arrays. While it's expected that a user will make use of the FLUX column for point-source data and the -SURF_BRIGHT column for an extended source, both columns are populated (except for -NIRSpec and NIRISS SOSS point sources, as mentioned above). +SURF_BRIGHT column for an extended source, both columns are populated +(except as mentioned above). The ``extract_1d`` step collapses the input data from 2-D to 1-D by summing one or more rows (or columns, depending on the dispersion direction). @@ -134,9 +138,9 @@ otherwise. .. _extract-1d-for-slits: -Extraction for 2D Slit Data ---------------------------- -The operational details of the 1D extraction depend heavily on the parameter +Box Extraction for 2D Slit Data +------------------------------- +For standard box extractions, the operational details depend heavily on the parameter values given in the :ref:`EXTRACT1D ` reference file. Here we describe their use within the ``extract_1d`` step. @@ -305,6 +309,48 @@ a separate polynomial. However, the independent variable (wavelength or pixel) does need to be the same for all polynomials for a given slit. +Optimal Extraction for 2D Slit Data +----------------------------------- + +Optimal extraction proceeds similarly to box extraction for 2D slit data, except that +instead of summing over an aperture defined by the reference files, a model of the point +spread function (PSF) is fit to the data at each dispersion element. This generally provides +higher signal-to-noise for the output spectrum than box extractions and has the advantage +of ignoring missing data due to bad pixels, cosmic rays, or saturation. Optimal extraction +also does not require a resampled spectral image as input: it can avoid the extra interpolation +by directly fitting the spatial profile along the curved trace at each dispersion element. + +Optimal extraction is suited only to point sources with known source locations, for which a +high-fidelity PSF model is available. Currently, only the MIRI LRS fixed slit exposure type +has a PSF model available in CRDS. + +When optimal extraction is selected (`extraction_type = 'optimal'`), the aperture definitions in +the extraction reference file are ignored, and the following parameters +are used instead: + +* `use_source_posn`: Source position is estimated from the input metadata and used to + center the PSF model. The recommended value is True, in order to account for spatial offsets + within the slit; if False, or if the source position could not be estimated, the source is + assumed to be at the center of the slit. +* `model_nod_pair`: If nod subtraction occurred prior to extraction, setting this option to + True will allow the extraction algorithm to model a single negative trace from the nod pair + alongside the positive trace. This can be helpful in accounting for PSF overlap between the + positive and negative traces. This option is ignored if no background subtraction occurred, + or if the dither pattern was not a 2-point nod. +* `optimize_psf_location`: Since source position estimates may be slightly inaccurate, + it may be useful to iteratively optimize the PSF location. When this option is set to True, the + location of the positive and negative traces (if used) are optimized with respect to the + residuals of the scene modeled by the PSF at that location. This option is + strongly recommended if `model_nod_pair` is True, since the negative nod location is less + reliably estimated than the positive trace location. +* `subtract_background`: Unlike during box extraction, the background levels can be modeled and removed + during optimal extraction without explicitly setting a background region. It is recommended to + set this parameter to True if background subtraction was skipped prior to extraction. Set this + parameter to False if a negative nod trace is present but not modeled (`model_nod_pair = False`). +* `override_psf`: If a custom flux model is required, it is possible to provide one by overriding + the PSF model reference file. Set this parameter to the filename for a FITS file matching the + :ref:`SpecPsfModel ` format. + .. _extract-1d-for-ifu: Extraction for 3D IFU Data diff --git a/docs/jwst/extract_1d/extract1d_api.rst b/docs/jwst/extract_1d/extract1d_api.rst index 7b3b8e836c..46c0e23a8c 100644 --- a/docs/jwst/extract_1d/extract1d_api.rst +++ b/docs/jwst/extract_1d/extract1d_api.rst @@ -4,3 +4,5 @@ Python interfaces for 1D Extraction .. automodapi:: jwst.extract_1d.extract .. automodapi:: jwst.extract_1d.extract1d .. automodapi:: jwst.extract_1d.ifu +.. automodapi:: jwst.extract_1d.source_location +.. automodapi:: jwst.extract_1d.psf_profile diff --git a/docs/jwst/extract_1d/reference_files.rst b/docs/jwst/extract_1d/reference_files.rst index 3e51a9a2b5..a47e5f3436 100644 --- a/docs/jwst/extract_1d/reference_files.rst +++ b/docs/jwst/extract_1d/reference_files.rst @@ -1,6 +1,9 @@ Reference File ============== -The ``extract_1d`` step uses an EXTRACT1D reference file and an APCORR reference file. +For most modes, the ``extract_1d`` step uses an EXTRACT1D reference file and an +APCORR reference file. For optimal extraction, it additionally uses a PSF +reference file. .. include:: ../references_general/extract1d_reffile.inc .. include:: ../references_general/apcorr_reffile.inc +.. include:: ../references_general/psf_reffile.inc diff --git a/docs/jwst/references_general/psf_reffile.inc b/docs/jwst/references_general/psf_reffile.inc new file mode 100644 index 0000000000..291a2cbc8a --- /dev/null +++ b/docs/jwst/references_general/psf_reffile.inc @@ -0,0 +1,44 @@ +.. _psf_reffile: + +PSF Reference File +^^^^^^^^^^^^^^^^^^ + +:REFTYPE: PSF +:Data model: `~jwst.datamodels.SpecPsfModel` + +The PSF reference file contains a model of the 1-D point spread function +by wavelength, intended to support spectral modeling and extraction. + + +Reference Selection Keywords for PSF +++++++++++++++++++++++++++++++++++++ +CRDS selects appropriate PSF references based on the following keywords. +PSF is not applicable for instruments not in the table. +All keywords used for file selection are *required*. + +========== ========================================================================= +Instrument Keywords +========== ========================================================================= +MIRI INSTRUME, DETECTOR, FILTER, EXP_TYPE +========== ========================================================================= + +.. include:: ../includes/standard_keywords.inc + +Reference File Format ++++++++++++++++++++++ +PSF reference files are in FITS format, with 2 IMAGE extensions. +The FITS primary HDU does not contain a data array. +The format and content of the file is as follows: + +======= ======== ===== ============== ========= +EXTNAME XTENSION NAXIS Dimensions Data type +======= ======== ===== ============== ========= +PSF IMAGE 2 ncols x nrows float +WAVE IMAGE 1 ncols float +======= ======== ===== ============== ========= + +The values in the ``PSF`` array give relative spectral flux values by cross-dispersion +position, at each dispersion element specified in the ``WAVE`` array. Detector +pixels are subsampled by the amount specified in the SUBPIX keyword, and the PSF +is centered on the cross-dispersion element specified in either CENTCOL (for vertical +dispersion) or CENTROW (for horizontal dispersion). diff --git a/docs/jwst/references_general/references_general.rst b/docs/jwst/references_general/references_general.rst index ca25eb25fc..9ca69a85b4 100644 --- a/docs/jwst/references_general/references_general.rst +++ b/docs/jwst/references_general/references_general.rst @@ -100,6 +100,8 @@ documentation on each reference file. + +--------------------------------------------------+ | | :ref:`APCORR ` | + +--------------------------------------------------+ +| | :ref:`PSF ` | ++ +--------------------------------------------------+ | | SPECKERNEL (NIRISS SOSS ATOCA only) | + +--------------------------------------------------+ | | SPECPROFILE (NIRISS SOSS ATOCA only) | @@ -258,6 +260,8 @@ documentation on each reference file. +--------------------------------------------------+-------------------------------------------------------+ | :ref:`PHOTOM ` | :ref:`photom ` | +--------------------------------------------------+-------------------------------------------------------+ +| :ref:`PSF ` | :ref:`extract_1d ` | ++--------------------------------------------------+-------------------------------------------------------+ | :ref:`PSFMASK ` | :ref:`align_refs ` | +--------------------------------------------------+-------------------------------------------------------+ | :ref:`READNOISE ` | :ref:`jump ` | diff --git a/jwst/extract_1d/extract.py b/jwst/extract_1d/extract.py index f78b826e0f..5abe1c67ed 100644 --- a/jwst/extract_1d/extract.py +++ b/jwst/extract_1d/extract.py @@ -4,25 +4,24 @@ from json.decoder import JSONDecodeError from astropy.modeling import polynomial -from gwcs.wcstools import grid_from_bounding_box -from scipy.interpolate import interp1d from stdatamodels.jwst import datamodels from stdatamodels.jwst.datamodels.apcorr import ( MirLrsApcorrModel, MirMrsApcorrModel, NrcWfssApcorrModel, NrsFsApcorrModel, NrsMosApcorrModel, NrsIfuApcorrModel, NisWfssApcorrModel ) -from jwst.assign_wcs.util import wcs_bbox_from_shape from jwst.datamodels import ModelContainer from jwst.lib import pipe_utils from jwst.lib.wcs_utils import get_wavelengths from jwst.extract_1d import extract1d, spec_wcs from jwst.extract_1d.apply_apcorr import select_apcorr +from jwst.extract_1d.psf_profile import psf_profile +from jwst.extract_1d.source_location import location_from_wcs __all__ = ['run_extract1d', 'read_extract1d_ref', 'read_apcorr_ref', 'get_extract_parameters', 'box_profile', 'aperture_center', - 'location_from_wcs', 'shift_by_offset', - 'define_aperture', 'extract_one_slit', 'create_extraction'] + 'shift_by_offset', 'define_aperture', 'extract_one_slit', + 'create_extraction'] log = logging.getLogger(__name__) @@ -145,8 +144,10 @@ def read_apcorr_ref(refname, exptype): def get_extract_parameters(ref_dict, input_model, slitname, sp_order, meta, smoothing_length=None, bkg_fit=None, bkg_order=None, - use_source_posn=None, subtract_background=None, - position_offset=0.0): + subtract_background=None, + use_source_posn=None, position_offset=0.0, + model_nod_pair=False, optimize_psf_location=False, + extraction_type='box', psf_ref_name='N/A'): """Get extraction parameter values. Parameters @@ -155,21 +156,16 @@ def get_extract_parameters(ref_dict, input_model, slitname, sp_order, meta, For an extract1d reference file in JSON format, `ref_dict` will be the entire contents of the file. If there is no extract1d reference file, `ref_dict` will be None. - input_model : JWSTDataModel This can be either the input science file or one SlitModel out of a list of slits. - slitname : str The name of the slit, or "ANY". - sp_order : int The spectral order number. - meta : ObjectNode The metadata for the actual input model, i.e. not just for the current slit, from input_model.meta. - smoothing_length : int or None, optional Width of a boxcar function for smoothing the background regions. If None, the smoothing length will be retrieved from `ref_dict`, or @@ -179,14 +175,12 @@ def get_extract_parameters(ref_dict, input_model, slitname, sp_order, meta, explicitly specified the value, so that value will be used. This argument is only used if background regions have been specified. - bkg_fit : str or None, optional The type of fit to apply to background values in each column (or row, if the dispersion is vertical). The default `poly` results in a polynomial fit of order `bkg_order`. Other options are `mean` and `median`. If `mean` or `median` is selected, `bkg_order` is ignored. - bkg_order : int or None, optional Polynomial order for fitting to each column (or row, if the dispersion is vertical) of background. If None, the polynomial @@ -198,21 +192,31 @@ def get_extract_parameters(ref_dict, input_model, slitname, sp_order, meta, specified the value, so that value will be used. This argument must be positive or zero, and it is only used if background regions have been specified. - + subtract_background : bool or None, optional + If False, all background parameters will be ignored. use_source_posn : bool or None, optional If True, the target and background positions specified in `ref_dict` (or a default target position) will be shifted to account for the actual source location in the data. - If None, the value specified in `ref_dict` will be used, or it will - be set to True if not found in `ref_dict`. - - subtract_background : bool or None, optional - If False, all background parameters will be ignored. - + If None, a default value will be set, based on the exposure type. position_offset : float or None, optional Pixel offset to apply to the nominal source location. If None, the value specified in `ref_dict` will be used or it will default to 0. + model_nod_pair : bool, optional + If True, and if `extraction_type` is 'optimal', then a negative + trace from nod subtraction will be modeled alongside the positive + source, if possible. + optimize_psf_location : bool + If True, and if `extraction_type` is 'optimal', then the source + location will be optimized, via iterative comparisons of the scene + model with the input data. + extraction_type : str, optional + Extraction type ('box' or 'optimal'). Optimal extraction is + only available if `psf_ref_name` is not 'N/A'. If set to None, + optimal extraction will be used if `use_source_posn` is True. + psf_ref_name : str, optional + The name of the PSF reference file, or "N/A". Returns ------- @@ -241,9 +245,11 @@ def get_extract_parameters(ref_dict, input_model, slitname, sp_order, meta, extract_params['subtract_background'] = False extract_params['extraction_type'] = 'box' extract_params['use_source_posn'] = False # no source position correction - extract_params['position_correction'] = 0 - extract_params['independent_var'] = 'pixel' extract_params['position_offset'] = 0. + extract_params['model_nod_pair'] = False + extract_params['optimize_psf_location'] = False + extract_params['psf'] = 'N/A' + extract_params['independent_var'] = 'pixel' extract_params['trace'] = None # Note that extract_params['dispaxis'] is not assigned. # This will be done later, possibly slit by slit. @@ -286,8 +292,12 @@ def get_extract_parameters(ref_dict, input_model, slitname, sp_order, meta, extract_params['src_coeff'] = aper.get('src_coeff') extract_params['bkg_coeff'] = aper.get('bkg_coeff') + if (extract_params['bkg_coeff'] is not None - and subtract_background is not False): + and subtract_background is None): + subtract_background = True + + if subtract_background: extract_params['subtract_background'] = True if bkg_fit is not None: # Mean value for background fitting is equivalent @@ -299,8 +309,8 @@ def get_extract_parameters(ref_dict, input_model, slitname, sp_order, meta, else: extract_params['bkg_fit'] = aper.get('bkg_fit', 'poly') else: - extract_params['bkg_fit'] = None extract_params['subtract_background'] = False + extract_params['bkg_fit'] = None extract_params['independent_var'] = aper.get('independent_var', 'pixel').lower() @@ -331,7 +341,6 @@ def get_extract_parameters(ref_dict, input_model, slitname, sp_order, meta, extract_params['trace'] = None extract_params['extract_width'] = aper.get('extract_width') - extract_params['position_correction'] = 0 # default value if smoothing_length is None: extract_params['smoothing_length'] = aper.get('smoothing_length', 0) @@ -357,9 +366,26 @@ def get_extract_parameters(ref_dict, input_model, slitname, sp_order, meta, f'{sm_length}.') extract_params['smoothing_length'] = sm_length - # Default the extraction type to 'box': 'optimal' - # is not yet supported. - extract_params['extraction_type'] = 'box' + extract_params['psf'] = psf_ref_name + extract_params['optimize_psf_location'] = optimize_psf_location + extract_params['model_nod_pair'] = model_nod_pair + + # Check for a valid PSF file + extraction_type = str(extraction_type).lower() + if extract_params['psf'] == 'N/A' and extraction_type != 'box': + log.warning("No PSF file available. Setting extraction type to 'box'.") + extraction_type = 'box' + + # Set the extraction type to 'box' or 'optimal' + if extraction_type == 'none': + if extract_params['use_source_posn']: + extract_params['extraction_type'] = 'optimal' + else: + extract_params['extraction_type'] = 'box' + log.info(f"Using extraction type '{extract_params['extraction_type']}' " + f"for use_source_posn = {extract_params['use_source_posn']}") + else: + extract_params['extraction_type'] = extraction_type break @@ -982,142 +1008,6 @@ def aperture_center(profile, dispaxis=1, middle_pix=None): return slit_center, spec_center -def location_from_wcs(input_model, slit): - """Get the cross-dispersion location of the spectrum, based on the WCS. - - None values will be returned if there was insufficient information - available, e.g. if the wavelength attribute or wcs function is not - defined. - - Parameters - ---------- - input_model : DataModel - The input science model containing metadata information. - slit : DataModel or None - One slit from a MultiSlitModel (or similar), or None. - The WCS and target coordinates will be retrieved from `slit` - unless `slit` is None. In that case, they will be retrieved - from `input_model`. - - Returns - ------- - middle : int or None - Pixel coordinate in the dispersion direction within the 2-D - cutout (or the entire input image) at the middle of the WCS - bounding box. This is the point at which to determine the - nominal extraction location, in case it varies along the - spectrum. The offset will then be the difference between - `location` (below) and the nominal location. - middle_wl : float or None - The wavelength at pixel `middle`. - location : float or None - Pixel coordinate in the cross-dispersion direction within the - spectral image that is at the planned target location. - The spectral extraction region should be centered here. - trace : ndarray or None - An array of source positions, one per dispersion element, corresponding - to the location at each point in the wavelength array. If the - input data is resampled, the trace corresponds directly to the - location. - """ - if slit is not None: - shape = slit.data.shape[-2:] - wcs = slit.meta.wcs - dispaxis = slit.meta.wcsinfo.dispersion_direction - else: - shape = input_model.data.shape[-2:] - wcs = input_model.meta.wcs - dispaxis = input_model.meta.wcsinfo.dispersion_direction - - bb = wcs.bounding_box # ((x0, x1), (y0, y1)) - if bb is None: - bb = wcs_bbox_from_shape(shape) - - if dispaxis == HORIZONTAL: - # Width (height) in the cross-dispersion direction, from the start of - # the 2-D cutout (or of the full image) to the upper limit of the bounding box. - # This may be smaller than the full width of the image, but it's all we - # need to consider. - xd_width = int(round(bb[1][1])) # must be an int - middle = int((bb[0][0] + bb[0][1]) / 2.) # Middle of the bounding_box in the dispersion direction. - x = np.empty(xd_width, dtype=np.float64) - x[:] = float(middle) - y = np.arange(xd_width, dtype=np.float64) - lower = bb[1][0] - upper = bb[1][1] - else: # dispaxis = VERTICAL - xd_width = int(round(bb[0][1])) # Cross-dispersion total width of bounding box; must be an int - middle = int((bb[1][0] + bb[1][1]) / 2.) # Mid-point of width along dispersion direction - x = np.arange(xd_width, dtype=np.float64) # 1-D vector of cross-dispersion (x) pixel indices - y = np.empty(xd_width, dtype=np.float64) # 1-D vector all set to middle y index - y[:] = float(middle) - - # lower and upper range in cross-dispersion direction - lower = bb[0][0] - upper = bb[0][1] - - # Get the wavelengths for the valid data in the sky transform, - # average to get the middle wavelength - fwd_transform = wcs(x, y) - middle_wl = np.nanmean(fwd_transform[2]) - - exp_type = input_model.meta.exposure.type - trace = None - if exp_type in ['NRS_FIXEDSLIT', 'NRS_MSASPEC', 'NRS_BRIGHTOBJ']: - log.info("Using source_xpos and source_ypos to center extraction.") - if slit is None: - xpos = input_model.source_xpos - ypos = input_model.source_ypos - else: - xpos = slit.source_xpos - ypos = slit.source_ypos - - slit2det = wcs.get_transform('slit_frame', 'detector') - if 'gwa' in wcs.available_frames: - # Input is not resampled, wavelengths need to be meters - _, location = slit2det(xpos, ypos, middle_wl * 1e-6) - else: - _, location = slit2det(xpos, ypos, middle_wl) - - if ~np.isnan(location): - trace = _nirspec_trace_from_wcs(shape, bb, wcs, xpos, ypos) - - elif exp_type == 'MIR_LRS-FIXEDSLIT': - log.info("Using dithered_ra and dithered_dec to center extraction.") - try: - if slit is None: - dithra = input_model.meta.dither.dithered_ra - dithdec = input_model.meta.dither.dithered_dec - else: - dithra = slit.meta.dither.dithered_ra - dithdec = slit.meta.dither.dithered_dec - location, _ = wcs.backward_transform(dithra, dithdec, middle_wl) - - except (AttributeError, TypeError): - log.warning("Dithered pointing location not found in wcsinfo.") - return None, None, None, None - - if ~np.isnan(location): - trace = _miri_trace_from_wcs(shape, bb, wcs, dithra, dithdec) - else: - log.warning(f"Source position cannot be found for EXP_TYPE {exp_type}") - return None, None, None, None - - if np.isnan(location): - log.warning('Source position could not be determined from WCS.') - return None, None, None, None - - # If the target is at the edge of the image or at the edge of the - # non-NaN area, we can't use the WCS to find the - # location of the target spectrum. - if location < lower or location > upper: - log.warning(f"WCS implies the target is at {location:.2f}, which is outside the bounding box,") - log.warning("so we can't get spectrum location using the WCS") - return None, None, None, None - - return middle, middle_wl, location, trace - - def shift_by_offset(offset, extract_params, update_trace=True): """Shift the nominal extraction parameters by a pixel offset. @@ -1158,130 +1048,6 @@ def shift_by_offset(offset, extract_params, update_trace=True): extract_params['trace'] += offset -def _nirspec_trace_from_wcs(shape, bounding_box, wcs_ref, source_xpos, source_ypos): - """Calculate NIRSpec source trace from WCS. - - The source trace is calculated by projecting the recorded source - positions source_xpos/ypos from the NIRSpec "slit_frame" onto - detector pixels. - - Parameters - ---------- - shape : tuple of int - 2D shape for the full input data array, (ny, nx). - bounding_box : tuple - A pair of tuples, each consisting of two numbers. - Represents the range of useful pixel values in both dimensions, - ((xmin, xmax), (ymin, ymax)). - wcs_ref : `~gwcs.WCS` - WCS for the input data model, containing slit and detector - transforms. - source_xpos : float - Slit position, in the x direction, for the target. - source_ypos : float - Slit position, in the y direction, for the target. - - Returns - ------- - trace : ndarray of float - Fractional pixel positions in the y (cross-dispersion direction) - of the trace for each x (dispersion direction) pixel. - """ - x, y = grid_from_bounding_box(bounding_box) - nx = int(bounding_box[0][1] - bounding_box[0][0]) - - # Calculate the wavelengths in the slit frame because they are in - # meters for cal files and um for s2d files - d2s = wcs_ref.get_transform("detector", "slit_frame") - _, _, slit_wavelength = d2s(x,y) - - # Make an initial array of wavelengths that will cover the wavelength range of the data - wave_vals = np.linspace(np.nanmin(slit_wavelength), np.nanmax(slit_wavelength), nx) - # Get arrays of the source position in the slit - pos_x = np.full(nx, source_xpos) - pos_y = np.full(nx, source_ypos) - - # Grab the wcs transform between the slit frame where we know the - # source position and the detector frame - s2d = wcs_ref.get_transform("slit_frame", "detector") - - # Calculate the expected center of the source trace - trace_x, trace_y = s2d(pos_x, pos_y, wave_vals) - - # Interpolate the trace to a regular pixel grid in the dispersion - # direction - interp_trace = interp1d(trace_x, trace_y, fill_value='extrapolate') - - # Get the trace position for each dispersion element - trace = interp_trace(np.arange(nx)) - - # Place the trace in the full array - full_trace = np.full(shape[1], np.nan) - x0 = int(np.ceil(bounding_box[0][0])) - full_trace[x0:x0 + nx] = trace - - return full_trace - - -def _miri_trace_from_wcs(shape, bounding_box, wcs_ref, source_ra, source_dec): - """Calculate MIRI LRS fixed slit source trace from WCS. - - The source trace is calculated by projecting the recorded source - positions dithered_ra/dec from the world frame onto detector pixels. - - Parameters - ---------- - shape : tuple of int - 2D shape for the full input data array, (ny, nx). - bounding_box : tuple - A pair of tuples, each consisting of two numbers. - Represents the range of useful pixel values in both dimensions, - ((xmin, xmax), (ymin, ymax)). - wcs_ref : `~gwcs.WCS` - WCS for the input data model, containing sky and detector - transforms, forward and backward. - source_ra : float - RA coordinate for the target. - source_dec : float - Dec coordinate for the target. - - Returns - ------- - trace : ndarray of float - Fractional pixel positions in the x (cross-dispersion direction) - of the trace for each y (dispersion direction) pixel. - """ - x, y = grid_from_bounding_box(bounding_box) - ny = int(bounding_box[1][1] - bounding_box[1][0]) - - # Calculate the wavelengths for the full array - _, _, slit_wavelength = wcs_ref(x, y) - - # Make an initial array of wavelengths that will cover the wavelength range of the data - wave_vals = np.linspace(np.nanmin(slit_wavelength), np.nanmax(slit_wavelength), ny) - - # Get arrays of the source position - pos_ra = np.full(ny, source_ra) - pos_dec = np.full(ny, source_dec) - - # Calculate the expected center of the source trace - trace_x, trace_y = wcs_ref.backward_transform(pos_ra, pos_dec, wave_vals) - - # Interpolate the trace to a regular pixel grid in the dispersion - # direction - interp_trace = interp1d(trace_y, trace_x, fill_value='extrapolate') - - # Get the trace position for each dispersion element within the bounding box - trace = interp_trace(np.arange(ny)) - - # Place the trace in the full array - full_trace = np.full(shape[0], np.nan) - y0 = int(np.ceil(bounding_box[1][0])) - full_trace[y0:y0 + ny] = trace - - return full_trace - - def define_aperture(input_model, slit, extract_params, exp_type): """Define an extraction aperture from input parameters. @@ -1322,6 +1088,11 @@ def define_aperture(input_model, slit, extract_params, exp_type): `bg_profile` is a 2D image containing pixel weights for background regions, to be fit during extraction. Otherwise, `bg_profile` is None. + nod_profile : ndarray of float or None + For optimal extraction, if nod subtraction was performed, a + second spatial profile is generated, modeling the negative source + in the slit. This second spatial profile is returned in `nod_profile` + if generated. Otherwise, `nod_profile` is None. limits : tuple of float Index limit values for the aperture, returned as (lower_limit, upper_limit, left_limit, right_limit). Upper/lower limits are along the @@ -1343,7 +1114,6 @@ def define_aperture(input_model, slit, extract_params, exp_type): if extract_params['use_source_posn']: # Source location from WCS middle_pix, middle_wl, location, trace = location_from_wcs(input_model, slit) - if location is not None: log.info(f"Computed source location is {location:.2f}, " f"at pixel {middle_pix}, wavelength {middle_wl:.2f}") @@ -1372,11 +1142,25 @@ def define_aperture(input_model, slit, extract_params, exp_type): shift_by_offset(offset, extract_params, update_trace=True) # Make a spatial profile, including source shifts if necessary - profile, lower_limit, upper_limit = box_profile( - data_shape, extract_params, wl_array, return_limits=True) + nod_profile = None + if extract_params['extraction_type'] == 'optimal': + profiles, lower_limit, upper_limit = psf_profile( + data_model, extract_params['trace'], + wl_array, extract_params['psf'], + model_nod_pair=extract_params['model_nod_pair'], + optimize_shifts=extract_params['optimize_psf_location']) + if len(profiles) > 1: + profile, nod_profile = profiles + else: + profile = profiles[0] + else: + profile, lower_limit, upper_limit = box_profile( + data_shape, extract_params, wl_array, return_limits=True) # Make sure profile weights are zero where wavelengths are invalid profile[~np.isfinite(wl_array)] = 0.0 + if nod_profile is not None: + nod_profile[~np.isfinite(wl_array)] = 0.0 # Get the effective left and right limits from the profile weights nonzero_weight = np.where(np.sum(profile, axis=extract_params['dispaxis'] - 1) > 0)[0] @@ -1419,10 +1203,11 @@ def define_aperture(input_model, slit, extract_params, exp_type): # Return limits as a tuple with 4 elements: lower, upper, left, right limits = (lower_limit, upper_limit, left_limit, right_limit) - return ra, dec, wavelength, profile, bg_profile, limits + return ra, dec, wavelength, profile, bg_profile, nod_profile, limits -def extract_one_slit(data_model, integration, profile, bg_profile, extract_params): +def extract_one_slit(data_model, integration, profile, bg_profile, + nod_profile, extract_params): """Extract data for one slit, or spectral order, or integration. Parameters @@ -1431,23 +1216,25 @@ def extract_one_slit(data_model, integration, profile, bg_profile, extract_param The input science model. May be a single slit from a MultiSlitModel (or similar), or a single data type, like an ImageModel, SlitModel, or CubeModel. - integration : int For the case that data_model is a SlitModel or a CubeModel, `integration` is the integration number. If the integration number is not relevant (i.e. the data array is 2-D), `integration` should be -1. - profile : ndarray of float Spatial profile indicating the aperture location. Must be a 2D image matching the input, with floating point values between 0 and 1 assigning a weight to each pixel. 0 means the pixel is not used, 1 means the pixel is fully included in the aperture. - bg_profile : ndarray of float or None Background profile indicating any background regions to use, following the same format as the spatial profile. Ignored if extract_params['subtract_background'] is False. - + nod_profile : ndarray of float or None + For optimal extraction, if nod subtraction was performed, a + second spatial profile is generated, modeling the negative source + in the slit. This second spatial profile may be passed in `nod_profile` + for simultaneous fitting with the primary source in `profile`. + Otherwise, `nod_profile` should be None. extract_params : dict Parameters read from the extract1d reference file, as returned by `get_extract_parameters`. @@ -1464,44 +1251,36 @@ def extract_one_slit(data_model, integration, profile, bg_profile, extract_param point source (column "flux"). Divide `sum_flux` by `npixels` (to compute the average) to get the array for the "surf_bright" (surface brightness) output column. - f_var_rnoise : ndarray, 1-D The extracted read noise variance values to go along with the sum_flux array. - f_var_poisson : ndarray, 1-D The extracted poisson variance values to go along with the sum_flux array. - f_var_flat : ndarray, 1-D The extracted flat field variance values to go along with the sum_flux array. - background : ndarray, 1-D The background count rate that was subtracted from the sum of the source data values to get `sum_flux`. - b_var_rnoise : ndarray, 1-D The extracted read noise variance values to go along with the background array. - b_var_poisson : ndarray, 1-D The extracted poisson variance values to go along with the background array. - b_var_flat : ndarray, 1-D The extracted flat field variance values to go along with the background array. - npixels : ndarray, 1-D, float64 The number of pixels that were added together to get `sum_flux`, including any fractional pixels included via non-integer weights in the input profile. - scene_model : ndarray, 2-D, float64 A 2D model of the flux in the spectral image, corresponding to the extracted aperture. - + residual : ndarray, 2-D, float64 + Residual image from the input minus the scene model. """ # Get the data and variance arrays if integration > -1: @@ -1526,11 +1305,13 @@ def extract_one_slit(data_model, integration, profile, bg_profile, extract_param # Transpose data for extraction if extract_params['dispaxis'] == HORIZONTAL: - profile_view = profile bg_profile_view = bg_profile + if nod_profile is not None: + profiles = [profile, nod_profile] + else: + profiles = [profile] else: data = data.T - profile_view = profile.T var_rnoise = var_rnoise.T var_poisson = var_poisson.T var_flat = var_flat.T @@ -1538,9 +1319,13 @@ def extract_one_slit(data_model, integration, profile, bg_profile, extract_param bg_profile_view = bg_profile.T else: bg_profile_view = None + if nod_profile is not None: + profiles = [profile.T, nod_profile.T] + else: + profiles = [profile.T] # Extract spectra from the data - result = extract1d.extract1d(data, [profile_view], var_rnoise, var_poisson, var_flat, + result = extract1d.extract1d(data, profiles, var_rnoise, var_poisson, var_flat, profile_bg=bg_profile_view, bg_smooth_length=extract_params['smoothing_length'], fit_bkg=extract_params['subtract_background'], @@ -1558,17 +1343,21 @@ def extract_one_slit(data_model, integration, profile, bg_profile, extract_param # of the number of input profiles. It may need to be transposed to match # the input data. scene_model = result[-1] + residual = data - scene_model if extract_params['dispaxis'] == HORIZONTAL: first_result.append(scene_model) + first_result.append(residual) else: first_result.append(scene_model.T) + first_result.append(residual.T) return first_result def create_extraction(input_model, slit, output_model, extract_ref_dict, slitname, sp_order, exp_type, apcorr_ref_model=None, log_increment=50, - save_profile=False, save_scene_model=False, **kwargs): + save_profile=False, save_scene_model=False, + save_residual_image=False, **kwargs): """Extract spectra from an input model and append to an output model. Input data, specified in the `slit` or `input_model`, should contain data @@ -1642,6 +1431,9 @@ def create_extraction(input_model, slit, output_model, save_scene_model : bool, optional If True, the flux model created during extraction will be returned as an ImageModel or CubeModel. If False, the return value is None. + save_residual_image : bool, optional + If True, the residual image (from input minus scene model) will be returned + as an ImageModel or CubeModel. If False, the return value is None. kwargs : dict, optional Additional options to pass to `get_extract_parameters`. @@ -1655,6 +1447,9 @@ def create_extraction(input_model, slit, output_model, If `save_scene_model` is True, the return value is an ImageModel or CubeModel matching the input data, containing the flux model generated during extraction. + residual : ImageModel, CubeModel, or None + If `save_residual_image` is True, the return value is an ImageModel or CubeModel + matching the input data, containing the residual image. """ if slit is None: @@ -1733,7 +1528,7 @@ def create_extraction(input_model, slit, output_model, # Set up spatial profiles and wavelength array, # to be used for every integration - (ra, dec, wavelength, profile, bg_profile, limits) = define_aperture( + (ra, dec, wavelength, profile, bg_profile, nod_profile, limits) = define_aperture( input_model, slit, extract_params, exp_type) valid = np.isfinite(wavelength) @@ -1744,7 +1539,10 @@ def create_extraction(input_model, slit, output_model, # Save the profile if desired if save_profile: - profile_model = datamodels.ImageModel(profile) + if nod_profile is not None: + profile_model = datamodels.ImageModel(profile + nod_profile) + else: + profile_model = datamodels.ImageModel(profile) profile_model.update(input_model, only='PRIMARY') profile_model.name = slitname else: @@ -1790,7 +1588,7 @@ def create_extraction(input_model, slit, output_model, integrations = range(shape[0]) progress_msg_printed = False - # Set up a flux model to update if desired + # Set up a scene model and residual image to update if desired if save_scene_model: if len(integrations) > 1: scene_model = datamodels.CubeModel(shape) @@ -1800,31 +1598,53 @@ def create_extraction(input_model, slit, output_model, scene_model.name = slitname else: scene_model = None + if save_residual_image: + if len(integrations) > 1: + residual = datamodels.CubeModel(shape) + else: + residual = datamodels.ImageModel() + residual.update(input_model, only='PRIMARY') + residual.name = slitname + else: + residual = None # Extract each integration for integ in integrations: (sum_flux, f_var_rnoise, f_var_poisson, f_var_flat, background, b_var_rnoise, b_var_poisson, - b_var_flat, npixels, scene_model_2d) = extract_one_slit( - data_model, integ, profile, bg_profile, extract_params) + b_var_flat, npixels, scene_model_2d, residual_2d) = extract_one_slit( + data_model, integ, profile, bg_profile, nod_profile, extract_params) - # Save the flux model + # Save the scene model and residual if save_scene_model: if isinstance(scene_model, datamodels.CubeModel): scene_model.data[integ] = scene_model_2d else: scene_model.data = scene_model_2d + if save_residual_image: + if isinstance(residual, datamodels.CubeModel): + residual.data[integ] = residual_2d + else: + residual.data = residual_2d # Convert the sum to an average, for surface brightness. npixels_temp = np.where(npixels > 0., npixels, 1.) - surf_bright = sum_flux / npixels_temp # may be reset below - sb_var_poisson = f_var_poisson / npixels_temp / npixels_temp - sb_var_rnoise = f_var_rnoise / npixels_temp / npixels_temp - sb_var_flat = f_var_flat / npixels_temp / npixels_temp + npixels_squared = npixels_temp ** 2 + if extract_params['extraction_type'] == 'optimal': + # surface brightness makes no sense for an optimal extraction + surf_bright = np.zeros_like(sum_flux) + sb_var_poisson = np.zeros_like(sum_flux) + sb_var_rnoise = np.zeros_like(sum_flux) + sb_var_flat = np.zeros_like(sum_flux) + else: + surf_bright = sum_flux / npixels_temp # may be reset below + sb_var_poisson = f_var_poisson / npixels_squared + sb_var_rnoise = f_var_rnoise / npixels_squared + sb_var_flat = f_var_flat / npixels_squared background /= npixels_temp - b_var_poisson = b_var_poisson / npixels_temp / npixels_temp - b_var_rnoise = b_var_rnoise / npixels_temp / npixels_temp - b_var_flat = b_var_flat / npixels_temp / npixels_temp + b_var_poisson = b_var_poisson / npixels_squared + b_var_rnoise = b_var_rnoise / npixels_squared + b_var_flat = b_var_flat / npixels_squared del npixels_temp @@ -1849,9 +1669,9 @@ def create_extraction(input_model, slit, output_model, sb_var_rnoise[:] = 0. sb_var_flat[:] = 0. background[:] /= pixel_solid_angle # MJy / sr - b_var_poisson = b_var_poisson / pixel_solid_angle / pixel_solid_angle - b_var_rnoise = b_var_rnoise / pixel_solid_angle / pixel_solid_angle - b_var_flat = b_var_flat / pixel_solid_angle / pixel_solid_angle + b_var_poisson = b_var_poisson / pixel_solid_angle ** 2 + b_var_rnoise = b_var_rnoise / pixel_solid_angle ** 2 + b_var_flat = b_var_flat / pixel_solid_angle ** 2 else: flux = sum_flux * pixel_solid_angle * 1.e6 # MJy / steradian --> Jy f_var_poisson *= (pixel_solid_angle ** 2 * 1.e12) # (MJy / sr)**2 --> Jy**2 @@ -1964,14 +1784,17 @@ def create_extraction(input_model, slit, output_model, if not progress_msg_printed: log.info(f"All {input_model.data.shape[0]} integrations done") - return profile_model, scene_model + return profile_model, scene_model, residual def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None, + psf_ref_name="N/A", extraction_type="box", smoothing_length=None, bkg_fit=None, bkg_order=None, log_increment=50, subtract_background=None, use_source_posn=None, position_offset=0.0, - save_profile=False, save_scene_model=False): + model_nod_pair=False, optimize_psf_location=True, + save_profile=False, save_scene_model=False, + save_residual_image=False): """Extract all 1-D spectra from an input model. Parameters @@ -1982,6 +1805,11 @@ def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None, The name of the extract1d reference file, or "N/A". apcorr_ref_name : str or None Name of the APCORR reference file. Default is None + psf_ref_name : str + The name of the PSF reference file, or "N/A". + extraction_type : str + Extraction type ('box' or 'optimal'). Optimal extraction is + only available if `psf_ref_name` is not "N/A". smoothing_length : int or None Width of a boxcar function for smoothing the background regions. bkg_fit : str or None @@ -2009,6 +1837,16 @@ def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None, position_offset : float Number of pixels to shift the nominal source position in the cross-dispersion direction. + model_nod_pair : bool + If True, and if `extraction_type` is 'optimal', then a negative trace + from nod subtraction is modeled alongside the positive source during + extraction. Even if set to True, this will be attempted only if the + input data has been background subtracted and the dither pattern + indicates that only 2 nods were used. + optimize_psf_location : bool + If True, and if `extraction_type` is 'optimal', then the source + location will be optimized, via iterative comparisons of the scene + model with the input data. save_profile : bool If True, the spatial profiles created for the input model will be returned as ImageModels. If False, the return value is None. @@ -2016,6 +1854,10 @@ def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None, If True, a model of the 2D flux as defined by the extraction aperture is returned as an ImageModel or CubeModel. If False, the return value is None. + save_residual_image : bool + If True, the residual image (from the input minus the scene model) + is returned as an ImageModel or CubeModel. If False, the return value + is None. Returns ------- @@ -2030,6 +1872,10 @@ def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None, If `save_scene_model` is True, the return value is an ImageModel or CubeModel matching the input data, containing a model of the flux as defined by the aperture, created during extraction. Otherwise, the return value is None. + residual : ModelContainer, ImageModel, CubeModel, or None + If `save_residual_image` is True, the return value is an ImageModel or CubeModel + matching the input data, containing the residual image (from the input minus + the scene model). Otherwise, the return value is None. """ # Set "meta_source" to either the first model in a container, # or the individual input model, for convenience @@ -2045,10 +1891,20 @@ def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None, # Read in the extract1d reference file. extract_ref_dict = read_extract1d_ref(extract_ref_name) + # Check for non-null PSF reference file + if psf_ref_name == 'N/A': + if extraction_type != 'box': + log.warning(f'Optimal extraction is not available for EXP_TYPE {exp_type}') + log.warning('Defaulting to box extraction.') + extraction_type = 'box' + # Read in the aperture correction reference file apcorr_ref_model = None if apcorr_ref_name is not None and apcorr_ref_name != 'N/A': - apcorr_ref_model = read_apcorr_ref(apcorr_ref_name, exp_type) + if extraction_type == 'optimal': + log.warning("Turning off aperture correction for optimal extraction") + else: + apcorr_ref_model = read_apcorr_ref(apcorr_ref_name, exp_type) # Set up the output model output_model = datamodels.MultiSpecModel() @@ -2064,6 +1920,7 @@ def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None, # Handle inputs that contain one or more slit models profile_model = None scene_model = None + residual = None if isinstance(input_model, (ModelContainer, datamodels.MultiSlitModel)): if isinstance(input_model, ModelContainer): slits = input_model @@ -2079,6 +1936,8 @@ def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None, profile_model = ModelContainer() if save_scene_model: scene_model = ModelContainer() + if save_residual_image: + residual = ModelContainer() for slit in slits: # Loop over the slits in the input model log.info(f'Working on slit {slit.name}') @@ -2097,16 +1956,21 @@ def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None, continue try: - profile, slit_scene_model = create_extraction( + profile, slit_scene_model, slit_residual = create_extraction( meta_source, slit, output_model, extract_ref_dict, slitname, sp_order, exp_type, apcorr_ref_model=apcorr_ref_model, log_increment=log_increment, save_profile=save_profile, save_scene_model=save_scene_model, + save_residual_image=save_residual_image, + psf_ref_name=psf_ref_name, + extraction_type=extraction_type, smoothing_length=smoothing_length, bkg_fit=bkg_fit, bkg_order=bkg_order, subtract_background=subtract_background, use_source_posn=use_source_posn, - position_offset=position_offset) + position_offset=position_offset, + model_nod_pair=model_nod_pair, + optimize_psf_location=optimize_psf_location) except ContinueError: continue @@ -2114,46 +1978,22 @@ def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None, profile_model.append(profile) if save_scene_model: scene_model.append(slit_scene_model) + if save_residual_image: + residual.append(slit_residual) else: # Define source of metadata slit = None - # This default value for slitname is not really a slit name. - # It may be assigned a better value below, in the sections for - # ImageModel or SlitModel. + # Default the slitname to the exp_type, in case there's no + # better value available slitname = exp_type - if isinstance(input_model, datamodels.ImageModel): + # Get the slit name from the input model if hasattr(input_model, "name") and input_model.name is not None: slitname = input_model.name - sp_order = get_spectral_order(input_model) - if sp_order == 0 and not prism_mode: - log.info("Spectral order 0 is a direct image, skipping ...") - else: - log.info(f'Processing spectral order {sp_order}') - try: - profile_model, scene_model = create_extraction( - input_model, slit, output_model, - extract_ref_dict, slitname, sp_order, exp_type, - apcorr_ref_model=apcorr_ref_model, log_increment=log_increment, - save_profile=save_profile, save_scene_model=save_scene_model, - smoothing_length=smoothing_length, - bkg_fit=bkg_fit, bkg_order=bkg_order, - subtract_background=subtract_background, - use_source_posn=use_source_posn, - position_offset=position_offset) - except ContinueError: - pass - elif isinstance(input_model, (datamodels.CubeModel, datamodels.SlitModel)): - # This branch will be invoked for inputs that are a CubeModel, which typically includes - # NIRSpec BrightObj (fixed slit) mode, as well as inputs that are a - # single SlitModel, which typically includes data from a single resampled/combined slit - # instance from level-3 processing of NIRSpec fixed slits and MOS modes. - - # Replace the default value for slitname with a more accurate value, if possible. # For NRS_BRIGHTOBJ, the slit name comes from the slit model info if exp_type == 'NRS_BRIGHTOBJ' and hasattr(input_model, "name"): slitname = input_model.name @@ -2166,30 +2006,34 @@ def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None, else: slitname = input_model.meta.instrument.fixed_slit - sp_order = get_spectral_order(input_model) - if sp_order == 0 and not prism_mode: - log.info("Spectral order 0 is a direct image, skipping ...") - else: - log.info(f'Processing spectral order {sp_order}') - - try: - profile_model, scene_model = create_extraction( - input_model, slit, output_model, - extract_ref_dict, slitname, sp_order, exp_type, - apcorr_ref_model=apcorr_ref_model, log_increment=log_increment, - save_profile=save_profile, save_scene_model=save_scene_model, - smoothing_length=smoothing_length, - bkg_fit=bkg_fit, bkg_order=bkg_order, - subtract_background=subtract_background, - use_source_posn=use_source_posn, - position_offset=position_offset) - except ContinueError: - pass - else: log.error("The input file is not supported for this step.") raise RuntimeError("Can't extract a spectrum from this file.") + sp_order = get_spectral_order(input_model) + if sp_order == 0 and not prism_mode: + log.info("Spectral order 0 is a direct image, skipping ...") + else: + log.info(f'Processing spectral order {sp_order}') + try: + profile_model, scene_model, residual = create_extraction( + input_model, slit, output_model, + extract_ref_dict, slitname, sp_order, exp_type, + apcorr_ref_model=apcorr_ref_model, log_increment=log_increment, + save_profile=save_profile, save_scene_model=save_scene_model, + save_residual_image=save_residual_image, + psf_ref_name=psf_ref_name, + extraction_type=extraction_type, + smoothing_length=smoothing_length, + bkg_fit=bkg_fit, bkg_order=bkg_order, + subtract_background=subtract_background, + use_source_posn=use_source_posn, + position_offset=position_offset, + model_nod_pair=model_nod_pair, + optimize_psf_location=optimize_psf_location) + except ContinueError: + pass + # Copy the integration time information from the INT_TIMES table to keywords in the output file. if pipe_utils.is_tso(input_model): populate_time_keywords(input_model, output_model) @@ -2208,4 +2052,4 @@ def run_extract1d(input_model, extract_ref_name="N/A", apcorr_ref_name=None, # x1d product just to hold this keyword. output_model.meta.target.source_type = None - return output_model, profile_model, scene_model + return output_model, profile_model, scene_model, residual diff --git a/jwst/extract_1d/extract1d.py b/jwst/extract_1d/extract1d.py index 46ceee056a..0cefe516eb 100644 --- a/jwst/extract_1d/extract1d.py +++ b/jwst/extract_1d/extract1d.py @@ -285,18 +285,24 @@ def _optimal_extract( if order > -1: with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=RuntimeWarning, message="invalid value") + warnings.filterwarnings("ignore", category=RuntimeWarning, message="divide by zero") - wgt_nobkg = [profiles_2d[i] * weights / np.sum(profiles_2d[i] ** 2 * weights, axis=0) + wgt_nobkg = [profiles_2d[i] * weights + / np.sum(profiles_2d[i] ** 2 * weights, axis=0) for i in range(nobjects)] - bkg = np.array([np.sum(wgt_nobkg[i] * bkg_2d, axis=0) for i in range(nobjects)]) - - var_bkg_rn = np.array([var_rn[i] - np.sum(wgt_nobkg[i] ** 2 * variance_rn, axis=0) - for i in range(nobjects)]) - var_bkg_phnoise = np.array([var_phnoise[i] - np.sum(wgt_nobkg[i] ** 2 * variance_phnoise, axis=0) - for i in range(nobjects)]) - var_bkg_flat = np.array([var_flat[i] - np.sum(wgt_nobkg[i] ** 2 * variance_flat, axis=0) - for i in range(nobjects)]) + bkg = np.array([np.sum(wgt_nobkg[i] * bkg_2d, axis=0) for i in range(nobjects)]) + + # Avoid overflow in squaring weights by multiplying by variance first + var_bkg_rn = np.array( + [var_rn[i] - np.sum(variance_rn * wgt_nobkg[i] * wgt_nobkg[i], axis=0) + for i in range(nobjects)]) + var_bkg_phnoise = np.array( + [var_phnoise[i] - np.sum(variance_phnoise * wgt_nobkg[i] * wgt_nobkg[i], axis=0) + for i in range(nobjects)]) + var_bkg_flat = np.array( + [var_flat[i] - np.sum(variance_flat * wgt_nobkg[i] * wgt_nobkg[i], axis=0) + for i in range(nobjects)]) # Make sure background values are finite bkg[~np.isfinite(bkg)] = 0.0 diff --git a/jwst/extract_1d/extract_1d_step.py b/jwst/extract_1d/extract_1d_step.py index 3738ee51fc..08970c47e7 100644 --- a/jwst/extract_1d/extract_1d_step.py +++ b/jwst/extract_1d/extract_1d_step.py @@ -1,11 +1,11 @@ +import crds from stdatamodels.jwst import datamodels from jwst.datamodels import ModelContainer, SourceModelContainer - -from ..stpipe import Step -from . import extract -from .soss_extract import soss_extract -from .ifu import ifu_extract1d +from jwst.stpipe import Step +from jwst.extract_1d import extract +from jwst.extract_1d.soss_extract import soss_extract +from jwst.extract_1d.ifu import ifu_extract1d __all__ = ["Extract1dStep"] @@ -25,15 +25,36 @@ class Extract1dStep(Step): Switch to select whether to apply an APERTURE correction during the Extract1dStep. Default is True. + extraction_type : str or None + If 'box', a standard extraction is performed, summing over an + aperture box. If 'optimal', a PSF-based extraction is performed. + If None, optimal extraction is attempted whenever use_source_posn is + True. Currently, optimal extraction is only available for MIRI LRS + Fixed Slit data. + use_source_posn : bool or None - If True, the source and background extraction positions specified in - the extract1d reference file (or the default position, if there is no - reference file) will be shifted to account for the computed position - of the source in the data. If None (the default), the values in the - reference file will be used. Aperture offset is determined by computing - the pixel location of the source based on its RA and Dec. It does not - make sense to apply aperture offsets for extended sources, so this - parameter can be overridden (set to False) internally by the step. + If True, the source and background extraction regions specified in + the extract1d reference file will be shifted to account for the computed + position of the source in the data. If None (the default), this parameter + is set to True for point sources in NIRSpec and MIRI LRS fixed slit modes. + + position_offset : float + Number of pixels to offset the source and background extraction regions + in the cross-dispersion direction. This is intended to allow a manual + tweak to the aperture defined via reference file; the default value is 0.0. + + model_nod_pair : bool + If True, and the extraction type is 'optimal', then a negative trace + from nod subtraction is modeled alongside the positive source during + extraction. Even if set to True, this will be attempted only if the + input data has been background subtracted and the dither pattern + indicates that only 2 nods were used. + + optimize_psf_location : bool + If True, and the extraction type is 'optimal', then the placement of + the PSF model for the source location (and negative nod, if present) + will be iteratively optimized. This parameter is recommended if + negative nods are modeled. smoothing_length : int or None If not None, the background regions (if any) will be smoothed @@ -72,6 +93,10 @@ class Extract1dStep(Step): If True, a model of the 2D flux as defined by the extraction aperture is saved to disk. Ignored for IFU and NIRISS SOSS extractions. + save_residual_image : bool + If True, the residual image (from the input minus the scene model) + is saved to disk. Ignored for IFU and NIRISS SOSS extractions. + center_xy : int or None A list of 2 pixel coordinate values at which to place the center of the IFU extraction aperture, overriding any centering done by the step. @@ -161,14 +186,18 @@ class Extract1dStep(Step): subtract_background = boolean(default=None) # subtract background? apply_apcorr = boolean(default=True) # apply aperture corrections? + extraction_type = option("box", "optimal", None, default="box") # Perform box or optimal extraction use_source_posn = boolean(default=None) # use source coords to center extractions? position_offset = float(default=0) # number of pixels to shift source trace in the cross-dispersion direction + model_nod_pair = boolean(default=True) # For optimal extraction, model a negative nod if possible + optimize_psf_location = boolean(default=True) # For optimal extraction, optimize source location smoothing_length = integer(default=None) # background smoothing size bkg_fit = option("poly", "mean", "median", None, default=None) # background fitting type bkg_order = integer(default=None, min=0) # order of background polynomial fit log_increment = integer(default=50) # increment for multi-integration log messages save_profile = boolean(default=False) # save spatial profile to disk save_scene_model = boolean(default=False) # save flux model to disk + save_residual_image = boolean(default=False) # save residual image to disk center_xy = float_list(min=2, max=2, default=None) # IFU extraction x/y center ifu_autocen = boolean(default=False) # Auto source centering for IFU point source data. @@ -192,7 +221,8 @@ class Extract1dStep(Step): soss_modelname = output_file(default = None) # Filename for optional model output of traces and pixel weights """ - reference_file_types = ['extract1d', 'apcorr', 'pastasoss', 'specprofile', 'speckernel'] + reference_file_types = ['extract1d', 'apcorr', 'pastasoss', 'specprofile', + 'speckernel', 'psf'] def _get_extract_reference_files_by_mode(self, model, exp_type): """Get extraction reference files with special handling by exposure type.""" @@ -213,7 +243,14 @@ def _get_extract_reference_files_by_mode(self, model, exp_type): if apcorr_ref != 'N/A': self.log.info(f'Using APCORR file {apcorr_ref}') - return extract_ref, apcorr_ref + try: + psf_ref = self.get_reference_file(model, 'psf') + except crds.core.exceptions.CrdsLookupError: + psf_ref = 'N/A' + if psf_ref != 'N/A': + self.log.info(f'Using PSF reference file {psf_ref}') + + return extract_ref, apcorr_ref, psf_ref def _extract_soss(self, model): """Extract NIRISS SOSS spectra.""" @@ -318,18 +355,26 @@ def _extract_ifu(self, model, exp_type, extract_ref, apcorr_ref): ) return result - def _save_intermediate(self, intermediate_model, suffix): + def _save_intermediate(self, intermediate_model, suffix, idx): """Save an intermediate output file.""" if isinstance(intermediate_model, ModelContainer): - # Save the profile with the slit name + suffix 'profile' + # Save the profile with the slit name + index + suffix 'profile' for model in intermediate_model: slit = str(model.name).lower() - output_path = self.make_output_path(suffix=f'{slit}_{suffix}') + if idx is not None: + complete_suffix = f'{slit}_{idx}_{suffix}' + else: + complete_suffix = f'{slit}_{suffix}' + output_path = self.make_output_path(suffix=complete_suffix) self.log.info(f"Saving {suffix} {output_path}") model.save(output_path) else: - # Only one profile - just use the suffix 'profile' - output_path = self.make_output_path(suffix=suffix) + # Only one profile - just use the index and suffix 'profile' + if idx is not None: + complete_suffix = f'{idx}_{suffix}' + else: + complete_suffix = suffix + output_path = self.make_output_path(suffix=complete_suffix) self.log.info(f"Saving {suffix} {output_path}") intermediate_model.save(output_path) intermediate_model.close() @@ -401,22 +446,25 @@ def process(self, input): else: result = ModelContainer() - for model in input_model: + for i, model in enumerate(input_model): # Get the reference file names - extract_ref, apcorr_ref = self._get_extract_reference_files_by_mode( - model, exp_type) + extract_ref, apcorr_ref, psf_ref = \ + self._get_extract_reference_files_by_mode(model, exp_type) profile = None scene_model = None + residual = None if isinstance(model, datamodels.IFUCubeModel): # Call the IFU specific extraction routine extracted = self._extract_ifu(model, exp_type, extract_ref, apcorr_ref) else: # Call the general extraction routine - extracted, profile, scene_model = extract.run_extract1d( + extracted, profile, scene_model, residual = extract.run_extract1d( model, extract_ref, apcorr_ref, + psf_ref, + self.extraction_type, self.smoothing_length, self.bkg_fit, self.bkg_order, @@ -424,8 +472,11 @@ def process(self, input): self.subtract_background, self.use_source_posn, self.position_offset, + self.model_nod_pair, + self.optimize_psf_location, self.save_profile, self.save_scene_model, + self.save_residual_image, ) # Set the step flag to complete in each model @@ -434,12 +485,20 @@ def process(self, input): del extracted # Save profile if needed + if len(input_model) > 1: + idx = i + else: + idx = None if self.save_profile and profile is not None: - self._save_intermediate(profile, 'profile') + self._save_intermediate(profile, 'profile', idx) # Save model if needed if self.save_scene_model and scene_model is not None: - self._save_intermediate(scene_model, 'scene_model') + self._save_intermediate(scene_model, 'scene_model', idx) + + # Save residual if needed + if self.save_residual_image and residual is not None: + self._save_intermediate(residual, 'residual', idx) # If only one result, return the model instead of the container if len(result) == 1: diff --git a/jwst/extract_1d/psf_profile.py b/jwst/extract_1d/psf_profile.py new file mode 100644 index 0000000000..f0ebf28dbb --- /dev/null +++ b/jwst/extract_1d/psf_profile.py @@ -0,0 +1,349 @@ +import logging +import numpy as np + +from scipy import ndimage, optimize +from stdatamodels.jwst.datamodels import SpecPsfModel + +from jwst.extract_1d.extract1d import extract1d +from jwst.extract_1d.source_location import ( + middle_from_wcs, nod_pair_location, trace_from_wcs) + +__all__ = ['psf_profile'] + +log = logging.getLogger(__name__) +log.setLevel(logging.DEBUG) + +HORIZONTAL = 1 +VERTICAL = 2 +"""Dispersion direction, predominantly horizontal or vertical.""" + +NOD_PAIR_PATTERN = ['ALONG-SLIT-NOD', '2-POINT-NOD'] + + +def open_psf(psf_refname, exp_type): + """Open the PSF reference file. + + Parameters + ---------- + psf_refname : str + The name of the psf reference file. + exp_type : str + The exposure type of the data. + + Returns + ------- + psf_model : SpecPsfModel + Returns the EPSF model. + + """ + if exp_type == 'MIR_LRS-FIXEDSLIT': + # The information we read in from PSF file is: + # center_col: psf_model.meta.psf.center_col + # super sample factor: psf_model.meta.psf.subpix) + # psf : psf_model.data (2d) + # wavelength of PSF planes: psf_model.wave + psf_model = SpecPsfModel(psf_refname) + + else: + # So far, only MIRI LRS has a PSF datamodel defined. For any other + # exposure type, try to use the model MIRI LRS uses to open the input model + try: + psf_model = SpecPsfModel(psf_refname) + except (ValueError, AttributeError): + raise NotImplementedError(f'PSF file for EXP_TYPE {exp_type} ' + f'could not be read as SpecPsfModel.') from None + return psf_model + + +def _normalize_profile(profile, dispaxis): + """Normalize a spatial profile along the cross-dispersion axis.""" + if dispaxis == HORIZONTAL: + psum = np.nansum(profile, axis=0) + nz = (psum != 0) + profile[:, nz] = profile[:, nz] / psum[nz] + profile[:, ~nz] = 0.0 + else: + psum = np.nansum(profile, axis=1) + nz = (psum != 0) + profile[nz, :] = profile[nz, :] / psum[nz, None] + profile[~nz, :] = 0.0 + profile[~np.isfinite(profile)] = 0.0 + + +def _make_cutout_profile(xidx, yidx, psf_subpix, psf_data, dispaxis, + extra_shift=0.0, nod_offset=None): + """Make a spatial profile corresponding to the data cutout. + + Input index values should already contain the shift to the trace location + in the cross-dispersion direction and any wavelength shifts necessary + in the dispersion direction. + + Parameters + ---------- + xidx : ndarray of float + Index array for x values. + yidx : ndarray of float + Index array for y values. + psf_subpix : float + Scaling factor for pixel size in the PSF data. + psf_data : ndarray of float + 2D PSF model. + dispaxis : int + Dispersion axis. + extra_shift : float, optional + An extra shift for the primary trace location, to be added to the + cross-dispersion indices. + nod_offset : float, optional + If not None, a negative trace is added to the spatial profile, + with a cross-dispersion shift of `nod_offset`. + + Returns + ------- + profiles : list of ndarray of float + 2D spatial profiles containing the primary trace and, optionally, + a negative trace for a nod pair. The profiles are normalized along + the cross-dispersion axis. + """ + # Add an extra spatial shift to the primary trace + if dispaxis == HORIZONTAL: + xmap = xidx + ymap = yidx + extra_shift * psf_subpix + else: + xmap = xidx + extra_shift * psf_subpix + ymap = yidx + sprofile = ndimage.map_coordinates(psf_data, [ymap, xmap], order=1) + _normalize_profile(sprofile, dispaxis) + + if nod_offset is None: + return [sprofile] + + # Make an additional profile for the negative nod if desired + if dispaxis == HORIZONTAL: + ymap += psf_subpix * nod_offset + else: + xmap += psf_subpix * nod_offset + + nod_profile = ndimage.map_coordinates(psf_data, [ymap, xmap], order=1) + _normalize_profile(nod_profile, dispaxis) + + return [sprofile, nod_profile * -1] + + +def _profile_residual(shifts_to_optimize, cutout, cutout_var, xidx, yidx, + psf_subpix, psf_data, dispaxis, fit_bkg=True): + """Residual function to minimize for optimizing trace locations.""" + if len(shifts_to_optimize) > 1: + nod_offset = shifts_to_optimize[1] + else: + nod_offset = None + sprofiles = _make_cutout_profile(xidx, yidx, psf_subpix, psf_data, dispaxis, + extra_shift=shifts_to_optimize[0], + nod_offset=nod_offset) + extract_kwargs = {'extraction_type': 'optimal', + 'fit_bkg': fit_bkg, + 'bkg_fit_type': 'poly', + 'bkg_order': 0} + if dispaxis == HORIZONTAL: + empty_var = np.zeros_like(cutout) + result = extract1d(cutout, sprofiles, cutout_var, empty_var, empty_var, + **extract_kwargs) + model = result[-1] + else: + sprofiles = [profile.T for profile in sprofiles] + empty_var = np.zeros_like(cutout.T) + result = extract1d(cutout.T, sprofiles, cutout_var.T, empty_var, empty_var, + **extract_kwargs) + model = result[-1].T + return np.nansum((model - cutout) ** 2 / cutout_var) + + +def psf_profile(input_model, trace, wl_array, psf_ref_name, + optimize_shifts=True, model_nod_pair=True): + """Create a spatial profile from a PSF reference. + + Provides PSF-based profiles for point sources in slit-like data containing + one positive trace and, optionally, one negative trace resulting from nod + subtraction. The location of the positive trace should be provided in the + `trace` input parameter; the negative trace location will be guessed from + the input metadata. If a negative trace is modeled, it is recommended that + `optimize_shifts` also be set to True, to improve the initial guess for the + trace location. + + Parameters + ---------- + input_model : data model + This can be either the input science file or one SlitModel out of + a list of slits. + trace : ndarray or None + Array of source cross-dispersion position values, one for each + dispersion element in the input model data. If None, the source + is assumed to be at the center of the slit. + wl_array : ndarray + Array of wavelength values, matching the input model data shape, for + each pixel in the array. + psf_ref_name : str + PSF reference filename. + optimize_shifts : bool, optional + If True, the spatial location of the trace will be optimized by + minimizing the residuals in a scene model compared to the data in + the first integration of `input_model`. + model_nod_pair : bool, optional + If True, and if background subtraction has taken place, a negative + PSF will be modeled at the mirrored spatial location of the positive + trace. + + Returns + ------- + profile : ndarray + Spatial profile matching the input data. + lower_limit : int + Lower limit of the aperture in the cross-dispersion direction. + For PSF profiles, this is always set to the lower edge of the bounding box, + since the full array may have non-zero weight. + upper_limit : int + Upper limit of the aperture in the cross-dispersion direction. + For PSF profiles, this is always set to the upper edge of the bounding box, + since the full array may have non-zero weight. + """ + # Read in reference files + exp_type = input_model.meta.exposure.type + psf_model = open_psf(psf_ref_name, exp_type) + + # Get the data cutout + data_shape = input_model.data.shape[-2:] + dispaxis = input_model.meta.wcsinfo.dispersion_direction + wcs = input_model.meta.wcs + bbox = wcs.bounding_box + + y0 = int(np.ceil(bbox[1][0])) + y1 = int(np.ceil(bbox[1][1])) + x0 = int(np.ceil(bbox[0][0])) + x1 = int(np.ceil(bbox[0][1])) + if input_model.data.ndim == 3: + # use the first integration only + cutout = input_model.data[0, y0:y1, x0:x1] + cutout_var = input_model.var_rnoise[0, y0:y1, x0:x1] + else: + cutout = input_model.data[y0:y1, x0:x1] + cutout_var = input_model.var_rnoise[y0:y1, x0:x1] + cutout_wl = wl_array[y0:y1, x0:x1] + + # Get the nominal center of the cutout + middle_disp, middle_xdisp, middle_wl = middle_from_wcs(wcs, bbox, dispaxis) + + # Get the effective index into the 1D PSF wavelengths from the data wavelengths + psf_wave = psf_model.wave + sort_idx = np.argsort(psf_wave) + valid_wave = np.isfinite(psf_wave[sort_idx]) + wave_idx = np.interp(cutout_wl, psf_wave[sort_idx][valid_wave], sort_idx[valid_wave], + left=np.nan, right=np.nan) + + if trace is None: + # Don't try to model a negative pair if we don't have a trace to start + if model_nod_pair: + log.warning('Cannot model a negative nod without position information') + model_nod_pair = False + + # Set the location to the middle of the cross-dispersion + # all the way across the array + location = middle_xdisp + if dispaxis == HORIZONTAL: + trace = trace_from_wcs(exp_type, data_shape, bbox, wcs, + middle_disp, middle_xdisp, dispaxis) + else: + trace = trace_from_wcs(exp_type, data_shape, bbox, wcs, + middle_xdisp, middle_disp, dispaxis) + + else: + # Nominal location from the middle dispersion point + location = trace[int(np.round(middle_disp))] + + # Trim the trace to the data cutout + if dispaxis == HORIZONTAL: + trace = trace[x0:x1] + else: + trace = trace[y0:y1] + + # Check if we need to add a negative nod pair trace + nod_offset = None + if model_nod_pair: + nod_subtracted = str(input_model.meta.cal_step.back_sub) == 'COMPLETE' + pattype_ok = str(input_model.meta.dither.primary_type) in NOD_PAIR_PATTERN + if not nod_subtracted: + log.info('Input data was not nod-subtracted. ' + 'A negative trace will not be modeled.') + elif not pattype_ok: + log.info('Input data was not a two-point nod. ' + 'A negative trace will not be modeled.') + else: + nod_center = nod_pair_location(input_model, middle_wl) + if np.isnan(nod_center) or (np.abs(location - nod_center) < 2): + log.warning('Nod center could not be estimated from the WCS.') + log.warning('The negative nod will not be modeled.') + else: + if not optimize_shifts: + log.warning('Negative nod locations are currently approximations only.') + log.warning('PSF location optimization is recommended when ' + 'negative nods are modeled.') + nod_offset = location - nod_center + + # Get an index grid for the data cutout + cutout_shape = cutout.shape + _y, _x = np.mgrid[:cutout_shape[0], :cutout_shape[1]] + + # Scale the trace location to the subsampled psf and + # add the wavelength and spatial shifts to the coordinates to map to + psf_subpix = psf_model.meta.psf.subpix + psf_location = trace - bbox[0][0] + if dispaxis == HORIZONTAL: + psf_shift = psf_model.meta.psf.center_row - (psf_location * psf_subpix) + xidx = wave_idx + yidx = _y * psf_subpix + psf_shift + else: + psf_shift = psf_model.meta.psf.center_col - (psf_location * psf_subpix) + xidx = _x * psf_subpix + psf_shift[:, None] + yidx = wave_idx + + # If desired, add additional spatial shifts to the starting locations of + # the primary trace (and negative nod pair trace if necessary) + if optimize_shifts: + log.info('Optimizing trace locations') + if nod_offset is None: + extra_shift, = optimize.minimize( + _profile_residual, [0.0], + (cutout, cutout_var, xidx, yidx, + psf_subpix, psf_model.data, dispaxis), method='Nelder-Mead').x + else: + extra_shift, nod_offset = optimize.minimize( + _profile_residual, [0.0, nod_offset], + (cutout, cutout_var, xidx, yidx, + psf_subpix, psf_model.data, dispaxis), method='Nelder-Mead').x + location -= extra_shift + else: + extra_shift = 0.0 + + log.info(f'Centering profile on spectrum at {location:.2f}, wavelength {middle_wl:.2f}') + if nod_offset is not None: + log.info(f'Also modeling a negative trace at {location - nod_offset:.2f} ' + f'(offset: {nod_offset:.2f})') + + # Make a spatial profile from the shifted PSF data + sprofiles = _make_cutout_profile(xidx, yidx, psf_subpix, psf_model.data, + dispaxis, extra_shift=extra_shift, + nod_offset=nod_offset) + + # Make the output profile, matching the input data + output_y = _y + y0 + output_x = _x + x0 + valid = (output_y >= 0) & (output_y < y1) & (output_x >= 0) & (output_x < x1) + profiles = [] + for sprofile in sprofiles: + profile = np.full(data_shape, 0.0) + profile[output_y[valid], output_x[valid]] = sprofile[valid] + profiles.append(profile) + + if dispaxis == HORIZONTAL: + limits = (y0, y1) + else: + limits = (x0, x1) + return profiles, *limits diff --git a/jwst/extract_1d/source_location.py b/jwst/extract_1d/source_location.py new file mode 100644 index 0000000000..eec8a61ad1 --- /dev/null +++ b/jwst/extract_1d/source_location.py @@ -0,0 +1,516 @@ +import logging +import numpy as np +from gwcs.wcstools import grid_from_bounding_box +from scipy.interpolate import interp1d +from stdatamodels.jwst.transforms.models import IdealToV2V3 + +from jwst.assign_wcs.util import wcs_bbox_from_shape + + +__all__ = ['middle_from_wcs', 'location_from_wcs', 'trace_from_wcs', + 'nod_pair_location'] + +HORIZONTAL = 1 +"""Horizontal dispersion axis.""" +VERTICAL = 2 +"""Vertical dispersion axis.""" + +log = logging.getLogger(__name__) +log.setLevel(logging.DEBUG) + + +def middle_from_wcs(wcs, bounding_box, dispaxis): + """Calculate the effective middle of the spectral region. + + Parameters + ---------- + wcs : `~gwcs.WCS` + WCS for the input data model, containing detector to wavelength + transforms. + bounding_box : tuple + A pair of tuples, each consisting of two numbers. + Represents the range of useful pixel values in both dimensions, + ((xmin, xmax), (ymin, ymax)). + dispaxis : int + Dispersion axis. + + Returns + ------- + middle_disp : float + Middle pixel in the dispersion axis. + middle_xdisp : float + Middle pixel in the cross-dispersion axis. + middle_wavelength : float + Wavelength at the middle pixel. + """ + if dispaxis == HORIZONTAL: + # Width (height) in the cross-dispersion direction, from the start of + # the 2-D cutout (or of the full image) to the upper limit of the bounding box. + xd_width = int(round(bounding_box[1][1])) # must be an int + + # Middle of the bounding_box in the dispersion direction. + middle_disp = (bounding_box[0][0] + bounding_box[0][1]) / 2. + x = np.full(xd_width, middle_disp) + + # 1-D vector of cross-dispersion (y) pixel indices + y = np.arange(xd_width, dtype=np.float64) + + else: + # Cross-dispersion total width of bounding box; must be an int + xd_width = int(round(bounding_box[0][1])) + + # Middle of the bounding_box in the dispersion direction. + middle_disp = (bounding_box[1][0] + bounding_box[1][1]) / 2. + y = np.full(xd_width, middle_disp) + + # 1-D vector of cross-dispersion (x) pixel indices + x = np.arange(xd_width, dtype=np.float64) + + # Get all the wavelengths at the middle dispersion element + _, _, center_wavelengths = wcs(x, y) + sort_idx = np.argsort(center_wavelengths) + valid = np.isfinite(center_wavelengths[sort_idx]) + + # Average to get the middle wavelength + middle_wavelength = np.nanmean(center_wavelengths) + + # Find the effective index in cross-dispersion coordinates for the + # averaged wavelength to get the cross-dispersion center + if dispaxis == HORIZONTAL: + if np.allclose(center_wavelengths, middle_wavelength): + middle_xdisp = np.mean(y) + else: + middle_xdisp = np.interp( + middle_wavelength, center_wavelengths[sort_idx][valid], + y[sort_idx[valid]]) + else: + if np.allclose(center_wavelengths, middle_wavelength): + middle_xdisp = np.mean(x) + else: + middle_xdisp = np.interp( + middle_wavelength, center_wavelengths[sort_idx][valid], + x[sort_idx[valid]]) + return middle_disp, middle_xdisp, middle_wavelength + + +def location_from_wcs(input_model, slit, make_trace=True): + """Get the cross-dispersion location of the spectrum, based on the WCS. + + None values will be returned if there was insufficient information + available, e.g. if the wavelength attribute or wcs function is not + defined. + + Parameters + ---------- + input_model : DataModel + The input science model containing metadata information. + slit : DataModel or None + One slit from a MultiSlitModel (or similar), or None. + The WCS and target coordinates will be retrieved from `slit` + unless `slit` is None. In that case, they will be retrieved + from `input_model`. + make_trace : bool, optional + If True, the source position will be calculated for each + dispersion element and returned in `trace`. If False, + None is returned. + + Returns + ------- + middle : int or None + Pixel coordinate in the dispersion direction within the 2-D + cutout (or the entire input image) at the middle of the WCS + bounding box. This is the point at which to determine the + nominal extraction location, in case it varies along the + spectrum. The offset will then be the difference between + `location` (below) and the nominal location. + middle_wl : float or None + The wavelength at pixel `middle`. + location : float or None + Pixel coordinate in the cross-dispersion direction within the + spectral image that is at the planned target location. + The spectral extraction region should be centered here. + trace : ndarray or None + An array of source positions, one per dispersion element, corresponding + to the location at each point in the wavelength array. If the + input data is resampled, the trace corresponds directly to the + location. If the trace could not be generated, or `make_trace` is + False, None is returned. + """ + if slit is not None: + shape = slit.data.shape[-2:] + wcs = slit.meta.wcs + dispaxis = slit.meta.wcsinfo.dispersion_direction + else: + shape = input_model.data.shape[-2:] + wcs = input_model.meta.wcs + dispaxis = input_model.meta.wcsinfo.dispersion_direction + + bb = wcs.bounding_box # ((x0, x1), (y0, y1)) + if bb is None: + bb = wcs_bbox_from_shape(shape) + if dispaxis == HORIZONTAL: + lower = bb[1][0] + upper = bb[1][1] + else: + lower = bb[0][0] + upper = bb[0][1] + + # Get the wavelengths for the valid data in the sky transform, + # average to get the middle wavelength + middle, _, middle_wl = middle_from_wcs(wcs, bb, dispaxis) + middle = int(np.round(middle)) + + exp_type = input_model.meta.exposure.type + trace = None + if exp_type in ['NRS_FIXEDSLIT', 'NRS_MSASPEC', 'NRS_BRIGHTOBJ']: + log.info("Using source_xpos and source_ypos to center extraction.") + if slit is None: + xpos = input_model.source_xpos + ypos = input_model.source_ypos + else: + xpos = slit.source_xpos + ypos = slit.source_ypos + + slit2det = wcs.get_transform('slit_frame', 'detector') + if 'gwa' in wcs.available_frames: + # Input is not resampled, wavelengths need to be meters + _, location = slit2det(xpos, ypos, middle_wl * 1e-6) + else: + _, location = slit2det(xpos, ypos, middle_wl) + + if ~np.isnan(location) and make_trace: + trace = _nirspec_trace_from_wcs(shape, bb, wcs, xpos, ypos) + + elif exp_type == 'MIR_LRS-FIXEDSLIT': + log.info("Using dithered_ra and dithered_dec to center extraction.") + try: + if slit is None: + dithra = input_model.meta.dither.dithered_ra + dithdec = input_model.meta.dither.dithered_dec + else: + dithra = slit.meta.dither.dithered_ra + dithdec = slit.meta.dither.dithered_dec + location, _ = wcs.backward_transform(dithra, dithdec, middle_wl) + + except (AttributeError, TypeError): + log.warning("Dithered pointing location not found in wcsinfo.") + return None, None, None, None + + if ~np.isnan(location) and make_trace: + trace = _miri_trace_from_wcs(shape, bb, wcs, dithra, dithdec) + else: + log.warning(f"Source position cannot be found for EXP_TYPE {exp_type}") + return None, None, None, None + + if np.isnan(location): + log.warning('Source position could not be determined from WCS.') + return None, None, None, None + + # If the target is at the edge of the image or at the edge of the + # non-NaN area, we can't use the WCS to find the + # location of the target spectrum. + if location < lower or location > upper: + log.warning(f"WCS implies the target is at {location:.2f}, which is outside the bounding box,") + log.warning("so we can't get spectrum location using the WCS") + return None, None, None, None + + return middle, middle_wl, location, trace + + +def _nirspec_trace_from_wcs(shape, bounding_box, wcs_ref, source_xpos, source_ypos): + """Calculate NIRSpec source trace from WCS. + + The source trace is calculated by projecting the recorded source + positions source_xpos/ypos from the NIRSpec "slit_frame" onto + detector pixels. + + Parameters + ---------- + shape : tuple of int + 2D shape for the full input data array, (ny, nx). + bounding_box : tuple + A pair of tuples, each consisting of two numbers. + Represents the range of useful pixel values in both dimensions, + ((xmin, xmax), (ymin, ymax)). + wcs_ref : `~gwcs.WCS` + WCS for the input data model, containing slit and detector + transforms. + source_xpos : float + Slit position, in the x direction, for the target. + source_ypos : float + Slit position, in the y direction, for the target. + + Returns + ------- + trace : ndarray of float + Fractional pixel positions in the y (cross-dispersion direction) + of the trace for each x (dispersion direction) pixel. + """ + x, y = grid_from_bounding_box(bounding_box) + nx = int(bounding_box[0][1] - bounding_box[0][0]) + + # Calculate the wavelengths in the slit frame because they are in + # meters for cal files and um for s2d files + d2s = wcs_ref.get_transform("detector", "slit_frame") + _, _, slit_wavelength = d2s(x, y) + + # Make an initial array of wavelengths that will cover the wavelength range of the data + wave_vals = np.linspace(np.nanmin(slit_wavelength), np.nanmax(slit_wavelength), nx) + # Get arrays of the source position in the slit + pos_x = np.full(nx, source_xpos) + pos_y = np.full(nx, source_ypos) + + # Grab the wcs transform between the slit frame where we know the + # source position and the detector frame + s2d = wcs_ref.get_transform("slit_frame", "detector") + + # Calculate the expected center of the source trace + trace_x, trace_y = s2d(pos_x, pos_y, wave_vals) + + # Interpolate the trace to a regular pixel grid in the dispersion + # direction + interp_trace = interp1d(trace_x, trace_y, fill_value='extrapolate') + + # Get the trace position for each dispersion element + trace = interp_trace(np.arange(nx)) + + # Place the trace in the full array + full_trace = np.full(shape[1], np.nan) + x0 = int(np.ceil(bounding_box[0][0])) + full_trace[x0:x0 + nx] = trace + + return full_trace + + +def _miri_trace_from_wcs(shape, bounding_box, wcs_ref, source_ra, source_dec): + """Calculate MIRI LRS fixed slit source trace from WCS. + + The source trace is calculated by projecting the recorded source + positions dithered_ra/dec from the world frame onto detector pixels. + + Parameters + ---------- + shape : tuple of int + 2D shape for the full input data array, (ny, nx). + bounding_box : tuple + A pair of tuples, each consisting of two numbers. + Represents the range of useful pixel values in both dimensions, + ((xmin, xmax), (ymin, ymax)). + wcs_ref : `~gwcs.WCS` + WCS for the input data model, containing sky and detector + transforms, forward and backward. + source_ra : float + RA coordinate for the target. + source_dec : float + Dec coordinate for the target. + + Returns + ------- + trace : ndarray of float + Fractional pixel positions in the x (cross-dispersion direction) + of the trace for each y (dispersion direction) pixel. + """ + x, y = grid_from_bounding_box(bounding_box) + ny = int(bounding_box[1][1] - bounding_box[1][0]) + + # Calculate the wavelengths for the full array + _, _, slit_wavelength = wcs_ref(x, y) + + # Make an initial array of wavelengths that will cover the wavelength range of the data + wave_vals = np.linspace(np.nanmin(slit_wavelength), np.nanmax(slit_wavelength), ny) + + # Get arrays of the source position + pos_ra = np.full(ny, source_ra) + pos_dec = np.full(ny, source_dec) + + # Calculate the expected center of the source trace + trace_x, trace_y = wcs_ref.backward_transform(pos_ra, pos_dec, wave_vals) + + # Interpolate the trace to a regular pixel grid in the dispersion + # direction + interp_trace = interp1d(trace_y, trace_x, fill_value='extrapolate') + + # Get the trace position for each dispersion element within the bounding box + trace = interp_trace(np.arange(ny)) + + # Place the trace in the full array + full_trace = np.full(shape[0], np.nan) + y0 = int(np.ceil(bounding_box[1][0])) + full_trace[y0:y0 + ny] = trace + + return full_trace + + +def trace_from_wcs(exp_type, shape, bounding_box, wcs_ref, source_x, source_y, dispaxis): + """Calculate a source trace from WCS. + + The source trace is calculated by projecting a fixed source + positions onto detector pixels, to get a source location at each + dispersion element. For MIRI LRS fixed slit and NIRSpec modes, this + will be a curved trace, using the sky or slit frame as appropriate. + For all other modes, a flat trace is returned, containing the + cross-dispersion position at all dispersion elements. + + Parameters + ---------- + exp_type : str + Exposure type for the input data. + shape : tuple of int + 2D shape for the full input data array, (ny, nx). + bounding_box : tuple + A pair of tuples, each consisting of two numbers. + Represents the range of useful pixel values in both dimensions, + ((xmin, xmax), (ymin, ymax)). + wcs_ref : `~gwcs.WCS` + WCS for the input data model, containing sky and detector + transforms, forward and backward. + source_x : float + X pixel coordinate for the target. + source_y : float + Y pixel coordinate for the target. + dispaxis : int + Dispersion axis. + + Returns + ------- + trace : ndarray of float + Pixel positions in the cross-dispersion direction + of the trace for each dispersion pixel. + """ + if exp_type == 'MIR_LRS-FIXEDSLIT': + source_ra, source_dec, _ = wcs_ref(source_x, source_y) + trace = _miri_trace_from_wcs(shape, bounding_box, wcs_ref, source_ra, source_dec) + elif exp_type.startswith('NRS'): + d2s = wcs_ref.get_transform("detector", "slit_frame") + source_xpos, source_ypos, _ = d2s(source_x, source_y) + trace = _nirspec_trace_from_wcs(shape, bounding_box, wcs_ref, source_xpos, source_ypos) + else: + # Flat trace containing the cross-dispersion position at every element + if dispaxis == HORIZONTAL: + trace = np.full(shape[1], np.nan) + x0 = int(np.ceil(bounding_box[0][0])) + nx = int(bounding_box[0][1] - bounding_box[0][0]) + trace[x0:x0 + nx] = source_y + else: + trace = np.full(shape[0], np.nan) + y0 = int(np.ceil(bounding_box[1][0])) + ny = int(bounding_box[1][1] - bounding_box[1][0]) + trace[y0:y0 + ny] = source_x + + return trace + + +def _nod_pair_from_dither(input_model, middle_wl, dispaxis): + """Estimate a nod pair location from the dither offsets. + + Expected location is at the opposite spatial offset from + the input model. Requires 'v2v3' transform in the WCS, so + is only available for unresampled data. + + Parameters + ---------- + input_model : DataModel + Model containing WCS and dither data. + middle_wl : float + Wavelength at the middle of the array. + dispaxis : int + Dispersion axis. + + Returns + ------- + nod_location : float + The expected location of the negative trace, in the + cross-dispersion direction, at the middle wavelength. + """ + if 'v2v3' not in input_model.meta.wcs.available_frames: + return np.nan + + idltov23 = IdealToV2V3( + input_model.meta.wcsinfo.v3yangle, + input_model.meta.wcsinfo.v2_ref, input_model.meta.wcsinfo.v3_ref, + input_model.meta.wcsinfo.vparity + ) + + if dispaxis == HORIZONTAL: + x_offset = input_model.meta.dither.x_offset + y_offset = -input_model.meta.dither.y_offset + else: + x_offset = -input_model.meta.dither.x_offset + y_offset = input_model.meta.dither.y_offset + + dithered_v2, dithered_v3 = idltov23(x_offset, y_offset) + + # v23toworld requires a wavelength along with v2, v3, but value does not affect return + v23toworld = input_model.meta.wcs.get_transform('v2v3', 'world') + dithered_ra, dithered_dec, _ = v23toworld(dithered_v2, dithered_v3, 0.0) + + x, y = input_model.meta.wcs.backward_transform(dithered_ra, dithered_dec, middle_wl) + + if dispaxis == HORIZONTAL: + return y + else: + return x + + +def _nod_pair_from_slitpos(input_model, middle_wl): + """Estimate a nod pair location from the source slit postion. + + Expected location is at the opposite spatial position from + the input model. Requires 'slit_frame' transform in the WCS. + Implemented only for NIRSpec, assuming horizontal dispersion axis. + + Parameters + ---------- + input_model : DataModel + Model containing WCS and dither data. + middle_wl : float + Wavelength at the middle of the array. + + Returns + ------- + nod_location : float + The expected location of the negative trace, in the + cross-dispersion direction, at the middle wavelength. + """ + xpos = input_model.source_xpos + ypos = -input_model.source_ypos + wcs = input_model.meta.wcs + slit2det = wcs.get_transform('slit_frame', 'detector') + if 'gwa' in wcs.available_frames: + # Input is not resampled, wavelengths need to be meters + _, location = slit2det(xpos, ypos, middle_wl * 1e-6) + else: + _, location = slit2det(xpos, ypos, middle_wl) + return location + + +def nod_pair_location(input_model, middle_wl): + """Estimate a nod pair location from the WCS. + + For MIRI, it will guess the location from the dither offsets. + For NIRSpec, it will guess from the slit position. + For anything else, or if the estimate fails, it will return NaN + for the location. + + Parameters + ---------- + input_model : DataModel + Model containing WCS and dither data. + middle_wl : float + Wavelength at the middle of the array. + + Returns + ------- + nod_location : float + The expected location of the negative trace, in the + cross-dispersion direction, at the middle wavelength. + """ + exp_type = input_model.meta.exposure.type + nod_center = np.nan + if exp_type == 'MIR_LRS-FIXEDSLIT': + dispaxis = input_model.meta.wcsinfo.dispersion_direction + nod_center = _nod_pair_from_dither(input_model, middle_wl, dispaxis) + elif exp_type.startswith('NRS'): + nod_center = _nod_pair_from_slitpos(input_model, middle_wl) + + return nod_center diff --git a/jwst/extract_1d/tests/conftest.py b/jwst/extract_1d/tests/conftest.py index 709cfe0754..33ecdcecae 100644 --- a/jwst/extract_1d/tests/conftest.py +++ b/jwst/extract_1d/tests/conftest.py @@ -38,16 +38,21 @@ def simple_wcs_function(x, y): def get_transform(*args, **kwargs): def return_results(*args, **kwargs): if len(args) == 2: - zeros = np.zeros(args[0].shape) - wave, _ = np.meshgrid(args[0], args[1]) + try: + zeros = np.zeros(args[0].shape) + wave, _ = np.meshgrid(args[0], args[1]) + except AttributeError: + zeros = 0.0 + wave = args[0] return zeros, zeros, wave if len(args) == 3: try: nx = len(args[0]) + pix = np.arange(nx) + trace = np.ones(nx) except TypeError: - nx = 1 - pix = np.arange(nx) - trace = np.ones(nx) + pix = 0 + trace = 1.0 return pix, trace return return_results @@ -87,13 +92,22 @@ def simple_wcs_function(x, y): def backward_transform(*args, **kwargs): try: nx = len(args[0]) + pix = np.arange(nx) + trace = np.ones(nx) except TypeError: - nx = 1 - pix = np.arange(nx) - trace = np.ones(nx) + pix = 0.0 + trace = 1.0 return trace, pix + # Mock a simple forward transform, for mocking a v2v3 frame + def get_transform(*args, **kwargs): + def return_results(*args, **kwargs): + return 1.0, 1.0, 1.0 + return return_results + + simple_wcs_function.get_transform = get_transform simple_wcs_function.backward_transform = backward_transform + simple_wcs_function.available_frames = [] return simple_wcs_function @@ -229,6 +243,7 @@ def mock_miri_lrs_fs(simple_wcs_transpose): model = dm.ImageModel() model.meta.instrument.name = 'MIRI' model.meta.instrument.detector = 'MIRIMAGE' + model.meta.instrument.filter = 'P750L' model.meta.observation.date = '2023-07-22' model.meta.observation.time = '06:24:45.569' model.meta.exposure.nints = 1 @@ -471,3 +486,72 @@ def nirspec_fs_apcorr_file(tmp_path, nirspec_fs_apcorr): filename = str(tmp_path / 'nirspec_fs_apcorr.fits') nirspec_fs_apcorr.save(filename) return filename + + +@pytest.fixture() +def psf_reference(): + psf_model = dm.SpecPsfModel() + psf_model.data = np.ones((50, 50), dtype=float) + psf_model.wave = np.linspace(0, 10, 50) + psf_model.meta.psf.subpix = 1.0 + psf_model.meta.psf.center_col = 25 + psf_model.meta.psf.center_row = 25 + yield psf_model + psf_model.close() + + +@pytest.fixture() +def psf_reference_file(tmp_path, psf_reference): + filename = str(tmp_path / 'psf_reference.fits') + psf_reference.save(filename) + return filename + + +@pytest.fixture() +def psf_reference_with_source(): + psf_model = dm.SpecPsfModel() + psf_model.data = np.full((50, 50), 1e-6) + psf_model.data[:, 24:27] += 1.0 + + psf_model.wave = np.linspace(0, 10, 50) + psf_model.meta.psf.subpix = 1.0 + psf_model.meta.psf.center_col = 25 + psf_model.meta.psf.center_row = 25 + yield psf_model + psf_model.close() + + +@pytest.fixture() +def psf_reference_file_with_source(tmp_path, psf_reference_with_source): + filename = str(tmp_path / 'psf_reference_with_source.fits') + psf_reference_with_source.save(filename) + return filename + + +@pytest.fixture() +def simple_profile(): + profile = np.zeros((50, 50), dtype=np.float32) + profile[20:30, :] = 1.0 + return profile + + +@pytest.fixture() +def background_profile(): + profile = np.zeros((50, 50), dtype=np.float32) + profile[:10, :] = 1.0 + profile[40:, :] = 1.0 + return profile + + +@pytest.fixture() +def nod_profile(): + profile = np.zeros((50, 50), dtype=np.float32) + profile[10:20, :] = 1.0 / 10 + return profile + + +@pytest.fixture() +def negative_nod_profile(): + profile = np.zeros((50, 50), dtype=np.float32) + profile[30:40, :] = -1.0 / 10 + return profile diff --git a/jwst/extract_1d/tests/test_extract.py b/jwst/extract_1d/tests/test_extract.py index 9e4dbd5baa..cf34b2a9ed 100644 --- a/jwst/extract_1d/tests/test_extract.py +++ b/jwst/extract_1d/tests/test_extract.py @@ -7,6 +7,7 @@ from jwst.datamodels import ModelContainer from jwst.extract_1d import extract as ex +from jwst.extract_1d import psf_profile as pp from jwst.tests.helpers import LogWatcher @@ -54,7 +55,6 @@ def extract_defaults(): 'extraction_type': 'box', 'independent_var': 'pixel', 'match': 'exact match', - 'position_correction': 0, 'smoothing_length': 0, 'spectral_order': 1, 'src_coeff': None, @@ -62,28 +62,16 @@ def extract_defaults(): 'position_offset': 0.0, 'trace': None, 'use_source_posn': False, + 'model_nod_pair': False, + 'optimize_psf_location': False, 'xstart': 0, 'xstop': 49, 'ystart': 0, - 'ystop': 49} + 'ystop': 49, + 'psf': 'N/A'} return default -@pytest.fixture() -def simple_profile(): - profile = np.zeros((50, 50), dtype=np.float32) - profile[20:30, :] = 1.0 - return profile - - -@pytest.fixture() -def background_profile(): - profile = np.zeros((50, 50), dtype=np.float32) - profile[:10, :] = 1.0 - profile[40:, :] = 1.0 - return profile - - @pytest.fixture() def create_extraction_inputs(mock_nirspec_fs_one_slit, extract1d_ref_dict): input_model = mock_nirspec_fs_one_slit @@ -198,7 +186,7 @@ def test_get_extract_parameters_no_match( def test_get_extract_parameters_source_posn_exptype( - mock_nirspec_bots, extract1d_ref_dict, extract_defaults): + mock_nirspec_bots, extract1d_ref_dict): input_model = mock_nirspec_bots input_model.meta.exposure.type = 'NRS_LAMP' @@ -212,10 +200,10 @@ def test_get_extract_parameters_source_posn_exptype( def test_get_extract_parameters_source_posn_from_ref( - mock_nirspec_bots, extract1d_ref_dict, extract_defaults): + mock_nirspec_bots, extract1d_ref_dict): input_model = mock_nirspec_bots - # match an entry that explicity sets use_source_posn + # match an entry that explicitly sets use_source_posn params = ex.get_extract_parameters( extract1d_ref_dict, input_model, 'slit6', 1, input_model.meta, use_source_posn=None) @@ -227,8 +215,7 @@ def test_get_extract_parameters_source_posn_from_ref( @pytest.mark.parametrize('length', [3, 4, 2.8, 3.5]) def test_get_extract_parameters_smoothing( - mock_nirspec_fs_one_slit, extract1d_ref_dict, - extract_defaults, length): + mock_nirspec_fs_one_slit, extract1d_ref_dict, length): input_model = mock_nirspec_fs_one_slit params = ex.get_extract_parameters( @@ -242,8 +229,7 @@ def test_get_extract_parameters_smoothing( @pytest.mark.parametrize('length', [-1, 1, 2, 1.3]) def test_get_extract_parameters_smoothing_bad_value( - mock_nirspec_fs_one_slit, extract1d_ref_dict, - extract_defaults, length): + mock_nirspec_fs_one_slit, extract1d_ref_dict, length): input_model = mock_nirspec_fs_one_slit params = ex.get_extract_parameters( @@ -254,6 +240,44 @@ def test_get_extract_parameters_smoothing_bad_value( assert params['smoothing_length'] == 0 +@pytest.mark.parametrize('use_source', [None, True, False]) +def test_get_extract_parameters_extraction_type_none( + mock_nirspec_fs_one_slit, extract1d_ref_dict, use_source, log_watcher): + input_model = mock_nirspec_fs_one_slit + + log_watcher.message = "Using extraction type" + params = ex.get_extract_parameters( + extract1d_ref_dict, input_model, 'slit1', 1, input_model.meta, + extraction_type=None, use_source_posn=use_source, psf_ref_name='available') + log_watcher.assert_seen() + + # Extraction type is set to optimal if use_source_posn is True + if use_source is None or use_source is True: + assert params['use_source_posn'] is True + assert params['extraction_type'] == 'optimal' + else: + assert params['use_source_posn'] is False + assert params['extraction_type'] == 'box' + + +@pytest.mark.parametrize('extraction_type', [None, 'box', 'optimal']) +def test_get_extract_parameters_no_psf( + mock_nirspec_fs_one_slit, extract1d_ref_dict, extraction_type, log_watcher): + input_model = mock_nirspec_fs_one_slit + + log_watcher.message = "Setting extraction type to 'box'" + params = ex.get_extract_parameters( + extract1d_ref_dict, input_model, 'slit1', 1, input_model.meta, + extraction_type=extraction_type, psf_ref_name='N/A') + + # Warning message issued if extraction type was not already 'box' + if extraction_type != 'box': + log_watcher.assert_seen() + + # Extraction type is always box if no psf is available + assert params['extraction_type'] == 'box' + + def test_log_params(extract_defaults, log_watcher): log_watcher.message = 'Extraction parameters' @@ -854,138 +878,6 @@ def test_aperture_center_variable_weight_by_spec(middle, dispaxis): assert spec_center == middle -@pytest.mark.parametrize('resampled', [True, False]) -@pytest.mark.parametrize('is_slit', [True, False]) -@pytest.mark.parametrize('missing_bbox', [True, False]) -def test_location_from_wcs_nirspec( - monkeypatch, mock_nirspec_fs_one_slit, resampled, is_slit, missing_bbox): - model = mock_nirspec_fs_one_slit - - if not resampled: - # mock available frames, so it looks like unresampled cal data - monkeypatch.setattr(model.meta.wcs, 'available_frames', ['gwa']) - - if missing_bbox: - # mock a missing bounding box - should have same results - # for the test data - monkeypatch.setattr(model.meta.wcs, 'bounding_box', None) - - if is_slit: - middle, middle_wl, location, trace = ex.location_from_wcs(model, model) - else: - middle, middle_wl, location, trace = ex.location_from_wcs(model, None) - - # middle pixel is center of dispersion axis - assert middle == int((model.data.shape[1] - 1) / 2) - - # middle wavelength is the wavelength at that point, from the mock wcs - assert np.isclose(middle_wl, 7.74) - - # location is 1.0 - from the mocked transform function - assert location == 1.0 - - # trace is the same, in an array - assert np.all(trace == 1.0) - - -@pytest.mark.parametrize('is_slit', [True, False]) -def test_location_from_wcs_miri(monkeypatch, mock_miri_lrs_fs, is_slit): - model = mock_miri_lrs_fs - - # monkey patch in a transform for the wcs - def radec2det(*args, **kwargs): - def return_one(*args, **kwargs): - return 1.0, 0.0 - return return_one - - monkeypatch.setattr(model.meta.wcs, 'backward_transform', radec2det()) - - # mock the trace function - def mock_trace(*args, **kwargs): - return np.full(model.data.shape[-2], 1.0) - - monkeypatch.setattr(ex, '_miri_trace_from_wcs', mock_trace) - - # Get the slit center from the WCS - if is_slit: - middle, middle_wl, location, trace = ex.location_from_wcs(model, model) - else: - middle, middle_wl, location, trace = ex.location_from_wcs(model, None) - - # middle pixel is center of dispersion axis - assert middle == int((model.data.shape[0] - 1) / 2) - - # middle wavelength is the wavelength at that point, from the mock wcs - assert np.isclose(middle_wl, 7.26) - - # location is 1.0 - from the mocked transform function - assert location == 1.0 - - # trace is the same, in an array - assert np.all(trace == 1.0) - - -def test_location_from_wcs_missing_data(mock_miri_lrs_fs, log_watcher): - model = mock_miri_lrs_fs - model.meta.wcs.backward_transform = None - - # model is missing WCS information - None values are returned - log_watcher.message = "Dithered pointing location not found" - result = ex.location_from_wcs(model, None) - assert result == (None, None, None, None) - log_watcher.assert_seen() - - -def test_location_from_wcs_wrong_exptype(mock_niriss_soss, log_watcher): - # model is not a handled exposure type - log_watcher.message = "Source position cannot be found for EXP_TYPE" - result = ex.location_from_wcs(mock_niriss_soss, None) - assert result == (None, None, None, None) - log_watcher.assert_seen() - - -def test_location_from_wcs_bad_location( - monkeypatch, mock_nirspec_fs_one_slit, log_watcher): - model = mock_nirspec_fs_one_slit - - # monkey patch in a transform for the wcs - def slit2det(*args, **kwargs): - def return_one(*args, **kwargs): - return 0.0, np.nan - return return_one - - monkeypatch.setattr(model.meta.wcs, 'get_transform', slit2det) - - # WCS transform returns NaN for the location - log_watcher.message = "Source position could not be determined" - result = ex.location_from_wcs(model, None) - assert result == (None, None, None, None) - log_watcher.assert_seen() - - -def test_location_from_wcs_location_out_of_range( - monkeypatch, mock_nirspec_fs_one_slit, log_watcher): - model = mock_nirspec_fs_one_slit - - # monkey patch in a transform for the wcs - def slit2det(*args, **kwargs): - def return_one(*args, **kwargs): - return 0.0, 2000 - return return_one - - monkeypatch.setattr(model.meta.wcs, 'get_transform', slit2det) - - # mock the trace function - def mock_trace(*args, **kwargs): - return np.full(model.data.shape[-1], 1.0) - - monkeypatch.setattr(ex, '_nirspec_trace_from_wcs', mock_trace) - - # WCS transform a value outside the bounding box - log_watcher.message = "outside the bounding box" - result = ex.location_from_wcs(model, None) - assert result == (None, None, None, None) - log_watcher.assert_seen() def test_shift_by_offset_horizontal(extract_defaults): @@ -1055,23 +947,6 @@ def test_shift_by_offset_trace_no_update(extract_defaults): assert np.all(extract_params['trace'] == np.arange(10)) -def test_nirspec_trace_from_wcs(mock_nirspec_fs_one_slit): - model = mock_nirspec_fs_one_slit - trace = ex._nirspec_trace_from_wcs(model.data.shape, model.meta.wcs.bounding_box, - model.meta.wcs, 1.0, 1.0) - # mocked model contains some mock transforms as well - all ones are expected - assert np.all(trace == np.ones(model.data.shape[-1])) - - -def test_miri_trace_from_wcs(mock_miri_lrs_fs): - model = mock_miri_lrs_fs - trace = ex._miri_trace_from_wcs(model.data.shape, model.meta.wcs.bounding_box, - model.meta.wcs, 1.0, 1.0) - - # mocked model contains some mock transforms as well - all ones are expected - assert np.all(trace == np.ones(model.data.shape[-1])) - - @pytest.mark.parametrize('is_slit', [True, False]) def test_define_aperture_nirspec(mock_nirspec_fs_one_slit, extract_defaults, is_slit): model = mock_nirspec_fs_one_slit @@ -1082,7 +957,7 @@ def test_define_aperture_nirspec(mock_nirspec_fs_one_slit, extract_defaults, is_ slit = None exptype = 'NRS_FIXEDSLIT' result = ex.define_aperture(model, slit, extract_defaults, exptype) - ra, dec, wavelength, profile, bg_profile, limits = result + ra, dec, wavelength, profile, bg_profile, nod_profile, limits = result assert np.isclose(ra, 45.05) assert np.isclose(dec, 45.1) assert wavelength.shape == (model.data.shape[1],) @@ -1106,7 +981,7 @@ def test_define_aperture_miri(mock_miri_lrs_fs, extract_defaults, is_slit): slit = None exptype = 'MIR_LRS-FIXEDSLIT' result = ex.define_aperture(model, slit, extract_defaults, exptype) - ra, dec, wavelength, profile, bg_profile, limits = result + ra, dec, wavelength, profile, bg_profile, nod_profile, limits = result assert np.isclose(ra, 45.05) assert np.isclose(dec, 45.1) assert wavelength.shape == (model.data.shape[1],) @@ -1130,7 +1005,7 @@ def test_define_aperture_with_bg(mock_nirspec_fs_one_slit, extract_defaults): extract_defaults['bkg_coeff'] = [[-0.5], [2.5]] result = ex.define_aperture(model, slit, extract_defaults, exptype) - bg_profile = result[-2] + bg_profile = result[-3] # Bg profile has 1s in the first 3 rows assert bg_profile.shape == model.data.shape @@ -1149,7 +1024,7 @@ def test_define_aperture_empty_aperture(mock_nirspec_fs_one_slit, extract_defaul extract_defaults['ystop'] = 3000 result = ex.define_aperture(model, slit, extract_defaults, exptype) - _, _, _, profile, _, limits = result + _, _, _, profile, _, _, limits = result assert np.all(profile == 0.0) assert limits == (2000, 3000, None, None) @@ -1179,7 +1054,8 @@ def return_nan(*args): assert dec is None -def test_define_aperture_use_source(monkeypatch, mock_nirspec_fs_one_slit, extract_defaults): +def test_define_aperture_use_source( + monkeypatch, mock_nirspec_fs_one_slit, extract_defaults): model = mock_nirspec_fs_one_slit extract_defaults['dispaxis'] = 1 slit = None @@ -1196,7 +1072,7 @@ def mock_source_location(*args): extract_defaults['extract_width'] = 6.0 result = ex.define_aperture(model, slit, extract_defaults, exptype) - _, _, _, profile, _, limits = result + _, _, _, profile, _, _, limits = result assert np.all(profile[:7] == 0.0) assert np.all(profile[7:13] == 1.0) @@ -1212,7 +1088,7 @@ def test_define_aperture_extra_offset(mock_nirspec_fs_one_slit, extract_defaults extract_defaults['position_offset'] = 2.0 result = ex.define_aperture(model, slit, extract_defaults, exptype) - _, _, _, profile, _, limits = result + _, _, _, profile, _, _, limits = result assert profile.shape == model.data.shape # Default profile is shifted 2 pixels up @@ -1221,6 +1097,87 @@ def test_define_aperture_extra_offset(mock_nirspec_fs_one_slit, extract_defaults assert limits == (2, model.data.shape[0] + 1, 0, model.data.shape[1] - 1) +def test_define_aperture_optimal(mock_miri_lrs_fs, extract_defaults, psf_reference_file): + model = mock_miri_lrs_fs + extract_defaults['dispaxis'] = 2 + slit = None + exptype = 'MIR_LRS-FIXEDSLIT' + + # set parameters for optimal extraction + extract_defaults['extraction_type'] = 'optimal' + extract_defaults['use_source_posn'] = True + extract_defaults['psf'] = psf_reference_file + + result = ex.define_aperture(model, slit, extract_defaults, exptype) + _, _, _, profile, bg_profile, nod_profile, limits = result + + assert bg_profile is None + assert nod_profile is None + + # profile is normalized along cross-dispersion + assert np.allclose(np.sum(profile, axis=1), 1.0) + + # trace is centered on 1.0, near the edge of the slit, + # and the psf data has the same size as the array (50x50), + # so only half the psf is included + npix = 26 + assert np.all(np.sum(profile != 0, axis=1) == npix) + + # psf is uniform when in range, 0 otherwise + assert np.allclose(profile[:, :npix], 1 / npix) + assert np.allclose(profile[:, npix:], 0.0) + + +def test_define_aperture_optimal_with_nod( + monkeypatch, mock_miri_lrs_fs, extract_defaults, psf_reference_file): + model = mock_miri_lrs_fs + extract_defaults['dispaxis'] = 2 + slit = None + exptype = 'MIR_LRS-FIXEDSLIT' + + # mock nod subtraction + mock_miri_lrs_fs.meta.cal_step.back_sub = 'COMPLETE' + mock_miri_lrs_fs.meta.dither.primary_type = 'ALONG-SLIT-NOD' + + # mock a nod position at the opposite end of the array + def mock_nod(*args, **kwargs): + return 48.0 + + monkeypatch.setattr(pp, 'nod_pair_location', mock_nod) + + # set parameters for optimal extraction + extract_defaults['extraction_type'] = 'optimal' + extract_defaults['use_source_posn'] = True + extract_defaults['psf'] = psf_reference_file + extract_defaults['model_nod_pair'] = True + + result = ex.define_aperture(model, slit, extract_defaults, exptype) + _, _, _, profile, bg_profile, nod_profile, limits = result + + assert bg_profile is None + assert nod_profile is not None + + # profiles are normalized along cross-dispersion, + # nod profile is negative + assert np.allclose(np.sum(profile, axis=1), 1.0) + assert np.allclose(np.sum(nod_profile, axis=1), -1.0) + + # positive trace is centered on 1.0, negative trace on + # 48.0, array size is 50. + npix = 26 + assert np.all(np.sum(profile != 0, axis=1) == npix) + assert np.all(np.sum(nod_profile != 0, axis=1) == npix) + + # psf is uniform when in range, 0 otherwise + assert np.allclose(profile[:, :npix], 1 / npix) + assert np.allclose(profile[:, npix:], 0.0) + + # nod profile is the same, but negative, and at the other + # end of the array + assert np.allclose(nod_profile[:, -npix:], -1 / npix) + assert np.allclose(nod_profile[:, :-npix], 0.0) + + def test_extract_one_slit_horizontal(mock_nirspec_fs_one_slit, extract_defaults, simple_profile, background_profile): # update parameters to subtract background @@ -1233,21 +1190,25 @@ def test_extract_one_slit_horizontal(mock_nirspec_fs_one_slit, extract_defaults, mock_nirspec_fs_one_slit.data[simple_profile != 0] += 1.0 result = ex.extract_one_slit(mock_nirspec_fs_one_slit, -1, simple_profile, - background_profile, extract_defaults) + background_profile, None, extract_defaults) - for data in result[:-1]: + for data in result[:-2]: assert np.all(data > 0) assert data.shape == (mock_nirspec_fs_one_slit.data.shape[1],) # residuals from the 2D scene model should be zero - this simple case # is exactly modeled with a box profile - scene_model = result[-1] + scene_model = result[-2] assert scene_model.shape == mock_nirspec_fs_one_slit.data.shape assert np.allclose(np.abs(mock_nirspec_fs_one_slit.data - scene_model), 0) + residual = result[-1] + assert residual.shape == mock_nirspec_fs_one_slit.data.shape + assert np.allclose(np.abs(residual), 0) + # flux should be 1.0 * npixels flux = result[0] - npixels = result[-2] + npixels = result[-3] assert np.allclose(flux, npixels) # npixels is sum of profile @@ -1269,21 +1230,25 @@ def test_extract_one_slit_vertical(mock_miri_lrs_fs, extract_defaults, # set a source in the profile region model.data[profile != 0] += 1.0 - result = ex.extract_one_slit(model, -1, profile, profile_bg, extract_defaults) + result = ex.extract_one_slit(model, -1, profile, profile_bg, None, extract_defaults) - for data in result[:-1]: + for data in result[:-2]: assert np.all(data > 0) assert data.shape == (model.data.shape[0],) # residuals from the 2D scene model should be zero - this simple case # is exactly modeled with a box profile - scene_model = result[-1] + scene_model = result[-2] assert scene_model.shape == model.data.shape assert np.allclose(np.abs(model.data - scene_model), 0) + residual = result[-1] + assert residual.shape == model.data.shape + assert np.allclose(np.abs(residual), 0) + # flux should be 1.0 * npixels flux = result[0] - npixels = result[-2] + npixels = result[-3] assert np.allclose(flux, npixels) # npixels is sum of profile @@ -1296,7 +1261,7 @@ def test_extract_one_slit_vertical_no_bg(mock_miri_lrs_fs, extract_defaults, profile = simple_profile.T extract_defaults['dispaxis'] = 2 - result = ex.extract_one_slit(model, -1, profile, None, extract_defaults) + result = ex.extract_one_slit(model, -1, profile, None, None, extract_defaults) # flux and variances are nonzero for data in result[:4]: @@ -1311,7 +1276,8 @@ def test_extract_one_slit_vertical_no_bg(mock_miri_lrs_fs, extract_defaults, # npixels is the sum of the profile assert np.allclose(result[8], np.sum(simple_profile, axis=0)) - # scene model has 2D shape + # scene model and residual has 2D shape + assert result[-2].shape == model.data.shape assert result[-1].shape == model.data.shape @@ -1321,7 +1287,7 @@ def test_extract_one_slit_multi_int(mock_nirspec_bots, extract_defaults, extract_defaults['dispaxis'] = 1 log_watcher.message = "Extracting integration 2" - result = ex.extract_one_slit(model, 1, simple_profile, None, extract_defaults) + result = ex.extract_one_slit(model, 1, simple_profile, None, None, extract_defaults) log_watcher.assert_seen() # flux and variances are nonzero @@ -1337,7 +1303,8 @@ def test_extract_one_slit_multi_int(mock_nirspec_bots, extract_defaults, # npixels is the sum of the profile assert np.allclose(result[8], np.sum(simple_profile, axis=0)) - # scene model has 2D shape + # scene model and residual has 2D shape + assert result[-2].shape == model.data.shape[-2:] assert result[-1].shape == model.data.shape[-2:] @@ -1354,7 +1321,7 @@ def test_extract_one_slit_missing_var(mock_nirspec_fs_one_slit, extract_defaults model.var_poisson = np.zeros((10, 10)) model.var_flat = np.zeros((10, 10)) - result = ex.extract_one_slit(model, -1, simple_profile, None, extract_defaults) + result = ex.extract_one_slit(model, -1, simple_profile, None, None, extract_defaults) # flux is nonzero assert np.all(result[0] > 0) @@ -1366,6 +1333,62 @@ def test_extract_one_slit_missing_var(mock_nirspec_fs_one_slit, extract_defaults assert data.shape == (model.data.shape[1],) +def test_extract_one_slit_optimal_horizontal( + mock_nirspec_fs_one_slit, extract_defaults, + nod_profile, negative_nod_profile): + model = mock_nirspec_fs_one_slit + extract_defaults['dispaxis'] = 1 + extract_defaults['extraction_type'] = 'optimal' + + result = ex.extract_one_slit(model, -1, nod_profile, None, + negative_nod_profile, extract_defaults) + + # flux and variances are nonzero + for data in result[:4]: + assert np.all(data > 0) + assert data.shape == (model.data.shape[0],) + + # background and variances are zero + for data in result[4:8]: + assert np.all(data == 0) + assert data.shape == (model.data.shape[0],) + + # npixels is the sum of the pixels in the positive profile + assert np.allclose(result[8], np.sum(nod_profile > 0, axis=0)) + + # scene model and residual has 2D shape + assert result[-1].shape == model.data.shape + assert result[-2].shape == model.data.shape + + +def test_extract_one_slit_optimal_vertical( + mock_miri_lrs_fs, extract_defaults, nod_profile, negative_nod_profile): + model = mock_miri_lrs_fs + nod_profile = nod_profile.T + negative_nod_profile = negative_nod_profile.T + extract_defaults['dispaxis'] = 2 + extract_defaults['extraction_type'] = 'optimal' + + result = ex.extract_one_slit(model, -1, nod_profile, None, negative_nod_profile, extract_defaults) + + # flux and variances are nonzero + for data in result[:4]: + assert np.all(data > 0) + assert data.shape == (model.data.shape[0],) + + # background and variances are zero + for data in result[4:8]: + assert np.all(data == 0) + assert data.shape == (model.data.shape[0],) + + # npixels is the sum of the pixels in the positive profile + assert np.allclose(result[8], np.sum(nod_profile > 0, axis=1)) + + # scene model and residual has 2D shape + assert result[-2].shape == model.data.shape + assert result[-1].shape == model.data.shape + + def test_create_extraction_with_photom(create_extraction_inputs): model = create_extraction_inputs[0] model.meta.cal_step.photom = 'COMPLETE' @@ -1529,55 +1552,92 @@ def mock_source_location(*args): log_watcher.assert_seen() +def test_create_extraction_optimal( + monkeypatch, create_extraction_inputs, psf_reference_file): + model = create_extraction_inputs[0] + + # mock nod subtraction + model.meta.cal_step.back_sub = 'COMPLETE' + model.meta.dither.primary_type = '2-POINT-NOD' + + # mock a nod position at the opposite end of the array + def mock_nod(*args, **kwargs): + return 48.0 + + monkeypatch.setattr(pp, 'nod_pair_location', mock_nod) + + profile_model, _, _ = ex.create_extraction( + *create_extraction_inputs, save_profile=True, + psf_ref_name=psf_reference_file, use_source_posn=True, + extraction_type='optimal', model_nod_pair=True) + + assert profile_model is not None + + # profile contains positive and negative nod, summed + assert np.all(profile_model.data[:10] > 0) + assert np.all(profile_model.data[-10:] < 0) + + profile_model.close() + + def test_run_extract1d(mock_nirspec_mos): model = mock_nirspec_mos - output_model, profile_model, scene_model = ex.run_extract1d(model) + output_model, profile_model, scene_model, residual = ex.run_extract1d(model) assert isinstance(output_model, dm.MultiSpecModel) assert profile_model is None assert scene_model is None + assert residual is None output_model.close() def test_run_extract1d_save_models(mock_niriss_wfss_l3): model = mock_niriss_wfss_l3 - output_model, profile_model, scene_model = ex.run_extract1d( - model, save_profile=True, save_scene_model=True) + output_model, profile_model, scene_model, residual = ex.run_extract1d( + model, save_profile=True, save_scene_model=True, save_residual_image=True) assert isinstance(output_model, dm.MultiSpecModel) assert isinstance(profile_model, ModelContainer) assert isinstance(scene_model, ModelContainer) + assert isinstance(residual, ModelContainer) assert len(profile_model) == len(model) assert len(scene_model) == len(model) + assert len(residual) == len(model) for pmodel in profile_model: assert isinstance(pmodel, dm.ImageModel) for smodel in scene_model: assert isinstance(smodel, dm.ImageModel) + for rmodel in residual: + assert isinstance(rmodel, dm.ImageModel) output_model.close() profile_model.close() scene_model.close() + residual.close() def test_run_extract1d_save_cube_scene(mock_nirspec_bots): model = mock_nirspec_bots - output_model, profile_model, scene_model = ex.run_extract1d( - model, save_profile=True, save_scene_model=True) + output_model, profile_model, scene_model, residual = ex.run_extract1d( + model, save_profile=True, save_scene_model=True, save_residual_image=True) assert isinstance(output_model, dm.MultiSpecModel) assert isinstance(profile_model, dm.ImageModel) assert isinstance(scene_model, dm.CubeModel) + assert isinstance(residual, dm.CubeModel) assert profile_model.data.shape == model.data.shape[-2:] assert scene_model.data.shape == model.data.shape + assert residual.data.shape == model.data.shape output_model.close() profile_model.close() scene_model.close() + residual.close() def test_run_extract1d_tso(mock_nirspec_bots): model = mock_nirspec_bots - output_model, _, _ = ex.run_extract1d(model) + output_model, _, _, _ = ex.run_extract1d(model) # time and integration keywords are populated for i, spec in enumerate(output_model.spec): @@ -1597,7 +1657,7 @@ def test_run_extract1d_slitmodel_name(mock_nirspec_fs_one_slit, from_name_attr): model.name = None model.meta.instrument.fixed_slit = 'S200A1' - output_model, _, _ = ex.run_extract1d(model) + output_model, _, _, _ = ex.run_extract1d(model) assert output_model.spec[0].name == 'S200A1' output_model.close() @@ -1612,7 +1672,7 @@ def test_run_extract1d_imagemodel_name(mock_miri_lrs_fs, from_name_attr): else: model.name = None - output_model, _, _ = ex.run_extract1d(model) + output_model, _, _, _ = ex.run_extract1d(model) if from_name_attr: assert output_model.spec[0].name == 'test_slit_name' else: @@ -1625,12 +1685,39 @@ def test_run_extract1d_apcorr(mock_miri_lrs_fs, miri_lrs_apcorr_file, log_watche model.meta.target.source_type = 'POINT' log_watcher.message = 'Creating aperture correction' - output_model, _, _ = ex.run_extract1d(model, apcorr_ref_name=miri_lrs_apcorr_file) + output_model, _, _, _ = ex.run_extract1d(model, apcorr_ref_name=miri_lrs_apcorr_file) + log_watcher.assert_seen() + + output_model.close() + + +def test_run_extract1d_apcorr_optimal( + mock_miri_lrs_fs, miri_lrs_apcorr_file, psf_reference_file, log_watcher): + model = mock_miri_lrs_fs + model.meta.target.source_type = 'POINT' + + # Aperture correction that is otherwise valid is nonetheless + # turned off for optimal extraction + log_watcher.message = 'Turning off aperture correction for optimal extraction' + output_model, _, _, _ = ex.run_extract1d(model, apcorr_ref_name=miri_lrs_apcorr_file, + psf_ref_name=psf_reference_file, + extraction_type='optimal') log_watcher.assert_seen() output_model.close() +def test_run_extract1d_optimal_no_psf(mock_miri_lrs_fs, log_watcher): + model = mock_miri_lrs_fs + model.meta.target.source_type = 'POINT' + + # Optimal extraction is turned off if there is no psf file provided + log_watcher.message = 'Optimal extraction is not available' + output_model, _, _, _ = ex.run_extract1d(model, extraction_type='optimal') + log_watcher.assert_seen() + + output_model.close() + def test_run_extract1d_invalid(): model = dm.MultiSpecModel() with pytest.raises(RuntimeError, match="Can't extract a spectrum"): @@ -1640,7 +1727,7 @@ def test_run_extract1d_invalid(): def test_run_extract1d_zeroth_order_slit(mock_nirspec_fs_one_slit): model = mock_nirspec_fs_one_slit model.meta.wcsinfo.spectral_order = 0 - output_model, _, _ = ex.run_extract1d(model) + output_model, _, _, _ = ex.run_extract1d(model) # no spectra extracted for zeroth order assert len(output_model.spec) == 0 @@ -1650,7 +1737,8 @@ def test_run_extract1d_zeroth_order_slit(mock_nirspec_fs_one_slit): def test_run_extract1d_zeroth_order_image(mock_miri_lrs_fs): model = mock_miri_lrs_fs model.meta.wcsinfo.spectral_order = 0 - output_model, _, _ = ex.run_extract1d(model) + model.meta.instrument.filter = None + output_model, _, _, _ = ex.run_extract1d(model) # no spectra extracted for zeroth order assert len(output_model.spec) == 0 @@ -1661,7 +1749,7 @@ def test_run_extract1d_zeroth_order_multispec(mock_nirspec_mos): model = mock_nirspec_mos for slit in model.slits: slit.meta.wcsinfo.spectral_order = 0 - output_model, _, _ = ex.run_extract1d(model) + output_model, _, _, _ = ex.run_extract1d(model) # no spectra extracted for zeroth order assert len(output_model.spec) == 0 @@ -1672,7 +1760,7 @@ def test_run_extract1d_no_data(mock_niriss_wfss_l3): container = mock_niriss_wfss_l3 for model in container: model.data = np.array([]) - output_model, _, _ = ex.run_extract1d(container) + output_model, _, _, _ = ex.run_extract1d(container) # no spectra extracted assert len(output_model.spec) == 0 @@ -1684,7 +1772,7 @@ def raise_continue_error(*args, **kwargs): raise ex.ContinueError('Test error') monkeypatch.setattr(ex, 'create_extraction', raise_continue_error) - output_model, _, _ = ex.run_extract1d(mock_nirspec_fs_one_slit) + output_model, _, _, _ = ex.run_extract1d(mock_nirspec_fs_one_slit) # no spectra extracted assert len(output_model.spec) == 0 @@ -1696,7 +1784,7 @@ def raise_continue_error(*args, **kwargs): raise ex.ContinueError('Test error') monkeypatch.setattr(ex, 'create_extraction', raise_continue_error) - output_model, _, _ = ex.run_extract1d(mock_miri_lrs_fs) + output_model, _, _, _ = ex.run_extract1d(mock_miri_lrs_fs) # no spectra extracted assert len(output_model.spec) == 0 @@ -1708,7 +1796,7 @@ def raise_continue_error(*args, **kwargs): raise ex.ContinueError('Test error') monkeypatch.setattr(ex, 'create_extraction', raise_continue_error) - output_model, _, _ = ex.run_extract1d(mock_nirspec_mos) + output_model, _, _, _ = ex.run_extract1d(mock_nirspec_mos) # no spectra extracted assert len(output_model.spec) == 0 diff --git a/jwst/extract_1d/tests/test_extract_1d_step.py b/jwst/extract_1d/tests/test_extract_1d_step.py index 0493679097..cc85564424 100644 --- a/jwst/extract_1d/tests/test_extract_1d_step.py +++ b/jwst/extract_1d/tests/test_extract_1d_step.py @@ -68,6 +68,22 @@ def test_extract_nirspec_bots(mock_nirspec_bots, simple_wcs): result.close() +def test_extract_miri_lrs_fs(mock_miri_lrs_fs, simple_wcs_transpose): + result = Extract1dStep.call(mock_miri_lrs_fs) + assert result.meta.cal_step.extract_1d == 'COMPLETE' + assert result.spec[0].name == 'MIR_LRS-FIXEDSLIT' + + # output wavelength is the same as input + _, _, expected_wave = simple_wcs_transpose(np.arange(50), np.arange(50)) + assert np.allclose(result.spec[0].spec_table['WAVELENGTH'], expected_wave) + + # output flux and errors are non-zero, exact values will depend + # on extraction parameters + assert np.all(result.spec[0].spec_table['FLUX'] > 0) + assert np.all(result.spec[0].spec_table['FLUX_ERROR'] > 0) + result.close() + + @pytest.mark.parametrize('ifu_set_srctype', [None, 'EXTENDED']) def test_extract_miri_ifu(mock_miri_ifu, simple_wcs_ifu, ifu_set_srctype): # Source type defaults to extended, results should be the @@ -188,23 +204,48 @@ def test_save_output_single(tmp_path, mock_nirspec_fs_one_slit): mock_nirspec_fs_one_slit.meta.filename = 'test_s2d.fits' result = Extract1dStep.call(mock_nirspec_fs_one_slit, save_results=True, save_profile=True, - save_scene_model=True, output_dir=str(tmp_path), - suffix='x1d') + save_scene_model=True, save_residual_image=True, + output_dir=str(tmp_path), suffix='x1d') output_path = str(tmp_path / 'test_x1d.fits') assert os.path.isfile(output_path) assert os.path.isfile(output_path.replace('x1d', 'profile')) assert os.path.isfile(output_path.replace('x1d', 'scene_model')) + assert os.path.isfile(output_path.replace('x1d', 'residual')) + + result.close() + + +def test_save_output_multiple(tmp_path, mock_nirspec_fs_one_slit): + input_container = ModelContainer([mock_nirspec_fs_one_slit.copy(), + mock_nirspec_fs_one_slit.copy()]) + + result = Extract1dStep.call(input_container, + save_results=True, save_profile=True, + save_scene_model=True, save_residual_image=True, + output_dir=str(tmp_path), + suffix='x1d', output_file='test') + + output_paths = [str(tmp_path / 'test_0_x1d.fits'), + str(tmp_path / 'test_1_x1d.fits')] + + for output_path in output_paths: + assert os.path.isfile(output_path) + assert os.path.isfile(output_path.replace('x1d', 'profile')) + assert os.path.isfile(output_path.replace('x1d', 'scene_model')) + assert os.path.isfile(output_path.replace('x1d', 'residual')) result.close() + input_container.close() def test_save_output_multislit(tmp_path, mock_nirspec_mos): mock_nirspec_mos.meta.filename = 'test_s2d.fits' result = Extract1dStep.call(mock_nirspec_mos, save_results=True, save_profile=True, - save_scene_model=True, output_dir=str(tmp_path), + save_scene_model=True, save_residual_image=True, + output_dir=str(tmp_path), suffix='x1d') output_path = str(tmp_path / 'test_x1d.fits') @@ -215,5 +256,29 @@ def test_save_output_multislit(tmp_path, mock_nirspec_mos): for slit in mock_nirspec_mos.slits: assert os.path.isfile(output_path.replace('x1d', f'{slit.name}_profile')) assert os.path.isfile(output_path.replace('x1d', f'{slit.name}_scene_model')) + assert os.path.isfile(output_path.replace('x1d', f'{slit.name}_residual')) + + result.close() + + +def test_save_output_multiple_multislit(tmp_path, mock_nirspec_mos): + input_container = ModelContainer([mock_nirspec_mos.copy(), + mock_nirspec_mos.copy()]) + result = Extract1dStep.call(input_container, + save_results=True, save_profile=True, + save_scene_model=True, save_residual_image=True, + output_dir=str(tmp_path), + suffix='x1d', output_file='test') + + for i in range(2): + output_path = str(tmp_path / f'test_{i}_x1d.fits') + assert os.path.isfile(output_path) + + # intermediate files for multislit data contain the slit name + for slit in mock_nirspec_mos.slits: + assert os.path.isfile(str(tmp_path / f'test_{slit.name}_{i}_profile.fits')) + assert os.path.isfile(str(tmp_path / f'test_{slit.name}_{i}_scene_model.fits')) + assert os.path.isfile(str(tmp_path / f'test_{slit.name}_{i}_residual.fits')) result.close() + input_container.close() diff --git a/jwst/extract_1d/tests/test_psf_profile.py b/jwst/extract_1d/tests/test_psf_profile.py new file mode 100644 index 0000000000..319d3061e4 --- /dev/null +++ b/jwst/extract_1d/tests/test_psf_profile.py @@ -0,0 +1,436 @@ +import logging + +import numpy as np +import pytest +from stdatamodels.jwst.datamodels import SpecPsfModel + +from jwst.extract_1d import psf_profile as pp +from jwst.tests.helpers import LogWatcher + + +@pytest.fixture +def log_watcher(monkeypatch): + # Set a log watcher to check for a log message at any level + # in the extract_1d.psf_profile module + watcher = LogWatcher('') + logger = logging.getLogger('jwst.extract_1d.psf_profile') + for level in ['debug', 'info', 'warning', 'error']: + monkeypatch.setattr(logger, level, watcher) + return watcher + + +@pytest.mark.parametrize('exp_type', ['MIR_LRS-FIXEDSLIT', 'NRS_FIXEDSLIT', 'UNKNOWN']) +def test_open_psf(psf_reference_file, exp_type): + # for any exptype, a model that can be read + # as SpecPsfModel will be, since it's the only + # one implemented so far + with pp.open_psf(psf_reference_file, exp_type=exp_type) as model: + assert isinstance(model, SpecPsfModel) + + +def test_open_psf_fail(): + with pytest.raises(NotImplementedError, match='could not be read'): + pp.open_psf('bad_file', 'UNKNOWN') + + +@pytest.mark.parametrize('dispaxis', [1, 2]) +def test_normalize_profile(nod_profile, dispaxis): + profile = 2 * nod_profile + if dispaxis == 2: + profile = profile.T + pp._normalize_profile(profile, dispaxis) + assert np.allclose(np.sum(profile, axis=dispaxis - 1), 1.0) + + +@pytest.mark.parametrize('dispaxis', [1, 2]) +def test_normalize_profile_with_nans(nod_profile, dispaxis): + profile = -1 * nod_profile + profile[10, :] = np.nan + if dispaxis == 2: + profile = profile.T + + pp._normalize_profile(profile, dispaxis) + assert np.allclose(np.sum(profile, axis=dispaxis - 1), 1.0) + assert np.all(np.isfinite(profile)) + + +@pytest.mark.parametrize('dispaxis', [1, 2]) +def test_make_cutout_profile_default(psf_reference, dispaxis): + data_shape = psf_reference.data.shape + yidx, xidx = np.mgrid[:data_shape[0], :data_shape[1]] + + psf_subpix = psf_reference.meta.psf.subpix + profiles = pp._make_cutout_profile(xidx, yidx, psf_subpix, psf_reference.data, dispaxis) + assert len(profiles) == 1 + assert profiles[0].shape == data_shape + + # No shift, profile is uniform and normalized to cross-dispersion size + if dispaxis == 1: + assert np.all(profiles[0] == 1 / data_shape[0]) + else: + assert np.all(profiles[0] == 1 / data_shape[1]) + + +@pytest.mark.parametrize('dispaxis', [1, 2]) +@pytest.mark.parametrize('extra_shift', [1, 2]) +def test_make_cutout_profile_shift_down(psf_reference, dispaxis, extra_shift): + data_shape = psf_reference.data.shape + yidx, xidx = np.mgrid[:data_shape[0], :data_shape[1]] + psf_subpix = psf_reference.meta.psf.subpix + + profiles = pp._make_cutout_profile(xidx, yidx, psf_subpix, psf_reference.data, + dispaxis, extra_shift=extra_shift) + assert len(profiles) == 1 + assert profiles[0].shape == data_shape + + # Profile is shifted down by shift amount + if dispaxis == 1: + nn = data_shape[0] - extra_shift + assert np.all(profiles[0][:-extra_shift, :] == 1 / nn) + assert np.all(profiles[0][-extra_shift:, :] == 0.0) + else: + nn = data_shape[1] - extra_shift + assert np.all(profiles[0][:, :-extra_shift] == 1 / nn) + assert np.all(profiles[0][:, -extra_shift:] == 0.0) + + +@pytest.mark.parametrize('dispaxis', [1, 2]) +@pytest.mark.parametrize('extra_shift', [-1, -2]) +def test_make_cutout_profile_shift_up(psf_reference, dispaxis, extra_shift): + data_shape = psf_reference.data.shape + yidx, xidx = np.mgrid[:data_shape[0], :data_shape[1]] + psf_subpix = psf_reference.meta.psf.subpix + + profiles = pp._make_cutout_profile(xidx, yidx, psf_subpix, psf_reference.data, + dispaxis, extra_shift=extra_shift) + assert len(profiles) == 1 + assert profiles[0].shape == data_shape + + # Profile is shifted up by shift amount + if dispaxis == 1: + nn = data_shape[0] + extra_shift + assert np.all(profiles[0][-extra_shift:, :] == 1 / nn) + assert np.all(profiles[0][:-extra_shift, :] == 0.0) + else: + nn = data_shape[1] + extra_shift + assert np.all(profiles[0][:, -extra_shift:] == 1 / nn) + assert np.all(profiles[0][:, :-extra_shift] == 0.0) + + +@pytest.mark.parametrize('dispaxis', [1, 2]) +def test_make_cutout_profile_with_nod(psf_reference, dispaxis): + data_shape = psf_reference.data.shape + yidx, xidx = np.mgrid[:data_shape[0], :data_shape[1]] + psf_subpix = psf_reference.meta.psf.subpix + + offset = 2 + profiles = pp._make_cutout_profile(xidx, yidx, psf_subpix, psf_reference.data, + dispaxis, nod_offset=offset) + assert len(profiles) == 2 + source, nod = profiles + assert source.shape == data_shape + assert nod.shape == data_shape + + # Profile is uniform, nod profile is the same, but shifted + # down by offset and multiplied by -1 + if dispaxis == 1: + assert np.all(source == 1 / data_shape[0]) + + nn = data_shape[0] - offset + assert np.all(nod[:-offset, :] == -1 / nn) + assert np.all(nod[-offset:, :] == 0.0) + else: + assert np.all(source == 1 / data_shape[1]) + + nn = data_shape[1] - offset + assert np.all(nod[:, :-offset] == -1 / nn) + assert np.all(nod[:, -offset:] == 0.0) + + +@pytest.mark.parametrize('dispaxis', [1, 2]) +def test_profile_residual(psf_reference, dispaxis): + data_shape = (50, 50) + yidx, xidx = np.mgrid[:data_shape[0], :data_shape[1]] + psf_subpix = psf_reference.meta.psf.subpix + + # Set data to all ones, so residual should be zero + # when background is not fit + data = np.full(data_shape, 1.0) + var = np.full(data_shape, 0.01) + + param = [0, None] + residual = pp._profile_residual( + param, data, var, xidx, yidx, + psf_subpix, psf_reference.data, dispaxis, fit_bkg=False) + assert np.isclose(residual, 0.0) + + +@pytest.mark.parametrize('dispaxis', [1, 2]) +def test_profile_residual_with_bkg(psf_reference, dispaxis): + data_shape = (50, 50) + yidx, xidx = np.mgrid[:data_shape[0], :data_shape[1]] + psf_subpix = psf_reference.meta.psf.subpix + + # Set data to all ones, so it is all background - residual + # should be all of the data + data = np.full(data_shape, 1.0) + var = np.full(data_shape, 0.01) + + param = [0, None] + residual = pp._profile_residual( + param, data, var, xidx, yidx, + psf_subpix, psf_reference.data, dispaxis, fit_bkg=True) + assert np.isclose(residual, np.sum(data ** 2 / var)) + + +@pytest.mark.parametrize('use_trace', [True, False]) +def test_psf_profile(mock_miri_lrs_fs, psf_reference_file, use_trace): + data_shape = mock_miri_lrs_fs.data.shape + if use_trace: + # Centered trace - should match None behavior + trace = np.full(data_shape[0], (data_shape[1] - 1) / 2.0) + else: + # Avoid miri specific behavior for trace - the mock WCS is not complete enough + mock_miri_lrs_fs.meta.exposure.type = 'ANY' + trace = None + + yidx, xidx = np.mgrid[:data_shape[0], :data_shape[1]] + _, _, wl_array = mock_miri_lrs_fs.meta.wcs(xidx, yidx) + + profiles, lower, upper = pp.psf_profile( + mock_miri_lrs_fs, trace, wl_array, psf_reference_file, + optimize_shifts=False, model_nod_pair=False) + + # no nod profile, data cutout matches full shape + assert len(profiles) == 1 + assert profiles[0].shape == data_shape + assert lower == 0 + assert upper == data_shape[1] + + # profile is uniform and centered + assert np.allclose(profiles[0], 1 / data_shape[1]) + + +@pytest.mark.parametrize('use_trace', [True, False]) +def test_psf_profile_multi_int(mock_nirspec_bots, psf_reference_file, use_trace): + data_shape = mock_nirspec_bots.data.shape[-2:] + + if use_trace: + trace = np.full(data_shape[1], (data_shape[0] - 1) / 2.0) + else: + # Avoid nirspec specific behavior for trace - + # the mock WCS is not complete enough + mock_nirspec_bots.meta.exposure.type = 'ANY' + trace = None + + yidx, xidx = np.mgrid[:data_shape[0], :data_shape[1]] + _, _, wl_array = mock_nirspec_bots.meta.wcs(xidx, yidx) + + profiles, lower, upper = pp.psf_profile( + mock_nirspec_bots, trace, wl_array, psf_reference_file, + optimize_shifts=False, model_nod_pair=False) + + # no nod profile, data cutout matches full shape + assert len(profiles) == 1 + assert profiles[0].shape == data_shape + assert lower == 0 + assert upper == data_shape[0] + + # profile is uniform and centered + assert np.allclose(profiles[0], 1 / data_shape[1]) + + +def test_psf_profile_model_nod(monkeypatch, mock_miri_lrs_fs, psf_reference_file): + model = mock_miri_lrs_fs + data_shape = model.data.shape + trace = np.full(data_shape[0], (data_shape[1] - 1) / 2.0) + + yidx, xidx = np.mgrid[:data_shape[0], :data_shape[1]] + _, _, wl_array = model.meta.wcs(xidx, yidx) + + # mock nod subtraction + model.meta.cal_step.back_sub = 'COMPLETE' + model.meta.dither.primary_type = '2-POINT-NOD' + + # mock a nod position at the opposite end of the array + def mock_nod(*args, **kwargs): + return 48.0 + + monkeypatch.setattr(pp, 'nod_pair_location', mock_nod) + + profiles, lower, upper = pp.psf_profile( + model, trace, wl_array, psf_reference_file, + optimize_shifts=False, model_nod_pair=True) + + # now returns nod profile + assert len(profiles) == 2 + source, nod = profiles + assert source.shape == data_shape + assert nod.shape == data_shape + + # source profile is uniform and centered + assert np.allclose(source, 1 / data_shape[1]) + + # nod profile is centered at the end of the array and has a negative value + assert np.allclose(nod[:, -2:], -1 / 26) + + +def test_psf_profile_model_nod_no_trace( + mock_miri_lrs_fs, psf_reference_file, log_watcher): + model = mock_miri_lrs_fs + model.meta.exposure.type = 'ANY' + data_shape = model.data.shape + trace = None + + yidx, xidx = np.mgrid[:data_shape[0], :data_shape[1]] + _, _, wl_array = model.meta.wcs(xidx, yidx) + + log_watcher.message = 'Cannot model a negative nod without position' + profiles, lower, upper = pp.psf_profile( + model, trace, wl_array, psf_reference_file, + optimize_shifts=False, model_nod_pair=True) + log_watcher.assert_seen() + + # does not return nod profile + assert len(profiles) == 1 + + +def test_psf_profile_model_nod_not_subtracted( + mock_miri_lrs_fs, psf_reference_file, log_watcher): + model = mock_miri_lrs_fs + data_shape = model.data.shape + trace = np.full(data_shape[0], (data_shape[1] - 1) / 2.0) + + yidx, xidx = np.mgrid[:data_shape[0], :data_shape[1]] + _, _, wl_array = model.meta.wcs(xidx, yidx) + + log_watcher.message = 'data was not nod-subtracted' + profiles, lower, upper = pp.psf_profile( + model, trace, wl_array, psf_reference_file, + optimize_shifts=False, model_nod_pair=True) + log_watcher.assert_seen() + + # does not return nod profile + assert len(profiles) == 1 + + +def test_psf_profile_model_nod_wrong_pattern( + mock_miri_lrs_fs, psf_reference_file, log_watcher): + model = mock_miri_lrs_fs + data_shape = model.data.shape + trace = np.full(data_shape[0], (data_shape[1] - 1) / 2.0) + model.meta.cal_step.back_sub = 'COMPLETE' + + yidx, xidx = np.mgrid[:data_shape[0], :data_shape[1]] + _, _, wl_array = model.meta.wcs(xidx, yidx) + + log_watcher.message = 'data was not a two-point nod' + profiles, lower, upper = pp.psf_profile( + model, trace, wl_array, psf_reference_file, + optimize_shifts=False, model_nod_pair=True) + log_watcher.assert_seen() + + # does not return nod profile + assert len(profiles) == 1 + + +def test_psf_profile_model_nod_bad_position( + mock_miri_lrs_fs, psf_reference_file, log_watcher): + model = mock_miri_lrs_fs + data_shape = model.data.shape + trace = np.full(data_shape[0], (data_shape[1] - 1) / 2.0) + model.meta.cal_step.back_sub = 'COMPLETE' + model.meta.dither.primary_type = '2-POINT-NOD' + + yidx, xidx = np.mgrid[:data_shape[0], :data_shape[1]] + _, _, wl_array = model.meta.wcs(xidx, yidx) + + log_watcher.message = 'Nod center could not be estimated' + profiles, lower, upper = pp.psf_profile( + model, trace, wl_array, psf_reference_file, + optimize_shifts=False, model_nod_pair=True) + log_watcher.assert_seen() + + # does not return nod profile + assert len(profiles) == 1 + + +@pytest.mark.parametrize('start_offset', [0.0, 0.1, 2.0, -1.0]) +def test_psf_profile_optimize( + mock_miri_lrs_fs, psf_reference_file_with_source, start_offset, + log_watcher): + model = mock_miri_lrs_fs + data_shape = model.data.shape + trace = np.full(data_shape[0], (data_shape[1] - 1) / 2.0) + trace += start_offset + + # add a peak to the center of the data, matching the model + model.data[:] = 0.01 + model.data[:, 24:27] += 1.0 + model.var_rnoise[:] = 0.1 + + yidx, xidx = np.mgrid[:data_shape[0], :data_shape[1]] + _, _, wl_array = model.meta.wcs(xidx, yidx) + + log_watcher.message = 'Centering profile on spectrum at 24.5' + profiles, lower, upper = pp.psf_profile( + model, trace, wl_array, psf_reference_file_with_source, + optimize_shifts=True, model_nod_pair=False) + log_watcher.assert_seen() + + # profile is centered at 24.5 + profile = profiles[0] + assert np.allclose(profile[:, 24:27], 1 / 3, atol=1e-4) + assert np.allclose(profile[:, :24], 0.0, atol=1e-4) + assert np.allclose(profile[:, 27:], 0.0, atol=1e-4) + + +@pytest.mark.parametrize('start_offset', [0.0, 0.1, 2.0, -1.0]) +def test_psf_profile_optimize_with_nod( + monkeypatch, mock_miri_lrs_fs, psf_reference_file_with_source, + start_offset, log_watcher): + model = mock_miri_lrs_fs + data_shape = model.data.shape + + # mock nod subtraction + model.meta.cal_step.back_sub = 'COMPLETE' + model.meta.dither.primary_type = '2-POINT-NOD' + + # trace at pixel 9.5 + trace = np.full(data_shape[0], 9.5) + + # with an offset to optimize out + trace += start_offset + + # mock a nod position at the opposite end of the array, + # with an extra offset from truth, and also from the trace + def mock_nod(*args, **kwargs): + return 39.5 + start_offset + 0.1 + + monkeypatch.setattr(pp, 'nod_pair_location', mock_nod) + + # add a peak at pixel 10, negative peak at pixel 40 + model.data[:] = 0.0 + model.data[:, 9:12] += 1.0 + model.data[:, 39:42] -= 1.0 + model.var_rnoise[:] = 0.1 + + yidx, xidx = np.mgrid[:data_shape[0], :data_shape[1]] + _, _, wl_array = model.meta.wcs(xidx, yidx) + + log_watcher.message = 'Also modeling a negative trace at 39.50' + profiles, lower, upper = pp.psf_profile( + model, trace, wl_array, psf_reference_file_with_source, + optimize_shifts=True, model_nod_pair=True) + log_watcher.assert_seen() + + # profile is centered at 10 + source, nod = profiles + assert np.allclose(source[:, 9:12], 1 / 3, atol=1e-4) + assert np.allclose(source[:, :9], 0.0, atol=1e-4) + assert np.allclose(source[:, 12:], 0.0, atol=1e-4) + + # nod is centered at 40 + assert np.allclose(nod[:, 39:42], -1 / 3, atol=1e-4) + assert np.allclose(nod[:, :39], 0.0, atol=1e-4) + assert np.allclose(nod[:, 42:], 0.0, atol=1e-4) diff --git a/jwst/extract_1d/tests/test_source_location.py b/jwst/extract_1d/tests/test_source_location.py new file mode 100644 index 0000000000..b9de563c83 --- /dev/null +++ b/jwst/extract_1d/tests/test_source_location.py @@ -0,0 +1,314 @@ +import logging +import numpy as np +import pytest + +from jwst.extract_1d import source_location as sl +from jwst.tests.helpers import LogWatcher + + +@pytest.fixture +def log_watcher(monkeypatch): + # Set a log watcher to check for a log message at any level + # in the extract_1d.extract module + watcher = LogWatcher('') + logger = logging.getLogger('jwst.extract_1d.source_location') + for level in ['debug', 'info', 'warning', 'error']: + monkeypatch.setattr(logger, level, watcher) + return watcher + + +@pytest.mark.parametrize('dispaxis', [1, 2]) +def test_middle_from_wcs_constant_wl(dispaxis): + # mock a wcs that returns a constant wavelength + def mock_wcs(x, y): + return None, None, np.full(x.shape, 10.0) + + bbox = (-0.5, 9.5), (-0.5, 9.5) + + md, mx, mw = sl.middle_from_wcs(mock_wcs, bbox, dispaxis) + + # middle for dispersion and cross-dispersion are at the center, 4.5 + assert md == 4.5 + assert mx == 4.5 + + # middle wavelength is the constant value + assert mw == 10.0 + + +@pytest.mark.parametrize('dispaxis', [1, 2]) +def test_middle_from_wcs_variable_wl(dispaxis): + # mock a wcs that returns a variable wavelength + def mock_wcs(x, y): + return None, None, np.arange(x.size, dtype=float).reshape(x.shape) + + bbox = (-0.5, 9.5), (-0.5, 9.5) + + md, mx, mw = sl.middle_from_wcs(mock_wcs, bbox, dispaxis) + + # middle for dispersion, cross-dispersion, and wavelength are all 4.5 + assert md == 4.5 + assert mx == 4.5 + assert mw == 4.5 + + +@pytest.mark.parametrize('resampled', [True, False]) +@pytest.mark.parametrize('is_slit', [True, False]) +@pytest.mark.parametrize('missing_bbox', [True, False]) +def test_location_from_wcs_nirspec( + monkeypatch, mock_nirspec_fs_one_slit, resampled, is_slit, missing_bbox): + model = mock_nirspec_fs_one_slit + + if not resampled: + # mock available frames, so it looks like unresampled cal data + monkeypatch.setattr(model.meta.wcs, 'available_frames', ['gwa']) + + if missing_bbox: + # mock a missing bounding box - should have same results + # for the test data + monkeypatch.setattr(model.meta.wcs, 'bounding_box', None) + + if is_slit: + middle, middle_wl, location, trace = sl.location_from_wcs(model, model) + else: + middle, middle_wl, location, trace = sl.location_from_wcs(model, None) + + # middle pixel is center of dispersion axis + assert middle == int((model.data.shape[1] - 1) / 2) + + # middle wavelength is the wavelength at that point, from the mock wcs + assert np.isclose(middle_wl, 7.745) + + # location is 1.0 - from the mocked transform function + assert location == 1.0 + + # trace is the same, in an array + assert np.all(trace == 1.0) + + +@pytest.mark.parametrize('is_slit', [True, False]) +def test_location_from_wcs_miri(monkeypatch, mock_miri_lrs_fs, is_slit): + model = mock_miri_lrs_fs + + # monkey patch in a transform for the wcs + def radec2det(*args, **kwargs): + def return_one(*args, **kwargs): + return 1.0, 0.0 + return return_one + + monkeypatch.setattr(model.meta.wcs, 'backward_transform', radec2det()) + + # mock the trace function + def mock_trace(*args, **kwargs): + return np.full(model.data.shape[-2], 1.0) + + monkeypatch.setattr(sl, '_miri_trace_from_wcs', mock_trace) + + # Get the slit center from the WCS + if is_slit: + middle, middle_wl, location, trace = sl.location_from_wcs(model, model) + else: + middle, middle_wl, location, trace = sl.location_from_wcs(model, None) + + # middle pixel is center of dispersion axis + assert middle == int((model.data.shape[0] - 1) / 2) + + # middle wavelength is the wavelength at that point, from the mock wcs + assert np.isclose(middle_wl, 7.255) + + # location is 1.0 - from the mocked transform function + assert location == 1.0 + + # trace is the same, in an array + assert np.all(trace == 1.0) + + +def test_location_from_wcs_missing_data(mock_miri_lrs_fs, log_watcher): + model = mock_miri_lrs_fs + model.meta.wcs.backward_transform = None + + # model is missing WCS information - None values are returned + log_watcher.message = "Dithered pointing location not found" + result = sl.location_from_wcs(model, None) + assert result == (None, None, None, None) + log_watcher.assert_seen() + + +def test_location_from_wcs_wrong_exptype(mock_niriss_soss, log_watcher): + # model is not a handled exposure type + log_watcher.message = "Source position cannot be found for EXP_TYPE" + result = sl.location_from_wcs(mock_niriss_soss, None) + assert result == (None, None, None, None) + log_watcher.assert_seen() + + +def test_location_from_wcs_bad_location( + monkeypatch, mock_nirspec_fs_one_slit, log_watcher): + model = mock_nirspec_fs_one_slit + + # monkey patch in a transform for the wcs + def slit2det(*args, **kwargs): + def return_one(*args, **kwargs): + return 0.0, np.nan + return return_one + + monkeypatch.setattr(model.meta.wcs, 'get_transform', slit2det) + + # WCS transform returns NaN for the location + log_watcher.message = "Source position could not be determined" + result = sl.location_from_wcs(model, None) + assert result == (None, None, None, None) + log_watcher.assert_seen() + + +def test_location_from_wcs_location_out_of_range( + monkeypatch, mock_nirspec_fs_one_slit, log_watcher): + model = mock_nirspec_fs_one_slit + + # monkey patch in a transform for the wcs + def slit2det(*args, **kwargs): + def return_one(*args, **kwargs): + return 0.0, 2000 + return return_one + + monkeypatch.setattr(model.meta.wcs, 'get_transform', slit2det) + + # mock the trace function + def mock_trace(*args, **kwargs): + return np.full(model.data.shape[-1], 1.0) + + monkeypatch.setattr(sl, '_nirspec_trace_from_wcs', mock_trace) + + # WCS transform a value outside the bounding box + log_watcher.message = "outside the bounding box" + result = sl.location_from_wcs(model, None) + assert result == (None, None, None, None) + log_watcher.assert_seen() + + +def test_nirspec_trace_from_wcs(mock_nirspec_fs_one_slit): + model = mock_nirspec_fs_one_slit + trace = sl._nirspec_trace_from_wcs(model.data.shape, model.meta.wcs.bounding_box, + model.meta.wcs, 1.0, 1.0) + # mocked model contains some mock transforms as well - all ones are expected + assert np.all(trace == np.ones(model.data.shape[1])) + + +def test_miri_trace_from_wcs(mock_miri_lrs_fs): + model = mock_miri_lrs_fs + trace = sl._miri_trace_from_wcs(model.data.shape, model.meta.wcs.bounding_box, + model.meta.wcs, 1.0, 1.0) + + # mocked model contains some mock transforms as well - all ones are expected + assert np.all(trace == np.ones(model.data.shape[0])) + + +def test_trace_from_wcs_nirspec(mock_nirspec_fs_one_slit): + model = mock_nirspec_fs_one_slit + trace = sl.trace_from_wcs( + 'NRS_FIXEDSLIT', model.data.shape, model.meta.wcs.bounding_box, + model.meta.wcs, 1.0, 1.0, 1) + + # mocked model contains some mock transforms as well - all ones are expected + assert np.all(trace == np.ones(model.data.shape[1])) + + +def test_trace_from_wcs_miri(mock_miri_lrs_fs): + model = mock_miri_lrs_fs + trace = sl.trace_from_wcs( + 'MIR_LRS-FIXEDSLIT', model.data.shape, model.meta.wcs.bounding_box, + model.meta.wcs, 1.0, 1.0, 2) + + # mocked model contains some mock transforms as well - all ones are expected + assert np.all(trace == np.ones(model.data.shape[0])) + + +def test_trace_from_wcs_other_horizontal(): + exp_type = 'ANY' + shape = (10, 20) + bbox = (1.5, 17.5), (1.5, 8.5) + wcs = None + source_x = 2.0 + source_y = 4.0 + dispaxis = 1 + + trace = sl.trace_from_wcs(exp_type, shape, bbox, wcs, source_x, source_y, dispaxis) + + # trace matches dispersion dimension (x) + assert trace.shape == (shape[1],) + + # trace is full of the source_y position, except outside the bounding box, + # where it is NaN + assert np.all(trace[2:18] == source_y) + assert np.all(np.isnan(trace[:2])) + assert np.all(np.isnan(trace[18:])) + + +def test_trace_from_wcs_other_vertical(): + exp_type = 'ANY' + shape = (10, 20) + bbox = (1.5, 17.5), (1.5, 8.5) + wcs = None + source_x = 2.0 + source_y = 4.0 + dispaxis = 2 + + trace = sl.trace_from_wcs(exp_type, shape, bbox, wcs, source_x, source_y, dispaxis) + + # trace matches dispersion dimension (y) + assert trace.shape == (shape[0],) + + # trace is full of the source_x position, except outside the bounding box, + # where it is NaN + assert np.all(trace[2:9] == source_x) + assert np.all(np.isnan(trace[:2])) + assert np.all(np.isnan(trace[9:])) + + +def test_nod_pair_location_nirspec(mock_nirspec_fs_one_slit): + model = mock_nirspec_fs_one_slit + middle_wl = 7.5 + + nod_center = sl.nod_pair_location(model, middle_wl) + + # for mock transforms, 1.0 is expected + assert nod_center == 1.0 + + +def test_nod_pair_location_nirspec_unresampled(mock_nirspec_fs_one_slit): + model = mock_nirspec_fs_one_slit + middle_wl = 7.5 + model.meta.wcs.available_frames = ['gwa'] + + nod_center = sl.nod_pair_location(model, middle_wl) + + # for mock transforms, 1.0 is expected + assert nod_center == 1.0 + + +@pytest.mark.parametrize('dispaxis', [1, 2]) +def test_nod_pair_location_miri(mock_miri_lrs_fs, dispaxis): + model = mock_miri_lrs_fs + middle_wl = 7.5 + + nod_center = sl.nod_pair_location(model, middle_wl) + + # for mock transforms as is, NaN is expected + assert np.isnan(nod_center) + + # mock v2v3 transform + model.meta.wcs.available_frames = ['v2v3'] + model.meta.wcsinfo.v3yangle = 1.0 + model.meta.wcsinfo.v2_ref = 1.0 + model.meta.wcsinfo.v3_ref = 1.0 + model.meta.wcsinfo.vparity = 1 + model.meta.dither.x_offset = 1.0 + model.meta.dither.y_offset = 1.0 + model.meta.wcsinfo.dispersion_direction = dispaxis + nod_center = sl.nod_pair_location(model, middle_wl) + + # the final backward transform mock returns (1.0, 0.0), + # so the location reported will be x=1.0 if vertical, y=0.0 if horizontal + if dispaxis == 2: + assert nod_center == 1.0 + else: + assert nod_center == 0.0 diff --git a/jwst/pipeline/calwebb_spec3.py b/jwst/pipeline/calwebb_spec3.py index 04712a2edc..8466830478 100644 --- a/jwst/pipeline/calwebb_spec3.py +++ b/jwst/pipeline/calwebb_spec3.py @@ -220,6 +220,7 @@ def process(self, input): result = self.mrs_imatch.run(result) # Call outlier detection and pixel replacement + resample_complete = None if exptype not in SLITLESS_TYPES: # Update the asn table name to the level 3 instance so that # the downstream products have the correct table name since @@ -235,8 +236,8 @@ def process(self, input): # interpolate pixels that have a NaN value or are flagged # as DO_NOT_USE or NON_SCIENCE. result = self.pixel_replace.run(result) + # Resample time. Dependent on whether the data is IFU or not. - resample_complete = None if exptype in IFU_EXPTYPES: result = self.cube_build.run(result) try: @@ -293,6 +294,10 @@ def process(self, input): if exptype in ['MIR_MRS']: result = self.spectral_leak.run(result) + elif exptype not in IFU_EXPTYPES: + # Extract spectra and combine results + result = self.extract_1d.run(result) + result = self.combine_1d.run(result) else: self.log.warning( 'Resampling was not completed. Skipping extract_1d.' diff --git a/jwst/regtest/test_miri_lrs_optimal_extraction.py b/jwst/regtest/test_miri_lrs_optimal_extraction.py new file mode 100644 index 0000000000..548d513de8 --- /dev/null +++ b/jwst/regtest/test_miri_lrs_optimal_extraction.py @@ -0,0 +1,91 @@ +from astropy.io.fits.diff import FITSDiff +import pytest + +from jwst.stpipe import Step + + +@pytest.fixture(scope="module") +def run_spec2_optimal(rtdata_module): + """Run the calwebb_spec2 pipeline on MIRI LRS fixedslit with optimal extraction.""" + rtdata = rtdata_module + + # Get the spec2 ASN and its members + rtdata.get_asn("miri/lrs/jw01530-o005_20221202t204827_spec2_00001_asn.json") + + # Run the calwebb_spec2 pipeline with optimal extraction and recommended + # parameters, saving intermediate files + args = ["calwebb_spec2", rtdata.input, + "--output_file=jw01530005001_03103_00001_mirimage_opt", + "--steps.resample_spec.skip=true", + "--steps.pixel_replace.skip=true", + "--steps.extract_1d.extraction_type=optimal", + "--steps.extract_1d.use_source_posn=true", + "--steps.extract_1d.model_nod_pair=true", + "--steps.extract_1d.optimize_psf_location=true", + "--steps.extract_1d.save_profile=true", + "--steps.extract_1d.save_scene_model=true", + "--steps.extract_1d.save_residual_image=true", + ] + Step.from_cmdline(args) + + +@pytest.fixture(scope="module") +def run_spec3_optimal(rtdata_module): + """Run the calwebb_spec3 pipeline on MIRI LRS fixedslit with optimal extraction.""" + rtdata = rtdata_module + + # Get the spec3 ASN and its members + rtdata.get_asn("miri/lrs/jw01530-o005_20221202t204827_spec3_00001_asn.json") + + # Run the calwebb_spec3 pipeline with optimal extraction and recommended + # parameters, saving intermediate files + args = ["calwebb_spec3", rtdata.input, + "--steps.resample_spec.skip=true", + "--steps.pixel_replace.skip=true", + "--steps.extract_1d.extraction_type=optimal", + "--steps.extract_1d.use_source_posn=true", + "--steps.extract_1d.model_nod_pair=true", + "--steps.extract_1d.optimize_psf_location=true", + "--steps.extract_1d.save_profile=true", + "--steps.extract_1d.save_scene_model=true", + "--steps.extract_1d.save_residual_image=true", + ] + Step.from_cmdline(args) + + +@pytest.mark.bigdata +@pytest.mark.parametrize("suffix", ["x1d", "profile", "scene_model", "residual"]) +def test_miri_lrs_slit_spec2_optimal( + run_spec2_optimal, fitsdiff_default_kwargs, rtdata_module, suffix): + """Regression test for MIRI LRS FS optimal extraction in spec2.""" + rtdata = rtdata_module + output = f"jw01530005001_03103_00001_mirimage_opt_{suffix}.fits" + rtdata.output = output + + # Get the truth files + rtdata.get_truth(f"truth/test_miri_lrs_optimal_extraction/{output}") + + # Compare the results + diff = FITSDiff(rtdata.output, rtdata.truth, **fitsdiff_default_kwargs) + assert diff.identical, diff.report() + + +@pytest.mark.bigdata +@pytest.mark.parametrize("suffix", ["0_x1d", "1_x1d", + "0_profile", "1_profile", + "0_scene_model", "1_scene_model", + "0_residual", "1_residual", + "c1d"]) +def test_miri_lrs_slit_spec3_optimal( + run_spec3_optimal, fitsdiff_default_kwargs, rtdata_module, suffix): + """Regression test for MIRI LRS FS optimal extraction in spec3.""" + rtdata = rtdata_module + output = f"jw01530-o005_t004_miri_p750l_{suffix}.fits" + rtdata.output = output + + # Get the truth files + rtdata.get_truth(f"truth/test_miri_lrs_optimal_extraction/{output}") + + # Compare the results + diff = FITSDiff(rtdata.output, rtdata.truth, **fitsdiff_default_kwargs) + assert diff.identical, diff.report() diff --git a/jwst/resample/resample_spec_step.py b/jwst/resample/resample_spec_step.py index 7aa406091b..aea280c961 100755 --- a/jwst/resample/resample_spec_step.py +++ b/jwst/resample/resample_spec_step.py @@ -54,7 +54,7 @@ def process(self, input): # If input is a 3D rateints MultiSlitModel (unsupported) skip the step if model_is_msm and len((input_new[0]).shape) == 3: self.log.warning('Resample spec step will be skipped') - input_new.meta.cal_step.resample_spec = 'SKIPPED' + input_new.meta.cal_step.resample = 'SKIPPED' return input_new diff --git a/pyproject.toml b/pyproject.toml index 6de7c48487..1edbf4be84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ "scipy>=1.14.1", "spherical-geometry>=1.2.22", "stcal>=1.11.0,<1.12.0", - "stdatamodels>=2.2.0,<2.3.0", + "stdatamodels @ git+https://github.com/spacetelescope/stdatamodels.git@main", "stpipe>=0.8.0,<0.9.0", "stsci.imagestats>=1.6.3", "synphot>=1.2",