Skip to content

Commit

Permalink
Address initial PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Oct 10, 2023
1 parent 4b6284e commit 938d834
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 46 deletions.
98 changes: 55 additions & 43 deletions src/kbmod/work_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,23 @@ class WorkUnit:
"""The work unit is a storage and I/O class for all of the data
needed for a full run of KBMOD, including the: the search parameters,
data files, and the data provenance metadata.
A WorkUnit file is a FITS file with the following extensions:
0 - Primary header with overall metadata
1 - The data provenance metadata
2 - The search parameters
3 through X - Image layers with either alternating layers of science, variance,
mask, and PSF. Layers may be empty if no data is provided.
"""

def __init__(self, im_stack=None, config=None):
self.im_stack = im_stack
self.config = config

@classmethod
def from_file(cls, filename):
"""Create a WorkUnit from a single file.
def from_fits(cls, filename):
"""Create a WorkUnit from a single FITS file.
A WorkUnit is written as a FITS file with the following extensions:
0 - Primary header with overall metadata
1 or "metadata" - The data provenance metadata
2 or "kbmod_config" - The search parameters
3+ - Image extensions for the science layer ("SCI_i"),
variance layer ("VAR_i"), mask layer ("MSK_i"), and
PSF ("PSF_i") of each image.
Parameters
----------
Expand All @@ -43,18 +44,16 @@ def from_file(cls, filename):
if not Path(filename).is_file():
raise ValueError(f"WorkUnit file {filename} not found.")

result = None
imgs = []
with fits.open(filename) as hdul:
num_layers = len(hdul)
if num_layers < 5:
raise ValueError(f"WorkUnit file has too few extensions {len(hdul)}.")

# TODO - Read in provenance metadata from extension #1.

# Read in the search parameters from extension #2.
config = SearchConfiguration()
config_table = Table(hdul[2].data)
config.set_from_table(config_table)
# Read in the search parameters from the 'kbmod_config' extension.
config = SearchConfiguration.from_hdu(hdul["kbmod_config"])

# Read the size and order information from the primary header.
num_images = hdul[0].header["NUMIMG"]
Expand All @@ -65,34 +64,30 @@ def from_file(cls, filename):
)

# Read in all the image files.
imgs = []
for i in range(num_images):
ext_num = 3 + 4 * i

# Read in science and variance layers.
sci = hdu_to_raw_image(hdul[ext_num])
var = hdu_to_raw_image(hdul[ext_num + 1])
# Read in science, variance, and mask layers.
sci = hdu_to_raw_image(hdul[f"SCI_{i}"])
var = hdu_to_raw_image(hdul[f"VAR_{i}"])
msk = hdu_to_raw_image(hdul[f"MSK_{i}"])

# Read the mask layer if it exists.
msk = hdu_to_raw_image(hdul[ext_num + 2])
if msk is None:
msk = RawImage(np.zeros((sci.get_height(), sci.get_width())))

# Check if the PSF layer exists.
if hdul[ext_num + 3].header["NAXIS"] == 2:
p = PSF(hdul[ext_num + 3].data)
else:
p = PSF(1e-8)
# Read the PSF layer.
p = PSF(hdul[f"PSF_{i}"].data)

imgs.append(LayeredImage(sci, var, msk, p))

im_stack = ImageStack(imgs)
im_stack = ImageStack(imgs)
return WorkUnit(im_stack=im_stack, config=config)

result = WorkUnit(im_stack=im_stack, config=config)
return result
def to_fits(self, filename, overwrite=False):
"""Write the WorkUnit to a single FITS file.
def write_to_file(self, filename, overwrite=False):
"""Write the WorkUnit to a single file.
Uses the following extensions:
0 - Primary header with overall metadata
1 or "metadata" - The data provenance metadata
2 or "kbmod_config" - The search parameters
3+ - Image extensions for the science layer ("SCI_i"),
variance layer ("VAR_i"), mask layer ("MSK_i"), and
PSF ("PSF_i") of each image.
Parameters
----------
Expand All @@ -111,18 +106,35 @@ def write_to_file(self, filename, overwrite=False):
pri = fits.PrimaryHDU()
pri.header["NUMIMG"] = self.im_stack.img_count()
hdul.append(pri)
hdul.append(fits.BinTableHDU())
hdul.append(fits.BinTableHDU(self.config.to_table(make_fits_safe=True)))

meta_hdu = fits.BinTableHDU()
meta_hdu.name = "metadata"
hdul.append(meta_hdu)

config_hdu = self.config.to_hdu()
config_hdu.name = "kbmod_config"
hdul.append(config_hdu)

for i in range(self.im_stack.img_count()):
layered = self.im_stack.get_single_image(i)
hdul.append(raw_image_to_hdu(layered.get_science()))
hdul.append(raw_image_to_hdu(layered.get_variance()))
hdul.append(raw_image_to_hdu(layered.get_mask()))

sci_hdu = raw_image_to_hdu(layered.get_science())
sci_hdu.name = f"SCI_{i}"
hdul.append(sci_hdu)

var_hdu = raw_image_to_hdu(layered.get_variance())
var_hdu.name = f"VAR_{i}"
hdul.append(var_hdu)

msk_hdu = raw_image_to_hdu(layered.get_mask())
msk_hdu.name = f"MSK_{i}"
hdul.append(msk_hdu)

p = layered.get_psf()
psf_array = np.array(p.get_kernel()).reshape((p.get_dim(), p.get_dim()))
hdul.append(fits.hdu.image.ImageHDU(psf_array))
psf_hdu = fits.hdu.image.ImageHDU(psf_array)
psf_hdu.name = f"PSF_{i}"
hdul.append(psf_hdu)

hdul.writeto(filename)

Expand All @@ -141,7 +153,7 @@ def raw_image_to_hdu(img):
The image extension.
"""
# Expensive copy. To be removed with RawImage refactor.
np_pixels = np.array(img.get_all_pixels()).astype("float32", casting="same_kind")
np_pixels = np.array(img.get_all_pixels(), dtype=np.single)
np_array = np_pixels.reshape((img.get_height(), img.get_width()))
hdu = fits.hdu.image.ImageHDU(np_array)
hdu.header["MJD"] = img.get_obstime()
Expand All @@ -162,7 +174,7 @@ def hdu_to_raw_image(hdu):
The RawImage if there is valid data and None otherwise.
"""
img = None
if hdu.header["NAXIS"] == 2:
if isinstance(hdu, fits.hdu.image.ImageHDU):
# Expensive copy. To be removed with RawImage refactor.
img = RawImage(hdu.data)
if "MJD" in hdu.header:
Expand Down
6 changes: 3 additions & 3 deletions tests/test_work_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,15 @@ def test_save_and_load_fits(self):
self.assertFalse(Path(file_path).is_file())

# Unable to load non-existent file.
self.assertRaises(ValueError, WorkUnit.from_file, file_path)
self.assertRaises(ValueError, WorkUnit.from_fits, file_path)

# Write out the existing WorkUnit
work = WorkUnit(self.im_stack, self.config)
work.write_to_file(file_path)
work.to_fits(file_path)
self.assertTrue(Path(file_path).is_file())

# Read in the file and check that the values agree.
work2 = WorkUnit.from_file(file_path)
work2 = WorkUnit.from_fits(file_path)
self.assertEqual(work2.im_stack.img_count(), self.num_images)
for i in range(self.num_images):
li = work2.im_stack.get_single_image(i)
Expand Down

0 comments on commit 938d834

Please sign in to comment.