From 938d8347a499355c2c7b81e003f2fe65dd518d5b Mon Sep 17 00:00:00 2001
From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com>
Date: Tue, 10 Oct 2023 16:23:14 -0400
Subject: [PATCH] Address initial PR comments

---
 src/kbmod/work_unit.py  | 98 +++++++++++++++++++++++------------------
 tests/test_work_unit.py |  6 +--
 2 files changed, 58 insertions(+), 46 deletions(-)

diff --git a/src/kbmod/work_unit.py b/src/kbmod/work_unit.py
index a3eb9ebf5..65662e271 100644
--- a/src/kbmod/work_unit.py
+++ b/src/kbmod/work_unit.py
@@ -13,13 +13,6 @@ class WorkUnit:
     """The work unit is a storage and I/O class for all of the data needed for
     a full run of KBMOD, including the search parameters, data files, and the
     data provenance metadata.
-
-    A WorkUnit file is a FITS file with the following extensions:
-        0 - Primary header with overall metadata
-        1 - The data provenance metadata
-        2 - The search parameters
-        3 through X - Image layers with either alternating layers of science, variance,
-            mask, and PSF. Layers may be empty if no data is provided.
     """
 
     def __init__(self, im_stack=None, config=None):
@@ -27,8 +20,16 @@ def __init__(self, im_stack=None, config=None):
         self.config = config
 
     @classmethod
-    def from_file(cls, filename):
-        """Create a WorkUnit from a single file.
+    def from_fits(cls, filename):
+        """Create a WorkUnit from a single FITS file.
+
+        A WorkUnit is written as a FITS file with the following extensions:
+            0 - Primary header with overall metadata
+            1 or "metadata" - The data provenance metadata
+            2 or "kbmod_config" - The search parameters
+            3+ - Image extensions for the science layer ("SCI_i"),
+                variance layer ("VAR_i"), mask layer ("MSK_i"), and
+                PSF ("PSF_i") of each image.
 
         Parameters
         ----------
@@ -43,7 +44,7 @@ def from_file(cls, filename):
         if not Path(filename).is_file():
             raise ValueError(f"WorkUnit file {filename} not found.")
 
-        result = None
+        imgs = []
         with fits.open(filename) as hdul:
             num_layers = len(hdul)
             if num_layers < 5:
@@ -51,10 +52,8 @@ def from_file(cls, filename):
 
             # TODO - Read in provenance metadata from extension #1.
 
-            # Read in the search parameters from extension #2.
-            config = SearchConfiguration()
-            config_table = Table(hdul[2].data)
-            config.set_from_table(config_table)
+            # Read in the search parameters from the 'kbmod_config' extension.
+            config = SearchConfiguration.from_hdu(hdul["kbmod_config"])
 
             # Read the size and order information from the primary header.
             num_images = hdul[0].header["NUMIMG"]
@@ -65,34 +64,30 @@ def from_file(cls, filename):
                 )
 
             # Read in all the image files.
-            imgs = []
             for i in range(num_images):
-                ext_num = 3 + 4 * i
-
-                # Read in science and variance layers.
-                sci = hdu_to_raw_image(hdul[ext_num])
-                var = hdu_to_raw_image(hdul[ext_num + 1])
+                # Read in science, variance, and mask layers.
+                sci = hdu_to_raw_image(hdul[f"SCI_{i}"])
+                var = hdu_to_raw_image(hdul[f"VAR_{i}"])
+                msk = hdu_to_raw_image(hdul[f"MSK_{i}"])
 
-                # Read the mask layer if it exists.
-                msk = hdu_to_raw_image(hdul[ext_num + 2])
-                if msk is None:
-                    msk = RawImage(np.zeros((sci.get_height(), sci.get_width())))
-
-                # Check if the PSF layer exists.
-                if hdul[ext_num + 3].header["NAXIS"] == 2:
-                    p = PSF(hdul[ext_num + 3].data)
-                else:
-                    p = PSF(1e-8)
+                # Read the PSF layer.
+                p = PSF(hdul[f"PSF_{i}"].data)
 
                 imgs.append(LayeredImage(sci, var, msk, p))
 
-            im_stack = ImageStack(imgs)
+        im_stack = ImageStack(imgs)
+        return WorkUnit(im_stack=im_stack, config=config)
 
-            result = WorkUnit(im_stack=im_stack, config=config)
-        return result
+    def to_fits(self, filename, overwrite=False):
+        """Write the WorkUnit to a single FITS file.
 
-    def write_to_file(self, filename, overwrite=False):
-        """Write the WorkUnit to a single file.
+        Uses the following extensions:
+            0 - Primary header with overall metadata
+            1 or "metadata" - The data provenance metadata
+            2 or "kbmod_config" - The search parameters
+            3+ - Image extensions for the science layer ("SCI_i"),
+                variance layer ("VAR_i"), mask layer ("MSK_i"), and
+                PSF ("PSF_i") of each image.
 
         Parameters
         ----------
@@ -111,18 +106,35 @@ def write_to_file(self, filename, overwrite=False):
         pri = fits.PrimaryHDU()
         pri.header["NUMIMG"] = self.im_stack.img_count()
         hdul.append(pri)
-        hdul.append(fits.BinTableHDU())
-        hdul.append(fits.BinTableHDU(self.config.to_table(make_fits_safe=True)))
+
+        meta_hdu = fits.BinTableHDU()
+        meta_hdu.name = "metadata"
+        hdul.append(meta_hdu)
+
+        config_hdu = self.config.to_hdu()
+        config_hdu.name = "kbmod_config"
+        hdul.append(config_hdu)
 
         for i in range(self.im_stack.img_count()):
             layered = self.im_stack.get_single_image(i)
-            hdul.append(raw_image_to_hdu(layered.get_science()))
-            hdul.append(raw_image_to_hdu(layered.get_variance()))
-            hdul.append(raw_image_to_hdu(layered.get_mask()))
+
+            sci_hdu = raw_image_to_hdu(layered.get_science())
+            sci_hdu.name = f"SCI_{i}"
+            hdul.append(sci_hdu)
+
+            var_hdu = raw_image_to_hdu(layered.get_variance())
+            var_hdu.name = f"VAR_{i}"
+            hdul.append(var_hdu)
+
+            msk_hdu = raw_image_to_hdu(layered.get_mask())
+            msk_hdu.name = f"MSK_{i}"
+            hdul.append(msk_hdu)
 
             p = layered.get_psf()
             psf_array = np.array(p.get_kernel()).reshape((p.get_dim(), p.get_dim()))
-            hdul.append(fits.hdu.image.ImageHDU(psf_array))
+            psf_hdu = fits.hdu.image.ImageHDU(psf_array)
+            psf_hdu.name = f"PSF_{i}"
+            hdul.append(psf_hdu)
 
         hdul.writeto(filename)
 
@@ -141,7 +153,7 @@ def raw_image_to_hdu(img):
         The image extension.
     """
     # Expensive copy. To be removed with RawImage refactor.
-    np_pixels = np.array(img.get_all_pixels()).astype("float32", casting="same_kind")
+    np_pixels = np.array(img.get_all_pixels(), dtype=np.single)
     np_array = np_pixels.reshape((img.get_height(), img.get_width()))
     hdu = fits.hdu.image.ImageHDU(np_array)
     hdu.header["MJD"] = img.get_obstime()
@@ -162,7 +174,7 @@ def hdu_to_raw_image(hdu):
         The RawImage if there is valid data and None otherwise.
     """
     img = None
-    if hdu.header["NAXIS"] == 2:
+    if isinstance(hdu, fits.hdu.image.ImageHDU):
         # Expensive copy. To be removed with RawImage refactor.
         img = RawImage(hdu.data)
         if "MJD" in hdu.header:
diff --git a/tests/test_work_unit.py b/tests/test_work_unit.py
index 5ada7659b..ca6e6e8e8 100644
--- a/tests/test_work_unit.py
+++ b/tests/test_work_unit.py
@@ -52,15 +52,15 @@ def test_save_and_load_fits(self):
             self.assertFalse(Path(file_path).is_file())
 
             # Unable to load non-existent file.
-            self.assertRaises(ValueError, WorkUnit.from_file, file_path)
+            self.assertRaises(ValueError, WorkUnit.from_fits, file_path)
 
             # Write out the existing WorkUnit
             work = WorkUnit(self.im_stack, self.config)
-            work.write_to_file(file_path)
+            work.to_fits(file_path)
             self.assertTrue(Path(file_path).is_file())
 
             # Read in the file and check that the values agree.
-            work2 = WorkUnit.from_file(file_path)
+            work2 = WorkUnit.from_fits(file_path)
             self.assertEqual(work2.im_stack.img_count(), self.num_images)
             for i in range(self.num_images):
                 li = work2.im_stack.get_single_image(i)
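
Note (illustration only, not part of the patch): the sketch below writes and then re-reads a toy file with the same extension layout the new to_fits() / from_fits() pair uses: a primary header carrying NUMIMG, "metadata" and "kbmod_config" table extensions, and per-image SCI_i / VAR_i / MSK_i / PSF_i image extensions, with an MJD card on the science, variance, and mask layers. It uses only numpy and astropy.io.fits rather than the kbmod classes; the file name, array sizes, and MJD values are invented for illustration.

import numpy as np
from astropy.io import fits

num_images = 2
hdul = fits.HDUList([fits.PrimaryHDU()])
hdul[0].header["NUMIMG"] = num_images

# Placeholder provenance and configuration tables, named as in the patch.
meta_hdu = fits.BinTableHDU()
meta_hdu.name = "metadata"
hdul.append(meta_hdu)
config_hdu = fits.BinTableHDU()
config_hdu.name = "kbmod_config"
hdul.append(config_hdu)

for i in range(num_images):
    # One science, variance, and mask layer per image, each with an MJD card.
    for layer in ("SCI", "VAR", "MSK"):
        img_hdu = fits.ImageHDU(np.zeros((32, 32), dtype=np.single))
        img_hdu.name = f"{layer}_{i}"
        img_hdu.header["MJD"] = 59000.0 + i  # invented observation time
        hdul.append(img_hdu)
    # A toy 3x3 normalized PSF kernel for this image.
    psf_hdu = fits.ImageHDU(np.full((3, 3), 1.0 / 9.0))
    psf_hdu.name = f"PSF_{i}"
    hdul.append(psf_hdu)

hdul.writeto("toy_workunit.fits", overwrite=True)  # hypothetical file name

# Reading back by extension name mirrors the lookups in the new from_fits().
with fits.open("toy_workunit.fits") as loaded:
    for i in range(loaded[0].header["NUMIMG"]):
        sci = loaded[f"SCI_{i}"]
        print(i, sci.data.shape, sci.header["MJD"], loaded[f"PSF_{i}"].data.sum())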