diff --git a/src/kbmod/work_unit.py b/src/kbmod/work_unit.py
new file mode 100644
index 000000000..27644c693
--- /dev/null
+++ b/src/kbmod/work_unit.py
@@ -0,0 +1,183 @@
+import math
+
+from astropy.io import fits
+from astropy.table import Table
+import numpy as np
+from pathlib import Path
+
+from kbmod.configuration import SearchConfiguration
+from kbmod.search import ImageStack, LayeredImage, PSF, RawImage
+
+
+class WorkUnit:
+    """The work unit is a storage and I/O class for all of the data
+    needed for a full run of KBMOD, including the search parameters,
+    data files, and the data provenance metadata.
+    """
+
+    def __init__(self, im_stack=None, config=None):
+        self.im_stack = im_stack
+        self.config = config
+
+    @classmethod
+    def from_fits(cls, filename):
+        """Create a WorkUnit from a single FITS file.
+
+        The FITS file will have at least the following extensions:
+
+        0. ``PRIMARY`` extension
+        1. ``METADATA`` extension containing provenance
+        2. ``KBMOD_CONFIG`` extension containing search parameters
+        3. (+) any additional image extensions are named ``SCI_i``, ``VAR_i``, ``MSK_i``
+        and ``PSF_i`` for the science, variance, mask and PSF of each image respectively,
+        where ``i`` runs from 0 to number of images in the `WorkUnit`.
+
+        Parameters
+        ----------
+        filename : `str`
+            The file to load.
+
+        Returns
+        -------
+        result : `WorkUnit`
+            The loaded WorkUnit.
+
+        Raises
+        ------
+        ValueError
+            If the file does not exist or has the wrong number of extensions.
+        """
+        if not Path(filename).is_file():
+            raise ValueError(f"WorkUnit file {filename} not found.")
+
+        imgs = []
+        with fits.open(filename) as hdul:
+            num_layers = len(hdul)
+            if num_layers < 5:
+                raise ValueError(f"WorkUnit file has too few extensions {len(hdul)}.")
+
+            # TODO - Read in provenance metadata from extension #1.
+
+            # Read in the search parameters from the 'kbmod_config' extension.
+            config = SearchConfiguration.from_hdu(hdul["kbmod_config"])
+
+            # Read the size and order information from the primary header.
+            num_images = hdul[0].header["NUMIMG"]
+            # Each image contributes 4 extensions (SCI, VAR, MSK, PSF) on top
+            # of the 3 fixed extensions (primary, metadata, config).
+            if len(hdul) != 4 * num_images + 3:
+                raise ValueError(
+                    f"WorkUnit wrong number of extensions. Expected "
+                    f"{4 * num_images + 3}. Found {len(hdul)}."
+                )
+
+            # Read in all the image files.
+            for i in range(num_images):
+                # Read in science, variance, and mask layers.
+                sci = hdu_to_raw_image(hdul[f"SCI_{i}"])
+                var = hdu_to_raw_image(hdul[f"VAR_{i}"])
+                msk = hdu_to_raw_image(hdul[f"MSK_{i}"])
+
+                # Read the PSF layer.
+                p = PSF(hdul[f"PSF_{i}"].data)
+
+                imgs.append(LayeredImage(sci, var, msk, p))
+
+        im_stack = ImageStack(imgs)
+        return WorkUnit(im_stack=im_stack, config=config)
+
+    def to_fits(self, filename, overwrite=False):
+        """Write the WorkUnit to a single FITS file.
+
+        Uses the following extensions:
+            0 - Primary header with overall metadata
+            1 or "metadata" - The data provenance metadata
+            2 or "kbmod_config" - The search parameters
+            3+ - Image extensions for the science layer ("SCI_i"),
+                variance layer ("VAR_i"), mask layer ("MSK_i"), and
+                PSF ("PSF_i") of each image.
+
+        Parameters
+        ----------
+        filename : `str`
+            The file to which to write the data.
+        overwrite : bool
+            Indicates whether to overwrite an existing file.
+        """
+        if Path(filename).is_file() and not overwrite:
+            print(f"Warning: WorkUnit file {filename} already exists.")
+            return
+
+        # Set up the initial HDU list, including the primary header
+        # the metadata (empty), and the configuration.
+        hdul = fits.HDUList()
+        pri = fits.PrimaryHDU()
+        pri.header["NUMIMG"] = self.im_stack.img_count()
+        hdul.append(pri)
+
+        meta_hdu = fits.BinTableHDU()
+        meta_hdu.name = "metadata"
+        hdul.append(meta_hdu)
+
+        config_hdu = self.config.to_hdu()
+        config_hdu.name = "kbmod_config"
+        hdul.append(config_hdu)
+
+        for i in range(self.im_stack.img_count()):
+            layered = self.im_stack.get_single_image(i)
+
+            sci_hdu = raw_image_to_hdu(layered.get_science())
+            sci_hdu.name = f"SCI_{i}"
+            hdul.append(sci_hdu)
+
+            var_hdu = raw_image_to_hdu(layered.get_variance())
+            var_hdu.name = f"VAR_{i}"
+            hdul.append(var_hdu)
+
+            msk_hdu = raw_image_to_hdu(layered.get_mask())
+            msk_hdu.name = f"MSK_{i}"
+            hdul.append(msk_hdu)
+
+            p = layered.get_psf()
+            psf_array = np.array(p.get_kernel()).reshape((p.get_dim(), p.get_dim()))
+            psf_hdu = fits.hdu.image.ImageHDU(psf_array)
+            psf_hdu.name = f"PSF_{i}"
+            hdul.append(psf_hdu)
+
+        # Forward the overwrite flag so overwrite=True actually replaces
+        # an existing file instead of raising OSError from astropy.
+        hdul.writeto(filename, overwrite=overwrite)
+
+
+def raw_image_to_hdu(img):
+    """Helper function that creates a HDU out of RawImage.
+
+    Parameters
+    ----------
+    img : `RawImage`
+        The RawImage to convert.
+
+    Returns
+    -------
+    hdu : `astropy.io.fits.hdu.image.ImageHDU`
+        The image extension.
+    """
+    # Expensive copy. To be removed with RawImage refactor.
+    np_pixels = np.array(img.get_all_pixels(), dtype=np.single)
+    np_array = np_pixels.reshape((img.get_height(), img.get_width()))
+    hdu = fits.hdu.image.ImageHDU(np_array)
+    hdu.header["MJD"] = img.get_obstime()
+    return hdu
+
+
+def hdu_to_raw_image(hdu):
+    """Helper function that creates a RawImage from a HDU.
+
+    Parameters
+    ----------
+    hdu : `astropy.io.fits.hdu.image.ImageHDU`
+        The image extension.
+
+    Returns
+    -------
+    img : `RawImage` or None
+        The RawImage if there is valid data and None otherwise.
+    """
+    img = None
+    if isinstance(hdu, fits.hdu.image.ImageHDU):
+        # Expensive copy. To be removed with RawImage refactor.
+        img = RawImage(hdu.data)
+        if "MJD" in hdu.header:
+            img.set_obstime(hdu.header["MJD"])
+    return img
diff --git a/tests/test_work_unit.py b/tests/test_work_unit.py
new file mode 100644
index 000000000..ca6e6e8e8
--- /dev/null
+++ b/tests/test_work_unit.py
@@ -0,0 +1,103 @@
+from astropy.io import fits
+from astropy.table import Table
+import tempfile
+import unittest
+from pathlib import Path
+
+from kbmod.configuration import SearchConfiguration
+import kbmod.search as kb
+from kbmod.work_unit import hdu_to_raw_image, raw_image_to_hdu, WorkUnit
+
+
+class test_work_unit(unittest.TestCase):
+    def setUp(self):
+        self.num_images = 5
+        self.width = 50
+        self.height = 70
+        self.images = [None] * self.num_images
+        self.p = [None] * self.num_images
+        for i in range(self.num_images):
+            self.p[i] = kb.PSF(5.0 / float(2 * i + 1))
+            self.images[i] = kb.LayeredImage(
+                ("layered_test_%i" % i),
+                self.width,
+                self.height,
+                2.0,  # noise_level
+                4.0,  # variance
+                2.0 * i + 1.0,  # time
+                self.p[i],
+            )
+
+            # Include one masked pixel per time step at (10, 10 + i).
+            mask = self.images[i].get_mask()
+            mask.set_pixel(10, 10 + i, 1)
+
+        self.im_stack = kb.ImageStack(self.images)
+
+        self.config = SearchConfiguration()
+        self.config.set("im_filepath", "Here")
+        self.config.set("num_obs", self.num_images)
+        self.config.set("mask_bits_dict", {"A": 1, "B": 2})
+        self.config.set("repeated_flag_keys", None)
+
+    def test_create(self):
+        work = WorkUnit(self.im_stack, self.config)
+        self.assertEqual(work.im_stack.img_count(), 5)
+        self.assertEqual(work.config["im_filepath"], "Here")
+        self.assertEqual(work.config["num_obs"], 5)
+
+    def test_save_and_load_fits(self):
+        with tempfile.TemporaryDirectory() as dir_name:
+            file_path = f"{dir_name}/test_workunit.fits"
+            self.assertFalse(Path(file_path).is_file())
+
+            # Unable to load non-existent file.
+            self.assertRaises(ValueError, WorkUnit.from_fits, file_path)
+
+            # Write out the existing WorkUnit
+            work = WorkUnit(self.im_stack, self.config)
+            work.to_fits(file_path)
+            self.assertTrue(Path(file_path).is_file())
+
+            # Read in the file and check that the values agree.
+            work2 = WorkUnit.from_fits(file_path)
+            self.assertEqual(work2.im_stack.img_count(), self.num_images)
+            for i in range(self.num_images):
+                li = work2.im_stack.get_single_image(i)
+                self.assertEqual(li.get_width(), self.width)
+                self.assertEqual(li.get_height(), self.height)
+
+                # Check the three image layers match.
+                sci1 = li.get_science()
+                var1 = li.get_variance()
+                msk1 = li.get_mask()
+
+                li_org = self.im_stack.get_single_image(i)
+                sci2 = li_org.get_science()
+                var2 = li_org.get_variance()
+                msk2 = li_org.get_mask()
+
+                for y in range(self.height):
+                    for x in range(self.width):
+                        self.assertAlmostEqual(sci1.get_pixel(x, y), sci2.get_pixel(x, y))
+                        self.assertAlmostEqual(var1.get_pixel(x, y), var2.get_pixel(x, y))
+                        self.assertAlmostEqual(msk1.get_pixel(x, y), msk2.get_pixel(x, y))
+
+                # Check the PSF layer matches.
+                p1 = self.p[i]
+                p2 = li.get_psf()
+                self.assertEqual(p1.get_dim(), p2.get_dim())
+
+                for y in range(p1.get_dim()):
+                    for x in range(p1.get_dim()):
+                        self.assertAlmostEqual(p1.get_value(x, y), p2.get_value(x, y))
+
+            # Check that we read in the configuration values correctly.
+            self.assertEqual(work2.config["im_filepath"], "Here")
+            self.assertEqual(work2.config["num_obs"], self.num_images)
+            self.assertDictEqual(work2.config["mask_bits_dict"], {"A": 1, "B": 2})
+            self.assertIsNone(work2.config["repeated_flag_keys"])
+
+
+if __name__ == "__main__":
+    unittest.main()