Skip to content

Commit

Permalink
Merge pull request #399 from dirac-institute/add_filters
Browse files Browse the repository at this point in the history
Upgrade filter logic
  • Loading branch information
jeremykubica authored Dec 5, 2023
2 parents 58e2283 + e1143b9 commit 0ff88d0
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 29 deletions.
16 changes: 10 additions & 6 deletions src/kbmod/analysis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from .file_utils import *
from .filters.clustering_filters import DBSCANFilter
from .filters.stats_filters import LHFilter, NumObsFilter
from .filters.stats_filters import CombinedStatsFilter
from .filters.sigma_g_filter import apply_clipped_sigma_g, SigmaGClipping
from .result_list import ResultList, ResultRow

Expand All @@ -25,6 +25,7 @@ class PostProcess:
def __init__(self, config, mjds):
self.coeff = None
self.num_cores = config["num_cores"]
self.num_obs = config["num_obs"]
self.sigmaG_lims = config["sigmaG_lims"]
self.eps = config["eps"]
self.cluster_type = config["cluster_type"]
Expand Down Expand Up @@ -75,6 +76,12 @@ def load_and_filter_results(
bnds = [25, 75]
clipper = SigmaGClipping(bnds[0], bnds[1], 2, self.clip_negative)

# Set up the combined stats filter.
if lh_level > 0.0:
stats_filter = CombinedStatsFilter(min_obs=self.num_obs, min_lh=lh_level)
else:
stats_filter = CombinedStatsFilter(min_obs=self.num_obs)

print("---------------------------------------")
print("Retrieving Results")
print("---------------------------------------")
Expand All @@ -94,6 +101,7 @@ def load_and_filter_results(
if trj.lh < lh_level:
likelihood_limit = True
break

if trj.lh < max_lh:
row = ResultRow(trj, len(self._mjds))
psi_curve = np.array(search.get_psi_curves(trj))
Expand All @@ -106,11 +114,7 @@ def load_and_filter_results(
print("Extracted batch of %i results for total of %i" % (batch_size, total_count))
if batch_size > 0:
apply_clipped_sigma_g(clipper, result_batch, self.num_cores)
result_batch.apply_filter(NumObsFilter(3))

# Apply the likelihood filter if one is provided.
if lh_level > 0.0:
result_batch.apply_filter(LHFilter(lh_level, None))
result_batch.apply_filter(stats_filter)

# Add the results to the final set.
keep.extend(result_batch)
Expand Down
103 changes: 103 additions & 0 deletions src/kbmod/filters/stats_filters.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import numpy as np

from kbmod.filters.base_filter import RowFilter
from kbmod.result_list import ResultRow

Expand Down Expand Up @@ -112,3 +114,104 @@ def keep_row(self, row: ResultRow):
An indicator of whether to keep the row.
"""
return len(row.valid_indices) >= self.min_obs


class CombinedStatsFilter(RowFilter):
"""A filter for result's likelihood and number of observations."""

def __init__(self, min_obs=0, min_lh=-np.inf, max_lh=np.inf, *args, **kwargs):
"""Create a ResultsLHFilter.
Parameters
----------
min_obs : ``int``
The minimum number of observations.
min_lh : ``float``
Minimal allowed likelihood.
max_lh : ``float``
Maximal allowed likelihood.
"""
super().__init__(*args, **kwargs)

self.min_obs = min_obs
self.min_lh = min_lh
self.max_lh = max_lh

def get_filter_name(self):
"""Get the name of the filter.
Returns
-------
str
The filter name.
"""
return f"CombinedStats_{self.min_obs}_{self.min_lh}_to_{self.max_lh}"

def keep_row(self, row: ResultRow):
"""Determine whether to keep an individual row based on
the likelihood.
Parameters
----------
row : ResultRow
The row to evaluate.
Returns
-------
bool
An indicator of whether to keep the row.
"""
if row.final_likelihood < self.min_lh or row.final_likelihood > self.max_lh:
return False
if len(row.valid_indices) < self.min_obs:
return False
return True


class DurationFilter(RowFilter):
"""A filter for the amount of time covered by the trajectory"""

def __init__(self, all_times, min_duration, *args, **kwargs):
"""Create a ResultsLHFilter.
Parameters
----------
all_times : ``list``
The time stamps in increasing order.
min_duration : ``float``
The minimum duration in days for a valid result.
"""
super().__init__(*args, **kwargs)

self.all_times = all_times
self.min_duration = min_duration

def get_filter_name(self):
"""Get the name of the filter.
Returns
-------
str
The filter name.
"""
return f"Duration_{self.min_duration}"

def keep_row(self, row: ResultRow):
"""Determine whether to keep an individual row based on
the likelihood.
Parameters
----------
row : ResultRow
The row to evaluate.
Returns
-------
bool
An indicator of whether to keep the row.
"""
min_index = np.min(row.valid_indices)
max_index = np.max(row.valid_indices)
if self.all_times[max_index] - self.all_times[min_index] < self.min_duration:
return False
return True
40 changes: 17 additions & 23 deletions tests/test_analysis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,12 @@
from kbmod.fake_data_creator import add_fake_object
from kbmod.result_list import *
from kbmod.search import *
from kbmod.trajectory_utils import make_trajectory

from utils.utils_for_tests import get_absolute_data_path


class test_analysis_utils(unittest.TestCase):
def _make_trajectory(self, x0, y0, xv, yv, lh):
t = Trajectory()
t.x = x0
t.y = y0
t.vx = xv
t.vy = yv
t.lh = lh
return t

def setUp(self):
# The configuration parameters.
self.default_mask_bits_dict = {
Expand Down Expand Up @@ -277,11 +269,11 @@ def test_clustering(self):
cluster_params["mjd"] = np.array(self.stack.build_zeroed_times())

trjs = [
self._make_trajectory(10, 11, 1, 2, 100.0),
self._make_trajectory(10, 11, 10, 20, 100.0),
self._make_trajectory(40, 5, -1, 2, 100.0),
self._make_trajectory(5, 0, 1, 2, 100.0),
self._make_trajectory(5, 1, 1, 2, 100.0),
make_trajectory(x=10, y=11, vx=1, vy=2, lh=100.0),
make_trajectory(x=10, y=11, vx=10, vy=20, lh=100.0),
make_trajectory(x=40, y=5, vx=-1, vy=2, lh=100.0),
make_trajectory(x=5, y=0, vx=1, vy=2, lh=100.0),
make_trajectory(x=5, y=1, vx=1, vy=2, lh=100.0),
]

# Try clustering with positions, velocities, and angles.
Expand All @@ -306,15 +298,15 @@ def test_clustering(self):
self.assertEqual(results2.num_results(), 3)

def test_load_and_filter_results_lh(self):
# Create fake result trajectories with given initial likelihoods.
# Create fake result trajectories with given initial likelihoods. The 1st is
# filtered by max likelihood. The 4th and 5th are filtered by min likelihood.
trjs = [
self._make_trajectory(20, 20, 0, 0, 9000.0), # Filtered by max likelihood
self._make_trajectory(30, 30, 0, 0, 100.0),
self._make_trajectory(40, 40, 0, 0, 50.0),
self._make_trajectory(50, 50, 0, 0, 2.0), # Filtered by min likelihood
self._make_trajectory(60, 60, 0, 0, 1.0), # Filtered by min likelihood
make_trajectory(20, 20, 0, 0, 500.0, 9000.0, self.img_count),
make_trajectory(30, 30, 0, 0, 100.0, 100.0, self.img_count),
make_trajectory(40, 40, 0, 0, 50.0, 50.0, self.img_count),
make_trajectory(50, 50, 0, 0, 1.0, 2.0, self.img_count),
make_trajectory(60, 60, 0, 0, 1.0, 1.0, self.img_count),
]
fluxes = [500.0, 100.0, 50.0, 1.0, 0.1]

# Create fake images with the objects in them.
imlist = []
Expand All @@ -324,7 +316,7 @@ def test_load_and_filter_results_lh(self):

# Add the objects.
for j, trj in enumerate(trjs):
add_fake_object(im, trj.x, trj.y, fluxes[j], self.p)
add_fake_object(im, trj.x, trj.y, trj.flux, self.p)

# Append the image.
imlist.append(im)
Expand All @@ -334,15 +326,17 @@ def test_load_and_filter_results_lh(self):
search.set_results(trjs)

# Do the filtering.
self.config["num_obs"] = 5
kb_post_process = PostProcess(self.config, self.time_list)

results = kb_post_process.load_and_filter_results(
search,
10.0, # min likelihood
chunk_size=500000,
max_lh=1000.0,
)

# Only the middle two results should pass the filtering.
# Only two of the middle results should pass the filtering.
self.assertEqual(results.num_results(), 2)
self.assertEqual(results.results[0].trajectory.y, 30)
self.assertEqual(results.results[1].trajectory.y, 40)
Expand Down
39 changes: 39 additions & 0 deletions tests/test_stats_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,45 @@ def test_filter_valid_indices(self):
for i in range(self.rs.num_results()):
self.assertGreaterEqual(len(self.rs.results[i].valid_indices), 4)

def test_combined_stats_filter(self):
self.assertEqual(self.rs.num_results(), 10)

f = CombinedStatsFilter(min_obs=4, min_lh=5.1)
self.assertEqual(f.get_filter_name(), "CombinedStats_4_5.1_to_inf")

# Do the filtering and check we have the correct ones.
self.rs.apply_filter(f)
self.assertEqual(self.rs.num_results(), 4)
for row in self.rs.results:
self.assertGreaterEqual(len(row.valid_indices), 4)
self.assertGreaterEqual(row.final_likelihood, 5.1)

def test_duration_filter(self):
f = DurationFilter(self.times, 0.81)
self.assertEqual(f.get_filter_name(), "Duration_0.81")

res_list = ResultList(self.times, track_filtered=True)

# Add a full track
row0 = ResultRow(Trajectory(), self.num_times)
res_list.append_result(row0)

# Add a track with every 4th observation
row1 = ResultRow(Trajectory(), self.num_times)
row1.filter_indices([k for k in range(self.num_times) if k % 4 == 0])
res_list.append_result(row1)

# Add a track with a short burst in the middle.
row2 = ResultRow(Trajectory(), self.num_times)
row2.filter_indices([3, 4, 5, 6, 7, 8, 9])
res_list.append_result(row2)

res_list.apply_filter(f)
self.assertEqual(res_list.num_results(), 2)

self.assertGreaterEqual(len(res_list.results[0].valid_indices), self.num_times)
self.assertGreaterEqual(len(res_list.results[1].valid_indices), int(self.num_times / 4))


if __name__ == "__main__":
unittest.main()

0 comments on commit 0ff88d0

Please sign in to comment.