Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use backend functions for SyncArray in CUDA and HIP #950

Merged
merged 1 commit into from
Apr 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 39 additions & 38 deletions backends/cuda-ref/ceed-cuda-vector.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,30 @@
#include <string.h>
#include "ceed-cuda-ref.h"


//------------------------------------------------------------------------------
// Check if host/device sync is needed
//------------------------------------------------------------------------------
static inline int CeedVectorNeedSync_Cuda(const CeedVector vec,
CeedMemType mem_type, bool *need_sync) {
int ierr;
CeedVector_Cuda *impl;
ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);

bool has_valid_array = false;
ierr = CeedVectorHasValidArray(vec, &has_valid_array); CeedChkBackend(ierr);
switch (mem_type) {
case CEED_MEM_HOST:
*need_sync = has_valid_array && !impl->h_array;
break;
case CEED_MEM_DEVICE:
*need_sync = has_valid_array && !impl->d_array;
break;
}

return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Sync host to device
//------------------------------------------------------------------------------
Expand Down Expand Up @@ -88,8 +112,16 @@ static inline int CeedVectorSyncD2H_Cuda(const CeedVector vec) {
//------------------------------------------------------------------------------
// Sync arrays
//------------------------------------------------------------------------------
static inline int CeedVectorSync_Cuda(const CeedVector vec,
CeedMemType mem_type) {
static int CeedVectorSyncArray_Cuda(const CeedVector vec,
CeedMemType mem_type) {
int ierr;
// Check whether device/host sync is needed
bool need_sync = false;
ierr = CeedVectorNeedSync_Cuda(vec, mem_type, &need_sync);
CeedChkBackend(ierr);
if (!need_sync)
return CEED_ERROR_SUCCESS;

switch (mem_type) {
case CEED_MEM_HOST: return CeedVectorSyncD2H_Cuda(vec);
case CEED_MEM_DEVICE: return CeedVectorSyncH2D_Cuda(vec);
Expand Down Expand Up @@ -167,29 +199,6 @@ static inline int CeedVectorHasBorrowedArrayOfType_Cuda(const CeedVector vec,
return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Check if is any array of given type
//------------------------------------------------------------------------------
static inline int CeedVectorNeedSync_Cuda(const CeedVector vec,
CeedMemType mem_type, bool *need_sync) {
int ierr;
CeedVector_Cuda *impl;
ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);

bool has_valid_array = false;
ierr = CeedVectorHasValidArray(vec, &has_valid_array); CeedChkBackend(ierr);
switch (mem_type) {
case CEED_MEM_HOST:
*need_sync = has_valid_array && !impl->h_array;
break;
case CEED_MEM_DEVICE:
*need_sync = has_valid_array && !impl->d_array;
break;
}

return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Set array from host
//------------------------------------------------------------------------------
Expand Down Expand Up @@ -368,11 +377,7 @@ static int CeedVectorTakeArray_Cuda(CeedVector vec, CeedMemType mem_type,
ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);

// Sync array to requested mem_type
bool need_sync = false;
ierr = CeedVectorNeedSync_Cuda(vec, mem_type, &need_sync); CeedChkBackend(ierr);
if (need_sync) {
ierr = CeedVectorSync_Cuda(vec, mem_type); CeedChkBackend(ierr);
}
ierr = CeedVectorSyncArray(vec, mem_type); CeedChkBackend(ierr);

// Update pointer
switch (mem_type) {
Expand Down Expand Up @@ -403,14 +408,8 @@ static int CeedVectorGetArrayCore_Cuda(const CeedVector vec,
CeedVector_Cuda *impl;
ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);

bool need_sync = false, has_array_of_type = true;
ierr = CeedVectorNeedSync_Cuda(vec, mem_type, &need_sync); CeedChkBackend(ierr);
ierr = CeedVectorHasArrayOfType_Cuda(vec, mem_type, &has_array_of_type);
CeedChkBackend(ierr);
if (need_sync) {
// Sync array to requested mem_type
ierr = CeedVectorSync_Cuda(vec, mem_type); CeedChkBackend(ierr);
}
// Sync array to requested mem_type
ierr = CeedVectorSyncArray(vec, mem_type); CeedChkBackend(ierr);

// Update pointer
switch (mem_type) {
Expand Down Expand Up @@ -763,6 +762,8 @@ int CeedVectorCreate_Cuda(CeedSize n, CeedVector vec) {
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SetValue",
(int (*)())(CeedVectorSetValue_Cuda));
CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray",
CeedVectorSyncArray_Cuda); CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArray",
CeedVectorGetArray_Cuda); CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead",
Expand Down
76 changes: 39 additions & 37 deletions backends/hip-ref/ceed-hip-ref-vector.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,30 @@
#include <string.h>
#include "ceed-hip-ref.h"


//------------------------------------------------------------------------------
// Check if host/device sync is needed
//------------------------------------------------------------------------------
static inline int CeedVectorNeedSync_Hip(const CeedVector vec,
CeedMemType mem_type, bool *need_sync) {
int ierr;
CeedVector_Hip *impl;
ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);

bool has_valid_array = false;
ierr = CeedVectorHasValidArray(vec, &has_valid_array); CeedChkBackend(ierr);
switch (mem_type) {
case CEED_MEM_HOST:
*need_sync = has_valid_array && !impl->h_array;
break;
case CEED_MEM_DEVICE:
*need_sync = has_valid_array && !impl->d_array;
break;
}

return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Sync host to device
//------------------------------------------------------------------------------
Expand Down Expand Up @@ -88,8 +112,16 @@ static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) {
//------------------------------------------------------------------------------
// Sync arrays
//------------------------------------------------------------------------------
static inline int CeedVectorSync_Hip(const CeedVector vec,
CeedMemType mem_type) {
static int CeedVectorSyncArray_Hip(const CeedVector vec,
CeedMemType mem_type) {
int ierr;
// Check whether device/host sync is needed
bool need_sync = false;
ierr = CeedVectorNeedSync_Hip(vec, mem_type, &need_sync);
CeedChkBackend(ierr);
if (!need_sync)
return CEED_ERROR_SUCCESS;

switch (mem_type) {
case CEED_MEM_HOST: return CeedVectorSyncD2H_Hip(vec);
case CEED_MEM_DEVICE: return CeedVectorSyncH2D_Hip(vec);
Expand Down Expand Up @@ -167,29 +199,6 @@ static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec,
return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Sync array of given type
//------------------------------------------------------------------------------
static inline int CeedVectorNeedSync_Hip(const CeedVector vec,
CeedMemType mem_type, bool *need_sync) {
int ierr;
CeedVector_Hip *impl;
ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);

bool has_valid_array = false;
ierr = CeedVectorHasValidArray(vec, &has_valid_array); CeedChkBackend(ierr);
switch (mem_type) {
case CEED_MEM_HOST:
*need_sync = has_valid_array && !impl->h_array;
break;
case CEED_MEM_DEVICE:
*need_sync = has_valid_array && !impl->d_array;
break;
}

return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Set array from host
//------------------------------------------------------------------------------
Expand Down Expand Up @@ -363,11 +372,7 @@ static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type,
ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);

// Sync array to requested mem_type
bool need_sync = false;
ierr = CeedVectorNeedSync_Hip(vec, mem_type, &need_sync); CeedChkBackend(ierr);
if (need_sync) {
ierr = CeedVectorSync_Hip(vec, mem_type); CeedChkBackend(ierr);
}
ierr = CeedVectorSyncArray(vec, mem_type); CeedChkBackend(ierr);

// Update pointer
switch (mem_type) {
Expand Down Expand Up @@ -398,13 +403,8 @@ static int CeedVectorGetArrayCore_Hip(const CeedVector vec,
CeedVector_Hip *impl;
ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);

bool need_sync = false;
ierr = CeedVectorNeedSync_Hip(vec, mem_type, &need_sync); CeedChkBackend(ierr);
CeedChkBackend(ierr);
if (need_sync) {
// Sync array to requested mem_type
ierr = CeedVectorSync_Hip(vec, mem_type); CeedChkBackend(ierr);
}
// Sync array to requested mem_type
ierr = CeedVectorSyncArray(vec, mem_type); CeedChkBackend(ierr);

// Update pointer
switch (mem_type) {
Expand Down Expand Up @@ -758,6 +758,8 @@ int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) {
CeedVectorTakeArray_Hip); CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SetValue",
(int (*)())(CeedVectorSetValue_Hip)); CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray",
CeedVectorSyncArray_Hip); CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArray",
CeedVectorGetArray_Hip); CeedChkBackend(ierr);
ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead",
Expand Down
1 change: 1 addition & 0 deletions interface/ceed.c
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,7 @@ int CeedInit(const char *resource, Ceed *ceed) {
CEED_FTABLE_ENTRY(CeedVector, SetArray),
CEED_FTABLE_ENTRY(CeedVector, TakeArray),
CEED_FTABLE_ENTRY(CeedVector, SetValue),
CEED_FTABLE_ENTRY(CeedVector, SyncArray),
jeremylt marked this conversation as resolved.
Show resolved Hide resolved
CEED_FTABLE_ENTRY(CeedVector, GetArray),
CEED_FTABLE_ENTRY(CeedVector, GetArrayRead),
CEED_FTABLE_ENTRY(CeedVector, GetArrayWrite),
Expand Down