From 2709a2930b1cb89d2bdddbcef297ab3563d92787 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 12 Feb 2024 12:15:12 -0500 Subject: [PATCH 1/4] Add tests --- tests/test_search.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/test_search.py b/tests/test_search.py index c9d2739f1..1a1689463 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -89,6 +89,34 @@ def setUp(self): self.params.m02_limit = 35.5 self.params.m20_limit = 35.5 + def test_set_get_results(self): + results = self.search.get_results(0, 10) + self.assertEqual(len(results), 0) + + trjs = [make_trajectory(i, i, 0.0, 0.0) for i in range(10)] + self.search.set_results(trjs) + + # Check that we extract them all. + results = self.search.get_results(0, 10) + self.assertEqual(len(results), 10) + for i in range(10): + self.assertEqual(results[i].x, i) + + # Check that we can run past the end of the results. + results = self.search.get_results(0, 100) + self.assertEqual(len(results), 10) + + # Check that we can pull a subset. + results = self.search.get_results(2, 4) + self.assertEqual(len(results), 2) + self.assertEqual(results[0].x, 2) + self.assertEqual(results[1].x, 3) + + # Check that clear works. + self.search.clear_results() + results = self.search.get_results(0, 10) + self.assertEqual(len(results), 0) + @unittest.skipIf(not HAS_GPU, "Skipping test (no GPU detected)") def test_evaluate_single_trajectory(self): test_trj = make_trajectory( From 9d52bce74b2bab5c96916f8b3a027d900dac3c46 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 12 Feb 2024 12:16:17 -0500 Subject: [PATCH 2/4] Add a clear function --- src/kbmod/search/pydocs/stack_search_docs.h | 25 +++++++++++++++++++-- src/kbmod/search/stack_search.cpp | 4 +++- src/kbmod/search/stack_search.h | 1 + 3 files changed, 27 insertions(+), 3 deletions(-) diff --git a/src/kbmod/search/pydocs/stack_search_docs.h b/src/kbmod/search/pydocs/stack_search_docs.h index 530ac5aa3..95519e82e 100644 --- a/src/kbmod/search/pydocs/stack_search_docs.h +++ b/src/kbmod/search/pydocs/stack_search_docs.h @@ -157,11 +157,32 @@ static const auto DOC_StackSearch_prepare_psi_phi = R"doc( )doc"; static const auto DOC_StackSearch_get_results = R"doc( - todo + Get a batch of cached results [start, finish). + + Parameters + ---------- + start : `int` + The (inclusive) starting index of the results to retrieve. + end : `int` + The (exclusive) ending index of the results to retrieve. + + Returns + ------- + results : `List` + A list of ``Trajectory`` objects for the cached results. )doc"; static const auto DOC_StackSearch_set_results = R"doc( - todo + Set the cached results. Used for testing. + + Parameters + ---------- + new_results : `List` + The list of results to store. + )doc"; + +static const auto DOC_StackSearch_clear_results = R"doc( + Clear the cached results. )doc"; static const auto DOC_StackSearch_evaluate_single_trajectory = R"doc( diff --git a/src/kbmod/search/stack_search.cpp b/src/kbmod/search/stack_search.cpp index be3f32224..7c46be4a3 100644 --- a/src/kbmod/search/stack_search.cpp +++ b/src/kbmod/search/stack_search.cpp @@ -265,6 +265,7 @@ std::vector StackSearch::get_results(int start, int count) { // This function is used only for testing by injecting known result trajectories. void StackSearch::set_results(const std::vector& new_results) { results = new_results; } +void StackSearch::clear_results() { results.clear(); } #ifdef Py_PYTHON_H static void stack_search_bindings(py::module& m) { @@ -303,7 +304,8 @@ static void stack_search_bindings(py::module& m) { .def("prepare_psi_phi", &ks::prepare_psi_phi, pydocs::DOC_StackSearch_prepare_psi_phi) .def("clear_psi_phi", &ks::clear_psi_phi, pydocs::DOC_StackSearch_clear_psi_phi) .def("get_results", &ks::get_results, pydocs::DOC_StackSearch_get_results) - .def("set_results", &ks::set_results, pydocs::DOC_StackSearch_set_results); + .def("set_results", &ks::set_results, pydocs::DOC_StackSearch_set_results) + .def("clear_results", &ks::clear_results, pydocs::DOC_StackSearch_clear_results); } #endif /* Py_PYTHON_H */ diff --git a/src/kbmod/search/stack_search.h b/src/kbmod/search/stack_search.h index 9d59342df..9c492dad5 100644 --- a/src/kbmod/search/stack_search.h +++ b/src/kbmod/search/stack_search.h @@ -63,6 +63,7 @@ class StackSearch { // Helper functions for testing void set_results(const std::vector& new_results); + void clear_results(); virtual ~StackSearch(){}; From 26bfa8be40527eb56143dc522382719db1f91212 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 12 Feb 2024 12:28:37 -0500 Subject: [PATCH 3/4] Fixes --- src/kbmod/search/pydocs/stack_search_docs.h | 14 ++++++++++---- src/kbmod/search/stack_search.cpp | 4 +++- tests/test_search.py | 6 +++++- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/kbmod/search/pydocs/stack_search_docs.h b/src/kbmod/search/pydocs/stack_search_docs.h index 95519e82e..0a9dea843 100644 --- a/src/kbmod/search/pydocs/stack_search_docs.h +++ b/src/kbmod/search/pydocs/stack_search_docs.h @@ -157,19 +157,25 @@ static const auto DOC_StackSearch_prepare_psi_phi = R"doc( )doc"; static const auto DOC_StackSearch_get_results = R"doc( - Get a batch of cached results [start, finish). + Get a batch of cached results. Parameters ---------- start : `int` - The (inclusive) starting index of the results to retrieve. - end : `int` - The (exclusive) ending index of the results to retrieve. + The starting index of the results to retrieve. Returns + an empty list is start is past the end of the cache. + count : `int` + The maximum number of results to retrieve. Returns fewer + results if there are not enough in the cache. Returns ------- results : `List` A list of ``Trajectory`` objects for the cached results. + + Raises + ------ + ``RunTimeError`` if start < 0 or count <= 0. )doc"; static const auto DOC_StackSearch_set_results = R"doc( diff --git a/src/kbmod/search/stack_search.cpp b/src/kbmod/search/stack_search.cpp index 7c46be4a3..ca91644ad 100644 --- a/src/kbmod/search/stack_search.cpp +++ b/src/kbmod/search/stack_search.cpp @@ -256,10 +256,12 @@ void StackSearch::sort_results() { } std::vector StackSearch::get_results(int start, int count) { + if (start < 0) throw std::runtime_error("start must be 0 or greater"); + if (count <= 0) throw std::runtime_error("count must be greater than 0"); + if (start + count >= results.size()) { count = results.size() - start; } - if (start < 0) throw std::runtime_error("start must be 0 or greater"); return std::vector(results.begin() + start, results.begin() + start + count); } diff --git a/tests/test_search.py b/tests/test_search.py index 1a1689463..29c7ce148 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -107,11 +107,15 @@ def test_set_get_results(self): self.assertEqual(len(results), 10) # Check that we can pull a subset. - results = self.search.get_results(2, 4) + results = self.search.get_results(2, 2) self.assertEqual(len(results), 2) self.assertEqual(results[0].x, 2) self.assertEqual(results[1].x, 3) + # Check invalid settings + self.assertRaises(RuntimeError, self.search.get_results, -1, 5) + self.assertRaises(RuntimeError, self.search.get_results, 0, 0) + # Check that clear works. self.search.clear_results() results = self.search.get_results(0, 10) From 484ad80ce2f70c9d1ec77900b102780598b2f678 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 12 Feb 2024 15:50:28 -0500 Subject: [PATCH 4/4] fix typo --- src/kbmod/search/pydocs/stack_search_docs.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/kbmod/search/pydocs/stack_search_docs.h b/src/kbmod/search/pydocs/stack_search_docs.h index 0a9dea843..2ccbf73c2 100644 --- a/src/kbmod/search/pydocs/stack_search_docs.h +++ b/src/kbmod/search/pydocs/stack_search_docs.h @@ -163,7 +163,7 @@ static const auto DOC_StackSearch_get_results = R"doc( ---------- start : `int` The starting index of the results to retrieve. Returns - an empty list is start is past the end of the cache. + an empty list if start is past the end of the cache. count : `int` The maximum number of results to retrieve. Returns fewer results if there are not enough in the cache.