Skip to content

Commit

Permalink
Merge pull request #394 from dirac-institute/sigma_g_move
Browse files Browse the repository at this point in the history
Move sigmaG filtering to its own file
  • Loading branch information
jeremykubica authored Dec 1, 2023
2 parents c45d315 + 8964c59 commit 62a7c75
Show file tree
Hide file tree
Showing 7 changed files with 253 additions and 191 deletions.
140 changes: 12 additions & 128 deletions src/kbmod/analysis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

from .file_utils import *
from .filters.clustering_filters import DBSCANFilter
from .filters.stats_filters import *
from .filters.stats_filters import LHFilter, NumObsFilter
from .filters.sigma_g_filter import apply_clipped_sigma_g, SigmaGClipping
from .result_list import ResultList, ResultRow


Expand Down Expand Up @@ -67,6 +68,13 @@ def load_and_filter_results(
res_num = 0
total_count = 0

# Set up the clipped sigmaG filter.
if self.sigmaG_lims is not None:
bnds = self.sigmaG_lims
else:
bnds = [25, 75]
clipper = SigmaGClipping(bnds[0], bnds[1], 2, self.clip_negative)

print("---------------------------------------")
print("Retrieving Results")
print("---------------------------------------")
Expand Down Expand Up @@ -97,11 +105,12 @@ def load_and_filter_results(
batch_size = result_batch.num_results()
print("Extracted batch of %i results for total of %i" % (batch_size, total_count))
if batch_size > 0:
self.apply_clipped_sigmaG(result_batch)
apply_clipped_sigma_g(clipper, result_batch, self.num_cores)
result_batch.apply_filter(NumObsFilter(3))

# Apply the likelihood filter if one is provided.
if lh_level > 0.0:
result_batch.apply_filter(LHFilter(lh_level, None))
result_batch.apply_filter(NumObsFilter(3))

# Add the results to the final set.
keep.extend(result_batch)
Expand Down Expand Up @@ -132,109 +141,6 @@ def get_all_stamps(self, result_list, search, stamp_radius):
# ref to its private field. That's a fix for another time.
row.all_stamps = np.array([stamp.image for stamp in stamps])

def apply_clipped_sigmaG(self, result_list):
    """Run clipped-sigmaG filtering over every row of a result list.

    SigmaG is used as a robust estimate of the standard deviation. Each
    row's psi/phi curves are handed to ``self._clipped_sigmaG`` and the
    indices it reports as valid are passed to the row's
    ``filter_indices``. Progress and timing are printed to stdout.

    Parameters
    ----------
    result_list : `ResultList`
        The values from trajectories. Its rows are modified in place by
        the filtering.
    """
    print("Applying Clipped-sigmaG Filtering")
    start_time = time.time()

    # The coefficient is derived once and cached on the instance; the
    # percentile bounds fall back to [25, 75] when no limits are configured.
    if self.coeff is None:
        self.percentiles = self.sigmaG_lims if self.sigmaG_lims is not None else [25, 75]
        self.coeff = find_sigmaG_coeff(self.percentiles)

    if self.num_cores > 1:
        # NOTE(review): the helper is named zip_phi_psi_idx while
        # _clipped_sigmaG takes (psi_curve, phi_curve, index) -- confirm the
        # tuple order produced by the helper.
        curve_args = result_list.zip_phi_psi_idx()

        print("Starting pooling...")
        pool = mp.Pool(processes=self.num_cores)
        pending = pool.starmap_async(self._clipped_sigmaG, curve_args)
        pool.close()
        pool.join()
        outcomes = pending.get()

        # Each outcome is (index, good_index, new_lh); only good_index is used.
        for position, outcome in enumerate(outcomes):
            result_list.results[position].filter_indices(outcome[1])
    else:
        for position, row in enumerate(result_list.results):
            outcome = self._clipped_sigmaG(row.psi_curve, row.phi_curve, position)
            row.filter_indices(outcome[1])

    time_elapsed = time.time() - start_time
    print("{:.2f}s elapsed".format(time_elapsed))
    print("Completed filtering.", flush=True)
    print("---------------------------------------")

def _clipped_sigmaG(self, psi_curve, phi_curve, index, n_sigma=2):
"""This function applies a clipped median filter to a set of likelihood
values. Points are eliminated if they are more than n_sigma*sigmaG away
from the median.
Parameters
----------
psi_curve : numpy array
A single Psi curve, likely from a `ResultRow`.
phi_curve : numpy array
A single Phi curve, likely from a `ResultRow`.
index : int
The index of the ResultRow being processed. Used track
multiprocessing.
n_sigma : int
The number of standard deviations away from the median that
the largest likelihood values (N=num_clipped) must be in order
to be eliminated.
Returns
-------
index : int
The index of the ResultRow being processed. Used track multiprocessing.
good_index: numpy array
The indices that pass the filtering for a given set of curves.
new_lh : float
The new maximum likelihood of the set of curves, after max_lh_index has
been applied.
"""
masked_phi = np.copy(phi_curve)
masked_phi[masked_phi == 0] = 1e9

lh = psi_curve / np.sqrt(masked_phi)
good_index = self._exclude_outliers(lh, n_sigma)
if len(good_index) == 0:
new_lh = 0
good_index = []
else:
new_lh = kb.calculate_likelihood_psi_phi(psi_curve[good_index], phi_curve[good_index])
return (index, good_index, new_lh)

def _exclude_outliers(self, lh, n_sigma):
if self.clip_negative:
lower_per, median, upper_per = np.percentile(
lh[lh > 0], [self.percentiles[0], 50, self.percentiles[1]]
)
sigmaG = self.coeff * (upper_per - lower_per)
nSigmaG = n_sigma * sigmaG
good_index = np.where(
np.logical_and(lh != 0, np.logical_and(lh > median - nSigmaG, lh < median + nSigmaG))
)[0]
else:
lower_per, median, upper_per = np.percentile(lh, [self.percentiles[0], 50, self.percentiles[1]])
sigmaG = self.coeff * (upper_per - lower_per)
nSigmaG = n_sigma * sigmaG
good_index = np.where(np.logical_and(lh > median - nSigmaG, lh < median + nSigmaG))[0]
return good_index

def apply_stamp_filter(
self,
result_list,
Expand Down Expand Up @@ -382,25 +288,3 @@ def apply_clustering(self, result_list, cluster_params):
cluster_params["mjd"],
)
result_list.apply_batch_filter(f)


# Additional math utilities -----------


def invert_Gaussian_CDF(z):
    """Evaluate the inverse of the standard Gaussian CDF at ``z``.

    Parameters
    ----------
    z : float
        A cumulative probability.

    Returns
    -------
    x : float
        The value whose standard Gaussian CDF equals ``z``.
    """
    # erfinv is always evaluated on a non-negative argument; the sign is
    # applied afterwards.
    sign = -1 if z < 0.5 else 1
    folded = sign * (2 * z - 1)
    return float(sign * np.sqrt(2) * erfinv(folded))


def find_sigmaG_coeff(percentiles):
    """Compute the coefficient that scales a percentile range to sigmaG.

    Parameters
    ----------
    percentiles : sequence of two numbers
        The lower and upper percentiles, expressed on a 0-100 scale.

    Returns
    -------
    coeff : float
        The reciprocal of the distance between the Gaussian quantiles at
        the two percentiles.
    """
    lower_quantile = invert_Gaussian_CDF(percentiles[0] / 100)
    upper_quantile = invert_Gaussian_CDF(percentiles[1] / 100)
    return 1 / (upper_quantile - lower_quantile)
123 changes: 123 additions & 0 deletions src/kbmod/filters/sigma_g_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""Functions to help with the SigmaG clipping.
For more details see:
Sifting Through the Static: Moving Object Detection in Difference Images
by Smotherman et al. 2021
"""

import multiprocessing as mp
import numpy as np
from scipy.special import erfinv

from kbmod.result_list import ResultList, ResultRow


class SigmaGClipping:
    """Parameters and logic for sigmaG clipping of a likelihood curve.

    SigmaG estimates the standard deviation from the width of a percentile
    interval, scaled by a coefficient derived from the Gaussian quantiles at
    the interval's bounds.

    Attributes
    ----------
    low_bnd : `float`
        The lower percentile of the interval used to estimate the standard
        deviation. Must be greater than 0.
    high_bnd : `float`
        The upper percentile of the interval used to estimate the standard
        deviation. Must be less than 100 and greater than ``low_bnd``.
    n_sigma : `float`
        The number of standard deviations to use for the bound.
    clip_negative : `bool`
        If True, only strictly positive likelihoods are used to compute the
        percentiles, and likelihoods that are exactly zero are always
        rejected.
    coeff : `float`
        The precomputed coefficient based on the given bounds.
    """

    def __init__(self, low_bnd=25, high_bnd=75, n_sigma=2, clip_negative=False):
        """Validate the clipping parameters and precompute the coefficient.

        Raises
        ------
        ValueError
            If the bounds are not strictly ordered within (0, 100) or if
            ``n_sigma`` is not positive.
        """
        # Equal bounds are rejected too: they describe a zero-width interval
        # and would divide by zero when computing the coefficient.
        if low_bnd >= high_bnd or low_bnd <= 0.0 or high_bnd >= 100.0:
            raise ValueError(f"Invalid bounds [{low_bnd}, {high_bnd}]")
        if n_sigma <= 0.0:
            raise ValueError(f"Invalid n_sigma {n_sigma}")

        self.low_bnd = low_bnd
        self.high_bnd = high_bnd
        self.clip_negative = clip_negative
        self.n_sigma = n_sigma
        self.coeff = SigmaGClipping.find_sigma_g_coeff(low_bnd, high_bnd)

    @staticmethod
    def find_sigma_g_coeff(low_bnd, high_bnd):
        """Compute the coefficient that scales a percentile range to sigmaG.

        Parameters
        ----------
        low_bnd : `float`
            The lower percentile on a 0-100 scale.
        high_bnd : `float`
            The upper percentile on a 0-100 scale.

        Returns
        -------
        coeff : `float`
            The reciprocal of the distance between the Gaussian quantiles
            at the two percentiles.
        """
        x1 = SigmaGClipping.invert_gauss_cdf(low_bnd / 100.0)
        x2 = SigmaGClipping.invert_gauss_cdf(high_bnd / 100.0)
        return 1 / (x2 - x1)

    @staticmethod
    def invert_gauss_cdf(z):
        """Evaluate the inverse of the standard Gaussian CDF at ``z``.

        Parameters
        ----------
        z : `float`
            A cumulative probability.

        Returns
        -------
        x : `float`
            The value whose standard Gaussian CDF equals ``z``.
        """
        # erfinv is always evaluated on a non-negative argument; the sign is
        # applied afterwards.
        sign = -1 if z < 0.5 else 1
        x = sign * np.sqrt(2) * erfinv(sign * (2 * z - 1))
        return float(x)

    def compute_clipped_sigma_g(self, lh):
        """Compute the SigmaG clipping on the given likelihood curve.

        Points are eliminated if they are more than n_sigma*sigmaG away from
        the median.

        Parameters
        ----------
        lh : numpy array
            A single likelihood curve. Any array-like is accepted.

        Returns
        -------
        good_index : numpy array
            The indices that pass the filtering. Empty when there are no
            values from which to estimate the percentiles.
        """
        lh = np.asarray(lh)

        # With clip_negative set, only strictly positive values contribute to
        # the percentile estimates.
        sample = lh[lh > 0] if self.clip_negative else lh
        if sample.size == 0:
            # Percentiles of an empty sample are undefined; nothing can pass.
            return np.array([], dtype=np.intp)

        lower_per, median, upper_per = np.percentile(sample, [self.low_bnd, 50, self.high_bnd])

        # Floor the interval width so a perfectly flat curve still keeps the
        # points equal to the median under the strict comparisons below.
        delta = max(upper_per - lower_per, 1e-8)
        sigma_g = self.coeff * delta
        n_sigma_g = self.n_sigma * sigma_g

        in_range = np.logical_and(lh > median - n_sigma_g, lh < median + n_sigma_g)

        # It is unclear why zeros are only filtered in one of the two cases,
        # but the logic is kept to stay consistent with the original code.
        if self.clip_negative:
            in_range = np.logical_and(lh != 0, in_range)

        return np.where(in_range)[0]


def apply_clipped_sigma_g(params, result_list, num_threads=1):
    """Apply a clipped median filter to the results of a KBMOD search.

    SigmaG is used as a robust estimator of the standard deviation. For each
    row, ``params.compute_clipped_sigma_g`` is evaluated on the row's
    likelihood curve and the surviving indices are passed to the row's
    ``filter_indices``.

    Parameters
    ----------
    params : `SigmaGClipping`
        The object to apply the SigmaG clipping.
    result_list : `ResultList`
        The values from trajectories. This data gets modified directly by
        the filtering.
    num_threads : `int`
        The number of workers to use. Values above 1 distribute the
        computation over a ``multiprocessing`` pool of that many processes.
    """
    rows = result_list.results
    if len(rows) == 0:
        # Nothing to filter; avoid spinning up worker processes for no work.
        return

    if num_threads > 1:
        lh_list = [row.likelihood_curve for row in rows]

        # The context manager guarantees the pool is torn down even if a
        # worker raises; map() returns the results in input order.
        with mp.Pool(processes=num_threads) as pool:
            keep_idx_results = pool.map(params.compute_clipped_sigma_g, lh_list)

        for row, keep_idx in zip(rows, keep_idx_results):
            row.filter_indices(keep_idx)
    else:
        for row in rows:
            row.filter_indices(params.compute_clipped_sigma_g(row.likelihood_curve))
45 changes: 21 additions & 24 deletions src/kbmod/result_list.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import math
import multiprocessing as mp
import os.path as ospath

import numpy as np
import os.path as ospath

from kbmod.file_utils import *

Expand Down Expand Up @@ -52,18 +51,16 @@ def light_curve(self):
Returns
-------
lc : list
The likelihood curve. This is an empty list if either
lc : `numpy.ndarray`
The light curve. This is an empty array if either
psi or phi are not set.
"""
if self.psi_curve is None or self.phi_curve is None:
return []
return np.array([])

num_elements = len(self.psi_curve)
lc = [0.0] * num_elements
for i in range(num_elements):
if self.phi_curve[i] != 0.0:
lc[i] = self.psi_curve[i] / self.phi_curve[i]
masked_phi = np.copy(self.phi_curve)
masked_phi[masked_phi == 0] = 1e12
lc = np.divide(self.psi_curve, masked_phi)
return lc

@property
Expand All @@ -72,20 +69,16 @@ def likelihood_curve(self):
Returns
-------
lh : list
The likelihood curve. This is an empty list if either
lh : `numpy.ndarray`
The likelihood curve. This is an empty array if either
psi or phi are not set.
"""
if self.psi_curve is None:
raise ValueError("Psi curve is None")
if self.phi_curve is None:
raise ValueError("Phi curve is None")

num_elements = len(self.psi_curve)
lh = [0.0] * num_elements
for i in range(num_elements):
if self.phi_curve[i] > 0.0:
lh[i] = self.psi_curve[i] / math.sqrt(self.phi_curve[i])
if self.psi_curve is None or self.phi_curve is None:
return np.array([])

masked_phi = np.copy(self.phi_curve)
masked_phi[masked_phi == 0] = 1e12
lh = np.divide(self.psi_curve, np.sqrt(masked_phi))
return lh

def valid_indices_as_booleans(self):
Expand Down Expand Up @@ -172,8 +165,8 @@ def _update_likelihood(self):
self.trajectory.lh = 0.0
self.trajectory.flux = 0.0
else:
self.final_likelihood = psi_sum / math.sqrt(phi_sum)
self.trajectory.lh = psi_sum / math.sqrt(phi_sum)
self.final_likelihood = psi_sum / np.sqrt(phi_sum)
self.trajectory.lh = psi_sum / np.sqrt(phi_sum)
self.trajectory.flux = psi_sum / phi_sum


Expand Down Expand Up @@ -208,6 +201,10 @@ def num_results(self):
"""
return len(self.results)

def __len__(self):
"""Return the number of results in the list."""
return len(self.results)

def clear(self):
"""Clear the list of results."""
self.results.clear()
Expand Down
Loading

0 comments on commit 62a7c75

Please sign in to comment.