Skip to content

Commit

Permalink
WIP for OfffsetVIew/DynRankView fixes for mdspan
Browse files Browse the repository at this point in the history
  • Loading branch information
nmm0 committed Aug 19, 2024
1 parent f08a036 commit 1dbabc1
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 8 deletions.
4 changes: 2 additions & 2 deletions containers/src/Kokkos_DynRankView.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,8 @@ class ViewMapping<
std::integral_constant<unsigned, 0>(),
src.layout()); // Check this for integer input1 for padding, etc
dst.m_map.m_impl_handle = Kokkos::Impl::ViewDataHandle<DstTraits>::assign(
src.m_map.m_impl_handle, src.m_track.m_tracker);
dst.m_track.assign(src.m_track.m_tracker, DstTraits::is_managed);
src.data(), src.impl_track());
dst.m_track.assign(src.impl_track(), DstTraits::is_managed);
dst.m_rank = Kokkos::View<ST, SP...>::rank();
}
};
Expand Down
5 changes: 4 additions & 1 deletion containers/src/Kokkos_OffsetView.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <Kokkos_Core.hpp>

#include <Kokkos_View.hpp>
#include <View/MDSpan/Kokkos_MDSpan_Accessor.hpp>

namespace Kokkos {

Expand Down Expand Up @@ -779,7 +780,9 @@ class OffsetView : public ViewTraits<DataType, Properties...> {
public:
KOKKOS_FUNCTION
view_type view() const {
view_type v(m_track, m_map);
using mdspan_type = typename view_type::mdspan_type;
using data_handle_type = typename view_type::data_handle_type;
view_type v(data_handle_type(m_track, data()), Kokkos::Impl::mapping_from_view_mapping<mdspan_type>(m_map));
return v;
}

Expand Down
13 changes: 8 additions & 5 deletions core/src/Kokkos_View.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,12 +286,12 @@ class View : public Impl::BasicViewFromTraits<DataType, Properties...>::type {
using offset_type = typename map_type::offset_type;
return map_type(
data(),
offset_type((std::integral_constant<unsigned, 0>(), layout())));
offset_type(std::integral_constant<unsigned, 0>(), layout()));
}

KOKKOS_INLINE_FUNCTION
const Kokkos::Impl::SharedAllocationTracker& impl_track() const {
return base_t::m_track.m_tracker;
return base_t::data_handle().tracker();
}
//----------------------------------------

Expand Down Expand Up @@ -461,6 +461,10 @@ class View : public Impl::BasicViewFromTraits<DataType, Properties...>::type {
return *this;
}

View(typename base_t::data_handle_type p,
const typename base_t::mapping_type& m)
: base_t(p, m){};

//----------------------------------------
// Compatible view copy constructor and assignment
// may assign unmanaged from managed.
Expand Down Expand Up @@ -506,7 +510,7 @@ class View : public Impl::BasicViewFromTraits<DataType, Properties...>::type {

template<class ... Args>
View(pointer_type ptr, Args ... args)
: base_t(Kokkos::view_wrap(ptr), typename mdspan_type::mapping_type(typename mdspan_type::extents_type{args...})) {}
: View(Kokkos::view_wrap(ptr), args...) {}

// Constructor which allows always 8 sizes should be deprecated
template <class... P>
Expand All @@ -528,7 +532,6 @@ class View : public Impl::BasicViewFromTraits<DataType, Properties...>::type {
"overload taking a layout object instead.");
}

#if 0
// Wrap memory according to properties and array layout
template <class... P>
explicit KOKKOS_INLINE_FUNCTION View(
Expand All @@ -537,7 +540,7 @@ class View : public Impl::BasicViewFromTraits<DataType, Properties...>::type {
typename traits::array_layout> const& arg_layout)
: base_t(arg_prop, arg_layout) {}


#if 0
template <class... P>
explicit KOKKOS_INLINE_FUNCTION View(
const Impl::ViewCtorProp<P...>& arg_prop,
Expand Down
10 changes: 10 additions & 0 deletions core/src/View/MDSpan/Kokkos_MDSpan_Accessor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,10 @@ class ReferenceCountedDataHandle {
m_handle = static_cast<pointer>(get_record()->data());
}

ReferenceCountedDataHandle(const SharedAllocationTracker& tracker,
pointer data_handle)
: m_tracker(tracker), m_handle(data_handle) {}

template <class OtherElementType,
class = std::enable_if_t<std::is_convertible_v<
OtherElementType (*)[], value_type (*)[]>>>
Expand Down Expand Up @@ -277,6 +281,7 @@ class ReferenceCountedDataHandle {
int use_count() const noexcept { return m_tracker.use_count(); }

std::string get_label() const { return m_tracker.get_label<memory_space>(); }
const SharedAllocationTracker& tracker() const noexcept { return m_tracker; }

friend bool operator==(const ReferenceCountedDataHandle& lhs,
const value_type* rhs) {
Expand Down Expand Up @@ -313,6 +318,10 @@ class ReferenceCountedDataHandle<ElementType, AnonymousSpace> {
m_handle = static_cast<pointer>(get_record()->data());
}

ReferenceCountedDataHandle(const SharedAllocationTracker& tracker,
pointer data_handle)
: m_tracker(tracker), m_handle(data_handle) {}

template <class OtherElementType,
class = std::enable_if_t<std::is_convertible_v<
OtherElementType (*)[], value_type (*)[]>>>
Expand Down Expand Up @@ -346,6 +355,7 @@ class ReferenceCountedDataHandle<ElementType, AnonymousSpace> {
int use_count() const noexcept { return m_tracker.use_count(); }

std::string get_label() const { return m_tracker.get_label<memory_space>(); }
const SharedAllocationTracker &tracker() const noexcept { return m_tracker; }

friend bool operator==(const ReferenceCountedDataHandle& lhs,
const value_type* rhs) {
Expand Down

0 comments on commit 1dbabc1

Please sign in to comment.