Skip to content

Commit

Permalink
Merge pull request #65 from pedrovalerolara/main
Browse files Browse the repository at this point in the history
Fixed some bugs in the AMDGPU backend regarding synchronization and the use of shared memory.
  • Loading branch information
pedrovalerolara authored Apr 11, 2024
2 parents e3b04d9 + bb9b81d commit ad771a6
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions ext/JACCAMDGPU/JACCAMDGPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ function JACC.parallel_for(N::I, f::F, x::Vararg{Union{<:Number,<:ROCArray}}) wh
numThreads = 512
threads = min(N, numThreads)
blocks = ceil(Int, N / threads)
@roc groupsize = threads gridsize = threads * blocks _parallel_for_amdgpu(f, x...)
@roc groupsize = threads gridsize = threads*blocks _parallel_for_amdgpu(f, x...)
# AMDGPU.synchronize()
end

Expand All @@ -16,7 +16,7 @@ function JACC.parallel_for((M, N)::Tuple{I,I}, f::F, x::Vararg{Union{<:Number,<:
Nthreads = min(N, numThreads)
Mblocks = ceil(Int, M / Mthreads)
Nblocks = ceil(Int, N / Nthreads)
@roc groupsize = (Mthreads, Nthreads) gridsize = (Mblocks * Mthreads, Nblocks * Nthreads) _parallel_for_amdgpu_MN(f, x...)
@roc groupsize = (Mthreads, Nthreads) gridsize = (Mblocks*Mthreads, Nblocks*Nthreads) _parallel_for_amdgpu_MN(f, x...)
# AMDGPU.synchronize()
end

Expand All @@ -26,10 +26,10 @@ function JACC.parallel_reduce(N::I, f::F, x::Vararg{Union{<:Number,<:ROCArray}})
blocks = ceil(Int, N / threads)
ret = AMDGPU.zeros(Float64, blocks)
rret = AMDGPU.zeros(Float64, 1)
@roc groupsize = threads gridsize = threads * blocks _parallel_reduce_amdgpu(N, ret, f, x...)
AMDGPU.synchronize()
@roc groupsize = threads gridsize = threads*blocks _parallel_reduce_amdgpu(N, ret, f, x...)
#AMDGPU.synchronize()
@roc groupsize = threads gridsize = threads reduce_kernel_amdgpu(blocks, ret, rret)
AMDGPU.synchronize()
#AMDGPU.synchronize()
return rret

end
Expand All @@ -42,10 +42,10 @@ function JACC.parallel_reduce((M, N)::Tuple{I,I}, f::F, x::Vararg{Union{<:Number
Nblocks = ceil(Int, N / Nthreads)
ret = AMDGPU.zeros(Float64, (Mblocks, Nblocks))
rret = AMDGPU.zeros(Float64, 1)
@roc groupsize = (Mthreads, Nthreads) gridsize = (Mblocks * Mthreads, Nblocks * Nthreads) _parallel_reduce_amdgpu_MN((M, N), ret, f, x...)
AMDGPU.synchronize()
@roc groupsize = (Mblocks, Nblocks) gridsize = (Mblocks, Nblocks) reduce_kernel_amdgpu_MN((Mblocks, Nblocks), ret, rret)
AMDGPU.synchronize()
@roc groupsize = (Mthreads, Nthreads) gridsize = (Mblocks*Mthreads, Nblocks*Nthreads) _parallel_reduce_amdgpu_MN((M, N), ret, f, x...)
#AMDGPU.synchronize()
@roc groupsize = (Mthreads, Nthreads) gridsize = (Mthreads, Nthreads) reduce_kernel_amdgpu_MN((Mblocks, Nblocks), ret, rret)
#AMDGPU.synchronize()
return rret
end

Expand All @@ -63,7 +63,7 @@ function _parallel_for_amdgpu_MN(f, x...)
end

function _parallel_reduce_amdgpu(N, ret, f, x...)
shared_mem = @ROCDynamicLocalArray(Float64, 512)
shared_mem = @ROCStaticLocalArray(Float64, 512)
i = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
ti = workitemIdx().x
tmp::Float64 = 0.0
Expand Down Expand Up @@ -115,7 +115,7 @@ function _parallel_reduce_amdgpu(N, ret, f, x...)
end

function reduce_kernel_amdgpu(N, red, ret)
shared_mem = @ROCDynamicLocalArray(Float64, 512)
shared_mem = @ROCStaticLocalArray(Float64, 512)
i = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
ii = i
tmp::Float64 = 0.0
Expand Down Expand Up @@ -169,7 +169,7 @@ function reduce_kernel_amdgpu(N, red, ret)
end

function _parallel_reduce_amdgpu_MN((M, N), ret, f, x...)
shared_mem = @ROCDynamicLocalArray(Float64, 16 * 16)
shared_mem = @ROCStaticLocalArray(Float64, 256)
i = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
j = (workgroupIdx().y - 1) * workgroupDim().y + workitemIdx().y
ti = workitemIdx().x
Expand Down Expand Up @@ -213,7 +213,7 @@ function _parallel_reduce_amdgpu_MN((M, N), ret, f, x...)
end

function reduce_kernel_amdgpu_MN((M, N), red, ret)
shared_mem = @ROCDynamicLocalArray(Float64, 16 * 16)
shared_mem = @ROCStaticLocalArray(Float64, 256)
i = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
j = (workgroupIdx().y - 1) * workgroupDim().y + workitemIdx().y
ii = i
Expand Down

0 comments on commit ad771a6

Please sign in to comment.