Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to output results to single table #437

Merged
merged 2 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion docs/source/user_manual/search_params.rst
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ This document serves to provide a quick overview of the existing parameters and
| | | directory with multiple FITS files |
| | | (one for each exposure). |
+------------------------+-----------------------------+----------------------------------------+
| ``ind_output_files`` | True | Output results to a series of |
| | | individual files (legacy format) |
+------------------------+-----------------------------+----------------------------------------+
| ``known_obj_obs`` | 3 | The minimum number of observations |
| | | needed to count a known object match. |
+------------------------+-----------------------------+----------------------------------------+
Expand Down Expand Up @@ -145,7 +148,12 @@ This document serves to provide a quick overview of the existing parameters and
| | | mask. See :ref:`Masking`. |
+------------------------+-----------------------------+----------------------------------------+
| ``res_filepath`` | None | The path of the directory in which to |
| | | store the results files. |
| | | store the individual results files. |
+------------------------+-----------------------------+----------------------------------------+
| ``result_filename``    | None                        | Full filename and path for a single    |
|                        |                             | tabular result saved as an ecsv file.  |
|                        |                             | Can be used in addition to             |
|                        |                             | outputting individual result files.    |
+------------------------+-----------------------------+----------------------------------------+
| ``sigmaG_lims`` | [25, 75] | The percentiles to use in sigmaG |
| | | filtering, if |
Expand Down
2 changes: 2 additions & 0 deletions src/kbmod/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(self):
"encode_num_bytes": -1,
"flag_keys": default_flag_keys,
"gpu_filter": False,
"ind_output_files": True,
"im_filepath": None,
"known_obj_obs": 3,
"known_obj_thresh": None,
Expand All @@ -72,6 +73,7 @@ def __init__(self):
"psf_file": None,
"repeated_flag_keys": default_repeated_flag_keys,
"res_filepath": None,
"result_filename": None,
"sigmaG_lims": [25, 75],
"stamp_radius": 10,
"stamp_type": "sum",
Expand Down
234 changes: 233 additions & 1 deletion src/kbmod/result_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,44 @@
import multiprocessing as mp
import numpy as np
import os.path as ospath
from pathlib import Path

from astropy.table import Table
from yaml import dump, safe_load

from kbmod.file_utils import *
from kbmod.trajectory_utils import (
make_trajectory,
trajectory_from_yaml,
trajectory_predict_skypos,
trajectory_to_yaml,
)


def _check_optional_allclose(arr1, arr2):
"""Check whether toward optional numpy arrays have the same information.

Parameters
----------
arr1 : `numpy.ndarray` or `None`
The first array.
arr1 : `numpy.ndarray` or `None`
The second array.

Returns
-------
result : `bool`
Indicates whether the arrays are the same.
"""
if arr1 is None and arr2 is None:
return True
if arr1 is not None and arr2 is None:
return False
if arr1 is None and arr2 is not None:
return False
return np.allclose(arr1, arr2)


class ResultRow:
"""This class stores a collection of related data from a single kbmod result.
In order to maintain a consistent internal state, the class uses private variables
Expand Down Expand Up @@ -86,6 +112,104 @@ def __init__(self, trj, num_times):
self.trajectory = trj
self._valid_indices = [i for i in range(num_times)]

@classmethod
def from_table_row(cls, data, num_times=None):
    """Create a ResultRow object directly from an AstroPy Table row.

    Parameters
    ----------
    data : `astropy.table.row.Row`
        The incoming row.
    num_times : `int`, optional
        The number of total times in the data. If ``None`` tries
        to extract it from a "num_times" or "all_times" column.

    Returns
    -------
    row : `ResultRow`
        The newly constructed result row.

    Raises
    ------
    KeyError if a required column is missing or the number of times
    cannot be determined.
    """
    if num_times is None:
        if "num_times" in data.columns:
            num_times = data["num_times"]
        elif "all_times" in data.columns:
            num_times = len(data["all_times"])
        else:
            raise KeyError("Number of times is not specified.")

    # Create the Trajectory object from the corresponding columns.
    trj = make_trajectory(
        data["trajectory_x"],
        data["trajectory_y"],
        data["trajectory_vx"],
        data["trajectory_vy"],
        data["flux"],
        data["likelihood"],
        data["obs_count"],
    )

    # Manually fill in all the rest of the values. The stamp related
    # columns may be absent (they are often dropped to save space).
    # Use cls (not a hard-coded class name) so subclasses construct
    # instances of themselves.
    row = cls(trj, num_times)
    row._final_likelihood = data["likelihood"]
    row._phi_curve = data["phi_curve"]
    row.pred_dec = data["pred_dec"]
    row.pred_ra = data["pred_ra"]
    row._psi_curve = data["psi_curve"]
    row._valid_indices = data["valid_indices"]

    row.all_stamps = data["all_stamps"] if "all_stamps" in data.columns else None
    row.stamp = data["stamp"] if "stamp" in data.columns else None

    return row

def __eq__(self, other):
    """Test if two result rows are equal."""
    if not isinstance(other, ResultRow):
        return False

    # Compare every attribute of the underlying trajectories.
    for attr in ("x", "y", "vx", "vy", "lh", "flux", "obs_count"):
        if getattr(self.trajectory, attr) != getattr(other.trajectory, attr):
            return False

    # Compare the simple scalar attributes.
    if self._num_times != other._num_times:
        return False
    if self._final_likelihood != other._final_likelihood:
        return False

    # Compare the (optional) curves, stamps, indices, and predicted positions.
    optional_pairs = (
        (self.all_stamps, other.all_stamps),
        (self._phi_curve, other._phi_curve),
        (self._psi_curve, other._psi_curve),
        (self.stamp, other.stamp),
        (self._valid_indices, other._valid_indices),
        (self.pred_dec, other.pred_dec),
        (self.pred_ra, other.pred_ra),
    )
    return all(_check_optional_allclose(a, b) for a, b in optional_pairs)

@property
def final_likelihood(self):
return self._final_likelihood
Expand Down Expand Up @@ -399,6 +523,59 @@ def from_yaml(cls, yaml_str):
result_list.filtered[key] = [ResultRow.from_yaml(row) for row in yaml_dict["filtered"][key]]
return result_list

@classmethod
def from_table(cls, data, all_times=None, track_filtered=False):
    """Extract a ResultList from an astropy Table.

    Parameters
    ----------
    data : `astropy.table.Table`
        The input data.
    all_times : `List` or `numpy.ndarray` or `None`
        The list of all time stamps. Must either be set or there
        must be an "all_times" column in the Table.
    track_filtered : `bool`
        Indicates whether the ResultList should track future filtered points.

    Returns
    -------
    result_list : `ResultList`
        The newly created list of results.

    Raises
    ------
    KeyError if any columns are missing or if ``all_times`` is None and there
    is no "all_times" column in the data.
    """
    # Check that we have some list of time stamps and place it in all_times.
    if all_times is None:
        if "all_times" not in data.columns:
            raise KeyError("No time stamps provided.")
        # Every row stores the full list of times, so use the first row's copy.
        all_times = data["all_times"][0]
    num_times = len(all_times)

    # Use cls (not a hard-coded class name) so subclasses construct
    # instances of themselves.
    result_list = cls(all_times, track_filtered)
    for i in range(len(data)):
        result_list.append_result(ResultRow.from_table_row(data[i], num_times))
    return result_list

@classmethod
def read_table(cls, filename):
    """Read a ResultList from a table file.

    Parameters
    ----------
    filename : `str`
        The name of the file to load.

    Returns
    -------
    result_list : `ResultList`
        The loaded list of results.

    Raises
    ------
    FileNotFoundError if the file is not found.
    KeyError if any of the required columns are missing.
    """
    if not Path(filename).is_file():
        # Include the offending path so the error is actionable.
        raise FileNotFoundError(filename)
    data = Table.read(filename)
    return cls.from_table(data)

def num_results(self):
"""Return the number of results in the list.

Expand All @@ -413,6 +590,37 @@ def __len__(self):
"""Return the number of results in the list."""
return len(self.results)

def __eq__(self, other):
    """Test if two ResultLists are equal. Includes both ordering and values.

    Parameters
    ----------
    other : any
        The object to compare against.

    Returns
    -------
    result : `bool`
        True if ``other`` is a ResultList with the same times, tracking
        setting, results (in order), and filtered rows (per label, in order).
    """
    if not isinstance(other, ResultList):
        return False
    if not np.allclose(self._all_times, other._all_times):
        return False
    if self.track_filtered != other.track_filtered:
        return False

    # Compare the (ordered) lists of results.
    num_results = len(self.results)
    if num_results != len(other.results):
        return False
    for i in range(num_results):
        if self.results[i] != other.results[i]:
            return False

    # Compare the filtered rows for each filter label.
    if len(self.filtered) != len(other.filtered):
        return False
    for key in self.filtered.keys():
        if key not in other.filtered:
            return False

        num_filtered = len(self.filtered[key])
        # Bug fix: compare against the length of the other list for this key.
        # The original compared against the number of filter labels.
        if num_filtered != len(other.filtered[key]):
            return False
        for i in range(num_filtered):
            if self.filtered[key][i] != other.filtered[key][i]:
                return False

    return True

def clear(self):
"""Clear the list of results."""
self.results.clear()
Expand Down Expand Up @@ -648,14 +856,16 @@ def revert_filter(self, label=None):

return self

def to_table(self, filtered_label=None):
def to_table(self, filtered_label=None, append_times=False):
"""Extract the results into an astropy table.

Parameters
----------
filtered_label : `str`, optional
The filtering label to extract. If None then extracts
the unfiltered rows. (default=None)
append_times : `bool`
Append the list of all times as a column in the data.

Returns
-------
Expand Down Expand Up @@ -690,12 +900,34 @@ def to_table(self, filtered_label=None):
"pred_ra": [],
"pred_dec": [],
}
if append_times:
table_dict["all_times"] = []

# Use a (slow) linear scan to do the transformation.
for row in list_ref:
row.append_to_dict(table_dict, True)
if append_times:
table_dict["all_times"].append(self._all_times)

return Table(table_dict)

def write_table(self, filename, overwrite=True):
    """Write the unfiltered results to a single (ecsv) file.

    Parameters
    ----------
    filename : `str`
        The name of the result file.
    overwrite : `bool`
        Overwrite the file if it already exists.
    """
    table_version = self.to_table(append_times=True)

    # Drop the all_stamps column as this is often too large to write in a CSV entry.
    table_version.remove_column("all_stamps")

    # Bug fix: honor the caller's ``overwrite`` argument instead of
    # unconditionally overwriting.
    table_version.write(filename, overwrite=overwrite)

def to_yaml(self, serialize_filtered=False):
"""Serialize the ResultList as a YAML string.

Expand Down
4 changes: 3 additions & 1 deletion src/kbmod/run_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,11 +191,13 @@ def run_search(self, config, stack):

# Save the results and the configuration information used.
print(f"Found {keep.num_results()} potential trajectories.")
if config["res_filepath"] is not None:
if config["res_filepath"] is not None and config["ind_output_files"]:
keep.save_to_files(config["res_filepath"], config["output_suffix"])

config_filename = os.path.join(config["res_filepath"], f"config_{config['output_suffix']}.yml")
config.to_file(config_filename, overwrite=True)
if config["result_filename"] is not None:
keep.write_table(config["result_filename"])

end = time.time()
print("Time taken for patch: ", end - start)
Expand Down
Loading