Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Have KBMOD append RA, dec columns if possible #770

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 85 additions & 6 deletions src/kbmod/run_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from .results import Results
from .trajectory_generator import create_trajectory_generator
from .trajectory_utils import predict_pixel_locations
from .wcs_utils import wcs_to_dict
from .work_unit import WorkUnit

Expand Down Expand Up @@ -187,7 +188,14 @@ def do_gpu_search(self, config, stack, trj_generator):
keep = self.load_and_filter_results(search, config)
return keep

def run_search(self, config, stack, trj_generator=None, wcs=None, extra_meta=None):
def run_search(
self,
config,
stack,
trj_generator=None,
workunit=None,
extra_meta=None,
):
"""This function serves as the highest-level python interface for starting
a KBMOD search given an ImageStack and SearchConfiguration.

Expand All @@ -200,8 +208,8 @@ def run_search(self, config, stack, trj_generator=None, wcs=None, extra_meta=Non
trj_generator : `TrajectoryGenerator`, optional
The object to generate the candidate trajectories for each pixel.
If None uses the default EclipticCenteredSearch
wcs : `astropy.wcs.WCS`, optional
A global WCS for all images in the search.
workunit : `WorkUnit`, optional
An optional WorkUnit with additional meta-data, including the per-image WCS.
extra_meta : `dict`, optional
Any additional metadata to save as part of the results file.

Expand Down Expand Up @@ -253,8 +261,11 @@ def run_search(self, config, stack, trj_generator=None, wcs=None, extra_meta=Non
if config["save_all_stamps"]:
append_all_stamps(keep, stack, config["stamp_radius"])

# Append the WCS information if it is provided. This will be saved with the results.
keep.table.wcs = wcs
# Append additional information derived from the WorkUnit if one is provided,
# including a global WCS and per-time (RA, dec) predictions for each image.
if workunit is not None:
keep.table.wcs = workunit.wcs
append_ra_dec_to_results(workunit, keep)

# Create and save any additional meta data that should be saved with the results.
num_img = stack.img_count()
Expand Down Expand Up @@ -306,6 +317,74 @@ def run_search_from_work_unit(self, work):
work.config,
work.im_stack,
trj_generator=trj_generator,
wcs=work.wcs,
workunit=work,
extra_meta=extra_meta,
)


def append_ra_dec_to_results(workunit, results):
"""Append predicted (RA, dec) positions to the results.

Parameters
----------
workunit : `WorkUnit`
The WorkUnit with all the WCS information.
results : `Results`
The current table of results including the per-pixel trajectories.
This is modified in-place.
"""
num_results = len(results)
if num_results == 0:
return # Nothing to do

num_times = workunit.im_stack.img_count()
times = workunit.im_stack.build_zeroed_times()

# Predict where each candidate trajectory will be at each time step.
xp = predict_pixel_locations(times, results["x"], results["vx"], as_int=False)
yp = predict_pixel_locations(times, results["y"], results["vy"], as_int=False)

# Compute the predicted (RA, dec) positions for each trajectory in global space.
if workunit.wcs is not None:
logger.info("Found common WCS. Adding global_ra and global_dec columns.")

skypos = workunit.wcs.pixel_to_world(xp, yp)
results.table["global_ra"] = skypos.ra.degree
results.table["global_dec"] = skypos.dec.degree

# Loop over the trajectories to build the original positions.
all_ra = []
all_dec = []
for idx in range(num_results):
pos_tuples = [(xp[idx, j], yp[idx, j]) for j in range(num_times)]
skypos = workunit.image_positions_to_original_icrs(
image_indices=np.arange(num_times), # Compute for all times.
positions=pos_tuples,
input_format="xy",
output_format="radec",
filter_in_frame=False,
)

# We get back a list of SkyCoord, because we gave a list.
# So we flatten it and extract the coordinate values.
all_ra.append([skypos[j].ra.degree for j in range(num_times)])
all_dec.append([skypos[j].dec.degree for j in range(num_times)])

results.table["img_ra"] = all_ra
results.table["img_dec"] = all_dec
else:
logger.info("No common WCS found. Skipping global_ra and global_dec columns.")

# If there are no global WCS, we just predict per image.
all_ra = np.zeros((len(results), num_times))
all_dec = np.zeros((len(results), num_times))

for time_idx in range(num_times):
wcs = workunit.get_wcs(time_idx)
if wcs is not None:
skypos = wcs.pixel_to_world(xp[:, time_idx], yp[:, time_idx])
all_ra[:, time_idx] = skypos.ra.degree
all_dec[:, time_idx] = skypos.dec.degree

results.table["img_ra"] = all_ra
results.table["img_dec"] = all_dec
144 changes: 144 additions & 0 deletions tests/test_run_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""Test some of the functions needed for running the search."""

import unittest

import numpy as np

from kbmod.configuration import SearchConfiguration
from kbmod.fake_data.fake_data_creator import create_fake_times, FakeDataSet
from kbmod.results import Results
from kbmod.run_search import append_ra_dec_to_results
from kbmod.search import *
from kbmod.wcs_utils import make_fake_wcs
from kbmod.work_unit import WorkUnit


class test_run_search(unittest.TestCase):
def test_append_ra_dec_global(self):
# Create a fake WorkUnit with 20 times, a completely random ImageStack,
# and no trajectories.
num_times = 20
fake_times = create_fake_times(num_times, t0=60676.0)
fake_ds = FakeDataSet(800, 600, fake_times)

# Append a global fake WCS and one for each time.
global_wcs = make_fake_wcs(20.0, 0.0, 800, 600, deg_per_pixel=0.5 / 3600.0)
all_wcs = []
for idx in range(num_times):
curr = make_fake_wcs(
20.01 + idx / 100.0, 0.01 + idx / 100.0, 800, 600, deg_per_pixel=0.5 / 3600.0
)
all_wcs.append(curr)

fake_wu = WorkUnit(
im_stack=fake_ds.stack,
config=SearchConfiguration(),
wcs=global_wcs,
per_image_wcs=all_wcs,
reprojected=True,
per_image_indices=[i for i in range(num_times)],
heliocentric_distance=np.full(num_times, 100.0),
obstimes=fake_times,
)

# Create three fake trajectories in the bounds of the images. We don't
# bother actually inserting them into the pixels.
trjs = [
Trajectory(x=5, y=10, vx=1, vy=1, flux=1000.0, lh=1000.0, obs_count=num_times),
Trajectory(x=400, y=300, vx=-5, vy=-2, flux=1000.0, lh=1000.0, obs_count=num_times),
Trajectory(x=100, y=500, vx=10, vy=-10, flux=1000.0, lh=1000.0, obs_count=num_times),
]
results = Results.from_trajectories(trjs)
self.assertEqual(len(results), 3)

append_ra_dec_to_results(fake_wu, results)

# The global RA should exist and be close to 20.0 for all observations.
self.assertEqual(len(results["global_ra"]), 3)
for i in range(3):
self.assertEqual(len(results["global_ra"][i]), num_times)
self.assertTrue(np.all(results["global_ra"][i] > 19.0))
self.assertTrue(np.all(results["global_ra"][i] < 21.0))

# The global Dec should exist and be close to 0.0 for all observations.
self.assertEqual(len(results["global_dec"]), 3)
for i in range(3):
self.assertEqual(len(results["global_dec"][i]), num_times)
self.assertTrue(np.all(results["global_dec"][i] > -1.0))
self.assertTrue(np.all(results["global_dec"][i] < 1.0))

# The per-image RA should exist, be close to 20.0 for all observations,
# and be different from the global RA
self.assertEqual(len(results["img_ra"]), 3)
for i in range(3):
self.assertEqual(len(results["img_ra"][i]), num_times)
self.assertTrue(np.all(results["img_ra"][i] > 19.0))
self.assertTrue(np.all(results["img_ra"][i] < 21.0))
self.assertFalse(np.any(results["img_ra"][i] == results["global_ra"][i]))

# The global Dec should exist and be close to 0.0 for all observations.
self.assertEqual(len(results["img_dec"]), 3)
for i in range(3):
self.assertEqual(len(results["img_dec"][i]), num_times)
self.assertTrue(np.all(results["img_dec"][i] > -1.0))
self.assertTrue(np.all(results["img_dec"][i] < 1.0))
self.assertFalse(np.any(results["img_dec"][i] == results["global_dec"][i]))

def test_append_ra_dec_no_global(self):
# Create a fake WorkUnit with 20 times, a completely random ImageStack,
# and no trajectories.
num_times = 20
fake_times = create_fake_times(num_times, t0=60676.0)
fake_ds = FakeDataSet(800, 600, fake_times)

# Append a global fake WCS and one for each time.
all_wcs = []
for idx in range(num_times):
curr = make_fake_wcs(
20.01 + idx / 100.0, 0.01 + idx / 100.0, 800, 600, deg_per_pixel=0.5 / 3600.0
)
all_wcs.append(curr)

fake_wu = WorkUnit(
im_stack=fake_ds.stack,
config=SearchConfiguration(),
wcs=None,
per_image_wcs=all_wcs,
reprojected=False,
per_image_indices=[i for i in range(num_times)],
obstimes=fake_times,
)

# Create three fake trajectories in the bounds of the images. We don't
# bother actually inserting them into the pixels.
trjs = [
Trajectory(x=5, y=10, vx=1, vy=1, flux=1000.0, lh=1000.0, obs_count=num_times),
Trajectory(x=400, y=300, vx=-5, vy=-2, flux=1000.0, lh=1000.0, obs_count=num_times),
Trajectory(x=100, y=500, vx=10, vy=-10, flux=1000.0, lh=1000.0, obs_count=num_times),
]
results = Results.from_trajectories(trjs)
self.assertEqual(len(results), 3)

append_ra_dec_to_results(fake_wu, results)

# The global RA and global dec should not exist.
self.assertFalse("global_ra" in results.colnames)
self.assertFalse("global_dec" in results.colnames)

# The per-image RA should exist, be close to 20.0 for all observations.
self.assertEqual(len(results["img_ra"]), 3)
for i in range(3):
self.assertEqual(len(results["img_ra"][i]), num_times)
self.assertTrue(np.all(results["img_ra"][i] > 19.0))
self.assertTrue(np.all(results["img_ra"][i] < 21.0))

# The global Dec should exist and be close to 0.0 for all observations.
self.assertEqual(len(results["img_dec"]), 3)
for i in range(3):
self.assertEqual(len(results["img_dec"][i]), num_times)
self.assertTrue(np.all(results["img_dec"][i] > -1.0))
self.assertTrue(np.all(results["img_dec"][i] < 1.0))


if __name__ == "__main__":
unittest.main()
Loading