diff --git a/core/include/detray/core/detail/container_buffers.hpp b/core/include/detray/core/detail/container_buffers.hpp index 0c5f3998f..e2e9d82d8 100644 --- a/core/include/detray/core/detail/container_buffers.hpp +++ b/core/include/detray/core/detail/container_buffers.hpp @@ -126,6 +126,19 @@ struct has_buffer, void> : public std::true_type { using type = vecmem::data::vector_buffer; }; +/// Specialization of the buffer getter for @c vecmem::device_vector +template +struct has_buffer, void> : public std::true_type { + using type = vecmem::data::vector_buffer; +}; + +/// Specialization of the buffer getter for @c vecmem::device_vector - const +template +struct has_buffer, void> + : public std::true_type { + using type = vecmem::data::vector_buffer; +}; + template inline constexpr bool has_buffer_v = has_buffer::value; diff --git a/core/include/detray/core/detail/container_views.hpp b/core/include/detray/core/detail/container_views.hpp index e49ccae97..833221f2b 100644 --- a/core/include/detray/core/detail/container_views.hpp +++ b/core/include/detray/core/detail/container_views.hpp @@ -169,12 +169,26 @@ struct detail::has_view, void> : public std::true_type { using type = dvector_view; }; -/// Specialization of the view getter for @c vecmem::vector +/// Specialization of the view getter for @c vecmem::vector - const template struct detail::has_view, void> : public std::true_type { using type = dvector_view; }; +/// Specialization of the view getter for @c vecmem::device_vector +template +struct detail::has_view, void> + : public std::true_type { + using type = dvector_view; +}; + +/// Specialization of the view getter for @c vecmem::device_vector - const +template +struct detail::has_view, void> + : public std::true_type { + using type = dvector_view; +}; + /// Specialized view for @c vecmem::jagged_vector containers template using djagged_vector_view = vecmem::data::jagged_vector_view; @@ -198,13 +212,27 @@ struct detail::has_view, void> using type = djagged_vector_view; }; -/// Specialization of the view getter for @c vecmem::jagged_vector +/// Specialization of the view getter for @c vecmem::jagged_vector - const template struct detail::has_view, void> : public std::true_type { using type = djagged_vector_view; }; +/// Specialization of the view getter for @c vecmem::jagged_device_vector +template +struct detail::has_view, void> + : public std::true_type { + using type = djagged_vector_view; +}; + +/// Specialization of the view getter for @c vecmem::jagged_device_vector +/// - const +template +struct detail::has_view, void> + : public std::true_type { + using type = djagged_vector_view; +}; /// @} } // namespace detray diff --git a/core/include/detray/core/detector.hpp b/core/include/detray/core/detector.hpp index 9dbcd4391..10d06c311 100644 --- a/core/include/detray/core/detector.hpp +++ b/core/include/detray/core/detector.hpp @@ -131,6 +131,9 @@ class detector { typename surface_container::view_type, dvector_view, typename volume_finder::view_type>; + static_assert(detail::is_device_view_v, + "Detector view type ill-formed"); + using const_view_type = dmulti_view, typename transform_container::const_view_type, @@ -140,6 +143,9 @@ class detector { dvector_view, typename volume_finder::const_view_type>; + static_assert(detail::is_device_view_v, + "Detector const view type ill-formed"); + /// Detector buffer types using buffer_type = dmulti_buffer< dvector_buffer, typename transform_container::buffer_type, @@ -148,6 +154,9 @@ class detector { typename surface_container::buffer_type, dvector_buffer, typename volume_finder::buffer_type>; + static_assert(detail::is_buffer_v, + "Detector buffer type ill-formed"); + detector() = delete; // The detector holds a lot of data and should never be copied detector(const detector &) = delete; diff --git a/tests/common/include/tests/common/test_base/propagator_test.hpp b/tests/common/include/tests/common/test_base/propagator_test.hpp index b1286696a..98fe1fc69 100644 --- a/tests/common/include/tests/common/test_base/propagator_test.hpp +++ b/tests/common/include/tests/common/test_base/propagator_test.hpp @@ -10,8 +10,6 @@ // Project include(s). #include "detray/definitions/algebra.hpp" #include "detray/definitions/units.hpp" -#include "detray/detectors/bfield.hpp" -#include "detray/detectors/toy_metadata.hpp" #include "detray/propagator/actor_chain.hpp" #include "detray/propagator/actors/aborters.hpp" #include "detray/propagator/actors/parameter_resetter.hpp" @@ -27,6 +25,9 @@ // Vecmem include(s) #include +// Covfie include(s) +#include + // GTest include(s). #include @@ -35,14 +36,8 @@ namespace detray { -// Host detector type -using detector_host_t = detector; - -// Device detector type using views -using detector_device_t = detector; - // These types are identical in host and device code for all bfield types -using transform3 = typename detector_host_t::transform3; +using transform3 = __plugin::transform3; using vector3_t = typename transform3::vector3; using point3_t = typename transform3::point3; using matrix_operator = standard_matrix_operator; @@ -147,9 +142,9 @@ inline vecmem::vector generate_tracks( } /// Test function for propagator on the host -template +template inline auto run_propagation_host(vecmem::memory_resource *mr, - const detector_host_t &det, + const host_detector_t &det, covfie::field &field, const vecmem::vector &tracks) -> std::tuple, @@ -158,7 +153,7 @@ inline auto run_propagation_host(vecmem::memory_resource *mr, // Construct propagator from stepper and navigator auto stepr = rk_stepper_t::view_t>{}; - auto nav = navigator_t{}; + auto nav = navigator_t{}; using propagator_host_t = propagator; diff --git a/tests/unit_tests/device/cuda/propagator_cuda_kernel.cu b/tests/unit_tests/device/cuda/propagator_cuda_kernel.cu index e44f29d0d..98c8bf5f3 100644 --- a/tests/unit_tests/device/cuda/propagator_cuda_kernel.cu +++ b/tests/unit_tests/device/cuda/propagator_cuda_kernel.cu @@ -22,6 +22,12 @@ __global__ void propagator_test_kernel( vecmem::data::jagged_vector_view jac_transports_data) { int gid = threadIdx.x + blockIdx.x * blockDim.x; + using detector_device_t = + detector; + + static_assert(std::is_same_v, + "Host and device detector views do not match"); detector_device_t det(det_data); vecmem::device_vector tracks(tracks_data); @@ -96,19 +102,25 @@ void propagator_test( } /// Explicit instantiation for a constant magnetic field -template void propagator_test( - detector_host_t::view_type, covfie::field_view, +template void propagator_test>( + detector::view_type, + covfie::field_view, vecmem::data::vector_view&, - vecmem::data::jagged_vector_view>&, + vecmem::data::jagged_vector_view< + intersection_t>>&, vecmem::data::jagged_vector_view&, vecmem::data::jagged_vector_view&, vecmem::data::jagged_vector_view&); /// Explicit instantiation for an inhomogeneous magnetic field -template void propagator_test( - detector_host_t::view_type, covfie::field_view, +template void propagator_test>( + detector::view_type, + covfie::field_view, vecmem::data::vector_view&, - vecmem::data::jagged_vector_view>&, + vecmem::data::jagged_vector_view< + intersection_t>>&, vecmem::data::jagged_vector_view&, vecmem::data::jagged_vector_view&, vecmem::data::jagged_vector_view&); diff --git a/tests/unit_tests/device/cuda/propagator_cuda_kernel.hpp b/tests/unit_tests/device/cuda/propagator_cuda_kernel.hpp index 8200fe648..7274e0ce9 100644 --- a/tests/unit_tests/device/cuda/propagator_cuda_kernel.hpp +++ b/tests/unit_tests/device/cuda/propagator_cuda_kernel.hpp @@ -8,6 +8,8 @@ #pragma once // Project include(s) +#include "detray/detectors/bfield.hpp" +#include "detray/detectors/toy_metadata.hpp" #include "tests/common/test_base/propagator_test.hpp" // Vecmem include(s) diff --git a/tests/unit_tests/device/sycl/propagator_kernel.sycl b/tests/unit_tests/device/sycl/propagator_kernel.sycl index 1e0f5d043..f605f3f51 100644 --- a/tests/unit_tests/device/sycl/propagator_kernel.sycl +++ b/tests/unit_tests/device/sycl/propagator_kernel.sycl @@ -36,6 +36,15 @@ void propagator_test( candidates_data, path_lengths_data, positions_data, jac_transports_data]( ::sycl::nd_item<1> item) { + using detector_device_t = + detector; + + static_assert( + std::is_same_v, + "Host and device detector views do not match"); + detector_device_t dev_det(det_data); vecmem::device_vector tracks(tracks_data); @@ -94,22 +103,26 @@ void propagator_test( } /// Explicit instantiation for a constant magnetic field -template void propagator_test( - detector_host_t::view_type, covfie::field_view, +template void propagator_test>( + detector::view_type, + covfie::field_view, vecmem::data::vector_view&, - vecmem::data::jagged_vector_view>&, + vecmem::data::jagged_vector_view< + intersection_t>>&, vecmem::data::jagged_vector_view&, vecmem::data::jagged_vector_view&, vecmem::data::jagged_vector_view&, detray::sycl::queue_wrapper); /// Explicit instantiation for an inhomogeneous magnetic field -/*template void propagator_test( - detector_host_t::view_type, +/*template void propagator_test>( detector::view_type, covfie::field_view, vecmem::data::vector_view&, - vecmem::data::jagged_vector_view>&, - vecmem::data::jagged_vector_view&, + vecmem::data::jagged_vector_view>>&, vecmem::data::jagged_vector_view&, vecmem::data::jagged_vector_view&, vecmem::data::jagged_vector_view&, detray::sycl::queue_wrapper);*/ diff --git a/tests/unit_tests/device/sycl/propagator_sycl_kernel.hpp b/tests/unit_tests/device/sycl/propagator_sycl_kernel.hpp index 3fd35cd1a..4f631ed27 100644 --- a/tests/unit_tests/device/sycl/propagator_sycl_kernel.hpp +++ b/tests/unit_tests/device/sycl/propagator_sycl_kernel.hpp @@ -8,6 +8,8 @@ #pragma once // Project include(s) +#include "detray/detectors/bfield.hpp" +#include "detray/detectors/toy_metadata.hpp" #include "queue_wrapper.hpp" #include "tests/common/test_base/propagator_test.hpp"