Generalize 1D parallel_reduce and refactor cufunction bits
PhilipFackler committed Dec 20, 2024
1 parent bdc8439 commit 45433e0
Showing 3 changed files with 74 additions and 94 deletions.
6 changes: 3 additions & 3 deletions Project.toml
@@ -5,7 +5,7 @@ authors = [
"williamfgc <[email protected]>",
"PhilipFackler <[email protected]>",
]
version = "0.1.0"
version = "0.1.1"

[deps]
Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
@@ -22,8 +22,8 @@ JACCCUDA = ["CUDA"]
JACCONEAPI = ["oneAPI"]

[compat]
AMDGPU = "0.8"
Atomix = "0.1"
AMDGPU = "1.1.6"
Atomix = "1.0.1"
CUDA = "5"
Preferences = "1.4.0"
julia = "1.9, 1.10, 1.11"
158 changes: 67 additions & 91 deletions ext/JACCCUDA/JACCCUDA.jl
@@ -14,18 +14,24 @@ using .Experimental

JACC.get_backend(::Val{:cuda}) = CUDABackend()

@inline kernel_args(args...) = cudaconvert.((args))

@inline function kernel_maxthreads(kernel_function, kargs)
p_tt = Tuple{Core.Typeof.(kargs)...}
p_kernel = cufunction(kernel_function, p_tt)
maxThreads = CUDA.maxthreads(p_kernel)
return (p_kernel, maxThreads)
end
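
# The two helpers above factor out the cudaconvert/cufunction boilerplate that each
# entry point previously repeated; CUDA.maxthreads reports the largest block size the
# compiled kernel supports. A minimal sketch of the pattern with a hypothetical
# _scale_kernel (illustrative only, not part of this commit):
function _scale_kernel(N, a, s)
    i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    if i <= N
        @inbounds a[i] *= s
    end
    return nothing
end

a = CUDA.ones(Float32, 1_000)
kargs = kernel_args(length(a), a, 2.0f0)
kernel, maxThreads = kernel_maxthreads(_scale_kernel, kargs)
threads = min(length(a), maxThreads)
kernel(kargs...; threads = threads, blocks = cld(length(a), threads))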

function JACC.parallel_for(
::CUDABackend, N::I, f::F, x...) where {I <: Integer, F <: Function}
parallel_args = (N, f, x...)
parallel_kargs = cudaconvert.(parallel_args)
parallel_tt = Tuple{Core.Typeof.(parallel_kargs)...}
parallel_kernel = cufunction(_parallel_for_cuda, parallel_tt)
maxThreads = CUDA.maxthreads(parallel_kernel)
kargs = kernel_args(N, f, x...)
kernel, maxThreads = kernel_maxthreads(_parallel_for_cuda, kargs)
threads = min(N, maxThreads)
blocks = ceil(Int, N / threads)
shmem_size = attribute(
device(), CUDA.DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK)
parallel_kernel(parallel_kargs...; threads = threads, blocks = blocks, shmem = shmem_size)
kernel(kargs...; threads = threads, blocks = blocks, shmem = shmem_size)
end
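
# Through the JACC front end, the 1D path above is reached by calls of the form below
# (illustrative only; assumes the CUDA back end is the active preference, as in the
# unit tests further down):
a = JACC.Array(ones(Float32, 1_000))
JACC.parallel_for(length(a), (i, a) -> a[i] += 1.0f0, a)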

abstract type BlockIndexer2D end
@@ -64,11 +70,8 @@ function JACC.parallel_for(
m, n = (N, M)
end

parallel_args = (indexer, (M, N), f, x...)
parallel_kargs = cudaconvert.(parallel_args)
parallel_tt = Tuple{Core.Typeof.(parallel_kargs)...}
parallel_kernel = cufunction(_parallel_for_cuda_MN, parallel_tt)
maxThreads = CUDA.maxthreads(parallel_kernel)
kargs = kernel_args(indexer, (M, N), f, x...)
kernel, maxThreads = kernel_maxthreads(_parallel_for_cuda_MN, kargs)
blockAttrs = (
max_x = attribute(dev, CUDA.DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X),
max_y = attribute(dev, CUDA.DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y),
@@ -92,7 +95,7 @@

shmem_size = attribute(dev, CUDA.DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK)

parallel_kernel(parallel_kargs...; threads = threads, blocks = blocks, shmem = shmem_size)
kernel(kargs...; threads = threads, blocks = blocks, shmem = shmem_size)
end

function JACC.parallel_for(
@@ -117,15 +120,27 @@ end

function JACC.parallel_reduce(
::CUDABackend, N::Integer, op, f::Function, x...; init)
numThreads = 512
threads = min(N, numThreads)
ret_inst = CUDA.CuArray{typeof(init)}(undef, 0)

kargs_1 = kernel_args(N, op, ret_inst, f, x...)
kernel_1, maxThreads_1 = kernel_maxthreads(_parallel_reduce_cuda, kargs_1)

rret = CUDA.CuArray([init])
kargs_2 = kernel_args(1, op, ret_inst, rret)
kernel_2, maxThreads_2 = kernel_maxthreads(reduce_kernel_cuda, kargs_2)

threads = min(N, maxThreads_1, maxThreads_2)
threads = 2^(floor(Int, log2(threads))) # Round threads down to a power of 2
blocks = ceil(Int, N / threads)
shmem_size = threads * sizeof(init)

ret = fill!(CUDA.CuArray{typeof(init)}(undef, blocks), init)
rret = CUDA.CuArray([init])
CUDA.@sync @cuda threads=threads blocks=blocks shmem=512 * sizeof(typeof(init)) _parallel_reduce_cuda(
N, op, ret, f, x...)
CUDA.@sync @cuda threads=threads blocks=1 shmem=512 * sizeof(typeof(init)) reduce_kernel_cuda(
blocks, op, ret, rret)
kargs = kernel_args(N, op, ret, f, x...)
kernel_1(kargs...; threads = threads, blocks = blocks, shmem = shmem_size)

kargs = kernel_args(blocks, op, ret, rret)
kernel_2(kargs...; threads = threads, blocks = 1, shmem = shmem_size)

return Core.Array(rret)[]
end
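
# The wrapper now compiles both kernels up front (ret_inst is an empty placeholder
# array used only to supply argument types), so the block size can respect whichever
# kernel has the smaller maxthreads limit; the block size is rounded down to a power
# of two for the tree reduction, and shared memory is sized from sizeof(init) rather
# than a hard-coded 512 threads. kernel_1 writes one partial result per block into
# ret, and kernel_2 folds those partials into the one-element rret. A front-end call
# that exercises this path (illustrative, mirroring the new unit tests below):
a = JACC.Array(randn(1_000))
mn = JACC.parallel_reduce(length(a), min, (i, a) -> a[i], a; init = Inf)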

@@ -174,113 +189,74 @@ function _parallel_for_cuda_LMN((L, M, N), f, x...)
return nothing
end

struct Pow2RevSeq
v::Int
end
Base.iterate(s::Pow2RevSeq) = s.v == 1 ? nothing : (s.v, Pow2RevSeq(s.v ÷ 2))
Base.iterate(::Pow2RevSeq, s::Pow2RevSeq) = Base.iterate(s)
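
# Pow2RevSeq is a small iterator that walks a power of two down by halving and stops
# before reaching 1; the reduction loops below use it to adapt to whatever block size
# was chosen instead of hard-coding 512-thread unrolled steps. A quick host-side check
# of what it yields (illustrative only, not part of this commit):
seq = Int[]
for tn in Pow2RevSeq(8)
    push!(seq, tn)
end
@assert seq == [8, 4, 2]  # 1 is never yielded; the kernels combine the final pair in their thread-1 branch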

function _parallel_reduce_cuda(N, op, ret, f, x...)
shared_mem = @cuDynamicSharedMem(eltype(ret), 512)
shmem_length = blockDim().x
shared_mem = CuDynamicSharedArray(eltype(ret), shmem_length)
i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
ti = threadIdx().x
tmp::eltype(ret) = 0.0
shared_mem[ti] = 0.0
shared_mem[ti] = ret[blockIdx().x]

if i <= N
tmp = @inbounds f(i, x...)
shared_mem[threadIdx().x] = tmp
end
sync_threads()
if (ti <= 256)
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 256])
end
sync_threads()
if (ti <= 128)
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 128])
end
sync_threads()
if (ti <= 64)
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 64])
end
sync_threads()
if (ti <= 32)
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 32])
end
sync_threads()
if (ti <= 16)
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 16])
end
sync_threads()
if (ti <= 8)
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 8])
shared_mem[ti] = tmp
end
sync_threads()
if (ti <= 4)
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 4])
end
sync_threads()
if (ti <= 2)
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 2])

for tn in Pow2RevSeq(shmem_length ÷ 2)
if (ti <= tn)
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + tn])
end
sync_threads()
end
sync_threads()

if (ti == 1)
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 1])
ret[blockIdx().x] = shared_mem[ti]
end

return nothing
end
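
# The loop above is a standard shared-memory tree reduction: each pass halves the
# active front, combining slot ti with slot ti + tn, so a block of T threads finishes
# in about log2(T) passes. Illustrative CPU analogue on a plain Vector (not part of
# this commit; assumes a power-of-two length, matching the power-of-two block size
# chosen by the host wrapper):
function tree_reduce(op, v::Vector)
    w = copy(v)
    length(w) == 1 && return w[1]
    tn = length(w) ÷ 2
    while tn >= 2                      # mirror the kernel: halve down to pairs...
        for ti in 1:tn
            w[ti] = op(w[ti], w[ti + tn])
        end
        tn ÷= 2
    end
    return op(w[1], w[2])              # ...and combine the final pair explicitly
end

@assert tree_reduce(+, collect(1.0:8.0)) == 36.0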

function reduce_kernel_cuda(N, op, red, ret)
shared_mem = @cuDynamicSharedMem(eltype(ret), 512)
shmem_length = blockDim().x
shared_mem = CuDynamicSharedArray(eltype(ret), shmem_length)
i = threadIdx().x
ii = i
tmp::eltype(ret) = 0.0
if N > 512
if N > shmem_length
while ii <= N
tmp = op(tmp, @inbounds red[ii])
ii += 512
ii += shmem_length
end
elseif (i <= N)
tmp = @inbounds red[i]
end
shared_mem[i] = tmp
sync_threads()
if (i <= 256)
shared_mem[i] = op(shared_mem[i], shared_mem[i + 256])
end
sync_threads()
if (i <= 128)
shared_mem[i] = op(shared_mem[i], shared_mem[i + 128])
end
sync_threads()
if (i <= 64)
shared_mem[i] = op(shared_mem[i], shared_mem[i + 64])
end
sync_threads()
if (i <= 32)
shared_mem[i] = op(shared_mem[i], shared_mem[i + 32])
end
sync_threads()
if (i <= 16)
shared_mem[i] = op(shared_mem[i], shared_mem[i + 16])
end
sync_threads()
if (i <= 8)
shared_mem[i] = op(shared_mem[i], shared_mem[i + 8])
end
sync_threads()
if (i <= 4)
shared_mem[i] = op(shared_mem[i], shared_mem[i + 4])
end
sync_threads()
if (i <= 2)
shared_mem[i] = op(shared_mem[i], shared_mem[i + 2])

for tn in Pow2RevSeq(shmem_length ÷ 2)
if (i <= tn)
shared_mem[i] = op(shared_mem[i], shared_mem[i + tn])
end
sync_threads()
end
sync_threads()

if (i == 1)
shared_mem[i] = op(shared_mem[i], shared_mem[i + 1])
ret[1] = shared_mem[1]
end

return nothing
end

function _parallel_reduce_cuda_MN((M, N), op, ret, f, x...)
shared_mem = @cuDynamicSharedMem(eltype(ret), 16*16)
shared_mem = CuDynamicSharedArray(eltype(ret), 16*16)
i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
j = (blockIdx().y - 1) * blockDim().y + threadIdx().y
ti = threadIdx().x
@@ -332,7 +308,7 @@ function _parallel_reduce_cuda_MN((M, N), op, ret, f, x...)
end

function reduce_kernel_cuda_MN((M, N), op, red, ret)
shared_mem = @cuDynamicSharedMem(eltype(ret), 16*16)
shared_mem = CuDynamicSharedArray(eltype(ret), 16*16)
i = threadIdx().x
j = threadIdx().y
ii = i
@@ -432,7 +408,7 @@ end

function JACC.shared(x::CuDeviceArray{T, N}) where {T, N}
size = length(x)
shmem = @cuDynamicSharedMem(T, size)
shmem = CuDynamicSharedArray(T, size)
num_threads = blockDim().x * blockDim().y
if (size <= num_threads)
if blockDim().y == 1
4 changes: 4 additions & 0 deletions test/unittests.jl
@@ -98,11 +98,15 @@ end
ad = JACC.Array(ah)
mxd = JACC.parallel_reduce(SIZE, max, (i, a) -> a[i], ad; init = -Inf)
@test mxd == maximum(ah)
mnd = JACC.parallel_reduce(SIZE, min, (i, a) -> a[i], ad; init = Inf)
@test mnd == minimum(ah)

ah2 = randn(FloatType, (SIZE, SIZE))
ad2 = JACC.Array(ah2)
mxd = JACC.parallel_reduce((SIZE, SIZE), max, (i, j, a) -> a[i, j], ad2; init = -Inf)
@test mxd == maximum(ah2)
mnd = JACC.parallel_reduce((SIZE, SIZE), min, (i, j, a) -> a[i, j], ad2; init = Inf)
@test mnd == minimum(ah2)
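
# Illustrative (not part of this commit): the generalized 1D reduce accepts any
# associative op with a matching init, e.g. an explicit sum of squares seeded with zero.
smd = JACC.parallel_reduce(SIZE, +, (i, a) -> a[i] * a[i], ad; init = zero(FloatType))
@test smd ≈ sum(abs2, ah)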
end

@testset "shared" begin