Skip to content

Commit

Permalink
Add device container specializations for the view type handling
Browse files Browse the repository at this point in the history
  • Loading branch information
niermann999 committed Oct 19, 2023
1 parent 77a8f7c commit feb511f
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 27 deletions.
13 changes: 13 additions & 0 deletions core/include/detray/core/detail/container_buffers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,19 @@ struct has_buffer<const vecmem::vector<T>, void> : public std::true_type {
using type = vecmem::data::vector_buffer<const T>;
};

/// Specialization of the buffer getter for @c vecmem::device_vector
template <typename T>
struct has_buffer<vecmem::device_vector<T>, void> : public std::true_type {
using type = vecmem::data::vector_buffer<T>;
};

/// Specialization of the buffer getter for @c vecmem::device_vector - const
template <typename T>
struct has_buffer<const vecmem::device_vector<T>, void>
: public std::true_type {
using type = vecmem::data::vector_buffer<const T>;
};

template <class T>
inline constexpr bool has_buffer_v = has_buffer<T>::value;

Expand Down
32 changes: 30 additions & 2 deletions core/include/detray/core/detail/container_views.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,26 @@ struct detail::has_view<vecmem::vector<T>, void> : public std::true_type {
using type = dvector_view<T>;
};

/// Specialization of the view getter for @c vecmem::vector
/// Specialization of the view getter for @c vecmem::vector - const
template <typename T>
struct detail::has_view<const vecmem::vector<T>, void> : public std::true_type {
using type = dvector_view<const T>;
};

/// Specialization of the view getter for @c vecmem::device_vector
template <typename T>
struct detail::has_view<vecmem::device_vector<T>, void>
: public std::true_type {
using type = dvector_view<T>;
};

/// Specialization of the view getter for @c vecmem::device_vector - const
template <typename T>
struct detail::has_view<const vecmem::device_vector<T>, void>
: public std::true_type {
using type = dvector_view<const T>;
};

/// Specialized view for @c vecmem::jagged_vector containers
template <typename T>
using djagged_vector_view = vecmem::data::jagged_vector_view<T>;
Expand All @@ -198,13 +212,27 @@ struct detail::has_view<vecmem::jagged_vector<T>, void>
using type = djagged_vector_view<T>;
};

/// Specialization of the view getter for @c vecmem::jagged_vector
/// Specialization of the view getter for @c vecmem::jagged_vector - const
template <typename T>
struct detail::has_view<const vecmem::jagged_vector<T>, void>
: public std::true_type {
using type = djagged_vector_view<const T>;
};

/// Specialization of the view getter for @c vecmem::jagged_device_vector
template <typename T>
struct detail::has_view<vecmem::jagged_device_vector<T>, void>
: public std::true_type {
using type = djagged_vector_view<T>;
};

/// Specialization of the view getter for @c vecmem::jagged_device_vector
/// - const
template <typename T>
struct detail::has_view<const vecmem::jagged_device_vector<T>, void>
: public std::true_type {
using type = djagged_vector_view<const T>;
};
/// @}

} // namespace detray
9 changes: 9 additions & 0 deletions core/include/detray/core/detector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ class detector {
typename surface_container::view_type, dvector_view<surface_type>,
typename volume_finder::view_type>;

static_assert(detail::is_device_view_v<view_type>,
"Detector view type ill-formed");

using const_view_type =
dmulti_view<dvector_view<const volume_type>,
typename transform_container::const_view_type,
Expand All @@ -140,6 +143,9 @@ class detector {
dvector_view<const surface_type>,
typename volume_finder::const_view_type>;

static_assert(detail::is_device_view_v<const_view_type>,
"Detector const view type ill-formed");

/// Detector buffer types
using buffer_type = dmulti_buffer<
dvector_buffer<volume_type>, typename transform_container::buffer_type,
Expand All @@ -148,6 +154,9 @@ class detector {
typename surface_container::buffer_type, dvector_buffer<surface_type>,
typename volume_finder::buffer_type>;

static_assert(detail::is_buffer_v<buffer_type>,
"Detector buffer type ill-formed");

detector() = delete;
// The detector holds a lot of data and should never be copied
detector(const detector &) = delete;
Expand Down
19 changes: 7 additions & 12 deletions tests/common/include/tests/common/test_base/propagator_test.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
// Project include(s).
#include "detray/definitions/algebra.hpp"
#include "detray/definitions/units.hpp"
#include "detray/detectors/bfield.hpp"
#include "detray/detectors/toy_metadata.hpp"
#include "detray/propagator/actor_chain.hpp"
#include "detray/propagator/actors/aborters.hpp"
#include "detray/propagator/actors/parameter_resetter.hpp"
Expand All @@ -27,6 +25,9 @@
// Vecmem include(s)
#include <vecmem/memory/memory_resource.hpp>

// Covfie include(s)
#include <covfie/core/field.hpp>

// GTest include(s).
#include <gtest/gtest.h>

Expand All @@ -35,14 +36,8 @@

namespace detray {

// Host detector type
using detector_host_t = detector<toy_metadata, host_container_types>;

// Device detector type using views
using detector_device_t = detector<toy_metadata, device_container_types>;

// These types are identical in host and device code for all bfield types
using transform3 = typename detector_host_t::transform3;
using transform3 = __plugin::transform3<detray::scalar>;
using vector3_t = typename transform3::vector3;
using point3_t = typename transform3::point3;
using matrix_operator = standard_matrix_operator<scalar>;
Expand Down Expand Up @@ -147,9 +142,9 @@ inline vecmem::vector<track_t> generate_tracks(
}

/// Test function for propagator on the host
template <typename bfield_bknd_t>
template <typename bfield_bknd_t, typename host_detector_t>
inline auto run_propagation_host(vecmem::memory_resource *mr,
const detector_host_t &det,
const host_detector_t &det,
covfie::field<bfield_bknd_t> &field,
const vecmem::vector<track_t> &tracks)
-> std::tuple<vecmem::jagged_vector<scalar>,
Expand All @@ -158,7 +153,7 @@ inline auto run_propagation_host(vecmem::memory_resource *mr,

// Construct propagator from stepper and navigator
auto stepr = rk_stepper_t<typename covfie::field<bfield_bknd_t>::view_t>{};
auto nav = navigator_t<detector_host_t>{};
auto nav = navigator_t<host_detector_t>{};

using propagator_host_t =
propagator<decltype(stepr), decltype(nav), actor_chain_host_t>;
Expand Down
24 changes: 18 additions & 6 deletions tests/unit_tests/device/cuda/propagator_cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ __global__ void propagator_test_kernel(
vecmem::data::jagged_vector_view<free_matrix> jac_transports_data) {

int gid = threadIdx.x + blockIdx.x * blockDim.x;
using detector_device_t =
detector<typename detector_t::metadata, device_container_types>;

static_assert(std::is_same_v<typename detector_t::view_type,
typename detector_device_t::view_type>,
"Host and device detector views do not match");

detector_device_t det(det_data);
vecmem::device_vector<track_t> tracks(tracks_data);
Expand Down Expand Up @@ -96,19 +102,25 @@ void propagator_test(
}

/// Explicit instantiation for a constant magnetic field
template void propagator_test<bfield::const_bknd_t, detector_host_t>(
detector_host_t::view_type, covfie::field_view<bfield::const_bknd_t>,
template void propagator_test<bfield::const_bknd_t,
detector<toy_metadata, host_container_types>>(
detector<toy_metadata, host_container_types>::view_type,
covfie::field_view<bfield::const_bknd_t>,
vecmem::data::vector_view<track_t>&,
vecmem::data::jagged_vector_view<intersection_t<detector_host_t>>&,
vecmem::data::jagged_vector_view<
intersection_t<detector<toy_metadata, host_container_types>>>&,
vecmem::data::jagged_vector_view<scalar>&,
vecmem::data::jagged_vector_view<vector3_t>&,
vecmem::data::jagged_vector_view<free_matrix>&);

/// Explicit instantiation for an inhomogeneous magnetic field
template void propagator_test<bfield::cuda::inhom_bknd_t, detector_host_t>(
detector_host_t::view_type, covfie::field_view<bfield::cuda::inhom_bknd_t>,
template void propagator_test<bfield::cuda::inhom_bknd_t,
detector<toy_metadata, host_container_types>>(
detector<toy_metadata, host_container_types>::view_type,
covfie::field_view<bfield::cuda::inhom_bknd_t>,
vecmem::data::vector_view<track_t>&,
vecmem::data::jagged_vector_view<intersection_t<detector_host_t>>&,
vecmem::data::jagged_vector_view<
intersection_t<detector<toy_metadata, host_container_types>>>&,
vecmem::data::jagged_vector_view<scalar>&,
vecmem::data::jagged_vector_view<vector3_t>&,
vecmem::data::jagged_vector_view<free_matrix>&);
Expand Down
2 changes: 2 additions & 0 deletions tests/unit_tests/device/cuda/propagator_cuda_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#pragma once

// Project include(s)
#include "detray/detectors/bfield.hpp"
#include "detray/detectors/toy_metadata.hpp"
#include "tests/common/test_base/propagator_test.hpp"

// Vecmem include(s)
Expand Down
27 changes: 20 additions & 7 deletions tests/unit_tests/device/sycl/propagator_kernel.sycl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ void propagator_test(
candidates_data, path_lengths_data,
positions_data, jac_transports_data](
::sycl::nd_item<1> item) {
using detector_device_t =
detector<typename detector_t::metadata,
device_container_types>;

static_assert(
std::is_same_v<typename detector_t::view_type,
typename detector_device_t::view_type>,
"Host and device detector views do not match");

detector_device_t dev_det(det_data);

vecmem::device_vector<track_t> tracks(tracks_data);
Expand Down Expand Up @@ -94,22 +103,26 @@ void propagator_test(
}

/// Explicit instantiation for a constant magnetic field
template void propagator_test<bfield::const_bknd_t, detector_host_t>(
detector_host_t::view_type, covfie::field_view<bfield::const_bknd_t>,
template void propagator_test<bfield::const_bknd_t,
detector<toy_metadata, host_container_types>>(
detector<toy_metadata, host_container_types>::view_type,
covfie::field_view<bfield::const_bknd_t>,
vecmem::data::vector_view<track_t>&,
vecmem::data::jagged_vector_view<intersection_t<detector_host_t>>&,
vecmem::data::jagged_vector_view<
intersection_t<detector<toy_metadata, host_container_types>>>&,
vecmem::data::jagged_vector_view<scalar>&,
vecmem::data::jagged_vector_view<vector3_t>&,
vecmem::data::jagged_vector_view<free_matrix>&,
detray::sycl::queue_wrapper);

/// Explicit instantiation for an inhomogeneous magnetic field
/*template void propagator_test<bfield::sycl::inhom_bknd_t, detector_host_t>(
detector_host_t::view_type,
/*template void propagator_test<bfield::sycl::inhom_bknd_t,
detector<toy_metadata, host_container_types>>( detector<toy_metadata,
host_container_types>::view_type,
covfie::field_view<bfield::sycl::inhom_bknd_t>,
vecmem::data::vector_view<track_t>&,
vecmem::data::jagged_vector_view<intersection_t<detector_host_t>>&,
vecmem::data::jagged_vector_view<scalar>&,
vecmem::data::jagged_vector_view<intersection_t<detector<toy_metadata,
host_container_types>>>&, vecmem::data::jagged_vector_view<scalar>&,
vecmem::data::jagged_vector_view<vector3_t>&,
vecmem::data::jagged_vector_view<free_matrix>&,
detray::sycl::queue_wrapper);*/
Expand Down
2 changes: 2 additions & 0 deletions tests/unit_tests/device/sycl/propagator_sycl_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#pragma once

// Project include(s)
#include "detray/detectors/bfield.hpp"
#include "detray/detectors/toy_metadata.hpp"
#include "queue_wrapper.hpp"
#include "tests/common/test_base/propagator_test.hpp"

Expand Down

0 comments on commit feb511f

Please sign in to comment.