Skip to content

Commit

Permalink
Update the way dtype is passed
Browse files Browse the repository at this point in the history
  • Loading branch information
bengineerd committed Oct 15, 2024
1 parent 5aa304b commit 25c5199
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 5 deletions.
4 changes: 2 additions & 2 deletions include/rogue/interfaces/stream/Frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -353,10 +353,10 @@ class Frame : public rogue::EnableSharedFromThis<rogue::interfaces::stream::Fram
* @return The read data as a 1-D numpy byte array
*
* @param[in] offset The byte offset into the frame to write to
* @param[in] size The number of bytes to write
* @param[in] count The number of bytes to write
*
*/
boost::python::object getNumpy(uint32_t offset = 0, uint32_t count = 0, int dtype = NPY_UINT8);
boost::python::object getNumpy(uint32_t offset, uint32_t count, boost::python::object dtype);

//! Python Frame data write using a numpy array as the source
/*
Expand Down
25 changes: 22 additions & 3 deletions src/rogue/interfaces/stream/Frame.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ void ris::Frame::writePy(boost::python::object p, uint32_t offset) {
}

//! Read the specified number of bytes at the specified offset of frame data into a numpy array
boost::python::object ris::Frame::getNumpy(uint32_t offset, uint32_t count, int dtype) {
boost::python::object ris::Frame::getNumpy(uint32_t offset, uint32_t count, bp::object dtype) {
// Retrieve the size, in bytes of the data
npy_intp size = getPayload();

Expand All @@ -414,9 +414,20 @@ boost::python::object ris::Frame::getNumpy(uint32_t offset, uint32_t count, int
size));
}


// Convert Python dtype object to NumPy type
int numpy_type;
PyObject* dtype_pyobj = dtype.ptr(); // Get the raw PyObject from the Boost.Python object
if (PyArray_DescrCheck(dtype_pyobj)) {
numpy_type = ((PyArray_Descr*)dtype_pyobj)->type_num;
} else {
throw(rogue::GeneralError::create("Frame::getNumpy",
"Invalid dtype argument. Must be a NumPy dtype object."));
}

// Create a numpy array to receive it and locate the destination data buffer
npy_intp dims[1] = {count};
PyObject* obj = PyArray_SimpleNew(1, dims, dtype);
PyObject* obj = PyArray_SimpleNew(1, dims, numpy_type);
PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(obj);
uint8_t* dst = reinterpret_cast<uint8_t*>(PyArray_DATA(arr));

Expand Down Expand Up @@ -488,6 +499,14 @@ void ris::Frame::setup_python() {

_import_array();

// Create a NumPy dtype object from the NPY_UINT8 constant
PyObject* dtype_uint8 = reinterpret_cast<PyObject*>(PyArray_DescrFromType(NPY_UINT8));
if (!dtype_uint8) {
throw(rogue::GeneralError::create("Frame::setup_python",
"Failed to create default dtype object for NPY_UINT8."));
}


bp::class_<ris::Frame, ris::FramePtr, boost::noncopyable>("Frame", bp::no_init)
.def("lock", &ris::Frame::lock)
.def("getSize", &ris::Frame::getSize)
Expand All @@ -508,7 +527,7 @@ void ris::Frame::setup_python() {
.def("getNumpy", &ris::Frame::getNumpy, (
bp::arg("offset")=0,
bp::arg("count")=0,
bp::arg("dtype")=NPY_UINT8))
bp::arg("dtype")=bp::object(bp::handle<>(bp::borrowed(dtype_uint8)))))
.def("putNumpy", &ris::Frame::putNumpy)
.def("_debug", &ris::Frame::debug);
#endif
Expand Down

0 comments on commit 25c5199

Please sign in to comment.