diff --git a/ext/JACCAMDGPU/JACCAMDGPU.jl b/ext/JACCAMDGPU/JACCAMDGPU.jl index f327f4d..3383d93 100644 --- a/ext/JACCAMDGPU/JACCAMDGPU.jl +++ b/ext/JACCAMDGPU/JACCAMDGPU.jl @@ -26,8 +26,8 @@ function JACC.parallel_reduce(N::I, f::F, x...) where {I<:Integer,F<:Function} blocks = ceil(Int, N / threads) ret = AMDGPU.zeros(Float64, blocks) rret = AMDGPU.zeros(Float64, 1) - @roc groupsize = threads gridsize = threads * blocks localmem = 512 * sizeof(Float64) _parallel_reduce_amdgpu(N, ret, f, x...) - @roc groupsize = threads gridsize = threads localmem = 512 * sizeof(Float64) reduce_kernel_amdgpu(blocks, ret, rret) + @roc groupsize = threads gridsize = threads * blocks _parallel_reduce_amdgpu(N, ret, f, x...) + @roc groupsize = threads gridsize = threads reduce_kernel_amdgpu(blocks, ret, rret) return rret end @@ -40,8 +40,8 @@ function JACC.parallel_reduce((M, N), f::F, x...) where {F<:Function} Nblocks = ceil(Int, N / Nthreads) ret = AMDGPU.zeros(Float64, (Mblocks, Nblocks)) rret = AMDGPU.zeros(Float64, 1) - @roc groupsize = (Mthreads, Nthreads) gridsize = (Mblocks * Mthreads, Nblocks * Nthreads) localmem = 16 * 16 * sizeof(Float64) _parallel_reduce_amdgpu_MN((M, N), ret, f, x...) - @roc groupsize = (Mblocks, Nblocks) gridsize = (Mblocks, Nblocks) localmem = 16 * 16 * sizeof(Float64) reduce_kernel_amdgpu_MN((Mblocks, Nblocks), ret, rret) + @roc groupsize = (Mthreads, Nthreads) gridsize = (Mblocks * Mthreads, Nblocks * Nthreads) _parallel_reduce_amdgpu_MN((M, N), ret, f, x...) + @roc groupsize = (Mblocks, Nblocks) gridsize = (Mblocks, Nblocks) reduce_kernel_amdgpu_MN((Mblocks, Nblocks), ret, rret) return rret end