From c666db101c64d5883b090a1389e58d6e3e357a7c Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Fri, 1 Dec 2023 09:21:48 -0500 Subject: [PATCH 1/4] Add multiple functions to ResultList --- src/kbmod/result_list.py | 283 +++++++++++++++++++++++++++------ src/kbmod/trajectory_utils.py | 31 ++++ tests/test_result_list.py | 119 +++++++++++++- tests/test_stats_filters.py | 7 +- tests/test_trajectory_utils.py | 25 +++ 5 files changed, 406 insertions(+), 59 deletions(-) diff --git a/src/kbmod/result_list.py b/src/kbmod/result_list.py index bf5ee2ddc..b59ac1ef0 100644 --- a/src/kbmod/result_list.py +++ b/src/kbmod/result_list.py @@ -1,3 +1,8 @@ +"""ResultList is a row-based data structure for tracking results with additional logic for +filtering and maintaining consistency between different attributes in each row. Each row is +represented as a ResultRow. +""" + import math import multiprocessing as mp import numpy as np @@ -7,32 +12,107 @@ from yaml import dump, safe_load from kbmod.file_utils import * -from kbmod.trajectory_utils import trajectory_from_yaml, trajectory_to_yaml +from kbmod.trajectory_utils import ( + trajectory_from_yaml, + trajectory_predict_skypos, + trajectory_to_yaml, +) class ResultRow: - """This class stores a collection of related data from a single kbmod result.""" + """This class stores a collection of related data from a single kbmod result. + In order to maintain a consistent internal state, the class uses private variables + with getter only properties for key pieces of information. While this adds overhead, + it requires the users to work through specific setter functions that update one + attribute in relation to another. Do not set any of the private data members directly. + + Example: + To set the psi and phi curves use: set_psi_phi(psi, phi). This will update the curves + the final likelihood, the trajectory's likelihood, and the trajectory's flux. + + Most attributes are optional so we only use space for the ones needed. The attributes + are set to None when unused. + + Attributes + ---------- + all_stamps : `numpy.ndarray` + An array of numpy arrays representing stamps for each timestep + (including the invalid/filtered ones). [Optional] + final_likelihood : `float` + The final likelihood as computed from the valid indices of the psi and phi + curves. Initially set from the trajectory's lh field. [Required] + num_times : `int` + The number of timesteps. [Required] + stamp : `numpy.ndarray` + The coadded stamp computed from the valid timesteps. [Optional] + phi_curve : `list` or `numpy.ndarray` + An array of numpy arrays representing phi values for each timestep + (including the invalid/filtered ones). [Optional] + pred_dec : `numpy.ndarray` + An array of the predict positions dec. [Optional] + pred_ra : `numpy.ndarray` + An array of the predict positions RA. [Optional] + psi_curve : `list` or `numpy.ndarray` + An array of numpy arrays representing psi values for each timestep + (including the invalid/filtered ones). [Optional] + trajectory : `kbmod.search.Trajectory` + The result trajectory in pixel space. [Required] + valid_indices : `list` or `numpy.ndarray` + The indices of the timesteps that are unfiltered (valid). [Required] + """ __slots__ = ( - "trajectory", - "stamp", "all_stamps", - "final_likelihood", - "valid_indices", - "psi_curve", - "phi_curve", - "num_times", + "_final_likelihood", + "_num_times", + "_phi_curve", + "pred_ra", + "pred_dec", + "_psi_curve", + "stamp", + "trajectory", + "_valid_indices", ) def __init__(self, trj, num_times): - self.trajectory = trj - self.stamp = None - self.final_likelihood = trj.lh - self.valid_indices = [i for i in range(num_times)] self.all_stamps = None - self.psi_curve = None - self.phi_curve = None - self.num_times = num_times + self._final_likelihood = trj.lh + self._num_times = num_times + self._phi_curve = None + self.pred_dec = None + self.pred_ra = None + self._psi_curve = None + self.stamp = None + self.trajectory = trj + self._valid_indices = [i for i in range(num_times)] + + @property + def final_likelihood(self): + return self._final_likelihood + + @property + def valid_indices(self): + return self._valid_indices + + @property + def psi_curve(self): + return self._psi_curve + + @property + def phi_curve(self): + return self._phi_curve + + @property + def num_times(self): + return self._num_times + + @property + def obs_count(self): + return self.trajectory.obs_count + + @property + def flux(self): + return self.trajectory.flux @classmethod def from_yaml(cls, yaml_str): @@ -53,7 +133,8 @@ def from_yaml(cls, yaml_str): # Copy the values into the object. for attr in ResultRow.__slots__: if attr != "trajectory": - setattr(result, attr, yaml_params[attr]) + attr_name = attr.lstrip("_") + setattr(result, attr, yaml_params[attr_name]) # Convert the stamps to np arrays if result.stamp is not None: @@ -83,7 +164,8 @@ def to_yaml(self): elif type(value) is np.float64: value = float(value) - yaml_dict[attr] = value + attr_name = attr.lstrip("_") + yaml_dict[attr_name] = value return dump(yaml_dict) def valid_times(self, all_times): @@ -99,7 +181,7 @@ def valid_times(self, all_times): list The times for the valid indices. """ - return [all_times[i] for i in self.valid_indices] + return [all_times[i] for i in self._valid_indices] @property def light_curve(self): @@ -111,12 +193,12 @@ def light_curve(self): The light curve. This is an empty array if either psi or phi are not set. """ - if self.psi_curve is None or self.phi_curve is None: + if self._psi_curve is None or self._phi_curve is None: return np.array([]) - masked_phi = np.copy(self.phi_curve) + masked_phi = np.copy(self._phi_curve) masked_phi[masked_phi == 0] = 1e12 - lc = np.divide(self.psi_curve, masked_phi) + lc = np.divide(self._psi_curve, masked_phi) return lc @property @@ -129,12 +211,12 @@ def likelihood_curve(self): The likelihood curve. This is an empty array if either psi or phi are not set. """ - if self.psi_curve is None or self.phi_curve is None: + if self._psi_curve is None or self._phi_curve is None: return np.array([]) - masked_phi = np.copy(self.phi_curve) + masked_phi = np.copy(self._phi_curve) masked_phi[masked_phi == 0] = 1e12 - lh = np.divide(self.psi_curve, np.sqrt(masked_phi)) + lh = np.divide(self._psi_curve, np.sqrt(masked_phi)) return lh def valid_indices_as_booleans(self): @@ -145,7 +227,7 @@ def valid_indices_as_booleans(self): result : list A list of bool indicating which indices appear in valid_indices """ - indices_set = set(self.valid_indices) + indices_set = set(self._valid_indices) result = [(x in indices_set) for x in range(self.num_times)] return result @@ -169,10 +251,27 @@ def set_psi_phi(self, psi, phi): raise ValueError( f"Expected arrays of length {self.num_times} got {len(phi)} and {len(psi)} instead" ) - self.psi_curve = psi - self.phi_curve = phi + self._psi_curve = psi + self._phi_curve = phi self._update_likelihood() + def compute_predicted_skypos(self, times, wcs): + """Set the predicted sky positions at each time. + + Parameters + ---------- + times : `list` or `numpy.ndarray` + The times at which to predict the positions. + wcs : `astropy.wcs.WCS` + The WCS for the images. + """ + if len(times) != self.num_times: + raise ValueError(f"Expected an array of length {self.num_times} got {len(times)} instead") + sky_pos = trajectory_predict_skypos(self.trajectory, wcs, times) + + self.pred_ra = sky_pos.ra.value + self.pred_dec = sky_pos.dec.value + def filter_indices(self, indices_to_keep): """Remove invalid indices and times from the ResultRow. This uses relative filtering where valid_indices[i] is kept for all i in indices_to_keep. @@ -187,15 +286,15 @@ def filter_indices(self, indices_to_keep): ------ ValueError: If any of the given indices are out of bounds. """ - current_num_inds = len(self.valid_indices) + current_num_inds = len(self._valid_indices) if any(v >= current_num_inds or v < 0 for v in indices_to_keep): raise ValueError(f"Out of bounds index in {indices_to_keep}") - self.valid_indices = [self.valid_indices[i] for i in indices_to_keep] + self._valid_indices = [self._valid_indices[i] for i in indices_to_keep] self._update_likelihood() # Update the count of valid observations in the trajectory object. - self.trajectory.obs_count = len(self.valid_indices) + self.trajectory.obs_count = len(self._valid_indices) def _update_likelihood(self): """Update the likelihood based on the result's psi and phi curves @@ -204,24 +303,23 @@ def _update_likelihood(self): Note ---- Requires that psi_curve and phi_curve have both been set. Otherwise - defaults to a likelihood of 0.0. + does not perform any updates. """ - if self.psi_curve is None or self.phi_curve is None: - self.final_likelihood = 0.0 + if self._psi_curve is None or self._phi_curve is None: return psi_sum = 0.0 phi_sum = 0.0 - for ind in self.valid_indices: - psi_sum += self.psi_curve[ind] - phi_sum += self.phi_curve[ind] + for ind in self._valid_indices: + psi_sum += self._psi_curve[ind] + phi_sum += self._phi_curve[ind] if phi_sum <= 0.0: - self.final_likelihood = 0.0 + self._final_likelihood = 0.0 self.trajectory.lh = 0.0 self.trajectory.flux = 0.0 else: - self.final_likelihood = psi_sum / np.sqrt(phi_sum) + self._final_likelihood = psi_sum / np.sqrt(phi_sum) self.trajectory.lh = psi_sum / np.sqrt(phi_sum) self.trajectory.flux = psi_sum / phi_sum @@ -246,13 +344,15 @@ def append_to_dict(self, result_dict, expand_trajectory=False): result_dict["flux"].append(self.trajectory.flux) else: result_dict["trajectory"].append(trajectory) - result_dict["likelihood"].append(self.final_likelihood) + result_dict["likelihood"].append(self._final_likelihood) result_dict["stamp"].append(self.stamp) result_dict["all_stamps"].append(self.all_stamps) - result_dict["valid_indices"].append(self.valid_indices) - result_dict["psi_curve"].append(self.psi_curve) - result_dict["phi_curve"].append(self.phi_curve) + result_dict["valid_indices"].append(self._valid_indices) + result_dict["psi_curve"].append(self._psi_curve) + result_dict["phi_curve"].append(self._phi_curve) + result_dict["pred_ra"].append(self.pred_ra) + result_dict["pred_dec"].append(self.pred_dec) class ResultList: @@ -269,13 +369,18 @@ def __init__(self, all_times, track_filtered=False): Whether to track (save) the filtered trajectories. This will use more memory and is recommended only for analysis. """ - self.all_times = all_times + self._all_times = all_times self.results = [] # Set up information to track which row is filtered at which round. self.track_filtered = track_filtered self.filtered = {} + # All times should be externally read-only once set. + @property + def all_times(self): + return self._all_times + @classmethod def from_yaml(cls, yaml_str): """Deserialize a ResultList from a YAML string. @@ -341,16 +446,43 @@ def extend(self, result_list): else: self.filtered[key] = result_list.filtered[key] - def zip_phi_psi_idx(self): - """Create and return a list of tuples for each psi/phi curve. + def sort(self, key="final_likelihood", reverse=True): + """Sort the results by the given key. This must correspond + to one of the proporties in ResultRow. + + Parameters + ---------- + key : `str` + A string representing the property by which to sort. + Default = final_likelihood + reverse : `bool` + Sort in increasing order. Returns ------- - iterable - A list of tuples with (psi_curve, phi_curve, index) for - each result in the ResultList. + self : ResultList + Returns a reference to itself to allow chaining. """ - return ((x.psi_curve, x.phi_curve, i) for i, x in enumerate(self.results)) + self.results.sort(key=lambda x: getattr(x, key), reverse=reverse) + return self + + def compute_predicted_skypos(self, wcs): + """Compute the predict sky position for each result's trajectory + at each time step. + + Parameters + ---------- + wcs : `astropy.wcs.WCS` + The WCS for the images. + + Returns + ------- + self : ResultList + Returns a reference to itself to allow chaining. + """ + for row in self.results: + row.compute_predicted_skypos(self._all_times, wcs) + return self def filter_results(self, indices_to_keep, label=None): """Filter the rows in the ResultList to only include those indices @@ -473,6 +605,49 @@ def get_filtered(self, label=None): return result + def revert_filter(self, label=None): + """Revert the filtering by re-adding filtered ResultRows. + + Note + ---- + Filtered rows are appended to the end of the list. Does not return + the results to the original ordering. + + Parameters + ---------- + label : str + The filtering stage to use. If no label is provided, + revert all filtered rows. + + Returns + ------- + self : ResultList + Returns a reference to itself to allow chaining. + + Raises + ------ + ValueError if filtering is not enabled. + KeyError if label is unknown. + """ + if not self.track_filtered: + raise ValueError("ResultList filter tracking not enabled.") + + if label is not None: + # Check if anything was filtered at this stage. + if label in self.filtered: + self.results.extend(self.filtered[label]) + del self.filtered[label] + else: + raise KeyError(f"Unknown filtered label {label}") + else: + for key in self.filtered: + self.results.extend(self.filtered[key]) + + # Reset the entire dictionary. + self.filtered = {} + + return self + def to_table(self, filtered_label=None): """Extract the results into an astropy table. @@ -512,6 +687,8 @@ def to_table(self, filtered_label=None): "psi_curve": [], "phi_curve": [], "all_stamps": [], + "pred_ra": [], + "pred_dec": [], } # Use a (slow) linear scan to do the transformation. @@ -533,7 +710,7 @@ def to_yaml(self, serialize_filtered=False): The serialized string. """ yaml_dict = { - "all_times": self.all_times, + "all_times": self._all_times, "results": [row.to_yaml() for row in self.results], "track_filtered": False, "filtered": {}, @@ -582,12 +759,12 @@ def save_to_files(self, res_filepath, out_suffix): ) FileUtils.save_csv_from_list( ospath.join(res_filepath, f"times_{out_suffix}.txt"), - [x.valid_times(self.all_times) for x in self.results], + [x.valid_times(self._all_times) for x in self.results], True, ) FileUtils.save_csv_from_list( ospath.join(res_filepath, f"all_times_{out_suffix}.txt"), - [self.all_times], + [self._all_times], True, ) np.savetxt( diff --git a/src/kbmod/trajectory_utils.py b/src/kbmod/trajectory_utils.py index bce8541ae..e05fde489 100644 --- a/src/kbmod/trajectory_utils.py +++ b/src/kbmod/trajectory_utils.py @@ -7,9 +7,13 @@ * Convert a Trajectory into another data type. * Serialize and deserialize a Trajectory. + +* Use a trajectory and WCS to predict RA, dec positions. """ import numpy as np + +from astropy.wcs import WCS from yaml import dump, safe_load from kbmod.search import Trajectory @@ -51,6 +55,33 @@ def make_trajectory(x=0, y=0, vx=0.0, vy=0.0, flux=0.0, lh=0.0, obs_count=0): return trj +def trajectory_predict_skypos(trj, wcs, times): + """Predict the (RA, dec) locations of the trajectory at different times. + + Parameters + ---------- + trj : `Trajectory` + The corresponding trajectory object. + wcs : `astropy.wcs.WCS` + The WCS for the images. + times : `list` or `numpy.ndarray` + The times at which to predict the positions. + + Returns + ------- + result : `astropy.coordinates.SkyCoord` + A SkyCoord with the transformed locations. + """ + np_times = np.array(times) + + # Predict locations in pixel space. + x_vals = trj.x + trj.vx * np_times + y_vals = trj.y + trj.vy * np_times + + result = wcs.pixel_to_world(x_vals, y_vals) + return result + + def trajectory_from_np_object(result): """Transform a numpy object holding trajectory information into a trajectory object. diff --git a/tests/test_result_list.py b/tests/test_result_list.py index e15b9fc07..44990321f 100644 --- a/tests/test_result_list.py +++ b/tests/test_result_list.py @@ -5,6 +5,7 @@ import unittest from astropy.table import Table +from astropy.wcs import WCS from pathlib import Path from kbmod.analysis_utils import * @@ -75,6 +76,8 @@ def test_to_from_yaml(self): self.assertEqual(row2.valid_times(self.times), [1.0, 2.0, 3.0, 4.0]) self.assertEqual(row2.trajectory.obs_count, 4) self.assertIsNone(row2.stamp) + self.assertIsNone(row2.pred_ra) + self.assertIsNone(row2.pred_dec) self.assertIsNotNone(row2.all_stamps) self.assertEqual(row2.all_stamps.shape[0], 4) self.assertEqual(row2.all_stamps.shape[1], 5) @@ -84,12 +87,43 @@ def test_to_from_yaml(self): self.assertAlmostEqual(row2.trajectory.flux, 1.15) self.assertAlmostEqual(row2.trajectory.lh, 2.3) + def test_compute_predicted_skypos(self): + self.assertIsNone(self.rdr.pred_ra) + self.assertIsNone(self.rdr.pred_dec) + + # Fill out the trajectory details + self.rdr.trajectory.x = 9 + self.rdr.trajectory.y = 9 + self.rdr.trajectory.vx = -1.0 + self.rdr.trajectory.vy = 3.0 + + # Create a fake WCS with a known pointing. + my_wcs = WCS(naxis=2) + my_wcs.wcs.crpix = [10.0, 10.0] # Reference point on the image (1-indexed) + my_wcs.wcs.crval = [45.0, -15.0] # Reference pointing on the sky + my_wcs.wcs.cdelt = [0.05, 0.15] # Pixel step size + my_wcs.wcs.ctype = ["RA---TAN-SIP", "DEC--TAN-SIP"] + + times = [0.0, 1.0, 2.0, 3.0] + self.rdr.compute_predicted_skypos(times, my_wcs) + self.assertEqual(len(self.rdr.pred_ra), 4) + self.assertEqual(len(self.rdr.pred_dec), 4) + self.assertAlmostEqual(self.rdr.pred_ra[0], 45.0, delta=0.01) + self.assertAlmostEqual(self.rdr.pred_dec[0], -15.0, delta=0.01) + class test_result_list(unittest.TestCase): def setUp(self): self.times = [(10.0 + 0.1 * float(i)) for i in range(20)] self.num_times = len(self.times) + # Create a fake WCS with a known pointing to use for the (RA, dec) predictions. + self.my_wcs = WCS(naxis=2) + self.my_wcs.wcs.crpix = [50.0, 50.0] # Reference point on the image (1-indexed) + self.my_wcs.wcs.crval = [45.0, -15.0] # Reference pointing on the sky + self.my_wcs.wcs.cdelt = [0.05, 0.05] # Pixel step size + self.my_wcs.wcs.ctype = ["RA---TAN-SIP", "DEC--TAN-SIP"] + def test_append_single(self): rs = ResultList(self.times) self.assertEqual(rs.num_results(), 0) @@ -144,6 +178,28 @@ def test_clear(self): rs.clear() self.assertEqual(rs.num_results(), 0) + def test_sort(self): + rs = ResultList(self.times) + rs.append_result(ResultRow(make_trajectory(x=0, lh=1.0, obs_count=1), self.num_times)) + rs.append_result(ResultRow(make_trajectory(x=1, lh=-1.0, obs_count=2), self.num_times)) + rs.append_result(ResultRow(make_trajectory(x=2, lh=5.0, obs_count=3), self.num_times)) + rs.append_result(ResultRow(make_trajectory(x=3, lh=4.0, obs_count=5), self.num_times)) + rs.append_result(ResultRow(make_trajectory(x=4, lh=6.0, obs_count=4), self.num_times)) + + # Sort by final likelihood. + rs.sort() + self.assertEqual(rs.num_results(), 5) + expected_order = [4, 2, 3, 0, 1] + for i, val in enumerate(expected_order): + self.assertEqual(rs.results[i].trajectory.x, val) + + # Sort by the number of observations. + rs.sort(key="obs_count", reverse=False) + self.assertEqual(rs.num_results(), 5) + expected_order = [0, 1, 2, 4, 3] + for i, val in enumerate(expected_order): + self.assertEqual(rs.results[i].trajectory.x, val) + def test_filter(self): rs = ResultList(self.times) for i in range(10): @@ -163,6 +219,12 @@ def test_filter(self): # Without tracking there should be nothing stored in the ResultList's # filtered dictionary. self.assertEqual(len(rs.filtered), 0) + with self.assertRaises(ValueError): + rs.get_filtered() + + # Without tracking we cannot revert anything. + with self.assertRaises(ValueError): + rs.revert_filter() def test_filter_dups(self): rs = ResultList(self.times, track_filtered=False) @@ -188,9 +250,8 @@ def test_filter_dups(self): def test_filter_track(self): rs = ResultList(self.times, track_filtered=True) for i in range(10): - t = Trajectory() - t.x = i - rs.append_result(ResultRow(t, self.num_times)) + trj = make_trajectory(x=i) + rs.append_result(ResultRow(trj, self.num_times)) self.assertEqual(rs.num_results(), 10) # Do the filtering. First remove elements 0 and 2. Then remove elements @@ -221,12 +282,59 @@ def test_filter_track(self): f_all = rs.get_filtered() self.assertEqual(len(f_all), 5) + def test_revert_filter(self): + rs = ResultList(self.times, track_filtered=True) + for i in range(10): + trj = make_trajectory(x=i) + rs.append_result(ResultRow(trj, self.num_times)) + self.assertEqual(rs.num_results(), 10) + + # Do the filtering. First remove elements 0 and 2. Then remove elements + # 0, 5, and 6 from the resulting list (1, 7, 8 in the original list). Then + # remove item 5 (9 from the original list). + rs.filter_results([1, 3, 4, 5, 6, 7, 8, 9], label="1") + self.assertEqual(rs.num_results(), 8) + rs.filter_results([1, 2, 3, 4, 7], label="2") + self.assertEqual(rs.num_results(), 5) + rs.filter_results([0, 1, 2, 3], label="3") + self.assertEqual(rs.num_results(), 4) + + # Test that we can recover the items filtered in stage 1. These are added to + # end, so we should get [3, 4, 5, 6, 0, 2] + rs.revert_filter(label="1") + self.assertEqual(rs.num_results(), 6) + expected_order = [3, 4, 5, 6, 0, 2] + for i, value in enumerate(expected_order): + self.assertEqual(rs.results[i].trajectory.x, value) + + # Test that we can recover the all items if we don't provide a label. + rs.revert_filter() + self.assertEqual(rs.num_results(), 10) + expected_order = [3, 4, 5, 6, 0, 2, 1, 7, 8, 9] + for i, value in enumerate(expected_order): + self.assertEqual(rs.results[i].trajectory.x, value) + + with self.assertRaises(KeyError): + rs.revert_filter(label="wrong") + + def test_compute_predicted_skypos(self): + rs = ResultList(self.times, track_filtered=True) + for i in range(5): + trj = make_trajectory(x=49 + i, y=49 + i, vx=2 * i, vy=-3 * i, obs_count=self.num_times - i) + + # Check that we have computed a position for each row and time. + rs.compute_predicted_skypos(self.my_wcs) + for row in rs.results: + self.assertEqual(len(row.pred_ra), len(self.times)) + self.assertEqual(len(row.pred_dec), len(self.times)) + def test_to_from_yaml(self): rs = ResultList(self.times, track_filtered=True) for i in range(10): row = ResultRow(Trajectory(), self.num_times) row.set_psi_phi(np.array([i] * self.num_times), np.array([0.01 * i] * self.num_times)) rs.append_result(row) + rs.compute_predicted_skypos(self.my_wcs) # Do the filtering and check we have the correct ones. inds = [0, 2, 6, 7] @@ -242,6 +350,8 @@ def test_to_from_yaml(self): for i in range(len(inds)): self.assertAlmostEqual(rs_a.results[i].psi_curve[0], inds[i]) self.assertAlmostEqual(rs_a.results[i].phi_curve[0], 0.01 * inds[i]) + self.assertEqual(len(rs_a.results[i].pred_ra), self.num_times) + self.assertEqual(len(rs_a.results[i].pred_dec), self.num_times) self.assertFalse(rs_a.track_filtered) self.assertEqual(len(rs_a.filtered), 0) @@ -269,6 +379,7 @@ def test_to_table(self): row.stamp = np.ones((10, 10)) row.all_stamps = np.array([np.ones((10, 10)) for _ in range(self.num_times)]) rs.append_result(row) + rs.compute_predicted_skypos(self.my_wcs) table = rs.to_table() self.assertEqual(len(table), 10) @@ -285,6 +396,8 @@ def test_to_table(self): self.assertEqual(len(table["valid_indices"][i]), self.num_times) self.assertEqual(len(table["psi_curve"][i]), self.num_times) self.assertEqual(len(table["phi_curve"][i]), self.num_times) + self.assertEqual(len(table["pred_ra"][i]), self.num_times) + self.assertEqual(len(table["pred_dec"][i]), self.num_times) for j in range(self.num_times): self.assertEqual(table["all_stamps"][i][j].shape, (10, 10)) diff --git a/tests/test_stats_filters.py b/tests/test_stats_filters.py index 83348a1ec..cc32a08ce 100644 --- a/tests/test_stats_filters.py +++ b/tests/test_stats_filters.py @@ -13,9 +13,9 @@ def setUp(self): self.rs = ResultList(self.times, track_filtered=True) for i in range(10): t = Trajectory() + t.lh = float(i) row = ResultRow(t, self.num_times) row.filter_indices([k for k in range(i)]) - row.final_likelihood = float(i) self.rs.append_result(row) def test_filter_min_likelihood(self): @@ -67,8 +67,9 @@ def test_filter_likelihood_mp(self): # Create a lot more results. rs = ResultList(self.times, track_filtered=True) for i in range(1000): - row = ResultRow(Trajectory(), self.num_times) - row.final_likelihood = 0.01 * float(i) + trj = Trajectory() + trj.lh = 0.01 * float(i) + row = ResultRow(trj, self.num_times) rs.append_result(row) self.assertEqual(rs.num_results(), 1000) diff --git a/tests/test_trajectory_utils.py b/tests/test_trajectory_utils.py index 92b4a32ab..7933eb4ca 100644 --- a/tests/test_trajectory_utils.py +++ b/tests/test_trajectory_utils.py @@ -1,5 +1,7 @@ import unittest +from astropy.wcs import WCS + from kbmod.trajectory_utils import * from kbmod.search import * @@ -15,6 +17,29 @@ def test_make_trajectory(self): self.assertEqual(trj.lh, 6.0) self.assertEqual(trj.obs_count, 7) + def test_predict_skypos(self): + # Create a fake WCS with a known pointing. + my_wcs = WCS(naxis=2) + my_wcs.wcs.crpix = [10.0, 10.0] # Reference point on the image (1-indexed) + my_wcs.wcs.crval = [45.0, -15.0] # Reference pointing on the sky + my_wcs.wcs.cdelt = [0.1, 0.1] # Pixel step size + my_wcs.wcs.ctype = ["RA---TAN-SIP", "DEC--TAN-SIP"] + + # Confirm that the wcs produces the correct prediction (using zero indexed pixel). + my_sky = my_wcs.pixel_to_world(9.0, 9.0) + self.assertAlmostEqual(my_sky.ra.deg, 45.0) + self.assertAlmostEqual(my_sky.dec.deg, -15.0) + + # Create a trajectory starting at the middle and traveling +2 pixels a day in x and -5 in y. + trj = make_trajectory(x=9, y=9, vx=2.0, vy=-5.0) + + # Predict locations at times 0.0 and 1.0 + my_sky = trajectory_predict_skypos(trj, my_wcs, [0.0, 1.0]) + self.assertAlmostEqual(my_sky.ra[0].deg, 45.0) + self.assertAlmostEqual(my_sky.dec[0].deg, -15.0) + self.assertAlmostEqual(my_sky.ra[1].deg, 45.2, delta=0.01) + self.assertAlmostEqual(my_sky.dec[1].deg, -15.5, delta=0.01) + def test_trajectory_from_np_object(self): np_obj = np.array( [(300.0, 750.0, 106.0, 44.0, 9.52, -0.5, 10.0)], From 39c9c69a3fdf6041b6412eedde9987cf2456dc85 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Fri, 1 Dec 2023 11:35:15 -0500 Subject: [PATCH 2/4] Add two new filters --- src/kbmod/analysis_utils.py | 17 +++-- src/kbmod/filters/stats_filters.py | 101 +++++++++++++++++++++++++++++ tests/test_stats_filters.py | 44 ++++++++++++- 3 files changed, 153 insertions(+), 9 deletions(-) diff --git a/src/kbmod/analysis_utils.py b/src/kbmod/analysis_utils.py index 06697c591..b88400840 100644 --- a/src/kbmod/analysis_utils.py +++ b/src/kbmod/analysis_utils.py @@ -25,6 +25,7 @@ class PostProcess: def __init__(self, config, mjds): self.coeff = None self.num_cores = config["num_cores"] + self.num_obs = config["num_obs"] self.sigmaG_lims = config["sigmaG_lims"] self.eps = config["eps"] self.cluster_type = config["cluster_type"] @@ -75,6 +76,12 @@ def load_and_filter_results( bnds = [25, 75] clipper = SigmaGClipping(bnds[0], bnds[1], 2, self.clip_negative) + # Set up the combined stats filter. + if lh_level > 0.0: + stats_filter = CombinedStatsFilter(min_obs=self.num_obs, min_lh=lh_level) + else: + stats_filter = CombinedStatsFilter(min_obs=self.num_obs) + print("---------------------------------------") print("Retrieving Results") print("---------------------------------------") @@ -94,7 +101,9 @@ def load_and_filter_results( if trj.lh < lh_level: likelihood_limit = True break - if trj.lh < max_lh: + + # Skip points with too high a likelihood or with too few observations. + if trj.lh < max_lh and trj.obs_count >= self.num_obs: row = ResultRow(trj, len(self._mjds)) psi_curve = np.array(search.get_psi_curves(trj)) phi_curve = np.array(search.get_phi_curves(trj)) @@ -106,11 +115,7 @@ def load_and_filter_results( print("Extracted batch of %i results for total of %i" % (batch_size, total_count)) if batch_size > 0: apply_clipped_sigma_g(clipper, result_batch, self.num_cores) - result_batch.apply_filter(NumObsFilter(3)) - - # Apply the likelihood filter if one is provided. - if lh_level > 0.0: - result_batch.apply_filter(LHFilter(lh_level, None)) + result_batch.apply_filter(stats_filter) # Add the results to the final set. keep.extend(result_batch) diff --git a/src/kbmod/filters/stats_filters.py b/src/kbmod/filters/stats_filters.py index c335ad30d..4c0600f9d 100644 --- a/src/kbmod/filters/stats_filters.py +++ b/src/kbmod/filters/stats_filters.py @@ -112,3 +112,104 @@ def keep_row(self, row: ResultRow): An indicator of whether to keep the row. """ return len(row.valid_indices) >= self.min_obs + + +class CombinedStatsFilter(RowFilter): + """A filter for result's likelihood and number of observations.""" + + def __init__(self, min_obs=0, min_lh=-math.inf, max_lh=math.inf, *args, **kwargs): + """Create a ResultsLHFilter. + + Parameters + ---------- + min_obs : ``int`` + The minimum number of observations. + min_lh : ``float`` + Minimal allowed likelihood. + max_lh : ``float`` + Maximal allowed likelihood. + """ + super().__init__(*args, **kwargs) + + self.min_obs = min_obs + self.min_lh = min_lh + self.max_lh = max_lh + + def get_filter_name(self): + """Get the name of the filter. + + Returns + ------- + str + The filter name. + """ + return f"CombinedStats_{self.num_obs}_{self.min_lh}_to_{self.max_lh}" + + def keep_row(self, row: ResultRow): + """Determine whether to keep an individual row based on + the likelihood. + + Parameters + ---------- + row : ResultRow + The row to evaluate. + + Returns + ------- + bool + An indicator of whether to keep the row. + """ + if row.final_likelihood < self.min_lh or row.final_likelihood > self.max_lh: + return False + if len(row.valid_indices) >= self.min_obs: + return False + return True + + +class DurationFilter(RowFilter): + """A filter for the amount of time covered by the trajectory""" + + def __init__(self, all_times, min_duration, *args, **kwargs): + """Create a ResultsLHFilter. + + Parameters + ---------- + all_times : ``list`` + The time stamps in increasing order. + min_duration : ``float`` + The minimum duration in days for a valid result. + """ + super().__init__(*args, **kwargs) + + self.all_times = all_times + self.min_duration = min_duration + + def get_filter_name(self): + """Get the name of the filter. + + Returns + ------- + str + The filter name. + """ + return f"Duration_{self.min_duration}" + + def keep_row(self, row: ResultRow): + """Determine whether to keep an individual row based on + the likelihood. + + Parameters + ---------- + row : ResultRow + The row to evaluate. + + Returns + ------- + bool + An indicator of whether to keep the row. + """ + min_index = np.min(row.valid_indices) + max_index = np.max(row.valid_indices) + if self.all_times[max_index] - self.all_times[min_index] < self.min_duration: + return False + return True diff --git a/tests/test_stats_filters.py b/tests/test_stats_filters.py index cc32a08ce..45238ec21 100644 --- a/tests/test_stats_filters.py +++ b/tests/test_stats_filters.py @@ -12,9 +12,8 @@ def setUp(self): self.rs = ResultList(self.times, track_filtered=True) for i in range(10): - t = Trajectory() - t.lh = float(i) - row = ResultRow(t, self.num_times) + trj = make_trajectory(lh=float(i)) + row = ResultRow(trj, self.num_times) row.filter_indices([k for k in range(i)]) self.rs.append_result(row) @@ -98,6 +97,45 @@ def test_filter_valid_indices(self): for i in range(self.rs.num_results()): self.assertGreaterEqual(len(self.rs.results[i].valid_indices), 4) + def test_combined_stats_filter(self): + self.assertEqual(self.rs.num_results(), 10) + + f = CombinedStatsFilter(min_obs=4, min_lh=5.1) + self.assertEqual(f.get_filter_name(), "CombinedStats_10_5.0_inf") + + # Do the filtering and check we have the correct ones. + self.rs.apply_filter(f) + self.assertEqual(self.rs.num_results(), 4) + for row in self.rs.results: + self.assertGreaterEqual(len(row.valid_indices), 4) + self.assertGreaterEqual(len(row.final_likelihood), 5.1) + + def test_duration_filter(self): + f = DurationFilter(self.times, 5.1) + self.assertEqual(f.get_filter_name(), "Duration_9.1") + + res_list = ResultList(self.times, track_filtered=True) + + # Add a full track + row0 = ResultRow(Trajectory(), self.num_times) + res_list.append_result(row0) + + # Add a track with every 4th observation + row1 = ResultRow(Trajectory(), self.num_times) + row1.filter_indices([k for k in range(self.num_times) if k % 4 == 0]) + res_list.append_result(row1) + + # Add a track with a short burst in the middle. + row2 = ResultRow(Trajectory(), self.num_times) + row2.filter_indices([3, 4, 5, 6, 7, 8, 9]) + res_list.append_result(row1) + + res_list.apply_filter(f) + self.assertEqual(res_list.num_results(), 2) + + self.assertGreaterEqual(len(res_list.results[0].valid_indices), self.num_times) + self.assertGreaterEqual(len(res_list.results[1].valid_indices), int(self.num_times / 4)) + if __name__ == "__main__": unittest.main() From 2e93f0752e176174b69824f8c26249dd3e97aad5 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Fri, 1 Dec 2023 11:37:27 -0500 Subject: [PATCH 3/4] Undo change (wrong branch) --- src/kbmod/analysis_utils.py | 17 ++--- src/kbmod/filters/stats_filters.py | 101 ----------------------------- tests/test_stats_filters.py | 49 ++------------ 3 files changed, 11 insertions(+), 156 deletions(-) diff --git a/src/kbmod/analysis_utils.py b/src/kbmod/analysis_utils.py index b88400840..06697c591 100644 --- a/src/kbmod/analysis_utils.py +++ b/src/kbmod/analysis_utils.py @@ -25,7 +25,6 @@ class PostProcess: def __init__(self, config, mjds): self.coeff = None self.num_cores = config["num_cores"] - self.num_obs = config["num_obs"] self.sigmaG_lims = config["sigmaG_lims"] self.eps = config["eps"] self.cluster_type = config["cluster_type"] @@ -76,12 +75,6 @@ def load_and_filter_results( bnds = [25, 75] clipper = SigmaGClipping(bnds[0], bnds[1], 2, self.clip_negative) - # Set up the combined stats filter. - if lh_level > 0.0: - stats_filter = CombinedStatsFilter(min_obs=self.num_obs, min_lh=lh_level) - else: - stats_filter = CombinedStatsFilter(min_obs=self.num_obs) - print("---------------------------------------") print("Retrieving Results") print("---------------------------------------") @@ -101,9 +94,7 @@ def load_and_filter_results( if trj.lh < lh_level: likelihood_limit = True break - - # Skip points with too high a likelihood or with too few observations. - if trj.lh < max_lh and trj.obs_count >= self.num_obs: + if trj.lh < max_lh: row = ResultRow(trj, len(self._mjds)) psi_curve = np.array(search.get_psi_curves(trj)) phi_curve = np.array(search.get_phi_curves(trj)) @@ -115,7 +106,11 @@ def load_and_filter_results( print("Extracted batch of %i results for total of %i" % (batch_size, total_count)) if batch_size > 0: apply_clipped_sigma_g(clipper, result_batch, self.num_cores) - result_batch.apply_filter(stats_filter) + result_batch.apply_filter(NumObsFilter(3)) + + # Apply the likelihood filter if one is provided. + if lh_level > 0.0: + result_batch.apply_filter(LHFilter(lh_level, None)) # Add the results to the final set. keep.extend(result_batch) diff --git a/src/kbmod/filters/stats_filters.py b/src/kbmod/filters/stats_filters.py index 4c0600f9d..c335ad30d 100644 --- a/src/kbmod/filters/stats_filters.py +++ b/src/kbmod/filters/stats_filters.py @@ -112,104 +112,3 @@ def keep_row(self, row: ResultRow): An indicator of whether to keep the row. """ return len(row.valid_indices) >= self.min_obs - - -class CombinedStatsFilter(RowFilter): - """A filter for result's likelihood and number of observations.""" - - def __init__(self, min_obs=0, min_lh=-math.inf, max_lh=math.inf, *args, **kwargs): - """Create a ResultsLHFilter. - - Parameters - ---------- - min_obs : ``int`` - The minimum number of observations. - min_lh : ``float`` - Minimal allowed likelihood. - max_lh : ``float`` - Maximal allowed likelihood. - """ - super().__init__(*args, **kwargs) - - self.min_obs = min_obs - self.min_lh = min_lh - self.max_lh = max_lh - - def get_filter_name(self): - """Get the name of the filter. - - Returns - ------- - str - The filter name. - """ - return f"CombinedStats_{self.num_obs}_{self.min_lh}_to_{self.max_lh}" - - def keep_row(self, row: ResultRow): - """Determine whether to keep an individual row based on - the likelihood. - - Parameters - ---------- - row : ResultRow - The row to evaluate. - - Returns - ------- - bool - An indicator of whether to keep the row. - """ - if row.final_likelihood < self.min_lh or row.final_likelihood > self.max_lh: - return False - if len(row.valid_indices) >= self.min_obs: - return False - return True - - -class DurationFilter(RowFilter): - """A filter for the amount of time covered by the trajectory""" - - def __init__(self, all_times, min_duration, *args, **kwargs): - """Create a ResultsLHFilter. - - Parameters - ---------- - all_times : ``list`` - The time stamps in increasing order. - min_duration : ``float`` - The minimum duration in days for a valid result. - """ - super().__init__(*args, **kwargs) - - self.all_times = all_times - self.min_duration = min_duration - - def get_filter_name(self): - """Get the name of the filter. - - Returns - ------- - str - The filter name. - """ - return f"Duration_{self.min_duration}" - - def keep_row(self, row: ResultRow): - """Determine whether to keep an individual row based on - the likelihood. - - Parameters - ---------- - row : ResultRow - The row to evaluate. - - Returns - ------- - bool - An indicator of whether to keep the row. - """ - min_index = np.min(row.valid_indices) - max_index = np.max(row.valid_indices) - if self.all_times[max_index] - self.all_times[min_index] < self.min_duration: - return False - return True diff --git a/tests/test_stats_filters.py b/tests/test_stats_filters.py index 45238ec21..83348a1ec 100644 --- a/tests/test_stats_filters.py +++ b/tests/test_stats_filters.py @@ -12,9 +12,10 @@ def setUp(self): self.rs = ResultList(self.times, track_filtered=True) for i in range(10): - trj = make_trajectory(lh=float(i)) - row = ResultRow(trj, self.num_times) + t = Trajectory() + row = ResultRow(t, self.num_times) row.filter_indices([k for k in range(i)]) + row.final_likelihood = float(i) self.rs.append_result(row) def test_filter_min_likelihood(self): @@ -66,9 +67,8 @@ def test_filter_likelihood_mp(self): # Create a lot more results. rs = ResultList(self.times, track_filtered=True) for i in range(1000): - trj = Trajectory() - trj.lh = 0.01 * float(i) - row = ResultRow(trj, self.num_times) + row = ResultRow(Trajectory(), self.num_times) + row.final_likelihood = 0.01 * float(i) rs.append_result(row) self.assertEqual(rs.num_results(), 1000) @@ -97,45 +97,6 @@ def test_filter_valid_indices(self): for i in range(self.rs.num_results()): self.assertGreaterEqual(len(self.rs.results[i].valid_indices), 4) - def test_combined_stats_filter(self): - self.assertEqual(self.rs.num_results(), 10) - - f = CombinedStatsFilter(min_obs=4, min_lh=5.1) - self.assertEqual(f.get_filter_name(), "CombinedStats_10_5.0_inf") - - # Do the filtering and check we have the correct ones. - self.rs.apply_filter(f) - self.assertEqual(self.rs.num_results(), 4) - for row in self.rs.results: - self.assertGreaterEqual(len(row.valid_indices), 4) - self.assertGreaterEqual(len(row.final_likelihood), 5.1) - - def test_duration_filter(self): - f = DurationFilter(self.times, 5.1) - self.assertEqual(f.get_filter_name(), "Duration_9.1") - - res_list = ResultList(self.times, track_filtered=True) - - # Add a full track - row0 = ResultRow(Trajectory(), self.num_times) - res_list.append_result(row0) - - # Add a track with every 4th observation - row1 = ResultRow(Trajectory(), self.num_times) - row1.filter_indices([k for k in range(self.num_times) if k % 4 == 0]) - res_list.append_result(row1) - - # Add a track with a short burst in the middle. - row2 = ResultRow(Trajectory(), self.num_times) - row2.filter_indices([3, 4, 5, 6, 7, 8, 9]) - res_list.append_result(row1) - - res_list.apply_filter(f) - self.assertEqual(res_list.num_results(), 2) - - self.assertGreaterEqual(len(res_list.results[0].valid_indices), self.num_times) - self.assertGreaterEqual(len(res_list.results[1].valid_indices), int(self.num_times / 4)) - if __name__ == "__main__": unittest.main() From 60c5caa7429467b8cbb5e7c78cbae95527ac73f6 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Fri, 1 Dec 2023 11:47:02 -0500 Subject: [PATCH 4/4] Fix breakage caused by reverting bad push --- tests/test_stats_filters.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_stats_filters.py b/tests/test_stats_filters.py index 83348a1ec..bc45fffc5 100644 --- a/tests/test_stats_filters.py +++ b/tests/test_stats_filters.py @@ -3,6 +3,7 @@ from kbmod.filters.stats_filters import * from kbmod.result_list import * from kbmod.search import * +from kbmod.trajectory_utils import make_trajectory class test_basic_filters(unittest.TestCase): @@ -12,10 +13,9 @@ def setUp(self): self.rs = ResultList(self.times, track_filtered=True) for i in range(10): - t = Trajectory() - row = ResultRow(t, self.num_times) + trj = make_trajectory(lh=float(i)) + row = ResultRow(trj, self.num_times) row.filter_indices([k for k in range(i)]) - row.final_likelihood = float(i) self.rs.append_result(row) def test_filter_min_likelihood(self): @@ -67,8 +67,8 @@ def test_filter_likelihood_mp(self): # Create a lot more results. rs = ResultList(self.times, track_filtered=True) for i in range(1000): - row = ResultRow(Trajectory(), self.num_times) - row.final_likelihood = 0.01 * float(i) + trj = make_trajectory(lh=0.01 * float(i)) + row = ResultRow(trj, self.num_times) rs.append_result(row) self.assertEqual(rs.num_results(), 1000)