Skip to content

Commit

Permalink
use ndarray comparison options in wcs comparison (#941)
Browse files Browse the repository at this point in the history
  • Loading branch information
braingram authored Oct 17, 2023
2 parents c5197ba + 8e3083c commit c5864bf
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 15 deletions.
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ general

- Use tolerance for more comparisons in ``compare_asdf`` [#917]

- Use array comparison options (including ``nan`` equality) when
comparing ``WCS`` objects during ``compare_asdf`` [#941]

ramp_fitting
------------

Expand Down
26 changes: 11 additions & 15 deletions romancal/regtest/regtestdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,26 +633,22 @@ def _wcs_to_ra_dec(wcs):
return wcs(x, y)


class WCSOperator(BaseOperator):
class WCSOperator(NDArrayTypeOperator):
def give_up_diffing(self, level, diff_instance):
# for comparing wcs instances this function evaluates
# each wcs and compares the resulting ra and dec outputs
# TODO should we compare the bounding boxes?
ra_a, dec_a = _wcs_to_ra_dec(level.t1)
ra_b, dec_b = _wcs_to_ra_dec(level.t2)
meta = {}
for name, a, b in [("ra", ra_a, ra_b), ("dec", dec_a, dec_b)]:
# TODO do we want to do something fancier than allclose?
if not np.allclose(a, b):
meta[name] = {
"abs_diff": np.abs(a - b),
}
if meta:
diff_instance.custom_report_result(
"wcs_differ",
level,
meta,
)
ra_diff = self._compare_arrays(ra_a, ra_b)
dec_diff = self._compare_arrays(dec_a, dec_b)
difference = {}
if ra_diff:
difference["ra"] = ra_diff
if dec_diff:
difference["dec"] = dec_diff
if difference:
diff_instance.custom_report_result("wcs_differ", level, difference)
return True


Expand Down Expand Up @@ -712,7 +708,7 @@ def compare_asdf(result, truth, ignore=None, rtol=1e-05, atol=1e-08, equal_nan=T
),
TimeOperator(types=[astropy.time.Time]),
TableOperator(rtol, atol, equal_nan, types=[astropy.table.Table]),
WCSOperator(types=[gwcs.WCS]),
WCSOperator(rtol, atol, equal_nan, types=[gwcs.WCS]),
]
# warnings can be seen in regtest runs which indicate
# that ddtrace logs are evaluated at times after the below
Expand Down
9 changes: 9 additions & 0 deletions romancal/regtest/test_regtestdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,23 @@
from roman_datamodels import datamodels as rdm
from roman_datamodels import maker_utils

from romancal.assign_wcs.assign_wcs_step import load_wcs
from romancal.regtest.regtestdata import compare_asdf


def _add_wcs(tmp_path, model):
dfn = tmp_path / "wcs_distortion.asdf"
distortion_model = rdm.DistortionRefModel(maker_utils.mk_distortion())
distortion_model.save(dfn)
load_wcs(model, {"distortion": dfn})


@pytest.mark.parametrize("modification", [None, "small", "large"])
def test_compare_asdf(tmp_path, modification):
fn0 = tmp_path / "test0.asdf"
fn1 = tmp_path / "test1.asdf"
l2 = rdm.ImageModel(maker_utils.mk_level2_image(shape=(100, 100)))
_add_wcs(tmp_path, l2)
l2.save(fn0)
atol = 0.0001
if modification == "small":
Expand Down

0 comments on commit c5864bf

Please sign in to comment.