Skip to content

Commit

Permalink
Breaking out helper functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Dec 18, 2024
1 parent f5c0148 commit 4898df1
Show file tree
Hide file tree
Showing 5 changed files with 191 additions and 136 deletions.
5 changes: 5 additions & 0 deletions src/kbmod/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
and helper functions for filtering and maintaining consistency between different attributes in each row.
"""

import copy
import csv
import logging
import numpy as np
Expand Down Expand Up @@ -132,6 +133,10 @@ def get_num_times(self):
return self.table["obs_valid"].shape[1]
return 0

def copy(self):
"""Return a deep copy of the current Results object."""
return copy.deepcopy(self)

@classmethod
def from_trajectories(cls, trajectories, track_filtered=False):
"""Extract data from a list of Trajectory objects.
Expand Down
97 changes: 97 additions & 0 deletions src/kbmod/trajectory_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from astropy.coordinates import SkyCoord
from astropy.wcs import WCS
from scipy.optimize import linear_sum_assignment

from kbmod.search import Trajectory

Expand Down Expand Up @@ -263,3 +264,99 @@ def evaluate_trajectory_mse(trj, x_vals, y_vals, zeroed_times, centered=True):
# Compute the errors.
sq_err = (x_vals - pred_x) ** 2 + (y_vals - pred_y) ** 2
return np.mean(sq_err)


def ave_trajectory_distance(trjA, trjB, times=[0.0]):
"""Evaluate the average distance between two trajectories (in pixels)
across different times.
Parameters
----------
trjA : `Trajectory`
The first Trajectory to evaluate.
trjB : `Trajectory`
The second Trajectory to evaluate.
times : `list` or `numpy.ndarray`
The zero-shifted times at which to evaluate the matches.
The average of the distances at these times are used.
Returns
-------
ave_dist : `float`
The average distance in pixels.
"""
times = np.asarray(times)
if len(times) == 0:
raise ValueError("Empty times array.")

# Compute the predicted x and y positions for the first trajectory.
px_a = trjA.x + times * trjA.vx
py_a = trjA.y + times * trjA.vy

# Compute the predicted x and y positions for the second trajectory.
px_b = trjB.x + times * trjB.vx
py_b = trjB.y + times * trjB.vy

# Compute the Euclidean distance at each point and then the average distance.
dists = np.sqrt((px_a - px_b) ** 2 + (py_a - py_b) ** 2)
ave_dist = np.mean(dists)
return ave_dist


def match_trajectory_sets(traj_query, traj_base, threshold, times=[0.0]):
"""Find the best matching pairs of queries (smallest distance) between the
query trajectories and base trajectories such that each trajectory is used in
at most one pair.
Note
----
This function is designed to evaluate the performance of searches by determining
which true trajectories (traj_query) were found in the result set (traj_base).
Parameters
----------
traj_query : `list`
A list of trajectories to compare.
traj_base : `list`
The second list of trajectories to compare.
threshold : float
The distance threshold between two trajectories to count a match (in pixels).
times : `list`
The list of zero-shifted times at which to evaluate the matches.
The average of the distances at these times are used.
Returns
-------
results : `list`
A list the same length as traj_query where each entry i indicates the index
of the trajectory in traj_base that best matches trajectory traj_query[i] or
-1 if no match was found with a distance below the given threshold.
"""
times = np.asarray(times)
if len(times) == 0:
raise ValueError("Empty times array.")

if threshold <= 0.0:
raise ValueError(f"Threshold must be greater than zero: {threshold}")

num_query = len(traj_query)
num_base = len(traj_base)

# Compute the matrix of distances between each pair. If this double FOR loop
# becomes a bottleneck, we can vectorize.
dists = np.zeros((num_query, num_base))
for q_idx in range(num_query):
for b_idx in range(num_base):
dists[q_idx][b_idx] = ave_trajectory_distance(traj_query[q_idx], traj_base[b_idx], times)

# Use scipy to solve the optimal bipartite matching problem.
row_inds, col_inds = linear_sum_assignment(dists)

# For each query (row) we find the best matching column and check
# the distance against the threshold.
results = np.full(num_query, -1)
for row, col in zip(row_inds, col_inds):
if dists[row, col] < threshold:
results[row] = col

return results
142 changes: 6 additions & 136 deletions tests/test_regression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,144 +20,13 @@
from kbmod.results import Results
from kbmod.run_search import SearchRunner
from kbmod.search import *
from kbmod.trajectory_utils import match_trajectory_sets
from kbmod.wcs_utils import make_fake_wcs_info
from kbmod.work_unit import WorkUnit

logger = logging.getLogger(__name__)


def ave_trajectory_distance(trjA, trjB, times=[0.0]):
"""Evaluate the average distance between two trajectories (in pixels)
at different times.
Parameters
----------
trjA : `kbmod.search.Trajectory`
The first Trajectory to evaluate.
trjB : `kbmod.search.Trajectory`
The second Trajectory to evaluate.
times : `list`
The list of zero-shifted times at which to evaluate the
matches. The average of the distances at these times
are used.
Returns
-------
ave_dist : `float`
The average distance in pixels.
"""
total = 0.0
for t in times:
dx = (trjA.x + t * trjA.vx) - (trjB.x + t * trjB.vx)
dy = (trjA.y + t * trjA.vy) - (trjB.y + t * trjB.vy)
total += math.sqrt(dx * dx + dy * dy)

ave_dist = total / len(times)
return ave_dist


def find_unique_overlap(traj_query, traj_base, threshold, times=[0.0]):
"""Finds the set of trajectories in traj_query that are 'close' to
trajectories in traj_base such that each Trajectory in traj_base
is used at most once.
Used to evaluate the performance of algorithms.
Parameters
----------
traj1 : `list`
A list of trajectories to compare.
traj2 : `list`
The second list of trajectories to compare.
threshold : float
The distance threshold between two observations to count a
match (in pixels).
times : `list`
The list of zero-shifted times at which to evaluate the matches.
The average of the distances at these times are used.
Returns
-------
results : `list`
The list of trajectories that appear in both traj1 and traj2
where each Trajectory in each set is only used once.
"""
num_times = len(times)
size_base = len(traj_base)
used = [False] * size_base

results = []
for query in traj_query:
best_dist = 10.0 * threshold
best_ind = -1

# Check the current query against all unused base trajectories.
for j in range(size_base):
if not used[j]:
dist = ave_trajectory_distance(query, traj_base[j], times)
if dist < best_dist:
best_dist = dist
best_ind = j

# If we found a good match, save it.
if best_dist <= threshold:
results.append(query)
used[best_ind] = True
return results


def find_set_difference(traj_query, traj_base, threshold, times=[0.0]):
"""Finds the set of trajectories in traj_query that are NOT 'close' to
any trajectories in traj_base such that each Trajectory in traj_base
is used at most once.
Used to evaluate the performance of algorithms.
Parameters
----------
traj_query : `list`
A list of trajectories to compare.
traj_base : `list`
The second list of trajectories to compare.
threshold : `float`
The distance threshold between two observations
to count a match (in pixels).
times : `list`
The list of zero-shifted times at which to evaluate the matches.
The average of the distances at these times are used.
Returns
-------
results : `list`
A list of trajectories that appear in traj_query but not
in traj_base where each Trajectory in each set is only
used once.
"""
num_times = len(times)
size_base = len(traj_base)
used = [False] * size_base

results = []
for query in traj_query:
best_dist = 10.0 * threshold
best_ind = -1

# Check the current query against all unused base trajectories.
for j in range(size_base):
if not used[j]:
dist = ave_trajectory_distance(query, traj_base[j], times)
if dist < best_dist:
best_dist = dist
best_ind = j

# If we found a good match, save it.
if best_dist <= threshold:
used[best_ind] = True
else:
results.append(query)
return results


def make_fake_ImageStack(times, trjs, psf_vals):
"""Make a stack of fake layered images.
Expand Down Expand Up @@ -356,20 +225,21 @@ def run_full_test():
)

# Determine which trajectories we did not recover.
overlap = find_unique_overlap(trjs, found, 3.0, [0.0, 2.0])
missing = find_set_difference(trjs, found, 3.0, [0.0, 2.0])
matches = match_trajectory_sets(trjs, found, 3.0, [0.0, 2.0])
overlap = np.where(matches > -1)[0]
missing = np.where(matches == -1)[0]

logger.debug("\nRecovered %i matching trajectories:" % len(overlap))
for x in overlap:
logger.debug(x)
logger.debug(trjs[x])

if len(missing) == 0:
logger.debug("*** PASSED ***")
return True
else:
logger.debug("\nFailed to recover %i trajectories:" % len(missing))
for x in missing:
logger.debug(x)
logger.debug(trjs[x])
logger.debug("*** FAILED ***")
return False

Expand Down
10 changes: 10 additions & 0 deletions tests/test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,16 @@ def test_from_table(self):
table = Results(Table(self.input_dict))
self._assert_results_match_dict(table, self.input_dict)

def test_copy(self):
table1 = Results(self.input_dict)
table2 = table1.copy()

# Add a new column to table2 and check that it is not in table1
# (i.e. we have done a deep copy).
table2.table["something_added"] = [i for i in range(self.num_entries)]
self.assertTrue("something_added" in table2.colnames)
self.assertFalse("something_added" in table1.colnames)

def test_make_trajectory_list(self):
self.input_dict["something_added"] = [i for i in range(self.num_entries)]
table = Results(self.input_dict)
Expand Down
73 changes: 73 additions & 0 deletions tests/test_trajectory_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,79 @@ def test_evaluate_trajectory_mse(self):

self.assertRaises(ValueError, evaluate_trajectory_mse, trj, [], [], [])

def test_ave_trajectory_distance(self):
times_0 = np.array([0.0])
times_1 = np.array([0.0, 1.0])
times_5 = np.array([0.0, 1.0, 2.0, 3.0, 4.0])

# A trajectory is always zero pixels from itself.
trjA = Trajectory(x=1, y=2, vx=1.0, vy=-1.0)
self.assertAlmostEqual(ave_trajectory_distance(trjA, trjA, times_0), 0.0)
self.assertAlmostEqual(ave_trajectory_distance(trjA, trjA, times_1), 0.0)
self.assertAlmostEqual(ave_trajectory_distance(trjA, trjA, times_5), 0.0)

# Create a trajectory with a constant 1 pixel offset in the y direction.
trjB = Trajectory(x=1, y=1, vx=1.0, vy=-1.0)
self.assertAlmostEqual(ave_trajectory_distance(trjA, trjB, times_0), 1.0)
self.assertAlmostEqual(ave_trajectory_distance(trjA, trjB, times_1), 1.0)
self.assertAlmostEqual(ave_trajectory_distance(trjA, trjB, times_5), 1.0)

# Create a trajectory with an increasing offset in the x direction.
trjC = Trajectory(x=1, y=2, vx=2.0, vy=-1.0)
self.assertAlmostEqual(ave_trajectory_distance(trjA, trjC, times_0), 0.0)
self.assertAlmostEqual(ave_trajectory_distance(trjA, trjC, times_1), 0.5)
self.assertAlmostEqual(ave_trajectory_distance(trjA, trjC, times_5), 2.0)

# Create a trajectory with an increasing offset in the y direction.
trjC = Trajectory(x=1, y=2, vx=1.0, vy=1.0)
self.assertAlmostEqual(ave_trajectory_distance(trjA, trjC, times_0), 0.0)
self.assertAlmostEqual(ave_trajectory_distance(trjA, trjC, times_1), 1.0)
self.assertAlmostEqual(ave_trajectory_distance(trjA, trjC, times_5), 4.0)

# A list of empty times is invalid.
with self.assertRaises(ValueError):
_ = ave_trajectory_distance(trjA, trjC, [])

def test_match_trajectory_sets(self):
queries = [
Trajectory(x=0, y=0, vx=0.0, vy=0.0),
Trajectory(x=10, y=10, vx=0.5, vy=-2.0),
Trajectory(x=50, y=80, vx=-1.0, vy=0.0),
]
candidates = [
Trajectory(x=0, y=0, vx=0.0, vy=0.0), # Same as queries[0]
Trajectory(x=49, y=82, vx=-1.0, vy=0.01), # Close to queries[2]
]
results = match_trajectory_sets(queries, candidates, 5.0, [0.0, 10.0])
self.assertTrue(np.array_equal(results, [0, -1, 1]))

# Add a trajectory that is too far from queries[1] to be a good match.
candidates.append(Trajectory(x=15, y=15, vx=0.5, vy=-2.0))
results = match_trajectory_sets(queries, candidates, 5.0, [0.0, 10.0])
self.assertTrue(np.array_equal(results, [0, -1, 1]))

# Add a trajectory that is close to queries[1].
candidates.append(Trajectory(x=10, y=10, vx=0.6, vy=-2.5))
results = match_trajectory_sets(queries, candidates, 5.0, [0.0, 10.0])
self.assertTrue(np.array_equal(results, [0, 3, 1]))

# Add a trajectory that is even closer to queries[1].
candidates.append(Trajectory(x=10, y=10, vx=0.6, vy=-2.1))
results = match_trajectory_sets(queries, candidates, 5.0, [0.0, 10.0])
self.assertTrue(np.array_equal(results, [0, 4, 1]))

# Add another query trajectory that is close to queries[0], but
# not close enough to steal its match.
queries.append(Trajectory(x=1, y=0, vx=0.0, vy=0.0))
results = match_trajectory_sets(queries, candidates, 5.0, [0.0, 10.0])
self.assertTrue(np.array_equal(results, [0, 4, 1, -1]))

# Add another trajectory that is close to queries[0], but not as close
# as its current match. So this gets matched with queries[3] instead.
candidates.append(Trajectory(x=0, y=0, vx=0.0, vy=0.01))
results = match_trajectory_sets(queries, candidates, 5.0, [0.0, 10.0])
self.assertTrue(np.array_equal(results, [0, 4, 1, 5]))


if __name__ == "__main__":
unittest.main()

0 comments on commit 4898df1

Please sign in to comment.