Skip to content

Commit

Permalink
Merge branch 'main' into single_trajectory
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Feb 12, 2024
2 parents f556434 + baa02c3 commit fb6d7c2
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 4 deletions.
31 changes: 29 additions & 2 deletions src/kbmod/search/pydocs/stack_search_docs.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,38 @@ static const auto DOC_StackSearch_prepare_psi_phi = R"doc(
)doc";

static const auto DOC_StackSearch_get_results = R"doc(
todo
Get a batch of cached results.
Parameters
----------
start : `int`
The starting index of the results to retrieve. Returns
an empty list if start is past the end of the cache.
count : `int`
The maximum number of results to retrieve. Returns fewer
results if there are not enough in the cache.
Returns
-------
results : `List`
A list of ``Trajectory`` objects for the cached results.
Raises
------
``RunTimeError`` if start < 0 or count <= 0.
)doc";

static const auto DOC_StackSearch_set_results = R"doc(
todo
Set the cached results. Used for testing.
Parameters
----------
new_results : `List`
The list of results to store.
)doc";

static const auto DOC_StackSearch_clear_results = R"doc(
Clear the cached results.
)doc";

static const auto DOC_StackSearch_evaluate_single_trajectory = R"doc(
Expand Down
8 changes: 6 additions & 2 deletions src/kbmod/search/stack_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,15 +256,18 @@ void StackSearch::sort_results() {
}

std::vector<Trajectory> StackSearch::get_results(int start, int count) {
if (start < 0) throw std::runtime_error("start must be 0 or greater");
if (count <= 0) throw std::runtime_error("count must be greater than 0");

if (start + count >= results.size()) {
count = results.size() - start;
}
if (start < 0) throw std::runtime_error("start must be 0 or greater");
return std::vector<Trajectory>(results.begin() + start, results.begin() + start + count);
}

// This function is used only for testing by injecting known result trajectories.
void StackSearch::set_results(const std::vector<Trajectory>& new_results) { results = new_results; }
void StackSearch::clear_results() { results.clear(); }

#ifdef Py_PYTHON_H
static void stack_search_bindings(py::module& m) {
Expand Down Expand Up @@ -303,7 +306,8 @@ static void stack_search_bindings(py::module& m) {
.def("prepare_psi_phi", &ks::prepare_psi_phi, pydocs::DOC_StackSearch_prepare_psi_phi)
.def("clear_psi_phi", &ks::clear_psi_phi, pydocs::DOC_StackSearch_clear_psi_phi)
.def("get_results", &ks::get_results, pydocs::DOC_StackSearch_get_results)
.def("set_results", &ks::set_results, pydocs::DOC_StackSearch_set_results);
.def("set_results", &ks::set_results, pydocs::DOC_StackSearch_set_results)
.def("clear_results", &ks::clear_results, pydocs::DOC_StackSearch_clear_results);
}
#endif /* Py_PYTHON_H */

Expand Down
1 change: 1 addition & 0 deletions src/kbmod/search/stack_search.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class StackSearch {

// Helper functions for testing
void set_results(const std::vector<Trajectory>& new_results);
void clear_results();

virtual ~StackSearch(){};

Expand Down
32 changes: 32 additions & 0 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,38 @@ def setUp(self):
self.params.m02_limit = 35.5
self.params.m20_limit = 35.5

def test_set_get_results(self):
results = self.search.get_results(0, 10)
self.assertEqual(len(results), 0)

trjs = [make_trajectory(i, i, 0.0, 0.0) for i in range(10)]
self.search.set_results(trjs)

# Check that we extract them all.
results = self.search.get_results(0, 10)
self.assertEqual(len(results), 10)
for i in range(10):
self.assertEqual(results[i].x, i)

# Check that we can run past the end of the results.
results = self.search.get_results(0, 100)
self.assertEqual(len(results), 10)

# Check that we can pull a subset.
results = self.search.get_results(2, 2)
self.assertEqual(len(results), 2)
self.assertEqual(results[0].x, 2)
self.assertEqual(results[1].x, 3)

# Check invalid settings
self.assertRaises(RuntimeError, self.search.get_results, -1, 5)
self.assertRaises(RuntimeError, self.search.get_results, 0, 0)

# Check that clear works.
self.search.clear_results()
results = self.search.get_results(0, 10)
self.assertEqual(len(results), 0)

@unittest.skipIf(not HAS_GPU, "Skipping test (no GPU detected)")
def test_evaluate_single_trajectory(self):
test_trj = make_trajectory(
Expand Down

0 comments on commit fb6d7c2

Please sign in to comment.