From f55643410484c9e10b40f61e175d2992a6385dd1 Mon Sep 17 00:00:00 2001
From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com>
Date: Mon, 12 Feb 2024 16:03:34 -0500
Subject: [PATCH] Create a tool for evaluating a single trajectory

---
 src/kbmod/filters/sigma_g_filter.py |  20 ++++-
 src/kbmod/trajectory_explorer.py    | 119 ++++++++++++++++++++++++++++
 tests/test_trajectory_explorer.py   |  77 ++++++++++++++++++
 3 files changed, 213 insertions(+), 3 deletions(-)
 create mode 100644 src/kbmod/trajectory_explorer.py
 create mode 100644 tests/test_trajectory_explorer.py

diff --git a/src/kbmod/filters/sigma_g_filter.py b/src/kbmod/filters/sigma_g_filter.py
index 89e80fbf9..5c8d39ecb 100644
--- a/src/kbmod/filters/sigma_g_filter.py
+++ b/src/kbmod/filters/sigma_g_filter.py
@@ -116,6 +116,21 @@ def compute_clipped_sigma_g(self, lh):
         return good_index
 
 
+def apply_single_clipped_sigma_g(params, result):
+    """This function applies a clipped median filter to a single result from
+    KBMOD using sigmaG as a robust estimater of standard deviation.
+
+    Parameters
+    ----------
+    params : `SigmaGClipping`
+        The object to apply the SigmaG clipping.
+    result : `ResultRow`
+        The result details. This data gets modified directly by the filtering.
+    """
+    single_res = params.compute_clipped_sigma_g(result.likelihood_curve)
+    result.filter_indices(single_res)
+
+
 def apply_clipped_sigma_g(params, result_list, num_threads=1):
     """This function applies a clipped median filter to the results of a KBMOD
     search using sigmaG as a robust estimater of standard deviation.
@@ -142,6 +157,5 @@ def apply_clipped_sigma_g(params, result_list, num_threads=1):
         for i, res in enumerate(keep_idx_results):
             result_list.results[i].filter_indices(res)
     else:
-        for i, row in enumerate(result_list.results):
-            single_res = params.compute_clipped_sigma_g(row.likelihood_curve)
-            row.filter_indices(single_res)
+        for row in result_list.results:
+            apply_single_clipped_sigma_g(params, row)
diff --git a/src/kbmod/trajectory_explorer.py b/src/kbmod/trajectory_explorer.py
new file mode 100644
index 000000000..c1c73ad6d
--- /dev/null
+++ b/src/kbmod/trajectory_explorer.py
@@ -0,0 +1,119 @@
+import numpy as np
+
+from kbmod.configuration import SearchConfiguration
+from kbmod.filters.sigma_g_filter import apply_single_clipped_sigma_g, SigmaGClipping
+from kbmod.masking import apply_mask_operations
+from kbmod.result_list import ResultRow
+from kbmod.search import StackSearch, StampCreator
+
+
+class TrajectoryExplorer:
+    """A class to interactively run test trajectories through KBMOD.
+
+    Attributes
+    ----------
+    config : `SearchConfiguration`
+        The configuration parameters.
+    debug : `bool`
+        Use verbose debug output.
+    search : `kb.StackSearch`
+        The search object (with cached data).
+    """
+
+    def __init__(self, img_stack, config=None, debug=False):
+        """
+        Parameters
+        ----------
+        im_stack : `ImageStack`
+            The images to search.
+        config : `SearchConfiguration`, optional
+            The configuration parameters. If ``None`` uses the default
+            configuration parameters.
+        debug : `bool`
+            Use verbose debug output.
+        """
+        self._data_initalized = False
+        self.im_stack = img_stack
+        if config is None:
+            self.config = SearchConfiguration()
+        else:
+            self.config = config
+        self.debug = debug
+
+        # Allocate and configure the StackSearch object.
+        self.search = None
+
+    def initialize_data(self):
+        """Perform any needed initialization and preprocessing on the images."""
+        if self._data_initalized:
+            return
+
+        # Check if we need to apply legacy masking.
+        if self.config["do_mask"]:
+            self.im_stack = apply_mask_operations(self.config, self.im_stack)
+
+        # If we are using an encoded image representation on GPU, enable it and
+        # set the parameters.
+        if self.config["encode_num_bytes"] > 0:
+            self.search.enable_gpu_encoding(self.config["encode_num_bytes"])
+            if self.debug:
+                print(f"Setting encoding = {self.config['encode_num_bytes']}")
+
+        # Allocate the search structure.
+        self.search = StackSearch(self.im_stack)
+        self.search.set_debug(self.debug)
+
+        self._data_initalized = True
+
+    def evaluate_linear_trajectory(self, x, y, vx, vy):
+        """Evaluate a single linear trajectory in pixel space. Skips all the filtering
+        steps and returns the raw data.
+
+        Parameters
+        ----------
+        x : `int`
+            The starting x pixel of the trajectory.
+        y : `int`
+            The starting y pixel of the trajectory.
+        vx : `float`
+            The x velocity of the trajectory in pixels per day.
+        vy : `float`
+            The y velocity of the trajectory in pixels per day.
+
+        Returns
+        -------
+        result : `ResultRow`
+            The result data with all fields filled out.
+        """
+        self.initialize_data()
+
+        # Evaluate the trajectory.
+        trj = self.search.search_linear_trajectory(x, y, vx, vy)
+        result = ResultRow(trj, self.im_stack.img_count())
+
+        # Get the psi and phi curves and do the sigma_g filtering.
+        psi_curve = np.array(self.search.get_psi_curves(trj))
+        phi_curve = np.array(self.search.get_phi_curves(trj))
+        result.set_psi_phi(psi_curve, phi_curve)
+
+        # Get the individual stamps.
+        stamps = StampCreator.get_stamps(self.im_stack, result.trajectory, self.config["stamp_radius"])
+        result.all_stamps = np.array([stamp.image for stamp in stamps])
+
+        return result
+
+    def apply_sigma_g(self, result):
+        """Apply sigma G clipping to a single ResultRow. Modifies the row in-place.
+
+        Parameters
+        ----------
+        result : `ResultRow`
+            The row to test for filtering.
+        """
+        clipper = SigmaGClipping(
+            self.config["sigmaG_lims"][0],
+            self.config["sigmaG_lims"][1],
+            2,
+            self.config["clip_negative"],
+        )
+        apply_single_clipped_sigma_g(clipper, result)
diff --git a/tests/test_trajectory_explorer.py b/tests/test_trajectory_explorer.py
new file mode 100644
index 000000000..44f75066b
--- /dev/null
+++ b/tests/test_trajectory_explorer.py
@@ -0,0 +1,77 @@
+import unittest
+
+import numpy as np
+
+from kbmod.fake_data.fake_data_creator import FakeDataSet
+from kbmod.search import HAS_GPU
+from kbmod.trajectory_explorer import TrajectoryExplorer
+from kbmod.trajectory_utils import make_trajectory
+
+
+class test_trajectory_explorer(unittest.TestCase):
+    def setUp(self):
+        # image properties
+        self.img_count = 20
+        self.dim_x = 120
+        self.dim_y = 115
+
+        # create a Trajectory for the object
+        self.x0 = 27
+        self.y0 = 50
+        self.vx = 21.0
+        self.vy = -5.0
+        self.trj = make_trajectory(self.x0, self.y0, self.vx, self.vy, flux=500.0)
+
+        # create image set with single moving object
+        fake_times = [i / self.img_count for i in range(self.img_count)]
+        fake_ds = FakeDataSet(
+            self.dim_x,
+            self.dim_y,
+            fake_times,
+            noise_level=2.0,
+            psf_val=1.0,
+            use_seed=True,
+        )
+        fake_ds.insert_object(self.trj)
+
+        # Remove at least observation from the trajectory.
+        pred_x = int(self.x0 + fake_times[10] * self.vx + 0.5)
+        pred_y = int(self.y0 + fake_times[10] * self.vy + 0.5)
+        sci_t10 = fake_ds.stack.get_single_image(10).get_science()
+        for dy in [-1, 0, 1]:
+            for dx in [-1, 0, 1]:
+                sci_t10.set_pixel(pred_y + dy, pred_x + dx, 0.0001)
+
+        self.explorer = TrajectoryExplorer(fake_ds.stack)
+
+    @unittest.skipIf(not HAS_GPU, "Skipping test (no GPU detected)")
+    def test_evaluate_trajectory(self):
+        result = self.explorer.evaluate_linear_trajectory(self.x0, self.y0, self.vx, self.vy)
+
+        # We found the trajectory we were loooking for.
+        self.assertIsNotNone(result.trajectory)
+        self.assertEqual(result.trajectory.x, self.x0)
+        self.assertEqual(result.trajectory.y, self.y0)
+        self.assertEqual(result.trajectory.vx, self.vx)
+        self.assertEqual(result.trajectory.vy, self.vy)
+
+        # The statistics seem reasonable.
+        self.assertGreater(result.trajectory.lh, 50.0)
+        self.assertGreater(result.trajectory.flux, 50.0)
+        self.assertGreater(result.trajectory.obs_count, 10)
+
+        # We compute the rest of the data we need.
+        self.assertEqual(result.num_times, 20)
+        self.assertEqual(len(result.valid_indices), 20)
+        self.assertEqual(len(result.psi_curve), 20)
+        self.assertEqual(len(result.phi_curve), 20)
+        self.assertEqual(len(result.all_stamps), 20)
+
+        # At least one index 10 should be filtered by sigma G filtering.
+        self.explorer.apply_sigma_g(result)
+        self.assertLess(len(result.valid_indices), 20)
+        self.assertFalse(10 in result.valid_indices)
+
+
+if __name__ == "__main__":
+    unittest.main()