Skip to content

Commit

Permalink
Update numpy set functions to be stride aware
Browse files Browse the repository at this point in the history
  • Loading branch information
bengineerd committed Sep 30, 2024
1 parent b47e9cd commit 340c295
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions src/rogue/interfaces/memory/Block.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -903,6 +903,7 @@ void rim::Block::setUIntPy(bp::object& value, rim::Variable* var, int32_t index)
PyArrayObject* arr = reinterpret_cast<decltype(arr)>(value.ptr());
npy_intp ndims = PyArray_NDIM(arr);
npy_intp* dims = PyArray_SHAPE(arr);
npy_intp* strides = PyArray_STRIDES(arr);

if (ndims != 1)
throw(rogue::GeneralError::create("Block::setUIntPy",
Expand All @@ -921,10 +922,18 @@ void rim::Block::setUIntPy(bp::object& value, rim::Variable* var, int32_t index)

if (PyArray_TYPE(arr) == NPY_UINT64) {
uint64_t* src = reinterpret_cast<uint64_t*>(PyArray_DATA(arr));
for (x = 0; x < dims[0]; x++) setUInt(src[x], var, index + x);
uint64_t value = 0;
for (x = 0; x < dims[0]; x++) {
value = &(src + x * strides[0]);
setUInt(value, var, index + x);
}
} else if (PyArray_TYPE(arr) == NPY_UINT32) {
uint32_t* src = reinterpret_cast<uint32_t*>(PyArray_DATA(arr));
for (x = 0; x < dims[0]; x++) setUInt(src[x], var, index + x);
uint32_t vlaue = 0;
for (x = 0; x < dims[0]; x++) {
value = &(src + x * strides[0]);
setUInt(value, var, index + x);
}
} else {
throw(rogue::GeneralError::create("Block::setUIntPy",
"Passed nparray is not of type (uint64 or uint32) for %s",
Expand Down

0 comments on commit 340c295

Please sign in to comment.