Skip to content

Commit

Permalink
Merge pull request acts-project#293 from beomki-yeo/detector-into-state
Browse files Browse the repository at this point in the history
Put detector in navigator_state
  • Loading branch information
beomki-yeo authored Sep 27, 2022
2 parents 8c6597b + a7e14b1 commit 01aad4a
Show file tree
Hide file tree
Showing 13 changed files with 73 additions and 67 deletions.
59 changes: 30 additions & 29 deletions core/include/detray/propagator/navigator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,26 @@ class navigator {
typename vector_type<intersection_type>::const_iterator;

public:
using detector_type = navigator::detector_type;

/// Default constructor
state() = default;

state(const detector_type &det) : _detector(&det) {}

/// Constructor with memory resource
DETRAY_HOST
state(vecmem::memory_resource &resource) : _candidates(&resource) {}
state(const detector_type &det, vecmem::memory_resource &resource)
: _detector(&det), _candidates(&resource) {}

/// Constructor from candidates vector_view
DETRAY_HOST_DEVICE state(vector_type<intersection_type> candidates)
: _candidates(candidates) {}
DETRAY_HOST_DEVICE state(const detector_type &det,
vector_type<intersection_type> candidates)
: _detector(&det), _candidates(candidates) {}

/// @returns a pointer of detector
DETRAY_HOST_DEVICE
auto detector() const { return _detector; }

/// Scalar representation of the navigation state,
/// @returns distance to next
Expand Down Expand Up @@ -357,6 +367,9 @@ class navigator {
/// Heartbeat of this navigation flow signals navigation is alive
bool _heartbeat = false;

/// Detector pointer
const detector_type *const _detector;

/// Our cache of candidates (intersections with any kind of surface)
vector_type<intersection_type> _candidates = {};

Expand Down Expand Up @@ -387,16 +400,6 @@ class navigator {
dindex _volume_index = 0;
};

/// Constructor from detector object, which is not owned by the navigator
/// and needs to be guaranteed to have a lifetime beyond that of the
/// navigator
DETRAY_HOST_DEVICE
navigator(const detector_t &d) : _detector(&d) {}

/// @returns reference to the detector
DETRAY_HOST_DEVICE
const detector_t &get_detector() const { return *_detector; }

/// Helper method to initialize a volume.
///
/// Calls the volumes accelerator structure for local navigation, then tests
Expand All @@ -410,8 +413,9 @@ class navigator {
DETRAY_HOST_DEVICE inline bool init(propagator_state_t &propagation) const {

state &navigation = propagation._navigation;
const auto det = navigation.detector();
const auto &track = propagation._stepping();
const auto &volume = _detector->volume_by_index(navigation.volume());
const auto &volume = det->volume_by_index(navigation.volume());

// Clean up state
navigation.clear();
Expand All @@ -421,11 +425,10 @@ class navigator {

// Loop over all indexed objects in volume, intersect and fill
// @todo - will come from the local object finder
const auto &tf_store = _detector->transform_store();
const auto &mask_store = _detector->mask_store();
const auto &tf_store = det->transform_store();
const auto &mask_store = det->mask_store();

for (const auto [obj_idx, obj] :
enumerate(_detector->surfaces(), volume)) {
for (const auto [obj_idx, obj] : enumerate(det->surfaces(), volume)) {

std::size_t count =
mask_store.template execute<intersection_initialize>(
Expand Down Expand Up @@ -526,6 +529,7 @@ class navigator {
propagator_state_t &propagation) const {

state &navigation = propagation._navigation;
const auto det = navigation.detector();
const auto &track = propagation._stepping();

// Current candidates are up to date, nothing left to do
Expand All @@ -539,7 +543,7 @@ class navigator {
navigation.n_candidates() == 1) {

// Update next candidate: If not reachable, 'high trust' is broken
if (not update_candidate(*navigation.next(), track)) {
if (not update_candidate(*navigation.next(), track, det)) {
navigation.set_state(navigation::status::e_unknown,
dindex_invalid,
navigation::trust_level::e_no_trust);
Expand All @@ -565,7 +569,7 @@ class navigator {

// Else: Track is on module.
// Ready the next candidate after the current module
if (update_candidate(*navigation.next(), track)) {
if (update_candidate(*navigation.next(), track, det)) {
return;
}

Expand All @@ -581,7 +585,7 @@ class navigator {

for (auto &candidate : navigation.candidates()) {
// Disregard this candidate if it is not reachable
if (not update_candidate(candidate, track)) {
if (not update_candidate(candidate, track, det)) {
// Forcefully set dist to numeric max for sorting
candidate.path = std::numeric_limits<scalar>::max();
}
Expand Down Expand Up @@ -671,15 +675,15 @@ class navigator {
/// @returns whether the track can reach this candidate.
template <typename track_t>
DETRAY_HOST_DEVICE inline bool update_candidate(
intersection_type &candidate, const track_t &track) const {
intersection_type &candidate, const track_t &track,
const detector_type *det) const {
// Remember the surface this candidate belongs to
const dindex obj_idx = candidate.index;

const auto &mask_store = _detector->mask_store();
const auto &sf = _detector->surface_by_index(obj_idx);
const auto &mask_store = det->mask_store();
const auto &sf = det->surface_by_index(obj_idx);
candidate = mask_store.template execute<intersection_update>(
sf.mask_type(), detail::ray(track), sf,
_detector->transform_store());
sf.mask_type(), detail::ray(track), sf, det->transform_store());

candidate.index = obj_idx;
// Check whether this candidate is reachable by the track
Expand All @@ -702,9 +706,6 @@ class navigator {
return detail::find_if(candidates.begin(), candidates.end(),
not_reachable);
}

/// the containers for all data
const detector_t *const _detector;
};

/// @return the vecmem jagged vector buffer for surface candidates
Expand Down
8 changes: 5 additions & 3 deletions core/include/detray/propagator/propagator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,22 @@ struct propagator {
/// @param candidates buffer for intersections in the navigator
DETRAY_HOST_DEVICE state(
const free_track_parameters_type &t_in,
const typename navigator_t::detector_type &det,
typename actor_chain_t::state actor_states = {},
vector_type<line_plane_intersection> &&candidates = {})
: _stepping(t_in),
_navigation(std::move(candidates)),
_navigation(det, std::move(candidates)),
_actor_states(actor_states) {}

/// Construct the propagation state with bound parameter
DETRAY_HOST_DEVICE state(
const bound_track_parameters_type &param,
const transform3_type &trf3,
const typename stepper_t::transform3_type &trf3,
const typename navigator_t::detector_type &det,
typename actor_chain_t::state actor_states = {},
vector_type<line_plane_intersection> &&candidates = {})
: _stepping(param, trf3),
_navigation(std::move(candidates)),
_navigation(det, std::move(candidates)),
_actor_states(actor_states) {}

// Is the propagation still alive?
Expand Down
4 changes: 2 additions & 2 deletions tests/benchmarks/cuda/benchmark_propagator_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ static void BM_PROPAGATOR_CPU(benchmark::State &state) {
rk_stepper_type s(B_field);

// Create navigator
navigator_host_type n(det);
navigator_host_type n;

// Create propagator
propagator_host_type p(std::move(s), std::move(n));
Expand All @@ -75,7 +75,7 @@ static void BM_PROPAGATOR_CPU(benchmark::State &state) {
for (auto &track : tracks) {

// Create the propagator state
propagator_host_type::state p_state(track);
propagator_host_type::state p_state(track, det);

// Run propagation
p.propagate(p_state);
Expand Down
4 changes: 2 additions & 2 deletions tests/benchmarks/cuda/benchmark_propagator_cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ __global__ void propagator_benchmark_kernel(
rk_stepper_type s(B_field);

// Create navigator
navigator_device_type n(det);
navigator_device_type n;

// Create propagator
propagator_device_type p(std::move(s), std::move(n));

// Create the propagator state
propagator_device_type::state p_state(
tracks.at(gid), actor_chain<>::state{}, candidates.at(gid));
tracks.at(gid), det, actor_chain<>::state{}, candidates.at(gid));

// Run propagation
p.propagate(p_state);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ TEST(ALGEBRA_PLUGIN, straight_line_navigation) {
using propagator_t = propagator<stepper_t, navigator_t, actor_chain<>>;

// Propagator
propagator_t prop(stepper_t{}, navigator_t{det});
propagator_t prop(stepper_t{}, navigator_t{});

constexpr std::size_t theta_steps{50};
constexpr std::size_t phi_steps{50};
Expand All @@ -75,7 +75,7 @@ TEST(ALGEBRA_PLUGIN, straight_line_navigation) {
// Now follow that ray with a track and check, if we find the same
// volumes and distances along the way
free_track_parameters_type track(ray.pos(), 0, ray.dir(), -1);
propagator_t::state propagation(track);
propagator_t::state propagation(track, det);

// Retrieve navigation information
auto &inspector = propagation._navigation.inspector();
Expand Down Expand Up @@ -141,7 +141,7 @@ TEST(ALGEBRA_PLUGIN, helix_navigation) {
const vector3 B{0. * unit_constants::T, 0. * unit_constants::T,
2. * unit_constants::T};
b_field_t b_field(B);
propagator_t prop(stepper_t{b_field}, navigator_t{det});
propagator_t prop(stepper_t{b_field}, navigator_t{});

constexpr std::size_t theta_steps{10};
constexpr std::size_t phi_steps{10};
Expand All @@ -168,7 +168,7 @@ TEST(ALGEBRA_PLUGIN, helix_navigation) {

// Now follow that helix with the same track and check, if we find
// the same volumes and distances along the way
propagator_t::state propagation(track);
propagator_t::state propagation(track, det);

// Retrieve navigation information
auto &inspector = propagation._navigation.inspector();
Expand Down
21 changes: 12 additions & 9 deletions tests/common/include/tests/common/test_telescope_detector.inl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ struct prop_state {
navigation_t _navigation;

template <typename track_t>
prop_state(const track_t &t_in) : _stepping(t_in) {}
prop_state(const track_t &t_in,
const typename navigation_t::detector_type &det)
: _stepping(t_in), _navigation(det) {}
};

} // anonymous namespace
Expand Down Expand Up @@ -119,18 +121,19 @@ TEST(ALGEBRA_PLUGIN, telescope_detector) {
free_track_parameters<transform3> test_track_x(pos, 0, mom, -1);

// navigators
navigator<decltype(z_tel_det1), inspector_t> navigator_z1(z_tel_det1);
navigator<decltype(z_tel_det2), inspector_t> navigator_z2(z_tel_det2);
navigator<decltype(x_tel_det), inspector_t> navigator_x(x_tel_det);
navigator<decltype(z_tel_det1), inspector_t> navigator_z1;
navigator<decltype(z_tel_det2), inspector_t> navigator_z2;
navigator<decltype(x_tel_det), inspector_t> navigator_x;
using navigation_state_t = decltype(navigator_z1)::state;
using stepping_state_t = rk_stepper_t::state;

// propagation states
prop_state<stepping_state_t, navigation_state_t> propgation_z1(
test_track_z1);
test_track_z1, z_tel_det1);
prop_state<stepping_state_t, navigation_state_t> propgation_z2(
test_track_z2);
prop_state<stepping_state_t, navigation_state_t> propgation_x(test_track_x);
test_track_z2, z_tel_det2);
prop_state<stepping_state_t, navigation_state_t> propgation_x(test_track_x,
x_tel_det);

stepping_state_t &stepping_z1 = propgation_z1._stepping;
stepping_state_t &stepping_z2 = propgation_z2._stepping;
Expand Down Expand Up @@ -201,10 +204,10 @@ TEST(ALGEBRA_PLUGIN, telescope_detector) {
host_mr, n_surfaces, tel_length, pilot_track, rk_stepper_z);

// make at least sure it is navigatable
navigator<decltype(tel_detector), inspector_t> tel_navigator(tel_detector);
navigator<decltype(tel_detector), inspector_t> tel_navigator;

prop_state<stepping_state_t, navigation_state_t> tel_propagation(
pilot_track);
pilot_track, tel_detector);
navigation_state_t &tel_navigation = tel_propagation._navigation;

// run propagation
Expand Down
5 changes: 2 additions & 3 deletions tests/common/include/tests/common/tools_guided_navigator.inl
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,8 @@ TEST(ALGEBRA_PLUGIN, guided_navigator) {
pathlimit_aborter::state pathlimit{200. * unit_constants::cm};

// Propagator
propagator_t p(runge_kutta_stepper{b_field},
guided_navigator{telescope_det});
propagator_t::state guided_state(track, std::tie(pathlimit));
propagator_t p(runge_kutta_stepper{b_field}, guided_navigator{});
propagator_t::state guided_state(track, telescope_det, std::tie(pathlimit));

// Propagate
p.propagate(guided_state);
Expand Down
4 changes: 2 additions & 2 deletions tests/common/include/tests/common/tools_navigator.inl
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,10 @@ TEST(ALGEBRA_PLUGIN, navigator) {
free_track_parameters<transform3> traj(pos, 0, mom, -1);

stepper_t stepper;
navigator_t nav(toy_det);
navigator_t nav;

prop_state<stepper_t::state, navigator_t::state> propagation{
stepper_t::state{traj}, navigator_t::state{}};
stepper_t::state{traj}, navigator_t::state(toy_det, host_mr)};
navigator_t::state &navigation = propagation._navigation;
stepper_t::state &stepping = propagation._stepping;

Expand Down
10 changes: 5 additions & 5 deletions tests/common/include/tests/common/tools_propagator.inl
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ TEST(ALGEBRA_PLUGIN, propagator_line_stepper) {
const vector3 mom{1., 1., 0.};
free_track_parameters<transform3> track(pos, 0, mom, -1);

propagator_t p(stepper_t{}, navigator_t{d});
propagator_t p(stepper_t{}, navigator_t{});

propagator_t::state state(track);
propagator_t::state state(track, d);

EXPECT_TRUE(p.propagate(state))
<< state._navigation.inspector().to_string() << std::endl;
Expand Down Expand Up @@ -142,7 +142,7 @@ TEST_P(PropagatorWithRkStepper, propagator_rk_stepper) {
const b_field_t b_field(B);

// Propagator is built from the stepper and navigator
propagator_t p(stepper_t{b_field}, navigator_t{d});
propagator_t p(stepper_t{b_field}, navigator_t{});

// Iterate through uniformly distributed momentum directions
for (auto track :
Expand All @@ -167,8 +167,8 @@ TEST_P(PropagatorWithRkStepper, propagator_rk_stepper) {
helix_insp_state, lim_print_insp_state, pathlimit_aborter_state);

// Init propagator states
propagator_t::state state(track, actor_states);
propagator_t::state lim_state(lim_track, lim_actor_states);
propagator_t::state state(track, d, actor_states);
propagator_t::state lim_state(lim_track, d, lim_actor_states);

// Set step constraints
state._stepping.template set_constraint<step::constraint::e_accuracy>(
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/cuda/navigator_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ TEST(navigator_cuda, navigator) {
n_edc_layers);

// Create navigator
navigator_host_t nav(det);
navigator_host_t nav;

// Create the vector of initial track parameters
vecmem::vector<free_track_parameters<transform3>> tracks_host(&mng_mr);
Expand Down Expand Up @@ -63,7 +63,7 @@ TEST(navigator_cuda, navigator) {
stepper_t stepper;

prop_state<navigator_host_t::state> propagation{
stepper_t::state{track}, navigator_host_t::state{mng_mr}};
stepper_t::state{track}, navigator_host_t::state(det, mng_mr)};

navigator_host_t::state& navigation = propagation._navigation;
stepper_t::state& stepping = propagation._stepping;
Expand Down
5 changes: 3 additions & 2 deletions tests/unit_tests/cuda/navigator_cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@ __global__ void navigator_test_kernel(
return;
}

navigator_device_t nav(det);
navigator_device_t nav;

auto& traj = tracks.at(gid);
stepper_t stepper;

prop_state<navigator_device_t::state> propagation{
stepper_t::state{traj}, navigator_device_t::state{candidates.at(gid)}};
stepper_t::state{traj},
navigator_device_t::state(det, candidates.at(gid))};

navigator_device_t::state& navigation = propagation._navigation;
stepper_t::state& stepping = propagation._stepping;
Expand Down
Loading

0 comments on commit 01aad4a

Please sign in to comment.