diff --git a/src/kbmod/search/pydocs/stack_search_docs.h b/src/kbmod/search/pydocs/stack_search_docs.h index 95519e82e..0a9dea843 100644 --- a/src/kbmod/search/pydocs/stack_search_docs.h +++ b/src/kbmod/search/pydocs/stack_search_docs.h @@ -157,19 +157,25 @@ static const auto DOC_StackSearch_prepare_psi_phi = R"doc( )doc"; static const auto DOC_StackSearch_get_results = R"doc( - Get a batch of cached results [start, finish). + Get a batch of cached results. Parameters ---------- start : `int` - The (inclusive) starting index of the results to retrieve. - end : `int` - The (exclusive) ending index of the results to retrieve. + The starting index of the results to retrieve. Returns + an empty list is start is past the end of the cache. + count : `int` + The maximum number of results to retrieve. Returns fewer + results if there are not enough in the cache. Returns ------- results : `List` A list of ``Trajectory`` objects for the cached results. + + Raises + ------ + ``RunTimeError`` if start < 0 or count <= 0. )doc"; static const auto DOC_StackSearch_set_results = R"doc( diff --git a/src/kbmod/search/stack_search.cpp b/src/kbmod/search/stack_search.cpp index 7c46be4a3..ca91644ad 100644 --- a/src/kbmod/search/stack_search.cpp +++ b/src/kbmod/search/stack_search.cpp @@ -256,10 +256,12 @@ void StackSearch::sort_results() { } std::vector StackSearch::get_results(int start, int count) { + if (start < 0) throw std::runtime_error("start must be 0 or greater"); + if (count <= 0) throw std::runtime_error("count must be greater than 0"); + if (start + count >= results.size()) { count = results.size() - start; } - if (start < 0) throw std::runtime_error("start must be 0 or greater"); return std::vector(results.begin() + start, results.begin() + start + count); } diff --git a/tests/test_search.py b/tests/test_search.py index 1a1689463..29c7ce148 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -107,11 +107,15 @@ def test_set_get_results(self): self.assertEqual(len(results), 10) # Check that we can pull a subset. - results = self.search.get_results(2, 4) + results = self.search.get_results(2, 2) self.assertEqual(len(results), 2) self.assertEqual(results[0].x, 2) self.assertEqual(results[1].x, 3) + # Check invalid settings + self.assertRaises(RuntimeError, self.search.get_results, -1, 5) + self.assertRaises(RuntimeError, self.search.get_results, 0, 0) + # Check that clear works. self.search.clear_results() results = self.search.get_results(0, 10)