Skip to content

Commit

Permalink
Merge pull request #426 from dirac-institute/file_load
Browse files Browse the repository at this point in the history
Simplify logic for different run_search paths
  • Loading branch information
jeremykubica authored Jan 17, 2024
2 parents 751d535 + 62537fc commit aeec0b5
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 60 deletions.
66 changes: 58 additions & 8 deletions src/kbmod/data_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from kbmod.configuration import SearchConfiguration
from kbmod.file_utils import *
from kbmod.work_unit import WorkUnit


def load_input_from_individual_files(
Expand Down Expand Up @@ -137,7 +138,7 @@ def load_input_from_individual_files(


def load_input_from_config(config, verbose=False):
"""This function loads images and ingests them into an ImageStack.
"""This function loads images and ingests them into a WorkUnit.
Parameters
----------
Expand All @@ -148,18 +149,67 @@ def load_input_from_config(config, verbose=False):
Returns
-------
stack : `kbmod.ImageStack`
The stack of images loaded.
wcs_list : `list`
A list of `astropy.wcs.WCS` objects for each image.
visit_times : `list`
A list of MJD times.
result : `kbmod.WorkUnit`
The input data as a ``WorkUnit``.
"""
return load_input_from_individual_files(
stack, wcs_list, _ = load_input_from_individual_files(
config["im_filepath"],
config["time_file"],
config["psf_file"],
config["mjd_lims"],
kb.PSF(config["psf_val"]), # Default PSF.
verbose=verbose,
)
return WorkUnit(stack, config, None, wcs_list)


def load_input_from_file(filename, overrides=None):
"""Build a WorkUnit from a single filename which could point to a WorkUnit
or configuration file.
Parameters
----------
filename : `str`
The path and file name of the data to load.
overrides : `dict`, optional
A dictionary of configuration parameters to override. For testing.
Returns
-------
result : `kbmod.WorkUnit`
The input data as a ``WorkUnit``.
Raises
------
``ValueError`` if unable to read the data.
"""
path_var = Path(filename)
if not path_var.is_file():
raise ValueError(f"File {filename} not found.")

work = None

path_suffix = path_var.suffix
if path_suffix == ".yml" or path_suffix == ".yaml":
# Try loading as a WorkUnit first.
with open(filename) as ff:
work = WorkUnit.from_yaml(ff.read(), strict=False)

# If that load did not work, try loading the file as a configuration
# and then using that to load the data files.
if work is None:
config = SearchConfiguration.from_file(filename, strict=False)
if overrides is not None:
config.set_multiple(overrides)
if config["im_filepath"] is not None:
return load_input_from_config(config)
elif ".fits" in filename:
work = WorkUnit.from_fits(filename)

# None of the load paths worked.
if work is None:
raise ValueError(f"Could not interprete {filename}.")

if overrides is not None:
work.config.set_multiple(overrides)
return work
76 changes: 32 additions & 44 deletions src/kbmod/run_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import kbmod.search as kb

from .analysis_utils import PostProcess
from .data_interface import load_input_from_config
from .data_interface import load_input_from_config, load_input_from_file
from .configuration import SearchConfiguration
from .masking import apply_mask_operations
from .result_list import *
Expand Down Expand Up @@ -202,81 +202,69 @@ def run_search(self, config, stack):

return keep

def run_search_from_config(self, config):
"""Run a KBMOD search from a SearchConfiguration object.
def run_search_from_work_unit(self, work):
"""Run a KBMOD search from a WorkUnit object.
Parameters
----------
config : `SearchConfiguration` or `dict`
The configuration object with all the information for the run.
work : `WorkUnit`
The input data and configuration.
Returns
-------
keep : ResultList
The results.
"""
if type(config) is dict:
config = SearchConfiguration.from_dict(config)

# Load the image files.
stack, wcs_list, _ = load_input_from_config(config, verbose=config["debug"])

# Compute the suggested search angle from the images. This is a 12 arcsecond
# segment parallel to the ecliptic is seen under from the image origin.
if config["average_angle"] == None:
center_pixel = (stack.get_width() / 2, stack.get_height() / 2)
config.set("average_angle", self._calc_suggested_angle(wcs_list[0], center_pixel))

return self.run_search(config, stack)
# Set the average angle if it is not set.
if work.config["average_angle"] is None:
center_pixel = (work.im_stack.get_width() / 2, work.im_stack.get_height() / 2)
if work.get_wcs(0) is not None:
work.config.set("average_angle", self._calc_suggested_angle(work.get_wcs(0), center_pixel))
else:
print("WARNING: average_angle is unset and no WCS provided. Using 0.0.")
work.config.set("average_angle", 0.0)

# Run the search.
return self.run_search(work.config, work.im_stack)

def run_search_from_config_file(self, filename, overrides=None):
"""Run a KBMOD search from a configuration file.
def run_search_from_config(self, config):
"""Run a KBMOD search from a SearchConfiguration object
(or corresponding dictionary).
Parameters
----------
filename : `str`
The name of the configuration file.
overrides : `dict`, optional
A dictionary of configuration parameters to override.
config : `SearchConfiguration` or `dict`
The configuration object with all the information for the run.
Returns
-------
keep : ResultList
The results.
"""
config = SearchConfiguration.from_file(filename)
if overrides is not None:
config.set_multiple(overrides)
if type(config) is dict:
config = SearchConfiguration.from_dict(config)

return self.run_search_from_config(config)
# Load the data.
work = load_input_from_config(config)
return self.run_search_from_work_unit(work)

def run_search_from_work_unit_file(self, filename, overrides=None):
"""Run a KBMOD search from a WorkUnit file.
def run_search_from_file(self, filename, overrides=None):
"""Run a KBMOD search from a configuration or WorkUnit file.
Parameters
----------
filename : `str`
The name of the WorkUnit file.
The name of the input file.
overrides : `dict`, optional
A dictionary of configuration parameters to override.
A dictionary of configuration parameters to override. For testing.
Returns
-------
keep : ResultList
The results.
"""
work = WorkUnit.from_fits(filename)

if overrides is not None:
work.config.set_multiple(overrides)

if work.config["average_angle"] == None:
print("WARNING: average_angle is unset. WorkUnit currently uses a default of 0.0")

# TODO: Support the correct setting of the angle.
work.config.set("average_angle", 0.0)

return self.run_search(work.config, work.im_stack)
work = load_input_from_file(filename, overrides)
return self.run_search_from_work_unit(work)

def _count_known_matches(self, result_list, search):
"""Look up the known objects that overlap the images and count how many
Expand Down
20 changes: 19 additions & 1 deletion src/kbmod/work_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,19 +241,37 @@ def from_dict(cls, workunit_dict):
return WorkUnit(im_stack=im_stack, config=config, wcs=global_wcs, per_image_wcs=per_image_wcs)

@classmethod
def from_yaml(cls, work_unit):
def from_yaml(cls, work_unit, strict=False):
"""Load a configuration from a YAML string.
Parameters
----------
work_unit : `str` or `_io.TextIOWrapper`
The serialized YAML data.
strict : `bool`
Raise an error if the file is not a WorkUnit.
Returns
-------
result : `WorkUnit` or `None`
Returns the extracted WorkUnit. If the file did not contain a WorkUnit and
strict=False the function will return None.
Raises
------
Raises a ``ValueError`` for any invalid parameters.
"""
yaml_dict = safe_load(work_unit)

# Check if this a WorkUnit yaml file by checking it has the required fields.
required_fields = ["config", "height", "num_images", "sci_imgs", "times", "var_imgs", "width"]
for name in required_fields:
if name not in yaml_dict:
if strict:
raise ValueError(f"Missing required field {name}")
else:
return None

return WorkUnit.from_dict(yaml_dict)

def to_fits(self, filename, overwrite=False):
Expand Down
84 changes: 79 additions & 5 deletions tests/test_data_interface.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
from astropy.wcs import WCS
import os
import tempfile
import unittest
from yaml import dump

from kbmod.configuration import SearchConfiguration
from kbmod.data_interface import (
load_input_from_individual_files,
load_input_from_config,
load_input_from_file,
load_input_from_individual_files,
)
from kbmod.fake_data_creator import FakeDataSet
from kbmod.search import *
from kbmod.work_unit import WorkUnit
from utils.utils_for_tests import get_absolute_data_path


Expand Down Expand Up @@ -60,19 +67,86 @@ def test_file_load_config(self):
config.set("psf_file", get_absolute_data_path("fake_psfs.dat")),
config.set("psf_val", 1.0)

stack, wcs_list, mjds = load_input_from_config(config, verbose=False)
self.assertEqual(stack.img_count(), 4)
worku = load_input_from_config(config, verbose=False)

# Check that each image loaded corrected.
true_times = [57130.2, 57130.21, 57130.22, 57162.0]
psfs_std = [1.0, 1.0, 1.3, 1.0]
for i in range(stack.img_count()):
img = stack.get_single_image(i)
for i in range(worku.im_stack.img_count()):
img = worku.im_stack.get_single_image(i)
self.assertEqual(img.get_width(), 64)
self.assertEqual(img.get_height(), 64)
self.assertAlmostEqual(img.get_obstime(), true_times[i], delta=0.005)
self.assertAlmostEqual(psfs_std[i], img.get_psf().get_std())

# Try writing the configuration to a YAML file and loading.
with tempfile.TemporaryDirectory() as dir_name:
yaml_file_path = os.path.join(dir_name, "test_config.yml")

with self.assertRaises(ValueError):
work_fits = load_input_from_file(yaml_file_path)

config.to_file(yaml_file_path)

work_yml = load_input_from_file(yaml_file_path)
self.assertIsNotNone(work_yml)
self.assertEqual(work_yml.im_stack.img_count(), 4)

def test_file_load_workunit(self):
# Create a fake WCS
fake_wcs = WCS(
{
"WCSAXES": 2,
"CTYPE1": "RA---TAN-SIP",
"CTYPE2": "DEC--TAN-SIP",
"CRVAL1": 200.614997245422,
"CRVAL2": -7.78878863332778,
"CRPIX1": 1033.934327,
"CRPIX2": 2043.548284,
"CTYPE1A": "LINEAR ",
"CTYPE2A": "LINEAR ",
"CUNIT1A": "PIXEL ",
"CUNIT2A": "PIXEL ",
}
)
fake_config = SearchConfiguration()
fake_data = FakeDataSet(64, 64, 11, obs_per_day=10, use_seed=True)
work = WorkUnit(fake_data.stack, fake_config, fake_wcs, None)

with tempfile.TemporaryDirectory() as dir_name:
# Save and load as FITS
fits_file_path = os.path.join(dir_name, "test_workunit.fits")

with self.assertRaises(ValueError):
work_fits = load_input_from_file(fits_file_path)

work.to_fits(fits_file_path)

work_fits = load_input_from_file(fits_file_path)
self.assertIsNotNone(work_fits)
self.assertEqual(work_fits.im_stack.img_count(), 11)

# Save and load as YAML
yaml_file_path = os.path.join(dir_name, "test_workunit.yml")
with open(yaml_file_path, "w") as file:
file.write(work.to_yaml())

work_yml = load_input_from_file(yaml_file_path)
self.assertIsNotNone(work_yml)
self.assertEqual(work_yml.im_stack.img_count(), 11)

def test_file_load_invalid(self):
# Create a YAML file that is neither a configuration nor a WorkUnit.
yaml_str = dump({"Field1": 1, "Field2": False})

with tempfile.TemporaryDirectory() as dir_name:
yaml_file_path = os.path.join(dir_name, "test_invalid.yml")
with open(yaml_file_path, "w") as file:
file.write(yaml_str)

with self.assertRaises(ValueError):
work = load_input_from_file(yaml_file_path)


if __name__ == "__main__":
unittest.main()
4 changes: 2 additions & 2 deletions tests/test_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_demo_config_file(self):
im_filepath = get_absolute_demo_data_path("demo")
config_file = get_absolute_demo_data_path("demo_config.yml")
rs = SearchRunner()
keep = rs.run_search_from_config_file(
keep = rs.run_search_from_file(
config_file,
overrides={"im_filepath": im_filepath},
)
Expand Down Expand Up @@ -120,7 +120,7 @@ def test_e2e_work_unit(self):
work.to_fits(file_path)

rs = SearchRunner()
keep = rs.run_search_from_work_unit_file(file_path)
keep = rs.run_search_from_file(file_path)
self.assertGreaterEqual(keep.num_results(), 1)


Expand Down

0 comments on commit aeec0b5

Please sign in to comment.