Skip to content

Commit

Permalink
Merge pull request #136 from PhilipFackler/CUDA-blocks-threads
Browse files Browse the repository at this point in the history
WIP: Better blocks/threads calculations for CUDA backend
  • Loading branch information
PhilipFackler authored Feb 7, 2025
2 parents 0998b86 + 2c802da commit 48383bb
Showing 1 changed file with 116 additions and 99 deletions.
215 changes: 116 additions & 99 deletions ext/JACCCUDA/JACCCUDA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,38 +14,85 @@ using .Experimental

JACC.get_backend(::Val{:cuda}) = CUDABackend()

@inline kernel_args(args...) = cudaconvert.((args))

@inline function kernel_maxthreads(kernel_function, kargs)
p_tt = Tuple{Core.Typeof.(kargs)...}
p_kernel = cufunction(kernel_function, p_tt)
maxThreads = CUDA.maxthreads(p_kernel)
return (p_kernel, CUDA.maxthreads(p_kernel))
end

function JACC.parallel_for(
::CUDABackend, N::I, f::F, x...) where {I <: Integer, F <: Function}
#parallel_args = (N, f, x...)
#parallel_kargs = cudaconvert.(parallel_args)
#parallel_tt = Tuple{Core.Typeof.(parallel_kargs)...}
#parallel_kernel = cufunction(_parallel_for_cuda, parallel_tt)
#maxPossibleThreads = CUDA.maxthreads(parallel_kernel)
maxPossibleThreads = 512
threads = min(N, maxPossibleThreads)
kargs = kernel_args(N, f, x...)
kernel, maxThreads = kernel_maxthreads(_parallel_for_cuda, kargs)
threads = min(N, maxThreads)
blocks = ceil(Int, N / threads)
shmem_size = attribute(
device(), CUDA.DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK)
#parallel_kernel(parallel_kargs...; threads = threads, blocks = blocks)
CUDA.@sync @cuda threads=threads blocks=blocks shmem=shmem_size _parallel_for_cuda(
N, f, x...)
kernel(kargs...; threads = threads, blocks = blocks, shmem = shmem_size)
end

abstract type BlockIndexer2D end

struct BlockIndexerBasic <: BlockIndexer2D end

function (blkIter::BlockIndexerBasic)(blockIdx, blockDim, threadIdx)
i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
j = (blockIdx().y - 1) * blockDim().y + threadIdx().y
return (i, j)
end

struct BlockIndexerSwapped <: BlockIndexer2D end

function (blkIter::BlockIndexerSwapped)(blockIdx, blockDim, threadIdx)
j = (blockIdx().x - 1) * blockDim().x + threadIdx().x
i = (blockIdx().y - 1) * blockDim().y + threadIdx().y
return (i, j)
end

function JACC.parallel_for(
::CUDABackend, (M, N)::Tuple{I, I}, f::F, x...) where {
I <: Integer, F <: Function}
#To use JACC.shared, it is recommended to use a high number of threads per block to maximize the
# potential benefit from using shared memory.
#numThreads = 32
numThreads = 16
Mthreads = min(M, numThreads)
Nthreads = min(N, numThreads)
Mblocks = ceil(Int, M / Mthreads)
Nblocks = ceil(Int, N / Nthreads)
shmem_size = attribute(
device(), CUDA.DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK)
CUDA.@sync @cuda threads=(Mthreads, Nthreads) blocks=(Mblocks, Nblocks) shmem=shmem_size _parallel_for_cuda_MN(
(M, N), f, x...)

dev = CUDA.device()
maxBlocks = (
x = attribute(dev, CUDA.DEVICE_ATTRIBUTE_MAX_GRID_DIM_X),
y = attribute(dev, CUDA.DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y),
)
indexer = BlockIndexerBasic()
m, n = (M, N)
if M < N && maxBlocks.x > maxBlocks.y
indexer = BlockIndexerSwapped()
m, n = (N, M)
end

kargs = kernel_args(indexer, (M, N), f, x...)
kernel, maxThreads = kernel_maxthreads(_parallel_for_cuda_MN, kargs)
blockAttrs = (
max_x = attribute(dev, CUDA.DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X),
max_y = attribute(dev, CUDA.DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y),
total = attribute(dev, CUDA.DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK))
x_thr = min(
blockAttrs.max_x,
nextpow(2, m / blockAttrs.total + 1),
blockAttrs.total,
maxThreads
)
y_thr = min(
blockAttrs.max_y,
ceil(Int, blockAttrs.total / x_thr),
ceil(Int, maxThreads / x_thr),
)
threads = (x_thr, y_thr)
blocks = (cld(m, x_thr), cld(n, y_thr))

shmem_size = attribute(dev, CUDA.DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK)

kernel(kargs...; threads = threads, blocks = blocks, shmem = shmem_size)
end

function JACC.parallel_for(
Expand All @@ -70,16 +117,27 @@ end

function JACC.parallel_reduce(
::CUDABackend, N::Integer, op, f::Function, x...; init)
numThreads = 512
threads = numThreads
ret_inst = CUDA.CuArray{typeof(init)}(undef, 0)

kargs_1 = kernel_args(N, op, ret_inst, f, x...)
kernel_1, maxThreads_1 = kernel_maxthreads(_parallel_reduce_cuda, kargs_1)

rret = CUDA.CuArray([init])
kargs_2 = kernel_args(1, op, ret_inst, rret)
kernel_2, maxThreads_2 = kernel_maxthreads(reduce_kernel_cuda, kargs_2)

threads = min(maxThreads_1, maxThreads_2, 512)
blocks = ceil(Int, N / threads)

shmem_size = threads * sizeof(init)

ret = fill!(CUDA.CuArray{typeof(init)}(undef, blocks), init)
rret = CUDA.CuArray([init])
shmem_size = 512 * sizeof(init)
CUDA.@sync @cuda threads=threads blocks=blocks shmem=shmem_size _parallel_reduce_cuda(
N, op, ret, f, x...)
CUDA.@sync @cuda threads=threads blocks=1 shmem=shmem_size reduce_kernel_cuda(
blocks, op, ret, rret)
kargs = kernel_args(N, op, ret, f, x...)
kernel_1(kargs...; threads = threads, blocks = blocks, shmem = shmem_size)

kargs = kernel_args(blocks, op, ret, rret)
kernel_2(kargs...; threads = threads, blocks = 1, shmem = shmem_size)

return Base.Array(rret)[]
end

Expand Down Expand Up @@ -107,9 +165,8 @@ function _parallel_for_cuda(N, f, x...)
return nothing
end

function _parallel_for_cuda_MN((M, N), f, x...)
i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
j = (blockIdx().y - 1) * blockDim().y + threadIdx().y
function _parallel_for_cuda_MN(indexer::BlockIndexer2D, (M, N), f, x...)
i, j = indexer(blockIdx, blockDim, threadIdx)
i > M && return nothing
j > N && return nothing
f(i, j, x...)
Expand All @@ -128,7 +185,8 @@ function _parallel_for_cuda_LMN((L, M, N), f, x...)
end

function _parallel_reduce_cuda(N, op, ret, f, x...)
shared_mem = CuDynamicSharedArray(eltype(ret), 512)
shmem_length = blockDim().x
shared_mem = CuDynamicSharedArray(eltype(ret), shmem_length)
i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
ti = threadIdx().x
shared_mem[ti] = ret[blockIdx().x]
Expand All @@ -138,96 +196,55 @@ function _parallel_reduce_cuda(N, op, ret, f, x...)
shared_mem[ti] = tmp
end
sync_threads()
if (ti <= 256)
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 256])
end
sync_threads()
if (ti <= 128)
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 128])
end
sync_threads()
if (ti <= 64)
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 64])
end
sync_threads()
if (ti <= 32)
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 32])
end
sync_threads()
if (ti <= 16)
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 16])
end
sync_threads()
if (ti <= 8)
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 8])
end
sync_threads()
if (ti <= 4)
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 4])
end
sync_threads()
if (ti <= 2)
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 2])

tn = div(shmem_length, 2)
while tn > 1
if (ti <= tn)
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + tn])
end
sync_threads()
tn = div(tn, 2)
end
sync_threads()
if (ti == 1)

if ti == 1
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 1])
ret[blockIdx().x] = shared_mem[ti]
end

return nothing
end

function reduce_kernel_cuda(N, op, red, ret)
shared_mem = CuDynamicSharedArray(eltype(ret), 512)
shmem_length = blockDim().x
shared_mem = CuDynamicSharedArray(eltype(ret), shmem_length)
i = threadIdx().x
ii = i
tmp = ret[1]
if N > 512
if N > shmem_length
while ii <= N
tmp = op(tmp, @inbounds red[ii])
ii += 512
ii += shmem_length
end
elseif (i <= N)
tmp = @inbounds red[i]
end
shared_mem[i] = tmp
sync_threads()
if (i <= 256)
shared_mem[i] = op(shared_mem[i], shared_mem[i + 256])
end
sync_threads()
if (i <= 128)
shared_mem[i] = op(shared_mem[i], shared_mem[i + 128])
end
sync_threads()
if (i <= 64)
shared_mem[i] = op(shared_mem[i], shared_mem[i + 64])
end
sync_threads()
if (i <= 32)
shared_mem[i] = op(shared_mem[i], shared_mem[i + 32])
end
sync_threads()
if (i <= 16)
shared_mem[i] = op(shared_mem[i], shared_mem[i + 16])
end
sync_threads()
if (i <= 8)
shared_mem[i] = op(shared_mem[i], shared_mem[i + 8])
end
sync_threads()
if (i <= 4)
shared_mem[i] = op(shared_mem[i], shared_mem[i + 4])
end
sync_threads()
if (i <= 2)
shared_mem[i] = op(shared_mem[i], shared_mem[i + 2])

tn = div(shmem_length, 2)
while tn > 1
if i <= tn
shared_mem[i] = op(shared_mem[i], shared_mem[i + tn])
end
sync_threads()
tn = div(tn, 2)
end
sync_threads()
if (i == 1)

if i == 1
shared_mem[i] = op(shared_mem[i], shared_mem[i + 1])
ret[1] = shared_mem[1]
end

return nothing
end

Expand Down

0 comments on commit 48383bb

Please sign in to comment.