Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
jeremykubica committed Feb 12, 2024
1 parent 9d52bce commit 26bfa8b
Showing 3 changed files with 18 additions and 6 deletions.
14 changes: 10 additions & 4 deletions src/kbmod/search/pydocs/stack_search_docs.h
Original file line number Diff line number Diff line change
@@ -157,19 +157,25 @@ static const auto DOC_StackSearch_prepare_psi_phi = R"doc(
)doc";

static const auto DOC_StackSearch_get_results = R"doc(
Get a batch of cached results [start, finish).
Get a batch of cached results.
Parameters
----------
start : `int`
The (inclusive) starting index of the results to retrieve.
end : `int`
The (exclusive) ending index of the results to retrieve.
The starting index of the results to retrieve. Returns
an empty list is start is past the end of the cache.
count : `int`
The maximum number of results to retrieve. Returns fewer
results if there are not enough in the cache.
Returns
-------
results : `List`
A list of ``Trajectory`` objects for the cached results.
Raises
------
``RunTimeError`` if start < 0 or count <= 0.
)doc";

static const auto DOC_StackSearch_set_results = R"doc(
4 changes: 3 additions & 1 deletion src/kbmod/search/stack_search.cpp
Original file line number Diff line number Diff line change
@@ -256,10 +256,12 @@ void StackSearch::sort_results() {
}

std::vector<Trajectory> StackSearch::get_results(int start, int count) {
if (start < 0) throw std::runtime_error("start must be 0 or greater");
if (count <= 0) throw std::runtime_error("count must be greater than 0");

if (start + count >= results.size()) {
count = results.size() - start;
}
if (start < 0) throw std::runtime_error("start must be 0 or greater");
return std::vector<Trajectory>(results.begin() + start, results.begin() + start + count);
}

6 changes: 5 additions & 1 deletion tests/test_search.py
Original file line number Diff line number Diff line change
@@ -107,11 +107,15 @@ def test_set_get_results(self):
self.assertEqual(len(results), 10)

# Check that we can pull a subset.
results = self.search.get_results(2, 4)
results = self.search.get_results(2, 2)
self.assertEqual(len(results), 2)
self.assertEqual(results[0].x, 2)
self.assertEqual(results[1].x, 3)

# Check invalid settings
self.assertRaises(RuntimeError, self.search.get_results, -1, 5)
self.assertRaises(RuntimeError, self.search.get_results, 0, 0)

# Check that clear works.
self.search.clear_results()
results = self.search.get_results(0, 10)

0 comments on commit 26bfa8b

Please sign in to comment.