diff --git a/tests/resample/__init__.py b/tests/resample/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/resample/conftest.py b/tests/resample/conftest.py new file mode 100644 index 000000000..fb7c34ddc --- /dev/null +++ b/tests/resample/conftest.py @@ -0,0 +1,34 @@ +""" Test various utility functions """ +import pytest + +from astropy import wcs as fitswcs +import numpy as np + +from . helpers import make_gwcs + + +@pytest.fixture(scope='module') +def wcs_gwcs(): + crval = (150.0, 2.0) + crpix = (500.0, 500.0) + shape = (1000, 1000) + pscale = 0.06 / 3600 + return make_gwcs(crpix, crval, pscale, shape) + + +@pytest.fixture(scope='module') +def wcs_fitswcs(wcs_gwcs): + fits_wcs = fitswcs.WCS(wcs_gwcs.to_fits_sip()) + fits_wcs.pixel_area = wcs_gwcs.pixel_area + fits_wcs.pixel_scale = wcs_gwcs.pixel_scale + return fits_wcs + + +@pytest.fixture(scope='module') +def wcs_slicedwcs(wcs_gwcs): + xmin, xmax = 100, 500 + slices = (slice(xmin, xmax), slice(xmin, xmax)) + sliced_wcs = fitswcs.wcsapi.SlicedLowLevelWCS(wcs_gwcs, slices) + sliced_wcs.pixel_area = wcs_gwcs.pixel_area + sliced_wcs.pixel_scale = wcs_gwcs.pixel_scale + return sliced_wcs diff --git a/tests/resample/helpers.py b/tests/resample/helpers.py new file mode 100644 index 000000000..78ad9a230 --- /dev/null +++ b/tests/resample/helpers.py @@ -0,0 +1,114 @@ +from astropy.nddata.bitmask import BitFlagNameMap +from astropy import coordinates as coord +from astropy.modeling import models as astmodels + +from gwcs import coordinate_frames as cf +from gwcs.wcstools import wcs_from_fiducial +import numpy as np + +from stcal.alignment import compute_s_region_imaging + + +class JWST_DQ_FLAG_DEF(BitFlagNameMap): + DO_NOT_USE = 1 + SATURATED = 2 + JUMP_DET = 4 + + +def make_gwcs(crpix, crval, pscale, shape): + """ Simulate a gwcs from FITS WCS parameters. + + crpix - tuple of floats + crval - tuple of floats (RA, DEC) + pscale - pixel scale in degrees + shape - array shape (numpy's convention) + + """ + prj = astmodels.Pix2Sky_TAN() + fiducial = np.array(crval) + + pc = np.array([[-1., 0.], [0., 1.]]) + pc_matrix = astmodels.AffineTransformation2D(pc, name='pc_rotation_matrix') + scale = (astmodels.Scale(pscale, name='cdelt1') & + astmodels.Scale(pscale, name='cdelt2')) + transform = pc_matrix | scale + + out_frame = cf.CelestialFrame( + name='world', + axes_names=('lon', 'lat'), + reference_frame=coord.ICRS() + ) + input_frame = cf.Frame2D(name="detector") + wnew = wcs_from_fiducial( + fiducial, + coordinate_frame=out_frame, + projection=prj, + transform=transform, + input_frame=input_frame + ) + + output_bounding_box = ( + (-0.5, float(shape[1]) - 0.5), + (-0.5, float(shape[0]) - 0.5) + ) + offset1, offset2 = crpix + offsets = (astmodels.Shift(-offset1, name='crpix1') & + astmodels.Shift(-offset2, name='crpix2')) + + wnew.insert_transform('detector', offsets, after=True) + wnew.bounding_box = output_bounding_box + + tr = wnew.pipeline[0].transform + pix_area = ( + np.deg2rad(tr['cdelt1'].factor.value) * + np.deg2rad(tr['cdelt2'].factor.value) + ) + + wnew.pixel_area = pix_area + wnew.pixel_scale = pscale + wnew.pixel_shape = shape[::-1] + wnew.array_shape = shape + + return wnew + + +def make_input_model(crpix, crval, pscale, shape, group_id=1, exptime=1): + w = make_gwcs( + crpix=crpix, + crval=crval, + pscale=pscale, + shape=shape + ) + + model = { + "data": np.zeros(shape, dtype=np.float32), + "dq": np.zeros(shape, dtype=np.int32), + + # meta: + "filename": "", + "group_id": group_id, + "s_region": compute_s_region_imaging(w), + "wcs": w, + "bunit_data": "MJy", + + "exposure_time": exptime, + "start_time": 0.0, + "end_time": exptime, + "duration": exptime, + "measurement_time": exptime, + "effective_exposure_time": exptime, + "elapsed_exposure_time": exptime, + + "pixelarea_steradians": w.pixel_area, + "pixelarea_arcsecsq": w.pixel_area * (np.rad2deg(1) * 3600)**2, + + "level": 0.0, # sky level + "subtracted": False, + } + + for arr in ["var_flat", "var_rnoise", "var_poisson"]: + model[arr] = np.ones(shape, dtype=np.float32) + + model["err"] = np.sqrt(3.0) * np.ones(shape, dtype=np.float32) + + return model diff --git a/tests/resample/test_resample.py b/tests/resample/test_resample.py new file mode 100644 index 000000000..f6e8e1c50 --- /dev/null +++ b/tests/resample/test_resample.py @@ -0,0 +1,57 @@ +from stcal.resample import Resample + +import numpy as np + +from . helpers import make_gwcs, make_input_model, JWST_DQ_FLAG_DEF + + +def test_resample_defaults(): + crval = (150.0, 2.0) + crpix = (500.0, 500.0) + shape = (1000, 1000) + pscale = 0.06 / 3600 + + output_wcs = make_gwcs( + crpix=(600, 600), + crval=crval, + pscale=pscale, + shape=(1200, 1200) + ) + + nmodels = 4 + + resample = Resample(n_input_models=nmodels, output_wcs=output_wcs) + resample.dq_flag_name_map = JWST_DQ_FLAG_DEF + + influx = 0.0 + ttime = 0.0 + + for k in range(nmodels): + im = make_input_model( + crpix=tuple(i - 6 * k for i in crpix), + crval=crval, + pscale=pscale, + shape=shape, + group_id=k + 1, + exptime=k + 1 + ) + im["data"][:, :] = k + 0.5 + influx += k + 0.5 + ttime += k + 1 + + resample.add_model(im) + + resample.finalize() + + odata = resample.output_model["data"] + oweight = resample.output_model["wht"] + + assert resample.output_model["pointings"] == nmodels + assert resample.output_model["exposure_time"] == ttime + + # next assert assumes constant IVM + assert abs(np.sum(odata * oweight) - influx * np.prod(shape)) < 1.0e-6 + + assert np.nansum(resample.output_model["var_flat"]) > 0.0 + assert np.nansum(resample.output_model["var_poisson"]) > 0.0 + assert np.nansum(resample.output_model["var_rnoise"]) > 0.0 diff --git a/tests/resample/test_resample_utils.py b/tests/resample/test_resample_utils.py new file mode 100644 index 000000000..4da667340 --- /dev/null +++ b/tests/resample/test_resample_utils.py @@ -0,0 +1,213 @@ +""" Test various utility functions """ +import asdf +from asdf_astropy.testing.helpers import assert_model_equal + +from gwcs import coordinate_frames as cf +from numpy.testing import assert_array_equal +import numpy as np +import pytest + +from stcal.resample.utils import ( + build_mask, + bytes2human, + compute_mean_pixel_area, + get_tmeasure, + is_imaging_wcs, + resample_range, + load_custom_wcs, +) + +from . helpers import JWST_DQ_FLAG_DEF + + +DQ = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8]) +BITVALUES = 2**0 + 2**2 +BITVALUES_STR = f'{2**0}, {2**2}' +BITVALUES_INV_STR = f'~{2**0}, {2**2}' +JWST_NAMES = 'DO_NOT_USE,JUMP_DET' +JWST_NAMES_INV = '~' + JWST_NAMES + + +def _assert_frame_equal(a, b): + """ Copied from `gwcs`'s test_wcs.py """ + __tracebackhide__ = True + + assert type(a) is type(b) + + if a is None: + return + + if not isinstance(a, cf.CoordinateFrame): + return a == b + + assert a.name == b.name # nosec + assert a.axes_order == b.axes_order # nosec + assert a.axes_names == b.axes_names # nosec + assert a.unit == b.unit # nosec + assert a.reference_frame == b.reference_frame # nosec + + +def _assert_wcs_equal(a, b): + """ Based on corresponding function from `gwcs`'s test_wcs.py """ + assert a.name == b.name # nosec + + assert a.pixel_shape == b.pixel_shape + assert a.array_shape == b.array_shape + if a.array_shape is not None: + assert a.array_shape == b.pixel_shape[::-1] + + assert len(a.available_frames) == len(b.available_frames) # nosec + for a_step, b_step in zip(a.pipeline, b.pipeline): + _assert_frame_equal(a_step.frame, b_step.frame) + assert_model_equal(a_step.transform, b_step.transform) + + +@pytest.mark.parametrize( + 'dq, bitvalues, expected', [ + (DQ, 0, np.array([1, 0, 0, 0, 0, 0, 0, 0, 0])), + (DQ, BITVALUES, np.array([1, 1, 0, 0, 1, 1, 0, 0, 0])), + (DQ, BITVALUES_STR, np.array([1, 1, 0, 0, 1, 1, 0, 0, 0])), + (DQ, BITVALUES_INV_STR, np.array([1, 0, 1, 0, 0, 0, 0, 0, 1])), + (DQ, JWST_NAMES, np.array([1, 1, 0, 0, 1, 1, 0, 0, 0])), + (DQ, JWST_NAMES_INV, np.array([1, 0, 1, 0, 0, 0, 0, 0, 1])), + (DQ, None, np.array([1, 1, 1, 1, 1, 1, 1, 1, 1])), + ] +) +def test_build_mask(dq, bitvalues, expected): + """ Test logic of mask building + + Parameters + ---------- + dq: numpy.array + The input data quality array + + bitvalues: int or str + The bitvalues to match against + + expected: numpy.array + Expected mask array + """ + result = build_mask(dq, bitvalues, flag_name_map=JWST_DQ_FLAG_DEF) + assert_array_equal(result, expected) + + +@pytest.mark.parametrize( + "data_shape, bbox, exception, truth", + [ + ((1, 2, 3), ((1, 500), (0, 350)), True, None), + ((1, 2, 3), None, True, None), + ((1, ), ((1, 500), (0, 350)), True, None), + ((1, ), None, True, None), + ((1000, 800), ((1, 500), ), True, None), + ((1000, 800), ((1, 500), (0, 350), (0, 350)), True, None), + ((1, ), ((1, 500), (0, 350)), True, None), + ((1200, 1400), ((700, 300), (600, 800)), False, (700, 700, 600, 800)), + ((1200, 1400), ((600, 800), (700, 300)), False, (600, 800, 700, 700)), + ((1200, 1400), ((300, 700), (600, 800)), False, (300, 700, 600, 800)), + ((750, 470), ((300, 700), (600, 800)), False, (300, 469, 600, 749)), + ((750, 470), ((-5, -1), (-800, -600)), False, (0, 0, 0, 0)), + ((750, 470), None, False, (0, 469, 0, 749)), + ((-750, -470), None, False, (0, 0, 0, 0)), + ] +) +def test_resample_range(data_shape, bbox, exception, truth): + if exception: + with pytest.raises(ValueError): + resample_range(data_shape, bbox) + return + + xyminmax = resample_range(data_shape, bbox) + assert np.allclose(xyminmax, truth, rtol=0, atol=1e-12) + + +def test_load_custom_wcs_no_shape(tmpdir, wcs_gwcs): + """ + Test loading a WCS from an asdf file. + """ + wcs_file = str(tmpdir / "wcs.asdf") + wcs_gwcs.pixel_shape = None + wcs_gwcs.array_shape = None + wcs_gwcs.bounding_box = None + + with asdf.AsdfFile({"wcs": wcs_gwcs}) as af: + af.write_to(wcs_file) + + with pytest.raises(ValueError): + load_custom_wcs(wcs_file, output_shape=None) + + +@pytest.mark.parametrize( + "array_shape, pixel_shape, output_shape, expected", + [ + # (None, None, None, (1000, 1000)), # from the bounding box + # # (None, (123, 456), None, (456, 123)), # fails + # ((456, 123), None, None, (456, 123)), + # ((456, 123), None, (567, 890), (890, 567)), + ((456, 123), (123, 456), (567, 890), (890, 567)), + ((456, 123), (123, 456), None, (890, 567)), + ] +) +def test_load_custom_wcs(tmpdir, wcs_gwcs, array_shape, pixel_shape, + output_shape, expected): + """ + Test loading a WCS from an asdf file. `expected` is expected + ``wcs.array_shape``. + + """ + wcs_file = str(tmpdir / "wcs.asdf") + + wcs_gwcs.pixel_shape = pixel_shape + wcs_gwcs.array_shape = array_shape + + with asdf.AsdfFile({"wcs": wcs_gwcs}) as af: + af.write_to(wcs_file) + + if output_shape is not None: + wcs_gwcs.array_shape = output_shape[::-1] + + wcs_read = load_custom_wcs(wcs_file, output_shape=output_shape) + + assert wcs_read.array_shape == expected + + _assert_wcs_equal(wcs_gwcs, wcs_read) + + +def test_get_tmeasure(): + model = { + "measurement_time": 12.34, + "exposure_time": 23.45, + } + + assert get_tmeasure(model) == (12.34, True) + + model["measurement_time"] = None + assert get_tmeasure(model) == (23.45, False) + + del model["measurement_time"] + assert get_tmeasure(model) == (23.45, False) + + del model["exposure_time"] + with pytest.raises(KeyError): + get_tmeasure(model) + + +@pytest.mark.parametrize( + "n, readable", + [ + (10000, "9.8K"), + (100001221, "95.4M") + ] +) +def test_bytes2human(n, readable): + assert bytes2human(n) == readable + + +def test_is_imaging_wcs(wcs_gwcs): + assert is_imaging_wcs(wcs_gwcs) + + +def test_compute_mean_pixel_area(wcs_gwcs): + area = np.deg2rad(wcs_gwcs.pixel_scale)**2 + assert abs( + compute_mean_pixel_area(wcs_gwcs) / area - 1.0 + ) < 1e-5