From 8effd6f46b8154390eaedebab95b732170fd3a2f Mon Sep 17 00:00:00 2001 From: William F Godoy Date: Tue, 6 Feb 2024 17:41:36 -0500 Subject: [PATCH] Fix AMDGPU Remove localmem for kernel launch --- ext/JACCAMDGPU/JACCAMDGPU.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ext/JACCAMDGPU/JACCAMDGPU.jl b/ext/JACCAMDGPU/JACCAMDGPU.jl index f327f4d..3383d93 100644 --- a/ext/JACCAMDGPU/JACCAMDGPU.jl +++ b/ext/JACCAMDGPU/JACCAMDGPU.jl @@ -26,8 +26,8 @@ function JACC.parallel_reduce(N::I, f::F, x...) where {I<:Integer,F<:Function} blocks = ceil(Int, N / threads) ret = AMDGPU.zeros(Float64, blocks) rret = AMDGPU.zeros(Float64, 1) - @roc groupsize = threads gridsize = threads * blocks localmem = 512 * sizeof(Float64) _parallel_reduce_amdgpu(N, ret, f, x...) - @roc groupsize = threads gridsize = threads localmem = 512 * sizeof(Float64) reduce_kernel_amdgpu(blocks, ret, rret) + @roc groupsize = threads gridsize = threads * blocks _parallel_reduce_amdgpu(N, ret, f, x...) + @roc groupsize = threads gridsize = threads reduce_kernel_amdgpu(blocks, ret, rret) return rret end @@ -40,8 +40,8 @@ function JACC.parallel_reduce((M, N), f::F, x...) where {F<:Function} Nblocks = ceil(Int, N / Nthreads) ret = AMDGPU.zeros(Float64, (Mblocks, Nblocks)) rret = AMDGPU.zeros(Float64, 1) - @roc groupsize = (Mthreads, Nthreads) gridsize = (Mblocks * Mthreads, Nblocks * Nthreads) localmem = 16 * 16 * sizeof(Float64) _parallel_reduce_amdgpu_MN((M, N), ret, f, x...) - @roc groupsize = (Mblocks, Nblocks) gridsize = (Mblocks, Nblocks) localmem = 16 * 16 * sizeof(Float64) reduce_kernel_amdgpu_MN((Mblocks, Nblocks), ret, rret) + @roc groupsize = (Mthreads, Nthreads) gridsize = (Mblocks * Mthreads, Nblocks * Nthreads) _parallel_reduce_amdgpu_MN((M, N), ret, f, x...) + @roc groupsize = (Mblocks, Nblocks) gridsize = (Mblocks, Nblocks) reduce_kernel_amdgpu_MN((Mblocks, Nblocks), ret, rret) return rret end