Skip to content

Commit

Permalink
Merge branch 'main' into workunit
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Oct 10, 2023
2 parents 7e0150e + 89555b7 commit 4b6284e
Show file tree
Hide file tree
Showing 4 changed files with 253 additions and 143 deletions.
105 changes: 105 additions & 0 deletions benchmarks/bench_filter_stamps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import timeit
import numpy as np

from kbmod.filters.stamp_filters import *
from kbmod.result_list import ResultRow
from kbmod.search import ImageStack, PSF, RawImage, StackSearch, StampParameters, StampType, Trajectory


def setup_coadd_stamp(params):
"""Create a coadded stamp to test with a single bright spot
slightly off center.
Parameters
----------
params : `StampParameters`
The parameters for stamp generation and filtering.
Returns
-------
stamp : `RawImage`
The coadded stamp.
"""
stamp_width = 2 * params.radius + 1

stamp = RawImage(stamp_width, stamp_width)
stamp.set_all(0.5)

# Insert a flux of 50.0 and apply a PSF.
flux = 50.0
p = PSF(1.0)
psf_dim = p.get_dim()
psf_rad = p.get_radius()
for i in range(psf_dim):
for j in range(psf_dim):
stamp.set_pixel(
(params.radius - 1) - psf_rad + i, # x is one pixel off center
params.radius - psf_rad + j, # y is centered
flux * p.get_value(i, j),
)

return stamp


def run_search_benchmark(params):
stamp = setup_coadd_stamp(params)

# Create an empty search stack.
im_stack = ImageStack([])
search = StackSearch(im_stack)

# Do the timing runs.
tmr = timeit.Timer(stmt="search.filter_stamp(stamp, params)", globals=locals())
res_time = np.mean(tmr.repeat(repeat=10, number=20))
return res_time


def run_row_benchmark(params, create_filter=""):
stamp = setup_coadd_stamp(params)
row = ResultRow(Trajectory(), 10)
row.stamp = np.array(stamp.get_all_pixels())

filt = eval(create_filter)

# Do the timing runs.
full_cmd = "filt.keep_row(row)"
tmr = timeit.Timer(stmt="filt.keep_row(row)", globals=locals())
res_time = np.mean(tmr.repeat(repeat=10, number=20))
return res_time


def run_all_benchmarks():
params = StampParameters()
params.radius = 5
params.do_filtering = True
params.stamp_type = StampType.STAMP_MEAN
params.center_thresh = 0.03
params.peak_offset_x = 1.5
params.peak_offset_y = 1.5
params.m01_limit = 0.6
params.m10_limit = 0.6
params.m11_limit = 2.0
params.m02_limit = 35.5
params.m20_limit = 35.5

print(" Rad | Method | Time")
print("-" * 40)
for r in [2, 5, 10, 20]:
params.radius = r

res_time = run_search_benchmark(params)
print(f" {r:2d} | C++ (all) | {res_time:10.7f}")

res_time = run_row_benchmark(params, f"StampPeakFilter({r}, 1.5, 1.5)")
print(f" {r:2d} | StampPeakFilter | {res_time:10.7f}")

res_time = run_row_benchmark(params, f"StampMomentsFilter({r}, 0.6, 0.6, 2.0, 35.5, 35.5)")
print(f" {r:2d} | StampMomentsFilter | {res_time:10.7f}")

res_time = run_row_benchmark(params, f"StampCenterFilter({r}, False, 0.03)")
print(f" {r:2d} | StampCenterFilter | {res_time:10.7f}")
print("-" * 40)


if __name__ == "__main__":
run_all_benchmarks()
158 changes: 65 additions & 93 deletions src/kbmod/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from astropy.table import Table
from numpy import result_type
from pathlib import Path
import pickle
import yaml
from yaml import dump, safe_load


Expand Down Expand Up @@ -123,7 +123,19 @@ def set(self, param, value, strict=True):
else:
self._params[param] = value

def set_from_dict(self, d, strict=True):
def validate(self):
"""Check that the configuration has the necessary parameters.
Raises
------
Raises a ``ValueError`` if a parameter is missing.
"""
for p in self._required_params:
if self._params.get(p, None) is None:
raise ValueError(f"Required configuration parameter {p} missing.")

@classmethod
def from_dict(cls, d, strict=True):
"""Sets multiple values from a dictionary.
Parameters
Expand All @@ -138,10 +150,13 @@ def set_from_dict(self, d, strict=True):
Raises a ``KeyError`` if the parameter is not part on the list of known parameters
and ``strict`` is False.
"""
config = SearchConfiguration()
for key, value in d.items():
self.set(key, value, strict)
config.set(key, value, strict)
return config

def set_from_table(self, t, strict=True):
@classmethod
def from_table(cls, t, strict=True):
"""Sets multiple values from an astropy Table with a single row and
one column for each parameter.
Expand All @@ -161,112 +176,81 @@ def set_from_table(self, t, strict=True):
"""
if len(t) > 1:
raise ValueError(f"More than one row in the configuration table ({len(t)}).")
for key in t.colnames:
# We use a special indicator for serializing certain types (including
# None and dict) to FITS.
if key.startswith("__PICKLED_"):
val = pickle.loads(t[key].value[0])
key = key[10:]
else:
val = t[key][0]

self.set(key, val, strict)

def to_table(self, make_fits_safe=False):
"""Create an astropy table with all the configuration parameters.
Parameter
---------
make_fits_safe : `bool`
Override Nones and dictionaries so we can write to FITS.
Returns
-------
t: `~astropy.table.Table`
The configuration table.
"""
t = Table()
for col in self._params.keys():
val = self._params[col]
t[col] = [val]

# If Table does not understand the type, pickle it.
if make_fits_safe and t[col].dtype == "O":
t.remove_column(col)
t["__PICKLED_" + col] = pickle.dumps(val)

return t

def validate(self):
"""Check that the configuration has the necessary parameters.
Raises
------
Raises a ``ValueError`` if a parameter is missing.
"""
for p in self._required_params:
if self._params.get(p, None) is None:
raise ValueError(f"Required configuration parameter {p} missing.")
# guaranteed to only have 1 element due to check above
params = {col.name: safe_load(col.value[0]) for col in t.values()}
return SearchConfiguration.from_dict(params)

def load_from_yaml_file(self, filename, strict=True):
@classmethod
def from_yaml(cls, config, strict=True):
"""Load a configuration from a YAML file.
Parameters
----------
filename : `str`
The filename, including path, of the configuration file.
config : `str` or `_io.TextIOWrapper`
The serialized YAML data.
strict : `bool`
Raise an exception on unknown parameters.
Raises
------
Raises a ``ValueError`` if the configuration file is not found.
Raises a ``KeyError`` if the parameter is not part on the list of known parameters
and ``strict`` is False.
"""
if not Path(filename).is_file():
raise ValueError(f"Configuration file {filename} not found.")

# Read the user-specified parameters from the file.
file_params = {}
with open(filename, "r") as config:
file_params = safe_load(config)
yaml_params = safe_load(config)
return SearchConfiguration.from_dict(yaml_params, strict)

# Merge in the new values.
self.set_from_dict(file_params, strict)

if strict:
self.validate()

def load_from_fits_file(self, filename, layer=0, strict=True):
@classmethod
def from_hdu(cls, hdu, strict=True):
"""Load a configuration from a FITS extension file.
Parameters
----------
filename : `str`
The filename, including path, of the configuration file.
layer : `int`
The extension number to use.
hdu : `astropy.io.fits.BinTableHDU`
The HDU from which to parse the configuration information.
strict : `bool`
Raise an exception on unknown parameters.
Raises
------
Raises a ``ValueError`` if the configuration file is not found.
Raises a ``KeyError`` if the parameter is not part on the list of known parameters
and ``strict`` is False.
"""
if not Path(filename).is_file():
raise ValueError(f"Configuration file {filename} not found.")
t = Table(hdu.data)
return SearchConfiguration.from_table(t)

# Read the user-specified parameters from the file.
t = Table.read(filename, hdu=layer)
self.set_from_table(t)
@classmethod
def from_file(cls, filename, strict=True):
with open(filename) as ff:
return SearchConfiguration.from_yaml(ff.read(), strict)

def to_hdu(self):
"""Create a fits HDU with all the configuration parameters.
Returns
-------
hdu : `astropy.io.fits.BinTableHDU`
The HDU with the configuration information.
"""
serialized_dict = {key: dump(val, default_flow_style=True) for key, val in self._params.items()}
t = Table(
rows=[
serialized_dict,
]
)
return fits.table_to_hdu(t)

def to_yaml(self):
"""Save a configuration file with the parameters.
if strict:
self.validate()
Returns
-------
result : `str`
The serialized YAML string.
"""
return dump(self._params)

def save_to_yaml_file(self, filename, overwrite=False):
def to_file(self, filename, overwrite=False):
"""Save a configuration file with the parameters.
Parameters
Expand All @@ -281,16 +265,4 @@ def save_to_yaml_file(self, filename, overwrite=False):
return

with open(filename, "w") as file:
file.write(dump(self._params))

def append_to_fits(self, filename):
"""Append the configuration table as a new extension on a FITS file
(creating a new file if needed).
Parameters
----------
filename : str
The filename, including path, of the configuration file.
"""
t = self.to_table(make_fits_safe=True)
t.write(filename, append=True)
file.write(self.to_yaml())
11 changes: 6 additions & 5 deletions src/kbmod/run_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,16 @@ class run_search:
"""

def __init__(self, input_parameters, config_file=None):
self.config = SearchConfiguration()

# Load parameters from a file.
if config_file != None:
self.config.load_from_yaml_file(config_file)
self.config = SearchConfiguration.from_file(config_file)
else:
self.config = SearchConfiguration()

# Load any additional parameters (overwriting what is there).
if len(input_parameters) > 0:
self.config.set_from_dict(input_parameters)
for key, value in input_parameters.items():
self.config.set(key, value)

# Validate the configuration.
self.config.validate()
Expand Down Expand Up @@ -301,7 +302,7 @@ def run_search(self):
config_filename = os.path.join(
self.config["res_filepath"], f"config_{self.config['output_suffix']}.yml"
)
self.config.save_to_yaml_file(config_filename, overwrite=True)
self.config.to_file(config_filename, overwrite=True)

end = time.time()
print("Time taken for patch: ", end - start)
Expand Down
Loading

0 comments on commit 4b6284e

Please sign in to comment.