-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #368 from dirac-institute/workunit
Create a framework WorkUnit Class
- Loading branch information
Showing
2 changed files
with
286 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
import math | ||
|
||
from astropy.io import fits | ||
from astropy.table import Table | ||
import numpy as np | ||
from pathlib import Path | ||
|
||
from kbmod.configuration import SearchConfiguration | ||
from kbmod.search import ImageStack, LayeredImage, PSF, RawImage | ||
|
||
|
||
class WorkUnit: | ||
"""The work unit is a storage and I/O class for all of the data | ||
needed for a full run of KBMOD, including the: the search parameters, | ||
data files, and the data provenance metadata. | ||
""" | ||
|
||
def __init__(self, im_stack=None, config=None): | ||
self.im_stack = im_stack | ||
self.config = config | ||
|
||
@classmethod | ||
def from_fits(cls, filename): | ||
"""Create a WorkUnit from a single FITS file. | ||
The FITS file will have at least the following extensions: | ||
0. ``PRIMARY`` extension | ||
1. ``METADATA`` extension containing provenance | ||
2. ``KBMOD_CONFIG`` extension containing search parameters | ||
3. (+) any additional image extensions are named ``SCI_i``, ``VAR_i``, ``MSK_i`` | ||
and ``PSF_i`` for the science, variance, mask and PSF of each image respectively, | ||
where ``i`` runs from 0 to number of images in the `WorkUnit`. | ||
Parameters | ||
---------- | ||
filename : `str` | ||
The file to load. | ||
Returns | ||
------- | ||
result : `WorkUnit` | ||
The loaded WorkUnit. | ||
""" | ||
if not Path(filename).is_file(): | ||
raise ValueError(f"WorkUnit file {filename} not found.") | ||
|
||
imgs = [] | ||
with fits.open(filename) as hdul: | ||
num_layers = len(hdul) | ||
if num_layers < 5: | ||
raise ValueError(f"WorkUnit file has too few extensions {len(hdul)}.") | ||
|
||
# TODO - Read in provenance metadata from extension #1. | ||
|
||
# Read in the search parameters from the 'kbmod_config' extension. | ||
config = SearchConfiguration.from_hdu(hdul["kbmod_config"]) | ||
|
||
# Read the size and order information from the primary header. | ||
num_images = hdul[0].header["NUMIMG"] | ||
if len(hdul) != 4 * num_images + 3: | ||
raise ValueError( | ||
f"WorkUnit wrong number of extensions. Expected " | ||
f"{4 * num_images + 3}. Found {len(hdul)}." | ||
) | ||
|
||
# Read in all the image files. | ||
for i in range(num_images): | ||
# Read in science, variance, and mask layers. | ||
sci = hdu_to_raw_image(hdul[f"SCI_{i}"]) | ||
var = hdu_to_raw_image(hdul[f"VAR_{i}"]) | ||
msk = hdu_to_raw_image(hdul[f"MSK_{i}"]) | ||
|
||
# Read the PSF layer. | ||
p = PSF(hdul[f"PSF_{i}"].data) | ||
|
||
imgs.append(LayeredImage(sci, var, msk, p)) | ||
|
||
im_stack = ImageStack(imgs) | ||
return WorkUnit(im_stack=im_stack, config=config) | ||
|
||
def to_fits(self, filename, overwrite=False): | ||
"""Write the WorkUnit to a single FITS file. | ||
Uses the following extensions: | ||
0 - Primary header with overall metadata | ||
1 or "metadata" - The data provenance metadata | ||
2 or "kbmod_config" - The search parameters | ||
3+ - Image extensions for the science layer ("SCI_i"), | ||
variance layer ("VAR_i"), mask layer ("MSK_i"), and | ||
PSF ("PSF_i") of each image. | ||
Parameters | ||
---------- | ||
filename : `str` | ||
The file to which to write the data. | ||
overwrite : bool | ||
Indicates whether to overwrite an existing file. | ||
""" | ||
if Path(filename).is_file() and not overwrite: | ||
print(f"Warning: WorkUnit file {filename} already exists.") | ||
return | ||
|
||
# Set up the initial HDU list, including the primary header | ||
# the metadata (empty), and the configuration. | ||
hdul = fits.HDUList() | ||
pri = fits.PrimaryHDU() | ||
pri.header["NUMIMG"] = self.im_stack.img_count() | ||
hdul.append(pri) | ||
|
||
meta_hdu = fits.BinTableHDU() | ||
meta_hdu.name = "metadata" | ||
hdul.append(meta_hdu) | ||
|
||
config_hdu = self.config.to_hdu() | ||
config_hdu.name = "kbmod_config" | ||
hdul.append(config_hdu) | ||
|
||
for i in range(self.im_stack.img_count()): | ||
layered = self.im_stack.get_single_image(i) | ||
|
||
sci_hdu = raw_image_to_hdu(layered.get_science()) | ||
sci_hdu.name = f"SCI_{i}" | ||
hdul.append(sci_hdu) | ||
|
||
var_hdu = raw_image_to_hdu(layered.get_variance()) | ||
var_hdu.name = f"VAR_{i}" | ||
hdul.append(var_hdu) | ||
|
||
msk_hdu = raw_image_to_hdu(layered.get_mask()) | ||
msk_hdu.name = f"MSK_{i}" | ||
hdul.append(msk_hdu) | ||
|
||
p = layered.get_psf() | ||
psf_array = np.array(p.get_kernel()).reshape((p.get_dim(), p.get_dim())) | ||
psf_hdu = fits.hdu.image.ImageHDU(psf_array) | ||
psf_hdu.name = f"PSF_{i}" | ||
hdul.append(psf_hdu) | ||
|
||
hdul.writeto(filename) | ||
|
||
|
||
def raw_image_to_hdu(img): | ||
"""Helper function that creates a HDU out of RawImage. | ||
Parameters | ||
---------- | ||
img : `RawImage` | ||
The RawImage to convert. | ||
Returns | ||
------- | ||
hdu : `astropy.io.fits.hdu.image.ImageHDU` | ||
The image extension. | ||
""" | ||
# Expensive copy. To be removed with RawImage refactor. | ||
np_pixels = np.array(img.get_all_pixels(), dtype=np.single) | ||
np_array = np_pixels.reshape((img.get_height(), img.get_width())) | ||
hdu = fits.hdu.image.ImageHDU(np_array) | ||
hdu.header["MJD"] = img.get_obstime() | ||
return hdu | ||
|
||
|
||
def hdu_to_raw_image(hdu): | ||
"""Helper function that creates a RawImage from a HDU. | ||
Parameters | ||
---------- | ||
hdu : `astropy.io.fits.hdu.image.ImageHDU` | ||
The image extension. | ||
Returns | ||
------- | ||
img : `RawImage` or None | ||
The RawImage if there is valid data and None otherwise. | ||
""" | ||
img = None | ||
if isinstance(hdu, fits.hdu.image.ImageHDU): | ||
# Expensive copy. To be removed with RawImage refactor. | ||
img = RawImage(hdu.data) | ||
if "MJD" in hdu.header: | ||
img.set_obstime(hdu.header["MJD"]) | ||
return img |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
from astropy.io import fits | ||
from astropy.table import Table | ||
import tempfile | ||
import unittest | ||
from pathlib import Path | ||
|
||
from kbmod.configuration import SearchConfiguration | ||
import kbmod.search as kb | ||
from kbmod.work_unit import hdu_to_raw_image, raw_image_to_hdu, WorkUnit | ||
|
||
|
||
class test_work_unit(unittest.TestCase): | ||
def setUp(self): | ||
self.num_images = 5 | ||
self.width = 50 | ||
self.height = 70 | ||
self.images = [None] * self.num_images | ||
self.p = [None] * self.num_images | ||
for i in range(self.num_images): | ||
self.p[i] = kb.PSF(5.0 / float(2 * i + 1)) | ||
self.images[i] = kb.LayeredImage( | ||
("layered_test_%i" % i), | ||
self.width, | ||
self.height, | ||
2.0, # noise_level | ||
4.0, # variance | ||
2.0 * i + 1.0, # time | ||
self.p[i], | ||
) | ||
|
||
# Include one masked pixel per time step at (10, 10 + i). | ||
mask = self.images[i].get_mask() | ||
mask.set_pixel(10, 10 + i, 1) | ||
|
||
self.im_stack = kb.ImageStack(self.images) | ||
|
||
self.config = SearchConfiguration() | ||
self.config.set("im_filepath", "Here") | ||
self.config.set("num_obs", self.num_images) | ||
self.config.set("mask_bits_dict", {"A": 1, "B": 2}) | ||
self.config.set("repeated_flag_keys", None) | ||
|
||
def test_create(self): | ||
work = WorkUnit(self.im_stack, self.config) | ||
self.assertEqual(work.im_stack.img_count(), 5) | ||
self.assertEqual(work.config["im_filepath"], "Here") | ||
self.assertEqual(work.config["num_obs"], 5) | ||
|
||
def test_save_and_load_fits(self): | ||
with tempfile.TemporaryDirectory() as dir_name: | ||
file_path = f"{dir_name}/test_workunit.fits" | ||
self.assertFalse(Path(file_path).is_file()) | ||
|
||
# Unable to load non-existent file. | ||
self.assertRaises(ValueError, WorkUnit.from_fits, file_path) | ||
|
||
# Write out the existing WorkUnit | ||
work = WorkUnit(self.im_stack, self.config) | ||
work.to_fits(file_path) | ||
self.assertTrue(Path(file_path).is_file()) | ||
|
||
# Read in the file and check that the values agree. | ||
work2 = WorkUnit.from_fits(file_path) | ||
self.assertEqual(work2.im_stack.img_count(), self.num_images) | ||
for i in range(self.num_images): | ||
li = work2.im_stack.get_single_image(i) | ||
self.assertEqual(li.get_width(), self.width) | ||
self.assertEqual(li.get_height(), self.height) | ||
|
||
# Check the three image layers match. | ||
sci1 = li.get_science() | ||
var1 = li.get_variance() | ||
msk1 = li.get_mask() | ||
|
||
li_org = self.im_stack.get_single_image(i) | ||
sci2 = li_org.get_science() | ||
var2 = li_org.get_variance() | ||
msk2 = li_org.get_mask() | ||
|
||
for y in range(self.height): | ||
for x in range(self.width): | ||
self.assertAlmostEqual(sci1.get_pixel(x, y), sci2.get_pixel(x, y)) | ||
self.assertAlmostEqual(var1.get_pixel(x, y), var2.get_pixel(x, y)) | ||
self.assertAlmostEqual(msk1.get_pixel(x, y), msk2.get_pixel(x, y)) | ||
|
||
# Check the PSF layer matches. | ||
p1 = self.p[i] | ||
p2 = li.get_psf() | ||
self.assertEqual(p1.get_dim(), p2.get_dim()) | ||
|
||
for y in range(p1.get_dim()): | ||
for x in range(p1.get_dim()): | ||
self.assertAlmostEqual(p1.get_value(x, y), p2.get_value(x, y)) | ||
|
||
# Check that we read in the configuration values correctly. | ||
self.assertEqual(work2.config["im_filepath"], "Here") | ||
self.assertEqual(work2.config["num_obs"], self.num_images) | ||
self.assertDictEqual(work2.config["mask_bits_dict"], {"A": 1, "B": 2}) | ||
self.assertIsNone(work2.config["repeated_flag_keys"]) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |