diff --git a/src/kbmod/__init__.py b/src/kbmod/__init__.py index a9558a895..f8a3b9837 100644 --- a/src/kbmod/__init__.py +++ b/src/kbmod/__init__.py @@ -11,6 +11,7 @@ from . import ( analysis, analysis_utils, + data_interface, file_utils, filters, jointfit_functions, diff --git a/src/kbmod/analysis_utils.py b/src/kbmod/analysis_utils.py index 66c0106e8..a249b92dc 100644 --- a/src/kbmod/analysis_utils.py +++ b/src/kbmod/analysis_utils.py @@ -2,10 +2,8 @@ import os import time -from astropy.io import fits -from astropy.wcs import WCS import numpy as np -from scipy.special import erfinv # import mpmath +from scipy.special import erfinv import kbmod.search as kb @@ -15,141 +13,6 @@ from .result_list import ResultList, ResultRow -class Interface: - """This class manages is responsible for loading in data from .fits - and auxiliary files. - """ - - def __init__(self): - return - - def load_images( - self, - im_filepath, - time_file, - psf_file, - mjd_lims, - default_psf, - verbose=False, - ): - """This function loads images and ingests them into a search object. - - Parameters - ---------- - im_filepath : string - Image file path from which to load images. - time_file : string - File name containing image times. - psf_file : string - File name containing the image-specific PSFs. - If set to None the code will use the provided default psf for - all images. - mjd_lims : list of ints - Optional MJD limits on the images to search. - default_psf : `psf` - The default PSF in case no image-specific PSF is provided. - verbose : bool - Use verbose output (mainly for debugging). - - Returns - ------- - stack : `kbmod.ImageStack` - The stack of images loaded. - wcs_list : `list` - A list of `astropy.wcs.WCS` objects for each image. - visit_times : `list` - A list of MJD times. - """ - print("---------------------------------------") - print("Loading Images") - print("---------------------------------------") - - # Load a mapping from visit numbers to the visit times. 
This dictionary stays - # empty if no time file is specified. - image_time_dict = FileUtils.load_time_dictionary(time_file) - if verbose: - print(f"Loaded {len(image_time_dict)} time stamps.") - - # Load a mapping from visit numbers to PSFs. This dictionary stays - # empty if no time file is specified. - image_psf_dict = FileUtils.load_psf_dictionary(psf_file) - if verbose: - print(f"Loaded {len(image_psf_dict)} image PSFs stamps.") - - # Retrieve the list of visits (file names) in the data directory. - patch_visits = sorted(os.listdir(im_filepath)) - - # Load the images themselves. - images = [] - visit_times = [] - wcs_list = [] - for visit_file in np.sort(patch_visits): - # Skip non-fits files. - if not ".fits" in visit_file: - if verbose: - print(f"Skipping non-FITS file {visit_file}") - continue - - # Compute the full file path for loading. - full_file_path = os.path.join(im_filepath, visit_file) - - # Try loading information from the FITS header. - visit_id = None - with fits.open(full_file_path) as hdu_list: - curr_wcs = WCS(hdu_list[1].header) - - # If the visit ID is in header (using Rubin tags), use for the visit ID. - # Otherwise extract it from the filename. - if "IDNUM" in hdu_list[0].header: - visit_id = str(hdu_list[0].header["IDNUM"]) - else: - name = os.path.split(full_file_path)[-1] - visit_id = FileUtils.visit_from_file_name(name) - - # Skip files without a valid visit ID. - if visit_id is None: - if verbose: - print(f"WARNING: Unable to extract visit ID for {visit_file}.") - continue - - # Check if the image has a specific PSF. - psf = default_psf - if visit_id in image_psf_dict: - psf = kb.PSF(image_psf_dict[visit_id]) - - # Load the image file and set its time. - if verbose: - print(f"Loading file: {full_file_path}") - img = kb.LayeredImage(full_file_path, psf) - time_stamp = img.get_obstime() - - # Overload the header's time stamp if needed. 
- if visit_id in image_time_dict: - time_stamp = image_time_dict[visit_id] - img.set_obstime(time_stamp) - - if time_stamp <= 0.0: - if verbose: - print(f"WARNING: No valid timestamp provided for {visit_file}.") - continue - - # Check if we should filter the record based on the time bounds. - if mjd_lims is not None and (time_stamp < mjd_lims[0] or time_stamp > mjd_lims[1]): - if verbose: - print(f"Pruning file {visit_file} by timestamp={time_stamp}.") - continue - - # Save image, time, and WCS information. - visit_times.append(time_stamp) - images.append(img) - wcs_list.append(curr_wcs) - - print(f"Loaded {len(images)} images") - stack = kb.ImageStack(images) - - return (stack, wcs_list, visit_times) - - class PostProcess: """This class manages the post-processing utilities used to filter out and otherwise remove false positives from the KBMOD search. This includes, diff --git a/src/kbmod/data_interface.py b/src/kbmod/data_interface.py new file mode 100644 index 000000000..2b146eb42 --- /dev/null +++ b/src/kbmod/data_interface.py @@ -0,0 +1,145 @@ +import os + +from astropy.io import fits +from astropy.wcs import WCS +import numpy as np + +import kbmod.search as kb + +from .file_utils import * +from .filters.stats_filters import * + + +class Interface: + """This class is responsible for loading data from .fits files + and auxiliary files. + """ + + def __init__(self): + return + + def load_images( + self, + im_filepath, + time_file, + psf_file, + mjd_lims, + default_psf, + verbose=False, + ): + """This function loads images and ingests them into a search object. + + Parameters + ---------- + im_filepath : string + Image file path from which to load images. + time_file : string + File name containing image times. + psf_file : string + File name containing the image-specific PSFs. + If set to None the code will use the provided default psf for + all images. + mjd_lims : list of ints + Optional MJD limits on the images to search.
+ default_psf : `psf` + The default PSF in case no image-specific PSF is provided. + verbose : bool + Use verbose output (mainly for debugging). + + Returns + ------- + stack : `kbmod.ImageStack` + The stack of images loaded. + wcs_list : `list` + A list of `astropy.wcs.WCS` objects for each image. + visit_times : `list` + A list of MJD times. + """ + print("---------------------------------------") + print("Loading Images") + print("---------------------------------------") + + # Load a mapping from visit numbers to the visit times. This dictionary stays + # empty if no time file is specified. + image_time_dict = FileUtils.load_time_dictionary(time_file) + if verbose: + print(f"Loaded {len(image_time_dict)} time stamps.") + + # Load a mapping from visit numbers to PSFs. This dictionary stays + # empty if no PSF file is specified. + image_psf_dict = FileUtils.load_psf_dictionary(psf_file) + if verbose: + print(f"Loaded {len(image_psf_dict)} image PSFs stamps.") + + # Retrieve the list of visits (file names) in the data directory. + patch_visits = sorted(os.listdir(im_filepath)) + + # Load the images themselves. + images = [] + visit_times = [] + wcs_list = [] + for visit_file in np.sort(patch_visits): + # Skip non-fits files. + if not ".fits" in visit_file: + if verbose: + print(f"Skipping non-FITS file {visit_file}") + continue + + # Compute the full file path for loading. + full_file_path = os.path.join(im_filepath, visit_file) + + # Try loading information from the FITS header. + visit_id = None + with fits.open(full_file_path) as hdu_list: + curr_wcs = WCS(hdu_list[1].header) + + # If the visit ID is in the header (using Rubin tags), use it as the visit ID. + # Otherwise extract it from the filename. + if "IDNUM" in hdu_list[0].header: + visit_id = str(hdu_list[0].header["IDNUM"]) + else: + name = os.path.split(full_file_path)[-1] + visit_id = FileUtils.visit_from_file_name(name) + + # Skip files without a valid visit ID.
+ if visit_id is None: + if verbose: + print(f"WARNING: Unable to extract visit ID for {visit_file}.") + continue + + # Check if the image has a specific PSF. + psf = default_psf + if visit_id in image_psf_dict: + psf = kb.PSF(image_psf_dict[visit_id]) + + # Load the image file and set its time. + if verbose: + print(f"Loading file: {full_file_path}") + img = kb.LayeredImage(full_file_path, psf) + time_stamp = img.get_obstime() + + # Overload the header's time stamp if needed. + if visit_id in image_time_dict: + time_stamp = image_time_dict[visit_id] + img.set_obstime(time_stamp) + + if time_stamp <= 0.0: + if verbose: + print(f"WARNING: No valid timestamp provided for {visit_file}.") + continue + + # Check if we should filter the record based on the time bounds. + if mjd_lims is not None and (time_stamp < mjd_lims[0] or time_stamp > mjd_lims[1]): + if verbose: + print(f"Pruning file {visit_file} by timestamp={time_stamp}.") + continue + + # Save image, time, and WCS information. + visit_times.append(time_stamp) + images.append(img) + wcs_list.append(curr_wcs) + + print(f"Loaded {len(images)} images") + stack = kb.ImageStack(images) + + return (stack, wcs_list, visit_times) diff --git a/src/kbmod/run_search.py b/src/kbmod/run_search.py index 14af551a2..9f95d997b 100644 --- a/src/kbmod/run_search.py +++ b/src/kbmod/run_search.py @@ -12,7 +12,8 @@ import kbmod.search as kb -from .analysis_utils import Interface, PostProcess +from .analysis_utils import PostProcess +from .data_interface import Interface from .configuration import SearchConfiguration from .masking import ( BitVectorMasker, diff --git a/tests/test_analysis_utils.py b/tests/test_analysis_utils.py index 00e1659be..681b4274e 100644 --- a/tests/test_analysis_utils.py +++ b/tests/test_analysis_utils.py @@ -1,6 +1,7 @@ import unittest from kbmod.analysis_utils import * +from kbmod.data_interface import Interface from kbmod.fake_data_creator import add_fake_object from kbmod.result_list import * from 
kbmod.search import * @@ -376,51 +377,6 @@ def test_load_and_filter_results_lh(self): self.assertEqual(results.results[0].trajectory.y, 30) self.assertEqual(results.results[1].trajectory.y, 40) - def test_file_load_basic(self): - loader = Interface() - stack, wcs_list, mjds = loader.load_images( - get_absolute_data_path("fake_images"), - None, - None, - [0, 157130.2], - PSF(1.0), - verbose=False, - ) - self.assertEqual(stack.img_count(), 4) - - # Check that each image loaded corrected. - true_times = [57130.2, 57130.21, 57130.22, 57131.2] - for i in range(stack.img_count()): - img = stack.get_single_image(i) - self.assertEqual(img.get_width(), 64) - self.assertEqual(img.get_height(), 64) - self.assertAlmostEqual(img.get_obstime(), true_times[i], delta=0.005) - self.assertAlmostEqual(1.0, img.get_psf().get_std()) - - def test_file_load_extra(self): - p = PSF(1.0) - - loader = Interface() - stack, wcs_list, mjds = loader.load_images( - get_absolute_data_path("fake_images"), - get_absolute_data_path("fake_times.dat"), - get_absolute_data_path("fake_psfs.dat"), - [0, 157130.2], - p, - verbose=False, - ) - self.assertEqual(stack.img_count(), 4) - - # Check that each image loaded corrected. 
- true_times = [57130.2, 57130.21, 57130.22, 57162.0] - psfs_std = [1.0, 1.0, 1.3, 1.0] - for i in range(stack.img_count()): - img = stack.get_single_image(i) - self.assertEqual(img.get_width(), 64) - self.assertEqual(img.get_height(), 64) - self.assertAlmostEqual(img.get_obstime(), true_times[i], delta=0.005) - self.assertAlmostEqual(psfs_std[i], img.get_psf().get_std()) - if __name__ == "__main__": unittest.main() diff --git a/tests/test_data_interface.py b/tests/test_data_interface.py new file mode 100644 index 000000000..6682e8c90 --- /dev/null +++ b/tests/test_data_interface.py @@ -0,0 +1,55 @@ +from kbmod.data_interface import Interface +import unittest +from utils.utils_for_tests import get_absolute_data_path +from kbmod.search import * + + +class test_data_interface(unittest.TestCase): + def test_file_load_basic(self): + loader = Interface() + stack, wcs_list, mjds = loader.load_images( + get_absolute_data_path("fake_images"), + None, + None, + [0, 157130.2], + PSF(1.0), + verbose=False, + ) + self.assertEqual(stack.img_count(), 4) + + # Check that each image loaded correctly. + true_times = [57130.2, 57130.21, 57130.22, 57131.2] + for i in range(stack.img_count()): + img = stack.get_single_image(i) + self.assertEqual(img.get_width(), 64) + self.assertEqual(img.get_height(), 64) + self.assertAlmostEqual(img.get_obstime(), true_times[i], delta=0.005) + self.assertAlmostEqual(1.0, img.get_psf().get_std()) + + def test_file_load_extra(self): + p = PSF(1.0) + + loader = Interface() + stack, wcs_list, mjds = loader.load_images( + get_absolute_data_path("fake_images"), + get_absolute_data_path("fake_times.dat"), + get_absolute_data_path("fake_psfs.dat"), + [0, 157130.2], + p, + verbose=False, + ) + self.assertEqual(stack.img_count(), 4) + + # Check that each image loaded correctly.
+ true_times = [57130.2, 57130.21, 57130.22, 57162.0] + psfs_std = [1.0, 1.0, 1.3, 1.0] + for i in range(stack.img_count()): + img = stack.get_single_image(i) + self.assertEqual(img.get_width(), 64) + self.assertEqual(img.get_height(), 64) + self.assertAlmostEqual(img.get_obstime(), true_times[i], delta=0.005) + self.assertAlmostEqual(psfs_std[i], img.get_psf().get_std()) + + +if __name__ == "__main__": + unittest.main()