Merge pull request #368 from dirac-institute/workunit
Create a framework WorkUnit Class
jeremykubica authored Oct 12, 2023
2 parents 89555b7 + 7175355 commit 806276e
Showing 2 changed files with 286 additions and 0 deletions.
183 changes: 183 additions & 0 deletions src/kbmod/work_unit.py
@@ -0,0 +1,183 @@
import math

from astropy.io import fits
from astropy.table import Table
import numpy as np
from pathlib import Path

from kbmod.configuration import SearchConfiguration
from kbmod.search import ImageStack, LayeredImage, PSF, RawImage


class WorkUnit:
"""The work unit is a storage and I/O class for all of the data
needed for a full run of KBMOD, including the: the search parameters,
data files, and the data provenance metadata.
"""

def __init__(self, im_stack=None, config=None):
self.im_stack = im_stack
self.config = config

@classmethod
def from_fits(cls, filename):
"""Create a WorkUnit from a single FITS file.

        The FITS file will have at least the following extensions:

        0. ``PRIMARY`` extension
        1. ``METADATA`` extension containing provenance
        2. ``KBMOD_CONFIG`` extension containing search parameters
        3. (+) any additional image extensions are named ``SCI_i``, ``VAR_i``, ``MSK_i``
           and ``PSF_i`` for the science, variance, mask, and PSF of each image respectively,
           where ``i`` runs from 0 to the number of images minus one.

        Parameters
        ----------
        filename : `str`
            The file to load.

        Returns
        -------
        result : `WorkUnit`
            The loaded WorkUnit.
"""
if not Path(filename).is_file():
raise ValueError(f"WorkUnit file {filename} not found.")

imgs = []
with fits.open(filename) as hdul:
num_layers = len(hdul)
if num_layers < 5:
raise ValueError(f"WorkUnit file has too few extensions {len(hdul)}.")

# TODO - Read in provenance metadata from extension #1.

# Read in the search parameters from the 'kbmod_config' extension.
config = SearchConfiguration.from_hdu(hdul["kbmod_config"])

# Read the size and order information from the primary header.
num_images = hdul[0].header["NUMIMG"]
if len(hdul) != 4 * num_images + 3:
raise ValueError(
f"WorkUnit wrong number of extensions. Expected "
f"{4 * num_images + 3}. Found {len(hdul)}."
)

# Read in all the image files.
for i in range(num_images):
# Read in science, variance, and mask layers.
sci = hdu_to_raw_image(hdul[f"SCI_{i}"])
var = hdu_to_raw_image(hdul[f"VAR_{i}"])
msk = hdu_to_raw_image(hdul[f"MSK_{i}"])

# Read the PSF layer.
p = PSF(hdul[f"PSF_{i}"].data)

imgs.append(LayeredImage(sci, var, msk, p))

im_stack = ImageStack(imgs)
return WorkUnit(im_stack=im_stack, config=config)

def to_fits(self, filename, overwrite=False):
"""Write the WorkUnit to a single FITS file.

        Uses the following extensions:

        0 - Primary header with overall metadata
        1 or "metadata" - The data provenance metadata
        2 or "kbmod_config" - The search parameters
        3+ - Image extensions for the science layer ("SCI_i"),
            variance layer ("VAR_i"), mask layer ("MSK_i"), and
            PSF ("PSF_i") of each image.

        Parameters
        ----------
        filename : `str`
            The file to which to write the data.
        overwrite : `bool`
            Indicates whether to overwrite an existing file.
"""
if Path(filename).is_file() and not overwrite:
print(f"Warning: WorkUnit file {filename} already exists.")
return

# Set up the initial HDU list, including the primary header
# the metadata (empty), and the configuration.
hdul = fits.HDUList()
pri = fits.PrimaryHDU()
pri.header["NUMIMG"] = self.im_stack.img_count()
hdul.append(pri)

meta_hdu = fits.BinTableHDU()
meta_hdu.name = "metadata"
hdul.append(meta_hdu)

config_hdu = self.config.to_hdu()
config_hdu.name = "kbmod_config"
hdul.append(config_hdu)

for i in range(self.im_stack.img_count()):
layered = self.im_stack.get_single_image(i)

sci_hdu = raw_image_to_hdu(layered.get_science())
sci_hdu.name = f"SCI_{i}"
hdul.append(sci_hdu)

var_hdu = raw_image_to_hdu(layered.get_variance())
var_hdu.name = f"VAR_{i}"
hdul.append(var_hdu)

msk_hdu = raw_image_to_hdu(layered.get_mask())
msk_hdu.name = f"MSK_{i}"
hdul.append(msk_hdu)

p = layered.get_psf()
psf_array = np.array(p.get_kernel()).reshape((p.get_dim(), p.get_dim()))
psf_hdu = fits.hdu.image.ImageHDU(psf_array)
psf_hdu.name = f"PSF_{i}"
hdul.append(psf_hdu)

        hdul.writeto(filename, overwrite=overwrite)


def raw_image_to_hdu(img):
"""Helper function that creates a HDU out of RawImage.
Parameters
----------
img : `RawImage`
The RawImage to convert.
Returns
-------
hdu : `astropy.io.fits.hdu.image.ImageHDU`
The image extension.
"""
# Expensive copy. To be removed with RawImage refactor.
np_pixels = np.array(img.get_all_pixels(), dtype=np.single)
np_array = np_pixels.reshape((img.get_height(), img.get_width()))
hdu = fits.hdu.image.ImageHDU(np_array)
hdu.header["MJD"] = img.get_obstime()
return hdu


def hdu_to_raw_image(hdu):
"""Helper function that creates a RawImage from a HDU.
Parameters
----------
hdu : `astropy.io.fits.hdu.image.ImageHDU`
The image extension.
Returns
-------
img : `RawImage` or None
The RawImage if there is valid data and None otherwise.
"""
img = None
if isinstance(hdu, fits.hdu.image.ImageHDU):
# Expensive copy. To be removed with RawImage refactor.
img = RawImage(hdu.data)
if "MJD" in hdu.header:
img.set_obstime(hdu.header["MJD"])
return img
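
For orientation, here is a minimal usage sketch of the new class (not part of the commit). It assumes the `kbmod.search` constructors behave as exercised in tests/test_work_unit.py below; the image sizes and the output file name are hypothetical.

import kbmod.search as kb
from kbmod.configuration import SearchConfiguration
from kbmod.work_unit import WorkUnit

# Build a small stack of synthetic images (constructor arguments follow the test below).
psf = kb.PSF(1.0)
images = [
    kb.LayeredImage(f"img_{i}", 64, 64, 2.0, 4.0, float(i), psf)
    for i in range(3)
]
stack = kb.ImageStack(images)

config = SearchConfiguration()
config.set("num_obs", 3)

# Round-trip the WorkUnit through a single FITS file.
work = WorkUnit(im_stack=stack, config=config)
work.to_fits("my_workunit.fits", overwrite=True)
work2 = WorkUnit.from_fits("my_workunit.fits")
print(work2.im_stack.img_count())  # 3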
103 changes: 103 additions & 0 deletions tests/test_work_unit.py
@@ -0,0 +1,103 @@
from astropy.io import fits
from astropy.table import Table
import tempfile
import unittest
from pathlib import Path

from kbmod.configuration import SearchConfiguration
import kbmod.search as kb
from kbmod.work_unit import hdu_to_raw_image, raw_image_to_hdu, WorkUnit


class test_work_unit(unittest.TestCase):
def setUp(self):
self.num_images = 5
self.width = 50
self.height = 70
self.images = [None] * self.num_images
self.p = [None] * self.num_images
for i in range(self.num_images):
self.p[i] = kb.PSF(5.0 / float(2 * i + 1))
self.images[i] = kb.LayeredImage(
("layered_test_%i" % i),
self.width,
self.height,
2.0, # noise_level
4.0, # variance
2.0 * i + 1.0, # time
self.p[i],
)

# Include one masked pixel per time step at (10, 10 + i).
mask = self.images[i].get_mask()
mask.set_pixel(10, 10 + i, 1)

self.im_stack = kb.ImageStack(self.images)

self.config = SearchConfiguration()
self.config.set("im_filepath", "Here")
self.config.set("num_obs", self.num_images)
self.config.set("mask_bits_dict", {"A": 1, "B": 2})
self.config.set("repeated_flag_keys", None)

def test_create(self):
work = WorkUnit(self.im_stack, self.config)
self.assertEqual(work.im_stack.img_count(), 5)
self.assertEqual(work.config["im_filepath"], "Here")
self.assertEqual(work.config["num_obs"], 5)

def test_save_and_load_fits(self):
with tempfile.TemporaryDirectory() as dir_name:
file_path = f"{dir_name}/test_workunit.fits"
self.assertFalse(Path(file_path).is_file())

            # Loading a non-existent file should raise a ValueError.
self.assertRaises(ValueError, WorkUnit.from_fits, file_path)

# Write out the existing WorkUnit
work = WorkUnit(self.im_stack, self.config)
work.to_fits(file_path)
self.assertTrue(Path(file_path).is_file())

# Read in the file and check that the values agree.
work2 = WorkUnit.from_fits(file_path)
self.assertEqual(work2.im_stack.img_count(), self.num_images)
for i in range(self.num_images):
li = work2.im_stack.get_single_image(i)
self.assertEqual(li.get_width(), self.width)
self.assertEqual(li.get_height(), self.height)

# Check the three image layers match.
sci1 = li.get_science()
var1 = li.get_variance()
msk1 = li.get_mask()

li_org = self.im_stack.get_single_image(i)
sci2 = li_org.get_science()
var2 = li_org.get_variance()
msk2 = li_org.get_mask()

for y in range(self.height):
for x in range(self.width):
self.assertAlmostEqual(sci1.get_pixel(x, y), sci2.get_pixel(x, y))
self.assertAlmostEqual(var1.get_pixel(x, y), var2.get_pixel(x, y))
self.assertAlmostEqual(msk1.get_pixel(x, y), msk2.get_pixel(x, y))

# Check the PSF layer matches.
p1 = self.p[i]
p2 = li.get_psf()
self.assertEqual(p1.get_dim(), p2.get_dim())

for y in range(p1.get_dim()):
for x in range(p1.get_dim()):
self.assertAlmostEqual(p1.get_value(x, y), p2.get_value(x, y))

# Check that we read in the configuration values correctly.
self.assertEqual(work2.config["im_filepath"], "Here")
self.assertEqual(work2.config["num_obs"], self.num_images)
self.assertDictEqual(work2.config["mask_bits_dict"], {"A": 1, "B": 2})
self.assertIsNone(work2.config["repeated_flag_keys"])


if __name__ == "__main__":
unittest.main()
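
As a complementary sanity check on the extension layout described in to_fits, the output can be inspected directly with astropy. This is again only a sketch, reusing the hypothetical file name from the example above.

from astropy.io import fits

with fits.open("my_workunit.fits") as hdul:
    # Expect the primary header, "metadata", "kbmod_config", and then
    # SCI_i / VAR_i / MSK_i / PSF_i extensions for each image.
    print([hdu.name for hdu in hdul])
    print("images:", hdul[0].header["NUMIMG"])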
