Skip to content

Commit

Permalink
add operator== with refcounted data handles and their internal pointe…
Browse files Browse the repository at this point in the history
…rs, fix conversion for refcounted accessors
  • Loading branch information
nmm0 committed Jul 15, 2024
1 parent 9425083 commit f08a036
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 18 deletions.
38 changes: 27 additions & 11 deletions core/src/View/Kokkos_BasicView.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ class BasicView
SliceSpecifiers... slices)
: mdspan_type(submdspan(
src_view, Impl::transform_kokkos_slice_to_mdspan_slice(slices)...)) {}
public:

public:
//----------------------------------------
// Conversion to MDSpan
template <class OtherElementType, class OtherExtents, class OtherLayoutPolicy,
Expand All @@ -172,22 +172,38 @@ class BasicView
return mdspan_type(*this);
}

template <class OtherAccessorType =
Kokkos::default_accessor<typename mdspan_type::element_type>,
//Impl::SpaceAwareAccessor<
// memory_space,
// Kokkos::default_accessor<typename mdspan_type::element_type>>,
typename = std::enable_if_t<std::is_assignable_v<
// Here we use an overload instead of a default parameter as a workaround
// to a potential compiler bug with clang 17. It may be present in other compilers
template <class OtherAccessorType = AccessorPolicy,
typename = std::enable_if_t<std::is_assignable_v<
typename mdspan_type::data_handle_type,
typename OtherAccessorType::data_handle_type>>>
KOKKOS_INLINE_FUNCTION constexpr auto to_mdspan() {
using ret_mdspan_type =
mdspan<typename mdspan_type::element_type,
typename mdspan_type::extents_type,
typename mdspan_type::layout_type, OtherAccessorType>;
return ret_mdspan_type(
static_cast<typename OtherAccessorType::data_handle_type>(
mdspan_type::data_handle()),
mdspan_type::mapping(),
static_cast<OtherAccessorType>(mdspan_type::accessor()));
}

template <class OtherAccessorType = AccessorPolicy,
typename = std::enable_if_t<std::is_assignable_v<
typename mdspan_type::data_handle_type,
typename OtherAccessorType::data_handle_type>>>
KOKKOS_INLINE_FUNCTION constexpr auto to_mdspan(
const OtherAccessorType& other_accessor =
static_cast<OtherAccessorType>(mdspan_type::accessor())) {
const OtherAccessorType &other_accessor) {
using ret_mdspan_type =
mdspan<typename mdspan_type::element_type,
typename mdspan_type::extents_type,
typename mdspan_type::layout_type, OtherAccessorType>;
return ret_mdspan_type(static_cast<typename OtherAccessorType::data_handle_type>(mdspan_type::data_handle()), mdspan_type::mapping(), other_accessor);
return ret_mdspan_type(
static_cast<typename OtherAccessorType::data_handle_type>(
mdspan_type::data_handle()),
mdspan_type::mapping(), other_accessor);
}

void assign_data(element_type* ptr) {
Expand Down
49 changes: 46 additions & 3 deletions core/src/View/MDSpan/Kokkos_MDSpan_Accessor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ struct SpaceAwareAccessor<AnonymousSpace, NestedAccessor> {

KOKKOS_FUNCTION
explicit operator NestedAccessor() const { return nested_acc; }

template<class OtherElementType,
class = std::enable_if_t<std::is_convertible_v<
element_type(*) [], OtherElementType (*)[]> &&
Expand Down Expand Up @@ -278,6 +278,16 @@ class ReferenceCountedDataHandle {

std::string get_label() const { return m_tracker.get_label<memory_space>(); }

friend bool operator==(const ReferenceCountedDataHandle& lhs,
const value_type* rhs) {
return lhs.m_handle == rhs;
}

friend bool operator==(const value_type* lhs,
const ReferenceCountedDataHandle& rhs) {
return lhs == rhs.m_handle;
}

private:
template <class OtherElementType, class OtherSpace>
friend class ReferenceCountedDataHandle;
Expand Down Expand Up @@ -337,6 +347,16 @@ class ReferenceCountedDataHandle<ElementType, AnonymousSpace> {

std::string get_label() const { return m_tracker.get_label<memory_space>(); }

friend bool operator==(const ReferenceCountedDataHandle& lhs,
const value_type* rhs) {
return lhs.m_handle == rhs;
}

friend bool operator==(const value_type* lhs,
const ReferenceCountedDataHandle& rhs) {
return lhs == rhs.m_handle;
}

private:
template <class OtherElementType, class OtherSpace>
friend class ReferenceCountedDataHandle;
Expand All @@ -345,6 +365,17 @@ class ReferenceCountedDataHandle<ElementType, AnonymousSpace> {
pointer m_handle = nullptr;
};

template <class ElementType, class MemorySpace, class NestedAccessor>
class ReferenceCountedAccessor;

template <class Accessor>
struct IsReferenceCountedAccessorImpl : std::false_type {};

template <class ElementType, class MemorySpace, class NestedAccessor>
struct IsReferenceCountedAccessorImpl<
ReferenceCountedAccessor<ElementType, MemorySpace, NestedAccessor>>
: std::true_type {};

template <class ElementType, class MemorySpace, class NestedAccessor>
class ReferenceCountedAccessor {
public:
Expand All @@ -370,7 +401,13 @@ class ReferenceCountedAccessor {
constexpr ReferenceCountedAccessor(
const default_accessor<OtherElementType>&) {}

operator NestedAccessor() const { return m_nested_acc; }
template <class DstAccessor,
typename = std::enable_if_t<
!IsReferenceCountedAccessorImpl<DstAccessor>::value &&
std::is_convertible_v<NestedAccessor, DstAccessor>>>
operator DstAccessor() const {
return m_nested_acc;
}

constexpr reference access(data_handle_type p, size_t i) const {
return m_nested_acc.access(p.get(), i);
Expand Down Expand Up @@ -421,7 +458,13 @@ class ReferenceCountedAccessor<ElementType, AnonymousSpace, NestedAccessor> {
constexpr ReferenceCountedAccessor(
const default_accessor<OtherElementType>&) {}

operator NestedAccessor() const { return m_nested_acc; }
template <class DstAccessor,
typename = std::enable_if_t<
!IsReferenceCountedAccessorImpl<DstAccessor>::value &&
std::is_convertible_v<NestedAccessor, DstAccessor>>>
operator DstAccessor() const {
return m_nested_acc;
}

constexpr reference access(data_handle_type p, size_t i) const {
return m_nested_acc.access(p.get(), i);
Expand Down
6 changes: 3 additions & 3 deletions core/unit_test/default/TestDefaultDeviceDevelop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ TEST(defaultdevicetype, development_test) {
Kokkos::View<int*, Kokkos::MemoryTraits<Kokkos::Atomic>> b_atomic(b);
Kokkos::View<int*, Kokkos::MemoryTraits<Kokkos::Unmanaged>> b_unmanaged(b);
Kokkos::mdspan<int, Kokkos::dextents<int, 1>> mds(b.data(), 5);
auto sub_a = Kokkos::submdspan(mds, std::pair{1,3});
auto sub_b = Kokkos::submdspan(mds, std::array{1,3});
auto sub_c = Kokkos::submdspan(mds, Kokkos::pair{1,3});
auto sub_a = Kokkos::submdspan(mds, std::pair{1,3});
auto sub_b = Kokkos::submdspan(mds, std::array{1,3});
auto sub_c = Kokkos::submdspan(mds, Kokkos::pair{1,3});
auto acc = c.accessor();
const decltype(acc) acc_const = acc;
const Kokkos::default_accessor<float> acc_def = acc_const;//static_cast<Kokkos::default_accessor<float>>(acc_const);
Expand Down
43 changes: 42 additions & 1 deletion core/unit_test/view/TestBasicViewMDSpanConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,50 @@ static_assert(
Kokkos::Impl::checked_reference_counted_accessor<
const long long, Kokkos::HostSpace>>>);

using test_atomic_view = Kokkos::View<double *, Kokkos::Serial, Kokkos::MemoryTraits<Kokkos::Atomic>>;
using test_atomic_view = Kokkos::View<double *, Kokkos::Serial,
Kokkos::MemoryTraits<Kokkos::Atomic>>;
static_assert(std::is_same_v<
decltype(std::declval<test_atomic_view>()(std::declval<int>())),
desul::AtomicRef<double, desul::MemoryOrderRelaxed,
desul::MemoryScopeDevice>>);

static_assert(std::is_convertible_v<Kokkos::default_accessor<double>,
Kokkos::Impl::ReferenceCountedAccessor<
double, Kokkos::HostSpace,
Kokkos::default_accessor<double>>>);

static_assert(std::is_constructible_v<Kokkos::default_accessor<const double>,
Kokkos::default_accessor<double>>);

static_assert(std::is_convertible_v<Kokkos::default_accessor<double>,
Kokkos::default_accessor<const double>>);

static_assert(
std::is_constructible_v<
Kokkos::Impl::ReferenceCountedAccessor<
const double, Kokkos::HostSpace,
Kokkos::default_accessor<const double>>,
Kokkos::Impl::ReferenceCountedAccessor<
double, Kokkos::HostSpace, Kokkos::default_accessor<double>>>);

static_assert(std::is_convertible_v<
Kokkos::Impl::ReferenceCountedAccessor<
double, Kokkos::HostSpace, Kokkos::default_accessor<double>>,
Kokkos::Impl::ReferenceCountedAccessor<
const double, Kokkos::HostSpace,
Kokkos::default_accessor<const double>>>);

static_assert(std::is_constructible_v<Kokkos::default_accessor<const double>,
Kokkos::Impl::ReferenceCountedAccessor<
double, Kokkos::HostSpace,
Kokkos::default_accessor<double>>>);

static_assert(
std::is_convertible_v<
Kokkos::Impl::SpaceAwareAccessor<
Kokkos::HostSpace,
Kokkos::Impl::ReferenceCountedAccessor<
double, Kokkos::HostSpace, Kokkos::default_accessor<double>>>,
Kokkos::Impl::SpaceAwareAccessor<
Kokkos::HostSpace, Kokkos::default_accessor<const double>>>);
#endif

0 comments on commit f08a036

Please sign in to comment.