diff --git a/benchmarks/bench_filter_stamps.py b/benchmarks/bench_filter_stamps.py index b43d36d93..3e9b45176 100644 --- a/benchmarks/bench_filter_stamps.py +++ b/benchmarks/bench_filter_stamps.py @@ -3,7 +3,16 @@ from kbmod.filters.stamp_filters import * from kbmod.result_list import ResultRow -from kbmod.search import ImageStack, PSF, RawImage, StackSearch, StampParameters, StampType, Trajectory +from kbmod.search import ( + ImageStack, + PSF, + RawImage, + StackSearch, + StampParameters, + StampType, + Trajectory, + StampCreator, +) def setup_coadd_stamp(params): @@ -30,7 +39,7 @@ def setup_coadd_stamp(params): p = PSF(1.0) psf_dim = p.get_dim() psf_rad = p.get_radius() - for i in range(psf_dim): + for i in range(1, psf_dim): for j in range(psf_dim): stamp.set_pixel( (params.radius - 1) - psf_rad + i, # x is one pixel off center @@ -45,11 +54,11 @@ def run_search_benchmark(params): stamp = setup_coadd_stamp(params) # Create an empty search stack. - im_stack = ImageStack([]) - search = StackSearch(im_stack) + # im_stack = ImageStack([]) + sc = StampCreator() # Do the timing runs. - tmr = timeit.Timer(stmt="search.filter_stamp(stamp, params)", globals=locals()) + tmr = timeit.Timer(stmt="sc.filter_stamp(stamp, params)", globals=locals()) res_time = np.mean(tmr.repeat(repeat=10, number=20)) return res_time @@ -57,7 +66,7 @@ def run_search_benchmark(params): def run_row_benchmark(params, create_filter=""): stamp = setup_coadd_stamp(params) row = ResultRow(Trajectory(), 10) - row.stamp = np.array(stamp.get_all_pixels()) + row.stamp = stamp.image filt = eval(create_filter) diff --git a/src/kbmod/filters/stamp_filters.py b/src/kbmod/filters/stamp_filters.py index 2bad57602..3d5db49aa 100644 --- a/src/kbmod/filters/stamp_filters.py +++ b/src/kbmod/filters/stamp_filters.py @@ -103,7 +103,7 @@ def keep_row(self, row: ResultRow): return False # Find the peak in the image. - stamp = row.stamp.reshape([self.width, self.width]) + stamp = row.stamp peak_pos = RawImage(stamp).find_peak(True) return ( abs(peak_pos.i - self.stamp_radius) < self.x_thresh @@ -179,7 +179,7 @@ def keep_row(self, row: ResultRow): return False # Find the peack in the image. - stamp = row.stamp.reshape([self.width, self.width]) + stamp = row.stamp moments = RawImage(stamp).find_central_moments() return ( (abs(moments.m01) < self.m01_thresh) @@ -235,25 +235,5 @@ def keep_row(self, row: ResultRow): bool An indicator of whether to keep the row. """ - # Filter rows without a valid stamp. - if not self._check_row_valid(row): - return False - - # Find the value of the center pixel. - stamp = row.stamp.flatten() - center_index = self.width * self.stamp_radius + self.stamp_radius - center_val = stamp[center_index] - - # Find the total flux in the image and check for other local_maxima - flux_sum = 0.0 - for i in range(self.width * self.width): - pix_val = stamp[i] - if pix_val != KB_NO_DATA: - flux_sum += pix_val - if i != center_index and self.local_max and (pix_val >= center_val): - return False - - # Check the flux percentage. - if flux_sum == 0.0: - return False - return center_val / flux_sum >= self.flux_thresh + image = RawImage(row.stamp) + return image.center_is_local_max(self.flux_thresh, self.local_max) diff --git a/src/kbmod/search/pydocs/raw_image_docs.h b/src/kbmod/search/pydocs/raw_image_docs.h index 2a34f0d7c..39bcd6d8d 100644 --- a/src/kbmod/search/pydocs/raw_image_docs.h +++ b/src/kbmod/search/pydocs/raw_image_docs.h @@ -176,6 +176,26 @@ static const auto DOC_RawImage_find_central_moments = R"doc( Image moments. )doc"; +static const auto DOC_RawImage_center_is_local_max = R"doc( + A filter on whether the center of the stamp is a local + maxima and the percentage of the stamp's total flux in this + pixel. + + Parameters + ---------- + local_max : ``bool`` + Require the central pixel to be a local maximum. + flux_thresh : ``float`` + The fraction of the stamp's total flux that needs to be in + the center pixel [0.0, 1.0]. + + Returns + ------- + keep_row : `bool` + Whether or not the stamp passes the check. + )doc"; + + static const auto DOC_RawImage_create_stamp = R"doc( Create an image stamp around a given region. diff --git a/src/kbmod/search/raw_image.cpp b/src/kbmod/search/raw_image.cpp index 7e430a129..a70c34984 100644 --- a/src/kbmod/search/raw_image.cpp +++ b/src/kbmod/search/raw_image.cpp @@ -331,6 +331,28 @@ ImageMoments RawImage::find_central_moments() const { return res; } +bool RawImage::center_is_local_max(double flux_thresh, bool local_max) const { + const int num_pixels = width * height; + int c_x = width / 2; + int c_y = height / 2; + int c_ind = c_y * width + c_x; + + auto pixels = image.reshaped(); + double center_val = pixels[c_ind]; + + // Find the sum of the zero-shifted (non-NO_DATA) pixels. + double sum = 0.0; + for (int p = 0; p < num_pixels; ++p) { + float pix_val = pixels[p]; + if (p != c_ind && local_max && pix_val >= center_val) { + return false; + } + sum += (pix_val != NO_DATA) ? pix_val : 0.0; + } + if (sum == 0.0) return false; + return center_val / sum >= flux_thresh; +} + void RawImage::load_time_from_file(fitsfile* fptr) { int mjd_status = 0; @@ -603,7 +625,9 @@ static void raw_image_bindings(py::module& m) { .def("compute_bounds", &rie::compute_bounds, pydocs::DOC_RawImage_compute_bounds) .def("find_peak", &rie::find_peak, pydocs::DOC_RawImage_find_peak) .def("find_central_moments", &rie::find_central_moments, - pydocs::DOC_RawImage_find_central_moments) + pydocs::DOC_RawImage_find_central_moments) + .def("center_is_local_max", &rie::center_is_local_max, + pydocs::DOC_RawImage_center_is_local_max) .def("create_stamp", &rie::create_stamp, pydocs::DOC_RawImage_create_stamp) .def("interpolate", &rie::interpolate, pydocs::DOC_RawImage_interpolate) .def("interpolated_add", &rie::interpolated_add, pydocs::DOC_RawImage_interpolated_add) diff --git a/src/kbmod/search/raw_image.h b/src/kbmod/search/raw_image.h index f60c458e4..bacdfc005 100644 --- a/src/kbmod/search/raw_image.h +++ b/src/kbmod/search/raw_image.h @@ -111,6 +111,8 @@ class RawImage { // Elements with NO_DATA are treated as zero. ImageMoments find_central_moments() const; + bool center_is_local_max(double flux_thresh, bool local_max) const; + // Load the image data from a specific layer of a FITS file. // Overwrites the current image data. void from_fits(const std::string& file_path, int layer_num);