Skip to content

Commit

Permalink
Allow offsets in pyopencl arrays
Browse files (browse the repository at this point in the history)
  • Loading branch information
isuruf committed Mar 15, 2022
1 parent 1cb234c commit 7cc60bf
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 11 deletions.
17 changes: 10 additions & 7 deletions pyvkfft/opencl.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ def fft(self, src: cla.Array, dest: cla.Array = None):
if dest is not None:
if src.data.int_ptr != dest.data.int_ptr:
raise RuntimeError("VkFFTApp.fft: dest is not None but this is an inplace transform")
res = _vkfft_opencl.fft(self.app, int(src.data.int_ptr), int(src.data.int_ptr), int(self.queue.int_ptr))
res = _vkfft_opencl.fft(self.app, int(src.base_data.int_ptr), int(src.base_data.int_ptr), int(self.queue.int_ptr),
int(src.offset), int(src.offset))
check_vkfft_result(res, src.shape, src.dtype, self.ndim, self.inplace, self.norm, self.r2c,
self.dct, backend="opencl")
if self.norm == "ortho":
Expand All @@ -186,7 +187,8 @@ def fft(self, src: cla.Array, dest: cla.Array = None):
raise RuntimeError("VkFFTApp.fft: dest and src are identical but this is an out-of-place transform")
if self.r2c:
assert (dest.size == src.size // src.shape[-1] * (src.shape[-1] // 2 + 1))
res = _vkfft_opencl.fft(self.app, int(src.data.int_ptr), int(dest.data.int_ptr), int(self.queue.int_ptr))
res = _vkfft_opencl.fft(self.app, int(src.base_data.int_ptr), int(dest.base_data.int_ptr),
int(self.queue.int_ptr), int(src.offset), int(dest.offset))
check_vkfft_result(res, src.shape, src.dtype, self.ndim, self.inplace, self.norm, self.r2c,
self.dct, backend="opencl")
if self.norm == "ortho":
Expand All @@ -206,7 +208,8 @@ def ifft(self, src: cla.Array, dest: cla.Array = None):
if dest is not None:
if src.data.int_ptr != dest.data.int_ptr:
raise RuntimeError("VkFFTApp.fft: dest!=src but this is an inplace transform")
res = _vkfft_opencl.ifft(self.app, int(src.data.int_ptr), int(src.data.int_ptr), int(self.queue.int_ptr))
res = _vkfft_opencl.ifft(self.app, int(src.base_data.int_ptr), int(src.base_data.int_ptr),
int(self.queue.int_ptr), int(src.offset), int(src.offset))
check_vkfft_result(res, src.shape, src.dtype, self.ndim, self.inplace, self.norm, self.r2c,
self.dct, backend="opencl")
if self.norm == "ortho":
Expand All @@ -226,11 +229,11 @@ def ifft(self, src: cla.Array, dest: cla.Array = None):
assert (src.size == dest.size // dest.shape[-1] * (dest.shape[-1] // 2 + 1))
# Special case, src and dest buffer sizes are different,
# VkFFT is configured to go back to the source buffer
res = _vkfft_opencl.ifft(self.app, int(dest.data.int_ptr), int(src.data.int_ptr),
int(self.queue.int_ptr))
res = _vkfft_opencl.ifft(self.app, int(dest.base_data.int_ptr), int(src.base_data.int_ptr),
int(self.queue.int_ptr), int(dest.offset), int(src.offset))
else:
res = _vkfft_opencl.ifft(self.app, int(src.data.int_ptr), int(dest.data.int_ptr),
int(self.queue.int_ptr))
res = _vkfft_opencl.ifft(self.app, int(src.base_data.int_ptr), int(dest.base_data.int_ptr),
int(self.queue.int_ptr), int(src.offset), int(dest.offset))
check_vkfft_result(res, src.shape, src.dtype, self.ndim, self.inplace, self.norm, self.r2c,
self.dct, backend="opencl")
if self.norm == "ortho":
Expand Down
18 changes: 14 additions & 4 deletions src/vkfft_opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ LIBRARY_API VkFFTConfiguration* make_config(const size_t, const size_t, const si

LIBRARY_API VkFFTApplication* init_app(const VkFFTConfiguration*, void*, int*);

LIBRARY_API int fft(VkFFTApplication* app, void*, void*, void*);
LIBRARY_API int fft(VkFFTApplication* app, void*, void*, void*, uint64_t, uint64_t);

LIBRARY_API int ifft(VkFFTApplication* app, void*, void*, void*);
LIBRARY_API int ifft(VkFFTApplication* app, void*, void*, void*, uint64_t, uint64_t);

LIBRARY_API void free_app(VkFFTApplication* app);

Expand Down Expand Up @@ -176,7 +176,8 @@ VkFFTApplication* init_app(const VkFFTConfiguration* config, void *queue, int *r
return app;
}

int fft(VkFFTApplication* app, void *in, void *out, void* queue)
int fft(VkFFTApplication* app, void *in, void *out, void* queue,
uint64_t in_offset, uint64_t out_offset)
{
cl_command_queue q = (cl_command_queue) queue;

Expand All @@ -185,18 +186,23 @@ int fft(VkFFTApplication* app, void *in, void *out, void* queue)
*(app->configuration.buffer) = (cl_mem)out;
*(app->configuration.inputBuffer) = (cl_mem)in;
*(app->configuration.outputBuffer) = (cl_mem)out;
app->configuration.inputBufferOffset = in_offset;
app->configuration.outputBufferOffset = out_offset;
app->configuration.commandQueue = &q;

VkFFTLaunchParams par = {};
par.commandQueue = &q;
par.buffer = app->configuration.buffer;
par.inputBuffer = app->configuration.inputBuffer;
par.outputBuffer = app->configuration.outputBuffer;
par.inputBufferOffset = app->configuration.inputBufferOffset;
par.outputBufferOffset = app->configuration.outputBufferOffset;

return VkFFTAppend(app, -1, &par);
}

int ifft(VkFFTApplication* app, void *in, void *out, void* queue)
int ifft(VkFFTApplication* app, void *in, void *out, void* queue,
uint64_t in_offset, uint64_t out_offset)
{
cl_command_queue q = (cl_command_queue) queue;

Expand All @@ -205,13 +211,17 @@ int ifft(VkFFTApplication* app, void *in, void *out, void* queue)
*(app->configuration.buffer) = (cl_mem)out;
*(app->configuration.inputBuffer) = (cl_mem)in;
*(app->configuration.outputBuffer) = (cl_mem)out;
app->configuration.inputBufferOffset = in_offset;
app->configuration.outputBufferOffset = out_offset;
app->configuration.commandQueue = &q;

VkFFTLaunchParams par = {};
par.commandQueue = &q;
par.buffer = app->configuration.buffer;
par.inputBuffer = app->configuration.inputBuffer;
par.outputBuffer = app->configuration.outputBuffer;
par.inputBufferOffset = app->configuration.inputBufferOffset;
par.outputBufferOffset = app->configuration.outputBufferOffset;

return VkFFTAppend(app, 1, &par);
}
Expand Down

0 comments on commit 7cc60bf

Please sign in to comment.