From 8bcd1464776553929febddb3d07231b8f09d0a79 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Thu, 5 Oct 2023 09:35:45 -0400 Subject: [PATCH 1/7] Add an initial benchmark --- benchmarks/bench_filter_stamps.py | 72 +++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 benchmarks/bench_filter_stamps.py diff --git a/benchmarks/bench_filter_stamps.py b/benchmarks/bench_filter_stamps.py new file mode 100644 index 000000000..bf468228f --- /dev/null +++ b/benchmarks/bench_filter_stamps.py @@ -0,0 +1,72 @@ +import timeit +import numpy as np + +from kbmod.search import * + + +def setup_coadd_stamp(params): + """Create a coadded stamp to test with a single bright spot + slightly off center. + + Parameters + ---------- + params : `StampParameters` + The parameters for stamp generation and filtering. + + Returns + ------- + stamp : `RawImage` + The coadded stamp. + """ + stamp_width = 2 * params.radius + 1 + + stamp = RawImage(stamp_width, stamp_width) + stamp.set_all(0.5) + + # Insert a flux of 50.0 and apply a PSF. + flux = 50.0 + p = PSF(1.0) + psf_dim = p.get_dim() + psf_rad = p.get_radius() + for i in range(psf_dim): + for j in range(psf_dim): + stamp.set_pixel( + (params.radius - 1) - psf_rad + i, # x is one pixel off center + params.radius - psf_rad + j, # y is centered + flux * p.get_value(i, j), + ) + + return stamp + + +def run_benchmark(stamp_radius=10): + params = StampParameters() + params.radius = stamp_radius + params.do_filtering = True + params.stamp_type = StampType.STAMP_MEAN + params.center_thresh = 0.03 + params.peak_offset_x = 1.5 + params.peak_offset_y = 1.5 + params.m01_limit = 0.6 + params.m10_limit = 0.6 + params.m11_limit = 2.0 + params.m02_limit = 35.5 + params.m20_limit = 35.5 + + # Create the stamp. + stamp = setup_coadd_stamp(params) + + # Create an empty search stack. + im_stack = ImageStack([]) + search = StackSearch(im_stack) + + # Do three timing runs and use the mean of the time taken. + tmr = timeit.Timer(stmt="search.filter_stamp(stamp, params)", globals=locals()) + res_time = np.mean(tmr.repeat(repeat=10, number=20)) + return res_time + + +if __name__ == "__main__": + for r in [5, 10, 20]: + res_time = run_benchmark(r) + print(f"Stamp Radius={r} -> Ave Time={res_time}") From fc0d22a2f1cd15c42256e21229cfcee9527d6985 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Thu, 5 Oct 2023 10:02:30 -0400 Subject: [PATCH 2/7] Add tests for python filters --- benchmarks/bench_filter_stamps.py | 63 +++++++++++++++++++++++-------- 1 file changed, 48 insertions(+), 15 deletions(-) diff --git a/benchmarks/bench_filter_stamps.py b/benchmarks/bench_filter_stamps.py index bf468228f..b43d36d93 100644 --- a/benchmarks/bench_filter_stamps.py +++ b/benchmarks/bench_filter_stamps.py @@ -1,7 +1,9 @@ import timeit import numpy as np -from kbmod.search import * +from kbmod.filters.stamp_filters import * +from kbmod.result_list import ResultRow +from kbmod.search import ImageStack, PSF, RawImage, StackSearch, StampParameters, StampType, Trajectory def setup_coadd_stamp(params): @@ -39,9 +41,36 @@ def setup_coadd_stamp(params): return stamp -def run_benchmark(stamp_radius=10): +def run_search_benchmark(params): + stamp = setup_coadd_stamp(params) + + # Create an empty search stack. + im_stack = ImageStack([]) + search = StackSearch(im_stack) + + # Do the timing runs. 
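+ # A note on the timing call below: tmr.repeat(repeat=10, number=20) returns
+ # ten totals, each the wall-clock time for twenty consecutive calls, so the
+ # mean taken from it is a mean batch time rather than a per-call time.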
+ tmr = timeit.Timer(stmt="search.filter_stamp(stamp, params)", globals=locals()) + res_time = np.mean(tmr.repeat(repeat=10, number=20)) + return res_time + + +def run_row_benchmark(params, create_filter=""): + stamp = setup_coadd_stamp(params) + row = ResultRow(Trajectory(), 10) + row.stamp = np.array(stamp.get_all_pixels()) + + filt = eval(create_filter) + + # Do the timing runs. + full_cmd = "filt.keep_row(row)" + tmr = timeit.Timer(stmt="filt.keep_row(row)", globals=locals()) + res_time = np.mean(tmr.repeat(repeat=10, number=20)) + return res_time + + +def run_all_benchmarks(): params = StampParameters() - params.radius = stamp_radius + params.radius = 5 params.do_filtering = True params.stamp_type = StampType.STAMP_MEAN params.center_thresh = 0.03 @@ -53,20 +82,24 @@ def run_benchmark(stamp_radius=10): params.m02_limit = 35.5 params.m20_limit = 35.5 - # Create the stamp. - stamp = setup_coadd_stamp(params) + print(" Rad | Method | Time") + print("-" * 40) + for r in [2, 5, 10, 20]: + params.radius = r - # Create an empty search stack. - im_stack = ImageStack([]) - search = StackSearch(im_stack) + res_time = run_search_benchmark(params) + print(f" {r:2d} | C++ (all) | {res_time:10.7f}") - # Do three timing runs and use the mean of the time taken. - tmr = timeit.Timer(stmt="search.filter_stamp(stamp, params)", globals=locals()) - res_time = np.mean(tmr.repeat(repeat=10, number=20)) - return res_time + res_time = run_row_benchmark(params, f"StampPeakFilter({r}, 1.5, 1.5)") + print(f" {r:2d} | StampPeakFilter | {res_time:10.7f}") + + res_time = run_row_benchmark(params, f"StampMomentsFilter({r}, 0.6, 0.6, 2.0, 35.5, 35.5)") + print(f" {r:2d} | StampMomentsFilter | {res_time:10.7f}") + + res_time = run_row_benchmark(params, f"StampCenterFilter({r}, False, 0.03)") + print(f" {r:2d} | StampCenterFilter | {res_time:10.7f}") + print("-" * 40) if __name__ == "__main__": - for r in [5, 10, 20]: - res_time = run_benchmark(r) - print(f"Stamp Radius={r} -> Ave Time={res_time}") + run_all_benchmarks() From 49185e36d9e7d52790541229ecd13d0462597e48 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Tue, 10 Oct 2023 10:23:45 -0400 Subject: [PATCH 3/7] Address PR comments --- src/kbmod/configuration.py | 167 +++++++++++++++++------------------- tests/test_configuration.py | 76 ++++++++-------- 2 files changed, 119 insertions(+), 124 deletions(-) diff --git a/src/kbmod/configuration.py b/src/kbmod/configuration.py index 1c3c7f8da..efe6df5fc 100644 --- a/src/kbmod/configuration.py +++ b/src/kbmod/configuration.py @@ -1,10 +1,10 @@ +import ast import math from astropy.io import fits from astropy.table import Table from numpy import result_type from pathlib import Path -import pickle from yaml import dump, safe_load @@ -123,7 +123,19 @@ def set(self, param, value, strict=True): else: self._params[param] = value - def set_from_dict(self, d, strict=True): + def validate(self): + """Check that the configuration has the necessary parameters. + + Raises + ------ + Raises a ``ValueError`` if a parameter is missing. + """ + for p in self._required_params: + if self._params.get(p, None) is None: + raise ValueError(f"Required configuration parameter {p} missing.") + + @classmethod + def from_dict(cls, d, strict=True): """Sets multiple values from a dictionary. Parameters @@ -138,10 +150,13 @@ def set_from_dict(self, d, strict=True): Raises a ``KeyError`` if the parameter is not part on the list of known parameters and ``strict`` is False. 
""" + config = SearchConfiguration() for key, value in d.items(): - self.set(key, value, strict) + config.set(key, value, strict) + return config - def set_from_table(self, t, strict=True): + @classmethod + def from_table(cls, t, strict=True): """Sets multiple values from an astropy Table with a single row and one column for each parameter. @@ -161,110 +176,102 @@ def set_from_table(self, t, strict=True): """ if len(t) > 1: raise ValueError(f"More than one row in the configuration table ({len(t)}).") + + config = SearchConfiguration() for key in t.colnames: # We use a special indicator for serializing certain types (including # None and dict) to FITS. - if key.startswith("__PICKLED_"): - val = pickle.loads(t[key].value[0]) - key = key[10:] + if key.startswith("__NONE__"): + val = None + key = key[8:] + elif key.startswith("__DICT__"): + val = dict(t[key][0]) + key = key[8:] else: val = t[key][0] - self.set(key, val, strict) - - def to_table(self, make_fits_safe=False): - """Create an astropy table with all the configuration parameters. - - Parameter - --------- - make_fits_safe : `bool` - Override Nones and dictionaries so we can write to FITS. + config.set(key, val, strict) + return config - Returns - ------- - t: `~astropy.table.Table` - The configuration table. - """ - t = Table() - for col in self._params.keys(): - val = self._params[col] - t[col] = [val] - - # If Table does not understand the type, pickle it. - if make_fits_safe and t[col].dtype == "O": - t.remove_column(col) - t["__PICKLED_" + col] = pickle.dumps(val) - - return t - - def validate(self): - """Check that the configuration has the necessary parameters. - - Raises - ------ - Raises a ``ValueError`` if a parameter is missing. - """ - for p in self._required_params: - if self._params.get(p, None) is None: - raise ValueError(f"Required configuration parameter {p} missing.") - - def load_from_yaml_file(self, filename, strict=True): + @classmethod + def from_yaml(cls, config, strict=True): """Load a configuration from a YAML file. Parameters ---------- - filename : `str` - The filename, including path, of the configuration file. + config : `str` or `_io.TextIOWrapper` + The serialized YAML data. strict : `bool` Raise an exception on unknown parameters. Raises ------ - Raises a ``ValueError`` if the configuration file is not found. Raises a ``KeyError`` if the parameter is not part on the list of known parameters and ``strict`` is False. """ - if not Path(filename).is_file(): - raise ValueError(f"Configuration file {filename} not found.") + yaml_params = safe_load(config) + return SearchConfiguration.from_dict(yaml_params, strict) - # Read the user-specified parameters from the file. - file_params = {} - with open(filename, "r") as config: - file_params = safe_load(config) - - # Merge in the new values. - self.set_from_dict(file_params, strict) - - if strict: - self.validate() - - def load_from_fits_file(self, filename, layer=0, strict=True): + @classmethod + def from_hdu(cls, hdu, strict=True): """Load a configuration from a FITS extension file. Parameters ---------- - filename : `str` - The filename, including path, of the configuration file. - layer : `int` - The extension number to use. + hdu : `astropy.io.fits.BinTableHDU` + The HDU from which to parse the configuration information. strict : `bool` Raise an exception on unknown parameters. Raises ------ - Raises a ``ValueError`` if the configuration file is not found. 
Raises a ``KeyError`` if the parameter is not part on the list of known parameters and ``strict`` is False. """ - if not Path(filename).is_file(): - raise ValueError(f"Configuration file {filename} not found.") + config = SearchConfiguration() + for column in hdu.data.columns: + key = column.name + val = hdu.data[key][0] - # Read the user-specified parameters from the file. - t = Table.read(filename, hdu=layer) - self.set_from_table(t) + # We use a special indicator for serializing certain types (including + # None and dict) to FITS. + if type(val) is str and val == "__NONE__": + val = None + elif key.startswith("__DICT__"): + val = ast.literal_eval(val) + key = key[8:] + + config.set(key, val, strict) + return config + + @classmethod + def from_file(cls, filename, extension=0, strict=True): + if filename.endswith("yaml"): + with open(filename) as ff: + return SearchConfiguration.from_yaml(ff.read()) + elif ".fits" in filename: + with fits.open(filename) as ff: + return SearchConfiguration.from_hdu(ff[extension]) + raise ValueError("Configuration file suffix unrecognized.") + + def to_hdu(self): + """Create a fits HDU with all the configuration parameters. - if strict: - self.validate() + Returns + ------- + hdu : `astropy.io.fits.BinTableHDU` + The HDU with the configuration information. + """ + t = Table() + for col in self._params.keys(): + val = self._params[col] + if val is None: + t[col] = ["__NONE__"] + elif type(val) is dict: + t["__DICT__" + col] = [str(val)] + else: + t[col] = [val] + return fits.table_to_hdu(t) def save_to_yaml_file(self, filename, overwrite=False): """Save a configuration file with the parameters. @@ -282,15 +289,3 @@ def save_to_yaml_file(self, filename, overwrite=False): with open(filename, "w") as file: file.write(dump(self._params)) - - def append_to_fits(self, filename): - """Append the configuration table as a new extension on a FITS file - (creating a new file if needed). - - Parameters - ---------- - filename : str - The filename, including path, of the configuration file. - """ - t = self.to_table(make_fits_safe=True) - t.write(filename, append=True) diff --git a/tests/test_configuration.py b/tests/test_configuration.py index 86719c618..bc2e9b9e6 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -31,38 +31,43 @@ def test_set(self): # The set should fail when using unknown parameters and strict checking. self.assertRaises(KeyError, config.set, "My_new_param", 100, strict=True) - def test_set_from_dict(self): - # Everything starts at its default. - config = SearchConfiguration() - self.assertIsNone(config["im_filepath"]) - self.assertEqual(config["num_obs"], 10) - + def test_from_dict(self): d = {"im_filepath": "Here2", "num_obs": 5} - config.set_from_dict(d) + config = SearchConfiguration.from_dict(d) self.assertEqual(config["im_filepath"], "Here2") self.assertEqual(config["num_obs"], 5) - def test_set_from_table(self): - # Everything starts at its default. 
- config = SearchConfiguration() - self.assertIsNone(config["im_filepath"]) - self.assertEqual(config["num_obs"], 10) + def test_from_hdu(self): + t = Table([["Here3"], [7], ["__NONE__"]], names=("im_filepath", "num_obs", "cluster_type")) + hdu = fits.table_to_hdu(t) - t = Table([["Here3"], [7]], names=("im_filepath", "num_obs")) - config.set_from_table(t) + config = SearchConfiguration.from_hdu(hdu) self.assertEqual(config["im_filepath"], "Here3") self.assertEqual(config["num_obs"], 7) + self.assertIsNone(config["cluster_type"]) - def test_to_table(self): + def test_to_hdu(self): # Everything starts at its default. - config = SearchConfiguration() - d = {"im_filepath": "Here2", "num_obs": 5} - config.set_from_dict(d) - - t = config.to_table() - self.assertEqual(len(t), 1) - self.assertEqual(t["im_filepath"][0], "Here2") - self.assertEqual(t["num_obs"][0], 5) + d = { + "im_filepath": "Here2", + "num_obs": 5, + "cluster_type": None, + "mask_bits_dict": {"bit1": 1, "bit2": 2}, + "do_clustering": False, + "res_filepath": "There", + "ang_arr": [1.0, 2.0, 3.0], + } + config = SearchConfiguration.from_dict(d) + hdu = config.to_hdu() + + self.assertEqual(hdu.data["im_filepath"][0], "Here2") + self.assertEqual(hdu.data["num_obs"][0], 5) + self.assertEqual(hdu.data["cluster_type"][0], "__NONE__") + self.assertEqual(hdu.data["__DICT__mask_bits_dict"][0], "{'bit1': 1, 'bit2': 2}") + self.assertEqual(hdu.data["res_filepath"][0], "There") + self.assertEqual(hdu.data["ang_arr"][0][0], 1.0) + self.assertEqual(hdu.data["ang_arr"][0][1], 2.0) + self.assertEqual(hdu.data["ang_arr"][0][2], 3.0) def test_save_and_load_yaml(self): config = SearchConfiguration() @@ -74,12 +79,11 @@ def test_save_and_load_yaml(self): config.set("mask_grow", 5) with tempfile.TemporaryDirectory() as dir_name: - file_path = f"{dir_name}/tmp_config_data.cfg" + file_path = f"{dir_name}/tmp_config_data.yaml" self.assertFalse(Path(file_path).is_file()) # Unable to load non-existent file. - config2 = SearchConfiguration() - self.assertRaises(ValueError, config2.load_from_yaml_file, file_path) + self.assertRaises(FileNotFoundError, SearchConfiguration.from_file, file_path) # Correctly saves file. config.save_to_yaml_file(file_path) @@ -87,7 +91,7 @@ def test_save_and_load_yaml(self): # Correctly loads file. try: - config2.load_from_yaml_file(file_path) + config2 = SearchConfiguration.from_file(file_path) except ValueError: self.fail("load_configuration() raised ValueError.") @@ -112,21 +116,17 @@ def test_save_and_load_fits(self): self.assertFalse(Path(file_path).is_file()) # Unable to load non-existent file. - config2 = SearchConfiguration() - self.assertRaises(ValueError, config2.load_from_fits_file, file_path) + self.assertRaises(FileNotFoundError, SearchConfiguration.from_file, file_path) - # Generate measningless data for table 0 and the configuration for table 1. - t0 = Table([[1] * 10, [2] * 10, [3] * 10], names=("A", "B", "C")) - t0.write(file_path) - self.assertTrue(Path(file_path).is_file()) - - # Append the FITS data to extension=1 - config.append_to_fits(file_path) - self.assertTrue(Path(file_path).is_file()) + # Generate empty data for the first two tables and config for the third. + hdu0 = fits.PrimaryHDU() + hdu1 = fits.ImageHDU() + hdu_list = fits.HDUList([hdu0, hdu1, config.to_hdu()]) + hdu_list.writeto(file_path) # Correctly loads file. 
try: - config2.load_from_fits_file(file_path, layer=2) + config2 = SearchConfiguration.from_file(file_path, extension=2) except ValueError: self.fail("load_from_fits_file() raised ValueError.") From 13f712ba8b248aad3c9505369e1df0f7bfa38182 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Tue, 10 Oct 2023 11:08:35 -0400 Subject: [PATCH 4/7] Address remaining PR comments --- src/kbmod/configuration.py | 25 +++++++++++++++---------- src/kbmod/run_search.py | 11 ++++++----- tests/test_configuration.py | 35 +++++++++++++++++++++++++++++------ 3 files changed, 50 insertions(+), 21 deletions(-) diff --git a/src/kbmod/configuration.py b/src/kbmod/configuration.py index efe6df5fc..f92967e88 100644 --- a/src/kbmod/configuration.py +++ b/src/kbmod/configuration.py @@ -245,14 +245,9 @@ def from_hdu(cls, hdu, strict=True): return config @classmethod - def from_file(cls, filename, extension=0, strict=True): - if filename.endswith("yaml"): - with open(filename) as ff: - return SearchConfiguration.from_yaml(ff.read()) - elif ".fits" in filename: - with fits.open(filename) as ff: - return SearchConfiguration.from_hdu(ff[extension]) - raise ValueError("Configuration file suffix unrecognized.") + def from_file(cls, filename, strict=True): + with open(filename) as ff: + return SearchConfiguration.from_yaml(ff.read(), strict) def to_hdu(self): """Create a fits HDU with all the configuration parameters. @@ -273,7 +268,17 @@ def to_hdu(self): t[col] = [val] return fits.table_to_hdu(t) - def save_to_yaml_file(self, filename, overwrite=False): + def to_yaml(self): + """Save a configuration file with the parameters. + + Returns + ------- + result : `str` + The serialized YAML string. + """ + return dump(self._params) + + def to_file(self, filename, overwrite=False): """Save a configuration file with the parameters. Parameters @@ -288,4 +293,4 @@ def save_to_yaml_file(self, filename, overwrite=False): return with open(filename, "w") as file: - file.write(dump(self._params)) + file.write(self.to_yaml()) diff --git a/src/kbmod/run_search.py b/src/kbmod/run_search.py index a874c1d4b..14af551a2 100644 --- a/src/kbmod/run_search.py +++ b/src/kbmod/run_search.py @@ -44,15 +44,16 @@ class run_search: """ def __init__(self, input_parameters, config_file=None): - self.config = SearchConfiguration() - # Load parameters from a file. if config_file != None: - self.config.load_from_yaml_file(config_file) + self.config = SearchConfiguration.from_file(config_file) + else: + self.config = SearchConfiguration() # Load any additional parameters (overwriting what is there). if len(input_parameters) > 0: - self.config.set_from_dict(input_parameters) + for key, value in input_parameters.items(): + self.config.set(key, value) # Validate the configuration. 
self.config.validate() @@ -301,7 +302,7 @@ def run_search(self): config_filename = os.path.join( self.config["res_filepath"], f"config_{self.config['output_suffix']}.yml" ) - self.config.save_to_yaml_file(config_filename, overwrite=True) + self.config.to_file(config_filename, overwrite=True) end = time.time() print("Time taken for patch: ", end - start) diff --git a/tests/test_configuration.py b/tests/test_configuration.py index bc2e9b9e6..b1a2a0c5d 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -3,6 +3,7 @@ import tempfile import unittest from pathlib import Path +from yaml import safe_load from kbmod.configuration import SearchConfiguration @@ -47,7 +48,6 @@ def test_from_hdu(self): self.assertIsNone(config["cluster_type"]) def test_to_hdu(self): - # Everything starts at its default. d = { "im_filepath": "Here2", "num_obs": 5, @@ -69,6 +69,30 @@ def test_to_hdu(self): self.assertEqual(hdu.data["ang_arr"][0][1], 2.0) self.assertEqual(hdu.data["ang_arr"][0][2], 3.0) + def test_to_yaml(self): + d = { + "im_filepath": "Here2", + "num_obs": 5, + "cluster_type": None, + "mask_bits_dict": {"bit1": 1, "bit2": 2}, + "do_clustering": False, + "res_filepath": "There", + "ang_arr": [1.0, 2.0, 3.0], + } + config = SearchConfiguration.from_dict(d) + yaml_str = config.to_yaml() + + yaml_dict = safe_load(yaml_str) + self.assertEqual(yaml_dict["im_filepath"], "Here2") + self.assertEqual(yaml_dict["num_obs"], 5) + self.assertEqual(yaml_dict["cluster_type"], None) + self.assertEqual(yaml_dict["mask_bits_dict"]["bit1"], 1) + self.assertEqual(yaml_dict["mask_bits_dict"]["bit2"], 2) + self.assertEqual(yaml_dict["res_filepath"], "There") + self.assertEqual(yaml_dict["ang_arr"][0], 1.0) + self.assertEqual(yaml_dict["ang_arr"][1], 2.0) + self.assertEqual(yaml_dict["ang_arr"][2], 3.0) + def test_save_and_load_yaml(self): config = SearchConfiguration() num_defaults = len(config._params) @@ -86,7 +110,7 @@ def test_save_and_load_yaml(self): self.assertRaises(FileNotFoundError, SearchConfiguration.from_file, file_path) # Correctly saves file. - config.save_to_yaml_file(file_path) + config.to_file(file_path) self.assertTrue(Path(file_path).is_file()) # Correctly loads file. @@ -125,10 +149,9 @@ def test_save_and_load_fits(self): hdu_list.writeto(file_path) # Correctly loads file. - try: - config2 = SearchConfiguration.from_file(file_path, extension=2) - except ValueError: - self.fail("load_from_fits_file() raised ValueError.") + config2 = SearchConfiguration() + with fits.open(file_path) as ff: + config2 = SearchConfiguration.from_hdu(ff[2]) self.assertEqual(len(config2._params), num_defaults) self.assertEqual(config2["im_filepath"], "Here2") From e8d960ae166ba7180013cf263219accb4a9cc263 Mon Sep 17 00:00:00 2001 From: DinoBektesevic Date: Tue, 10 Oct 2023 10:46:28 -0700 Subject: [PATCH 5/7] Don't parse our own types, let yaml do it for us. 
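The approach, sketched below as a standalone round trip outside of kbmod (the parameter names and values here are illustrative, not kbmod API): dump every configuration value to a YAML string, store the strings as a single-row astropy Table, and let safe_load rebuild the original Python types when the table is read back, so None, dicts, and lists survive the FITS layer without the __NONE__/__DICT__ markers.

    from astropy.io import fits
    from astropy.table import Table
    from yaml import dump, safe_load

    params = {"num_obs": 5, "cluster_type": None, "ang_arr": [1.0, 2.0, 3.0]}

    # Serialize each value to a YAML string and pack the strings into one table row.
    serialized = {key: dump(val, default_flow_style=True) for key, val in params.items()}
    hdu = fits.table_to_hdu(Table(rows=[serialized]))

    # Reading back, YAML restores the original types, including None and the list.
    t = Table(hdu.data)
    restored = {name: safe_load(t[name][0]) for name in t.colnames}
    assert restored == params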
--- src/kbmod/configuration.py | 47 +++++++------------------------------ tests/test_configuration.py | 27 ++++++++++++++------- 2 files changed, 27 insertions(+), 47 deletions(-) diff --git a/src/kbmod/configuration.py b/src/kbmod/configuration.py index f92967e88..7b86265ed 100644 --- a/src/kbmod/configuration.py +++ b/src/kbmod/configuration.py @@ -5,6 +5,7 @@ from astropy.table import Table from numpy import result_type from pathlib import Path +import yaml from yaml import dump, safe_load @@ -177,21 +178,10 @@ def from_table(cls, t, strict=True): if len(t) > 1: raise ValueError(f"More than one row in the configuration table ({len(t)}).") - config = SearchConfiguration() - for key in t.colnames: - # We use a special indicator for serializing certain types (including - # None and dict) to FITS. - if key.startswith("__NONE__"): - val = None - key = key[8:] - elif key.startswith("__DICT__"): - val = dict(t[key][0]) - key = key[8:] - else: - val = t[key][0] + # guaranteed to only have 1 element due to check above + params = {col.name: safe_load(col.value[0]) for col in t.values()} + return SearchConfiguration.from_dict(params) - config.set(key, val, strict) - return config @classmethod def from_yaml(cls, config, strict=True): @@ -228,21 +218,8 @@ def from_hdu(cls, hdu, strict=True): Raises a ``KeyError`` if the parameter is not part on the list of known parameters and ``strict`` is False. """ - config = SearchConfiguration() - for column in hdu.data.columns: - key = column.name - val = hdu.data[key][0] - - # We use a special indicator for serializing certain types (including - # None and dict) to FITS. - if type(val) is str and val == "__NONE__": - val = None - elif key.startswith("__DICT__"): - val = ast.literal_eval(val) - key = key[8:] - - config.set(key, val, strict) - return config + t = Table(hdu.data) + return SearchConfiguration.from_table(t) @classmethod def from_file(cls, filename, strict=True): @@ -257,15 +234,9 @@ def to_hdu(self): hdu : `astropy.io.fits.BinTableHDU` The HDU with the configuration information. 
""" - t = Table() - for col in self._params.keys(): - val = self._params[col] - if val is None: - t[col] = ["__NONE__"] - elif type(val) is dict: - t["__DICT__" + col] = [str(val)] - else: - t[col] = [val] + serialized_dict = {key: dump(val, default_flow_style=True) + for key, val in self._params.items()} + t = Table(rows=[serialized_dict, ]) return fits.table_to_hdu(t) def to_yaml(self): diff --git a/tests/test_configuration.py b/tests/test_configuration.py index b1a2a0c5d..ae8f8ee66 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -39,12 +39,23 @@ def test_from_dict(self): self.assertEqual(config["num_obs"], 5) def test_from_hdu(self): - t = Table([["Here3"], [7], ["__NONE__"]], names=("im_filepath", "num_obs", "cluster_type")) + t = Table( + [ + ["Here3"], + ["7"], + ["null"], + [ + "[1, 2]", + ], + ], + names=("im_filepath", "num_obs", "cluster_type", "ang_arr"), + ) hdu = fits.table_to_hdu(t) config = SearchConfiguration.from_hdu(hdu) self.assertEqual(config["im_filepath"], "Here3") self.assertEqual(config["num_obs"], 7) + self.assertEqual(config["ang_arr"], [1, 2]) self.assertIsNone(config["cluster_type"]) def test_to_hdu(self): @@ -60,14 +71,12 @@ def test_to_hdu(self): config = SearchConfiguration.from_dict(d) hdu = config.to_hdu() - self.assertEqual(hdu.data["im_filepath"][0], "Here2") - self.assertEqual(hdu.data["num_obs"][0], 5) - self.assertEqual(hdu.data["cluster_type"][0], "__NONE__") - self.assertEqual(hdu.data["__DICT__mask_bits_dict"][0], "{'bit1': 1, 'bit2': 2}") - self.assertEqual(hdu.data["res_filepath"][0], "There") - self.assertEqual(hdu.data["ang_arr"][0][0], 1.0) - self.assertEqual(hdu.data["ang_arr"][0][1], 2.0) - self.assertEqual(hdu.data["ang_arr"][0][2], 3.0) + self.assertEqual(hdu.data["im_filepath"][0], "Here2\n...") + self.assertEqual(hdu.data["num_obs"][0], "5\n...") + self.assertEqual(hdu.data["cluster_type"][0], "null\n...") + self.assertEqual(hdu.data["mask_bits_dict"][0], "{bit1: 1, bit2: 2}") + self.assertEqual(hdu.data["res_filepath"][0], "There\n...") + self.assertEqual(hdu.data["ang_arr"][0], "[1.0, 2.0, 3.0]") def test_to_yaml(self): d = { From 23dd72a17402a09a560a46fddd813c43bb1775c9 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Tue, 10 Oct 2023 13:54:04 -0400 Subject: [PATCH 6/7] Fix linting error --- src/kbmod/configuration.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/kbmod/configuration.py b/src/kbmod/configuration.py index 7b86265ed..379656b49 100644 --- a/src/kbmod/configuration.py +++ b/src/kbmod/configuration.py @@ -182,7 +182,6 @@ def from_table(cls, t, strict=True): params = {col.name: safe_load(col.value[0]) for col in t.values()} return SearchConfiguration.from_dict(params) - @classmethod def from_yaml(cls, config, strict=True): """Load a configuration from a YAML file. @@ -234,9 +233,12 @@ def to_hdu(self): hdu : `astropy.io.fits.BinTableHDU` The HDU with the configuration information. """ - serialized_dict = {key: dump(val, default_flow_style=True) - for key, val in self._params.items()} - t = Table(rows=[serialized_dict, ]) + serialized_dict = {key: dump(val, default_flow_style=True) for key, val in self._params.items()} + t = Table( + rows=[ + serialized_dict, + ] + ) return fits.table_to_hdu(t) def to_yaml(self): From ccb042bd3c662912e2859454f6b0c9f99b9bc9d3 Mon Sep 17 00:00:00 2001 From: DinoBektesevic Date: Tue, 10 Oct 2023 10:46:28 -0700 Subject: [PATCH 7/7] Don't parse our own types, let yaml do it for us. 
--- src/kbmod/configuration.py | 52 +++++++++---------------------------- tests/test_configuration.py | 27 ++++++++++++------- 2 files changed, 30 insertions(+), 49 deletions(-) diff --git a/src/kbmod/configuration.py b/src/kbmod/configuration.py index f92967e88..07ca2ae2d 100644 --- a/src/kbmod/configuration.py +++ b/src/kbmod/configuration.py @@ -1,10 +1,10 @@ -import ast import math from astropy.io import fits from astropy.table import Table from numpy import result_type from pathlib import Path +import yaml from yaml import dump, safe_load @@ -177,21 +177,9 @@ def from_table(cls, t, strict=True): if len(t) > 1: raise ValueError(f"More than one row in the configuration table ({len(t)}).") - config = SearchConfiguration() - for key in t.colnames: - # We use a special indicator for serializing certain types (including - # None and dict) to FITS. - if key.startswith("__NONE__"): - val = None - key = key[8:] - elif key.startswith("__DICT__"): - val = dict(t[key][0]) - key = key[8:] - else: - val = t[key][0] - - config.set(key, val, strict) - return config + # guaranteed to only have 1 element due to check above + params = {col.name: safe_load(col.value[0]) for col in t.values()} + return SearchConfiguration.from_dict(params) @classmethod def from_yaml(cls, config, strict=True): @@ -228,21 +216,8 @@ def from_hdu(cls, hdu, strict=True): Raises a ``KeyError`` if the parameter is not part on the list of known parameters and ``strict`` is False. """ - config = SearchConfiguration() - for column in hdu.data.columns: - key = column.name - val = hdu.data[key][0] - - # We use a special indicator for serializing certain types (including - # None and dict) to FITS. - if type(val) is str and val == "__NONE__": - val = None - elif key.startswith("__DICT__"): - val = ast.literal_eval(val) - key = key[8:] - - config.set(key, val, strict) - return config + t = Table(hdu.data) + return SearchConfiguration.from_table(t) @classmethod def from_file(cls, filename, strict=True): @@ -257,15 +232,12 @@ def to_hdu(self): hdu : `astropy.io.fits.BinTableHDU` The HDU with the configuration information. 
""" - t = Table() - for col in self._params.keys(): - val = self._params[col] - if val is None: - t[col] = ["__NONE__"] - elif type(val) is dict: - t["__DICT__" + col] = [str(val)] - else: - t[col] = [val] + serialized_dict = {key: dump(val, default_flow_style=True) for key, val in self._params.items()} + t = Table( + rows=[ + serialized_dict, + ] + ) return fits.table_to_hdu(t) def to_yaml(self): diff --git a/tests/test_configuration.py b/tests/test_configuration.py index b1a2a0c5d..ae8f8ee66 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -39,12 +39,23 @@ def test_from_dict(self): self.assertEqual(config["num_obs"], 5) def test_from_hdu(self): - t = Table([["Here3"], [7], ["__NONE__"]], names=("im_filepath", "num_obs", "cluster_type")) + t = Table( + [ + ["Here3"], + ["7"], + ["null"], + [ + "[1, 2]", + ], + ], + names=("im_filepath", "num_obs", "cluster_type", "ang_arr"), + ) hdu = fits.table_to_hdu(t) config = SearchConfiguration.from_hdu(hdu) self.assertEqual(config["im_filepath"], "Here3") self.assertEqual(config["num_obs"], 7) + self.assertEqual(config["ang_arr"], [1, 2]) self.assertIsNone(config["cluster_type"]) def test_to_hdu(self): @@ -60,14 +71,12 @@ def test_to_hdu(self): config = SearchConfiguration.from_dict(d) hdu = config.to_hdu() - self.assertEqual(hdu.data["im_filepath"][0], "Here2") - self.assertEqual(hdu.data["num_obs"][0], 5) - self.assertEqual(hdu.data["cluster_type"][0], "__NONE__") - self.assertEqual(hdu.data["__DICT__mask_bits_dict"][0], "{'bit1': 1, 'bit2': 2}") - self.assertEqual(hdu.data["res_filepath"][0], "There") - self.assertEqual(hdu.data["ang_arr"][0][0], 1.0) - self.assertEqual(hdu.data["ang_arr"][0][1], 2.0) - self.assertEqual(hdu.data["ang_arr"][0][2], 3.0) + self.assertEqual(hdu.data["im_filepath"][0], "Here2\n...") + self.assertEqual(hdu.data["num_obs"][0], "5\n...") + self.assertEqual(hdu.data["cluster_type"][0], "null\n...") + self.assertEqual(hdu.data["mask_bits_dict"][0], "{bit1: 1, bit2: 2}") + self.assertEqual(hdu.data["res_filepath"][0], "There\n...") + self.assertEqual(hdu.data["ang_arr"][0], "[1.0, 2.0, 3.0]") def test_to_yaml(self): d = {