diff --git a/core/include/detray/propagator/navigator.hpp b/core/include/detray/propagator/navigator.hpp index 6124f4b9bf..4f627fb6fe 100644 --- a/core/include/detray/propagator/navigator.hpp +++ b/core/include/detray/propagator/navigator.hpp @@ -109,16 +109,26 @@ class navigator { typename vector_type::const_iterator; public: + using detector_type = navigator::detector_type; + /// Default constructor state() = default; + state(const detector_type &det) : _detector(&det) {} + /// Constructor with memory resource DETRAY_HOST - state(vecmem::memory_resource &resource) : _candidates(&resource) {} + state(const detector_type &det, vecmem::memory_resource &resource) + : _detector(&det), _candidates(&resource) {} /// Constructor from candidates vector_view - DETRAY_HOST_DEVICE state(vector_type candidates) - : _candidates(candidates) {} + DETRAY_HOST_DEVICE state(const detector_type &det, + vector_type candidates) + : _detector(&det), _candidates(candidates) {} + + /// @returns a pointer of detector + DETRAY_HOST_DEVICE + auto detector() const { return _detector; } /// Scalar representation of the navigation state, /// @returns distance to next @@ -357,6 +367,9 @@ class navigator { /// Heartbeat of this navigation flow signals navigation is alive bool _heartbeat = false; + /// Detector pointer + const detector_type *const _detector; + /// Our cache of candidates (intersections with any kind of surface) vector_type _candidates = {}; @@ -387,16 +400,6 @@ class navigator { dindex _volume_index = 0; }; - /// Constructor from detector object, which is not owned by the navigator - /// and needs to be guaranteed to have a lifetime beyond that of the - /// navigator - DETRAY_HOST_DEVICE - navigator(const detector_t &d) : _detector(&d) {} - - /// @returns reference to the detector - DETRAY_HOST_DEVICE - const detector_t &get_detector() const { return *_detector; } - /// Helper method to initialize a volume. /// /// Calls the volumes accelerator structure for local navigation, then tests @@ -410,8 +413,9 @@ class navigator { DETRAY_HOST_DEVICE inline bool init(propagator_state_t &propagation) const { state &navigation = propagation._navigation; + const auto det = navigation.detector(); const auto &track = propagation._stepping(); - const auto &volume = _detector->volume_by_index(navigation.volume()); + const auto &volume = det->volume_by_index(navigation.volume()); // Clean up state navigation.clear(); @@ -421,11 +425,10 @@ class navigator { // Loop over all indexed objects in volume, intersect and fill // @todo - will come from the local object finder - const auto &tf_store = _detector->transform_store(); - const auto &mask_store = _detector->mask_store(); + const auto &tf_store = det->transform_store(); + const auto &mask_store = det->mask_store(); - for (const auto [obj_idx, obj] : - enumerate(_detector->surfaces(), volume)) { + for (const auto [obj_idx, obj] : enumerate(det->surfaces(), volume)) { std::size_t count = mask_store.template execute( @@ -526,6 +529,7 @@ class navigator { propagator_state_t &propagation) const { state &navigation = propagation._navigation; + const auto det = navigation.detector(); const auto &track = propagation._stepping(); // Current candidates are up to date, nothing left to do @@ -539,7 +543,7 @@ class navigator { navigation.n_candidates() == 1) { // Update next candidate: If not reachable, 'high trust' is broken - if (not update_candidate(*navigation.next(), track)) { + if (not update_candidate(*navigation.next(), track, det)) { navigation.set_state(navigation::status::e_unknown, dindex_invalid, navigation::trust_level::e_no_trust); @@ -565,7 +569,7 @@ class navigator { // Else: Track is on module. // Ready the next candidate after the current module - if (update_candidate(*navigation.next(), track)) { + if (update_candidate(*navigation.next(), track, det)) { return; } @@ -581,7 +585,7 @@ class navigator { for (auto &candidate : navigation.candidates()) { // Disregard this candidate if it is not reachable - if (not update_candidate(candidate, track)) { + if (not update_candidate(candidate, track, det)) { // Forcefully set dist to numeric max for sorting candidate.path = std::numeric_limits::max(); } @@ -671,15 +675,15 @@ class navigator { /// @returns whether the track can reach this candidate. template DETRAY_HOST_DEVICE inline bool update_candidate( - intersection_type &candidate, const track_t &track) const { + intersection_type &candidate, const track_t &track, + const detector_type *det) const { // Remember the surface this candidate belongs to const dindex obj_idx = candidate.index; - const auto &mask_store = _detector->mask_store(); - const auto &sf = _detector->surface_by_index(obj_idx); + const auto &mask_store = det->mask_store(); + const auto &sf = det->surface_by_index(obj_idx); candidate = mask_store.template execute( - sf.mask_type(), detail::ray(track), sf, - _detector->transform_store()); + sf.mask_type(), detail::ray(track), sf, det->transform_store()); candidate.index = obj_idx; // Check whether this candidate is reachable by the track @@ -702,9 +706,6 @@ class navigator { return detail::find_if(candidates.begin(), candidates.end(), not_reachable); } - - /// the containers for all data - const detector_t *const _detector; }; /// @return the vecmem jagged vector buffer for surface candidates diff --git a/core/include/detray/propagator/propagator.hpp b/core/include/detray/propagator/propagator.hpp index 4e3ac9b978..c8ddfdcd0a 100644 --- a/core/include/detray/propagator/propagator.hpp +++ b/core/include/detray/propagator/propagator.hpp @@ -59,20 +59,22 @@ struct propagator { /// @param candidates buffer for intersections in the navigator DETRAY_HOST_DEVICE state( const free_track_parameters_type &t_in, + const typename navigator_t::detector_type &det, typename actor_chain_t::state actor_states = {}, vector_type &&candidates = {}) : _stepping(t_in), - _navigation(std::move(candidates)), + _navigation(det, std::move(candidates)), _actor_states(actor_states) {} /// Construct the propagation state with bound parameter DETRAY_HOST_DEVICE state( const bound_track_parameters_type ¶m, - const transform3_type &trf3, + const typename stepper_t::transform3_type &trf3, + const typename navigator_t::detector_type &det, typename actor_chain_t::state actor_states = {}, vector_type &&candidates = {}) : _stepping(param, trf3), - _navigation(std::move(candidates)), + _navigation(det, std::move(candidates)), _actor_states(actor_states) {} // Is the propagation still alive? diff --git a/tests/benchmarks/cuda/benchmark_propagator_cuda.cpp b/tests/benchmarks/cuda/benchmark_propagator_cuda.cpp index 2c082ada82..30ec0e8416 100644 --- a/tests/benchmarks/cuda/benchmark_propagator_cuda.cpp +++ b/tests/benchmarks/cuda/benchmark_propagator_cuda.cpp @@ -56,7 +56,7 @@ static void BM_PROPAGATOR_CPU(benchmark::State &state) { rk_stepper_type s(B_field); // Create navigator - navigator_host_type n(det); + navigator_host_type n; // Create propagator propagator_host_type p(std::move(s), std::move(n)); @@ -75,7 +75,7 @@ static void BM_PROPAGATOR_CPU(benchmark::State &state) { for (auto &track : tracks) { // Create the propagator state - propagator_host_type::state p_state(track); + propagator_host_type::state p_state(track, det); // Run propagation p.propagate(p_state); diff --git a/tests/benchmarks/cuda/benchmark_propagator_cuda_kernel.cu b/tests/benchmarks/cuda/benchmark_propagator_cuda_kernel.cu index 3a09911171..8b31964cd3 100644 --- a/tests/benchmarks/cuda/benchmark_propagator_cuda_kernel.cu +++ b/tests/benchmarks/cuda/benchmark_propagator_cuda_kernel.cu @@ -34,14 +34,14 @@ __global__ void propagator_benchmark_kernel( rk_stepper_type s(B_field); // Create navigator - navigator_device_type n(det); + navigator_device_type n; // Create propagator propagator_device_type p(std::move(s), std::move(n)); // Create the propagator state propagator_device_type::state p_state( - tracks.at(gid), actor_chain<>::state{}, candidates.at(gid)); + tracks.at(gid), det, actor_chain<>::state{}, candidates.at(gid)); // Run propagation p.propagate(p_state); diff --git a/tests/common/include/tests/common/check_geometry_navigation.inl b/tests/common/include/tests/common/check_geometry_navigation.inl index b7e1011823..e5566e6370 100644 --- a/tests/common/include/tests/common/check_geometry_navigation.inl +++ b/tests/common/include/tests/common/check_geometry_navigation.inl @@ -56,7 +56,7 @@ TEST(ALGEBRA_PLUGIN, straight_line_navigation) { using propagator_t = propagator>; // Propagator - propagator_t prop(stepper_t{}, navigator_t{det}); + propagator_t prop(stepper_t{}, navigator_t{}); constexpr std::size_t theta_steps{50}; constexpr std::size_t phi_steps{50}; @@ -75,7 +75,7 @@ TEST(ALGEBRA_PLUGIN, straight_line_navigation) { // Now follow that ray with a track and check, if we find the same // volumes and distances along the way free_track_parameters_type track(ray.pos(), 0, ray.dir(), -1); - propagator_t::state propagation(track); + propagator_t::state propagation(track, det); // Retrieve navigation information auto &inspector = propagation._navigation.inspector(); @@ -141,7 +141,7 @@ TEST(ALGEBRA_PLUGIN, helix_navigation) { const vector3 B{0. * unit_constants::T, 0. * unit_constants::T, 2. * unit_constants::T}; b_field_t b_field(B); - propagator_t prop(stepper_t{b_field}, navigator_t{det}); + propagator_t prop(stepper_t{b_field}, navigator_t{}); constexpr std::size_t theta_steps{10}; constexpr std::size_t phi_steps{10}; @@ -168,7 +168,7 @@ TEST(ALGEBRA_PLUGIN, helix_navigation) { // Now follow that helix with the same track and check, if we find // the same volumes and distances along the way - propagator_t::state propagation(track); + propagator_t::state propagation(track, det); // Retrieve navigation information auto &inspector = propagation._navigation.inspector(); diff --git a/tests/common/include/tests/common/test_telescope_detector.inl b/tests/common/include/tests/common/test_telescope_detector.inl index 4aea832967..7b621228eb 100644 --- a/tests/common/include/tests/common/test_telescope_detector.inl +++ b/tests/common/include/tests/common/test_telescope_detector.inl @@ -34,7 +34,9 @@ struct prop_state { navigation_t _navigation; template - prop_state(const track_t &t_in) : _stepping(t_in) {} + prop_state(const track_t &t_in, + const typename navigation_t::detector_type &det) + : _stepping(t_in), _navigation(det) {} }; } // anonymous namespace @@ -119,18 +121,19 @@ TEST(ALGEBRA_PLUGIN, telescope_detector) { free_track_parameters test_track_x(pos, 0, mom, -1); // navigators - navigator navigator_z1(z_tel_det1); - navigator navigator_z2(z_tel_det2); - navigator navigator_x(x_tel_det); + navigator navigator_z1; + navigator navigator_z2; + navigator navigator_x; using navigation_state_t = decltype(navigator_z1)::state; using stepping_state_t = rk_stepper_t::state; // propagation states prop_state propgation_z1( - test_track_z1); + test_track_z1, z_tel_det1); prop_state propgation_z2( - test_track_z2); - prop_state propgation_x(test_track_x); + test_track_z2, z_tel_det2); + prop_state propgation_x(test_track_x, + x_tel_det); stepping_state_t &stepping_z1 = propgation_z1._stepping; stepping_state_t &stepping_z2 = propgation_z2._stepping; @@ -201,10 +204,10 @@ TEST(ALGEBRA_PLUGIN, telescope_detector) { host_mr, n_surfaces, tel_length, pilot_track, rk_stepper_z); // make at least sure it is navigatable - navigator tel_navigator(tel_detector); + navigator tel_navigator; prop_state tel_propagation( - pilot_track); + pilot_track, tel_detector); navigation_state_t &tel_navigation = tel_propagation._navigation; // run propagation diff --git a/tests/common/include/tests/common/tools_guided_navigator.inl b/tests/common/include/tests/common/tools_guided_navigator.inl index 08a06bdfe9..9e3eb84bb1 100644 --- a/tests/common/include/tests/common/tools_guided_navigator.inl +++ b/tests/common/include/tests/common/tools_guided_navigator.inl @@ -64,9 +64,8 @@ TEST(ALGEBRA_PLUGIN, guided_navigator) { pathlimit_aborter::state pathlimit{200. * unit_constants::cm}; // Propagator - propagator_t p(runge_kutta_stepper{b_field}, - guided_navigator{telescope_det}); - propagator_t::state guided_state(track, std::tie(pathlimit)); + propagator_t p(runge_kutta_stepper{b_field}, guided_navigator{}); + propagator_t::state guided_state(track, telescope_det, std::tie(pathlimit)); // Propagate p.propagate(guided_state); diff --git a/tests/common/include/tests/common/tools_navigator.inl b/tests/common/include/tests/common/tools_navigator.inl index ebe7a528e4..2cd6366c59 100644 --- a/tests/common/include/tests/common/tools_navigator.inl +++ b/tests/common/include/tests/common/tools_navigator.inl @@ -130,10 +130,10 @@ TEST(ALGEBRA_PLUGIN, navigator) { free_track_parameters traj(pos, 0, mom, -1); stepper_t stepper; - navigator_t nav(toy_det); + navigator_t nav; prop_state propagation{ - stepper_t::state{traj}, navigator_t::state{}}; + stepper_t::state{traj}, navigator_t::state(toy_det, host_mr)}; navigator_t::state &navigation = propagation._navigation; stepper_t::state &stepping = propagation._stepping; diff --git a/tests/common/include/tests/common/tools_propagator.inl b/tests/common/include/tests/common/tools_propagator.inl index 0b32dd9c37..53b42c17c6 100644 --- a/tests/common/include/tests/common/tools_propagator.inl +++ b/tests/common/include/tests/common/tools_propagator.inl @@ -91,9 +91,9 @@ TEST(ALGEBRA_PLUGIN, propagator_line_stepper) { const vector3 mom{1., 1., 0.}; free_track_parameters track(pos, 0, mom, -1); - propagator_t p(stepper_t{}, navigator_t{d}); + propagator_t p(stepper_t{}, navigator_t{}); - propagator_t::state state(track); + propagator_t::state state(track, d); EXPECT_TRUE(p.propagate(state)) << state._navigation.inspector().to_string() << std::endl; @@ -142,7 +142,7 @@ TEST_P(PropagatorWithRkStepper, propagator_rk_stepper) { const b_field_t b_field(B); // Propagator is built from the stepper and navigator - propagator_t p(stepper_t{b_field}, navigator_t{d}); + propagator_t p(stepper_t{b_field}, navigator_t{}); // Iterate through uniformly distributed momentum directions for (auto track : @@ -167,8 +167,8 @@ TEST_P(PropagatorWithRkStepper, propagator_rk_stepper) { helix_insp_state, lim_print_insp_state, pathlimit_aborter_state); // Init propagator states - propagator_t::state state(track, actor_states); - propagator_t::state lim_state(lim_track, lim_actor_states); + propagator_t::state state(track, d, actor_states); + propagator_t::state lim_state(lim_track, d, lim_actor_states); // Set step constraints state._stepping.template set_constraint( diff --git a/tests/unit_tests/cuda/navigator_cuda.cpp b/tests/unit_tests/cuda/navigator_cuda.cpp index 87a181d419..6ae473a77a 100644 --- a/tests/unit_tests/cuda/navigator_cuda.cpp +++ b/tests/unit_tests/cuda/navigator_cuda.cpp @@ -29,7 +29,7 @@ TEST(navigator_cuda, navigator) { n_edc_layers); // Create navigator - navigator_host_t nav(det); + navigator_host_t nav; // Create the vector of initial track parameters vecmem::vector> tracks_host(&mng_mr); @@ -63,7 +63,7 @@ TEST(navigator_cuda, navigator) { stepper_t stepper; prop_state propagation{ - stepper_t::state{track}, navigator_host_t::state{mng_mr}}; + stepper_t::state{track}, navigator_host_t::state(det, mng_mr)}; navigator_host_t::state& navigation = propagation._navigation; stepper_t::state& stepping = propagation._stepping; diff --git a/tests/unit_tests/cuda/navigator_cuda_kernel.cu b/tests/unit_tests/cuda/navigator_cuda_kernel.cu index cdc00bf4c6..f1860441d6 100644 --- a/tests/unit_tests/cuda/navigator_cuda_kernel.cu +++ b/tests/unit_tests/cuda/navigator_cuda_kernel.cu @@ -31,13 +31,14 @@ __global__ void navigator_test_kernel( return; } - navigator_device_t nav(det); + navigator_device_t nav; auto& traj = tracks.at(gid); stepper_t stepper; prop_state propagation{ - stepper_t::state{traj}, navigator_device_t::state{candidates.at(gid)}}; + stepper_t::state{traj}, + navigator_device_t::state(det, candidates.at(gid))}; navigator_device_t::state& navigation = propagation._navigation; stepper_t::state& stepping = propagation._stepping; diff --git a/tests/unit_tests/cuda/propagator_cuda.cpp b/tests/unit_tests/cuda/propagator_cuda.cpp index 0d05f90d2c..2fdb9a033e 100644 --- a/tests/unit_tests/cuda/propagator_cuda.cpp +++ b/tests/unit_tests/cuda/propagator_cuda.cpp @@ -60,7 +60,7 @@ TEST_P(CudaPropagatorWithRkStepper, propagator) { // Create RK stepper rk_stepper_type s(B_field); // Create navigator - navigator_host_type n(det); + navigator_host_type n; // Create propagator propagator_host_type p(std::move(s), std::move(n)); @@ -76,7 +76,7 @@ TEST_P(CudaPropagatorWithRkStepper, propagator) { pathlimit_aborter::state pathlimit_state{path_limit}; propagator_host_type::state state( - tracks_host[i], thrust::tie(insp_state, pathlimit_state)); + tracks_host[i], det, thrust::tie(insp_state, pathlimit_state)); state._stepping.template set_constraint( constrainted_step_size); diff --git a/tests/unit_tests/cuda/propagator_cuda_kernel.cu b/tests/unit_tests/cuda/propagator_cuda_kernel.cu index 0c0a8d7ad3..edd44c0a99 100644 --- a/tests/unit_tests/cuda/propagator_cuda_kernel.cu +++ b/tests/unit_tests/cuda/propagator_cuda_kernel.cu @@ -40,7 +40,7 @@ __global__ void propagator_test_kernel( rk_stepper_type s(B_field); // Create navigator - navigator_device_type n(det); + navigator_device_type n; // Create propagator propagator_device_type p(std::move(s), std::move(n)); @@ -51,7 +51,7 @@ __global__ void propagator_test_kernel( pathlimit_aborter::state aborter_state{path_limit}; // Create the propagator state - propagator_device_type::state state(tracks[gid], + propagator_device_type::state state(tracks[gid], det, thrust::tie(insp_state, aborter_state), candidates.at(gid));