Skip to content

Commit

Permalink
Merge pull request #1759 from LLNL/artv3/dynamic-forall-reductions
Browse files Browse the repository at this point in the history
Add support for the new reducer interface to dynamic forall
  • Loading branch information
artv3 authored Dec 12, 2024
2 parents 3c71010 + d2b2c29 commit 0e40820
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 96 deletions.
33 changes: 15 additions & 18 deletions examples/dynamic-forall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,28 +77,25 @@ int main(int argc, char *argv[])

//----------------------------------------------------------------------------//

std::cout << "\n Running C-style vector addition...\n";

// _cstyle_vector_add_start
for (int i = 0; i < N; ++i) {
c[i] = a[i] + b[i];
}
// _cstyle_vector_add_end

checkResult(c, N);
//printResult(c, N);


//----------------------------------------------------------------------------//
// Example of dynamic policy selection for forall
//----------------------------------------------------------------------------//
std::cout << "\n Running dynamic forall vector addition and reductions...\n";

int sum = 0;
using VAL_INT_SUM = RAJA::expt::ValOp<int, RAJA::operators::plus>;

RAJA::RangeSegment range(0, N);

//policy is chosen from the list
RAJA::expt::dynamic_forall<policy_list>(pol, RAJA::RangeSegment(0, N), [=] RAJA_HOST_DEVICE (int i) {
RAJA::dynamic_forall<policy_list>(pol, range,
RAJA::expt::Reduce<RAJA::operators::plus>(&sum),
RAJA::expt::KernelName("RAJA dynamic forall"),
[=] RAJA_HOST_DEVICE (int i, VAL_INT_SUM &_sum) {

c[i] = a[i] + b[i];
_sum += 1;
});
// _rajaseq_vector_add_end

std::cout << "Sum = " << sum << ", expected sum: " << N << std::endl;
checkResult(c, N);
//printResult(c, N);

Expand Down Expand Up @@ -126,9 +123,9 @@ void checkResult(int* res, int len)
if ( res[i] != 0 ) { correct = false; }
}
if ( correct ) {
std::cout << "\n\t result -- PASS\n";
std::cout << "\n\t Vector sum result -- PASS\n";
} else {
std::cout << "\n\t result -- FAIL\n";
std::cout << "\n\t Vector sum result -- FAIL\n";
}
}

Expand Down
2 changes: 1 addition & 1 deletion examples/resource-dynamic-forall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ int main(int argc, char *argv[])
RAJA::resources::Resource res = RAJA::Get_Host_Resource(host_res, select_cpu_or_gpu);
#endif

RAJA::expt::dynamic_forall<policy_list>
RAJA::dynamic_forall<policy_list>
(res, pol, RAJA::RangeSegment(0, N), [=] RAJA_HOST_DEVICE (int i) {

c[i] = a[i] + b[i];
Expand Down
143 changes: 69 additions & 74 deletions include/RAJA/pattern/forall.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -647,104 +647,99 @@ RAJA_INLINE camp::resources::EventProxy<Res> CallForallIcount::operator()(T cons
// - Returns a generic event proxy only if a resource is provided
// avoids overhead of constructing a typed erased resource
//
namespace expt
template<camp::idx_t IDX, typename POLICY_LIST>
struct dynamic_helper
{

template<camp::idx_t IDX, typename POLICY_LIST>
struct dynamic_helper
template<typename SEGMENT, typename... PARAMS>
static void invoke_forall(const int pol, SEGMENT const &seg, PARAMS&&... params)
{
template<typename SEGMENT, typename BODY>
static void invoke_forall(const int pol, SEGMENT const &seg, BODY const &body)
{
if(IDX==pol){
using t_pol = typename camp::at<POLICY_LIST,camp::num<IDX>>::type;
RAJA::forall<t_pol>(seg, body);
return;
}
dynamic_helper<IDX-1, POLICY_LIST>::invoke_forall(pol, seg, body);
}

template<typename SEGMENT, typename BODY>
static resources::EventProxy<resources::Resource>
invoke_forall(RAJA::resources::Resource r, const int pol, SEGMENT const &seg, BODY const &body)
{

if(IDX==pol){
using t_pol = typename camp::at<POLICY_LIST,camp::num<IDX>>::type;
using resource_type = typename resources::get_resource<t_pol>::type;

if(IDX==pol){
RAJA::forall<t_pol>(r.get<resource_type>(), seg, body);

//Return a generic event proxy from r,
//because forall returns a typed event proxy
return {r};
}

return dynamic_helper<IDX-1, POLICY_LIST>::invoke_forall(r, pol, seg, body);
RAJA::forall<t_pol>(seg, params...);
return;
}
dynamic_helper<IDX-1, POLICY_LIST>::invoke_forall(pol, seg, params...);
}

};

template<typename POLICY_LIST>
struct dynamic_helper<0, POLICY_LIST>
template<typename SEGMENT, typename... PARAMS>
static resources::EventProxy<resources::Resource>
invoke_forall(RAJA::resources::Resource r, const int pol, SEGMENT const &seg, PARAMS&&... params)
{
template<typename SEGMENT, typename BODY>
static void
invoke_forall(const int pol, SEGMENT const &seg, BODY const &body)
{
if(0==pol){
using t_pol = typename camp::at<POLICY_LIST,camp::num<0>>::type;
RAJA::forall<t_pol>(seg, body);
return;
}
RAJA_ABORT_OR_THROW("Policy enum not supported ");
}

template<typename SEGMENT, typename BODY>
static resources::EventProxy<resources::Resource>
invoke_forall(RAJA::resources::Resource r, const int pol, SEGMENT const &seg, BODY const &body)
{
if(pol != 0) RAJA_ABORT_OR_THROW("Policy value out of range ");

using t_pol = typename camp::at<POLICY_LIST,camp::num<0>>::type;
using resource_type = typename resources::get_resource<t_pol>::type;
using t_pol = typename camp::at<POLICY_LIST,camp::num<IDX>>::type;
using resource_type = typename resources::get_resource<t_pol>::type;

RAJA::forall<t_pol>(r.get<resource_type>(), seg, body);
if(IDX==pol){
RAJA::forall<t_pol>(r.get<resource_type>(), seg, params...);

//Return a generic event proxy from r,
//because forall returns a typed event proxy
return {r};
}

};
return dynamic_helper<IDX-1, POLICY_LIST>::invoke_forall(r, pol, seg, params...);
}

template<typename POLICY_LIST, typename SEGMENT, typename BODY>
void dynamic_forall(const int pol, SEGMENT const &seg, BODY const &body)
{
constexpr int N = camp::size<POLICY_LIST>::value;
static_assert(N > 0, "RAJA policy list must not be empty");
};

if(pol > N-1) {
RAJA_ABORT_OR_THROW("Policy enum not supported");
template<typename POLICY_LIST>
struct dynamic_helper<0, POLICY_LIST>
{
template<typename SEGMENT, typename... PARAMS>
static void
invoke_forall(const int pol, SEGMENT const &seg, PARAMS&&... params)
{
if(0==pol){
using t_pol = typename camp::at<POLICY_LIST,camp::num<0>>::type;
RAJA::forall<t_pol>(seg, params...);
return;
}
dynamic_helper<N-1, POLICY_LIST>::invoke_forall(pol, seg, body);
RAJA_ABORT_OR_THROW("Policy enum not supported ");
}

template<typename POLICY_LIST, typename SEGMENT, typename BODY>
resources::EventProxy<resources::Resource>
dynamic_forall(RAJA::resources::Resource r, const int pol, SEGMENT const &seg, BODY const &body)
template<typename SEGMENT, typename... PARAMS>
static resources::EventProxy<resources::Resource>
invoke_forall(RAJA::resources::Resource r, const int pol, SEGMENT const &seg, PARAMS&&... params)
{
constexpr int N = camp::size<POLICY_LIST>::value;
static_assert(N > 0, "RAJA policy list must not be empty");
if(pol != 0) RAJA_ABORT_OR_THROW("Policy value out of range ");

if(pol > N-1) {
RAJA_ABORT_OR_THROW("Policy value out of range");
}
using t_pol = typename camp::at<POLICY_LIST,camp::num<0>>::type;
using resource_type = typename resources::get_resource<t_pol>::type;

RAJA::forall<t_pol>(r.get<resource_type>(), seg, params...);

//Return a generic event proxy from r,
//because forall returns a typed event proxy
return {r};
}

};

return dynamic_helper<N-1, POLICY_LIST>::invoke_forall(r, pol, seg, body);
template<typename POLICY_LIST, typename SEGMENT, typename... PARAMS>
void dynamic_forall(const int pol, SEGMENT const &seg, PARAMS&&... params)
{
constexpr int N = camp::size<POLICY_LIST>::value;
static_assert(N > 0, "RAJA policy list must not be empty");

if(pol > N-1) {
RAJA_ABORT_OR_THROW("Policy enum not supported");
}
dynamic_helper<N-1, POLICY_LIST>::invoke_forall(pol, seg, params...);
}

template<typename POLICY_LIST, typename SEGMENT, typename... PARAMS>
resources::EventProxy<resources::Resource>
dynamic_forall(RAJA::resources::Resource r, const int pol, SEGMENT const &seg, PARAMS&&... params)
{
constexpr int N = camp::size<POLICY_LIST>::value;
static_assert(N > 0, "RAJA policy list must not be empty");

if(pol > N-1) {
RAJA_ABORT_OR_THROW("Policy value out of range");
}

} // namespace expt
return dynamic_helper<N-1, POLICY_LIST>::invoke_forall(r, pol, seg, params...);
}


} // namespace RAJA
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ void DynamicForallResourceRangeSegmentTestImpl(INDEX_TYPE first, INDEX_TYPE last

std::iota(test_array, test_array + RAJA::stripIndexType(N), rbegin);

RAJA::expt::dynamic_forall<POLICY_LIST>(working_res, pol, r1, [=] RAJA_HOST_DEVICE(INDEX_TYPE idx) {
RAJA::dynamic_forall<POLICY_LIST>(working_res, pol, r1, [=] RAJA_HOST_DEVICE(INDEX_TYPE idx) {
working_array[RAJA::stripIndexType(idx - rbegin)] = idx;
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ void DynamicForallRangeSegmentTestImpl(INDEX_TYPE first, INDEX_TYPE last, const

std::iota(test_array, test_array + RAJA::stripIndexType(N), rbegin);

RAJA::expt::dynamic_forall<POLICY_LIST>(pol, r1, [=] RAJA_HOST_DEVICE(INDEX_TYPE idx) {
RAJA::dynamic_forall<POLICY_LIST>(pol, r1, [=] RAJA_HOST_DEVICE(INDEX_TYPE idx) {
working_array[RAJA::stripIndexType(idx - rbegin)] = idx;
});

Expand All @@ -50,7 +50,7 @@ void DynamicForallRangeSegmentTestImpl(INDEX_TYPE first, INDEX_TYPE last, const

working_res.memcpy(working_array, test_array, sizeof(INDEX_TYPE) * data_len);

RAJA::expt::dynamic_forall<POLICY_LIST>(pol, r1, [=] RAJA_HOST_DEVICE(INDEX_TYPE idx) {
RAJA::dynamic_forall<POLICY_LIST>(pol, r1, [=] RAJA_HOST_DEVICE(INDEX_TYPE idx) {
(void) idx;
working_array[0]++;
});
Expand Down

0 comments on commit 0e40820

Please sign in to comment.