
Add tag dispatch to separate implementations
PhilipFackler committed Apr 18, 2024
1 parent abe26c3 commit 025175b
Showing 4 changed files with 51 additions and 21 deletions.
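In short: each back end now defines an empty tag type (AMDGPUTag, CUDATag, oneAPITag, plus a ThreadsTag fallback in JACC itself), the public parallel_for / parallel_reduce wrappers forward to parallel_for(Tag(), ...), and each extension's __init__ points the global JACC.Tag at its own tag type, so the active back end is chosen by multiple dispatch instead of the extensions overwriting the base methods. A minimal sketch of the pattern, using made-up names (run_for, ActiveTag, the *Sketch structs) that are not part of JACC's API:

# Each back end defines an empty tag struct and a method specialized on it.
struct ThreadsTagSketch end     # stands in for JACC's ThreadsTag
struct GPUTagSketch end         # stands in for CUDATag / AMDGPUTag / oneAPITag

run_for(::ThreadsTagSketch, N, f) = foreach(f, 1:N)
run_for(::GPUTagSketch, N, f) = error("pretend this launches a GPU kernel")

# A single global binding names the active tag type; the public entry point
# instantiates it and lets multiple dispatch pick the matching method.
ActiveTag = ThreadsTagSketch
run_for(N, f) = run_for(ActiveTag(), N, f)

run_for(3, i -> println("iteration ", i))   # runs the ThreadsTagSketch method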
11 changes: 7 additions & 4 deletions ext/JACCAMDGPU/JACCAMDGPU.jl
@@ -4,15 +4,17 @@ using JACC, AMDGPU

include("array.jl")

function JACC.parallel_for(N::I, f::F, x...) where {I <: Integer, F <: Function}
struct AMDGPUTag end

function JACC.parallel_for(::AMDGPUTag, N::I, f::F, x...) where {I <: Integer, F <: Function}
numThreads = 512
threads = min(N, numThreads)
blocks = ceil(Int, N / threads)
@roc groupsize = threads gridsize = blocks _parallel_for_amdgpu(f, x...)
AMDGPU.synchronize()
end

function JACC.parallel_for((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
function JACC.parallel_for(::AMDGPUTag, (M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
numThreads = 16
Mthreads = min(M, numThreads)
Nthreads = min(N, numThreads)
@@ -22,7 +24,7 @@ function JACC.parallel_for((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer,
AMDGPU.synchronize()
end

function JACC.parallel_reduce(N::I, f::F, x...) where {I <: Integer, F <: Function}
function JACC.parallel_reduce(::AMDGPUTag, N::I, f::F, x...) where {I <: Integer, F <: Function}
numThreads = 512
threads = min(N, numThreads)
blocks = ceil(Int, N / threads)
@@ -36,7 +38,7 @@ function JACC.parallel_reduce(N::I, f::F, x...) where {I <: Integer, F <: Functi

end

function JACC.parallel_reduce((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
function JACC.parallel_reduce(::AMDGPUTag, (M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
numThreads = 16
Mthreads = min(M, numThreads)
Nthreads = min(N, numThreads)
@@ -303,6 +305,7 @@ end

function __init__()
const JACC.Array = AMDGPU.ROCArray{T, N} where {T, N}
const JACC.Tag = AMDGPUTag
end

end # module JACCAMDGPU
16 changes: 12 additions & 4 deletions ext/JACCCUDA/JACCCUDA.jl
@@ -5,7 +5,9 @@ using JACC, CUDA
# overloaded array functions
include("array.jl")

function JACC.parallel_for(N::I, f::F, x...) where {I <: Integer, F <: Function}
struct CUDATag end

function JACC.parallel_for(::CUDATag, N::I, f::F, x...) where {I <: Integer, F <: Function}
parallel_args = (N, f, x...)
parallel_kargs = cudaconvert.(parallel_args)
parallel_tt = Tuple{Core.Typeof.(parallel_kargs)...}
@@ -14,9 +16,14 @@ function JACC.parallel_for(N::I, f::F, x...) where {I <: Integer, F <: Function}
threads = min(N, maxPossibleThreads)
blocks = ceil(Int, N / threads)
parallel_kernel(parallel_kargs...; threads=threads, blocks=blocks)

# maxPossibleThreads = attribute(device(), CUDA.DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X)
# threads = min(N, maxPossibleThreads)
# blocks = ceil(Int, N / threads)
# CUDA.@sync @cuda threads = threads blocks = blocks _parallel_for_cuda(f, x...)
end

function JACC.parallel_for((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
function JACC.parallel_for(::CUDATag, (M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
numThreads = 16
Mthreads = min(M, numThreads)
Nthreads = min(N, numThreads)
@@ -25,7 +32,7 @@ function JACC.parallel_for((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer,
CUDA.@sync @cuda threads = (Mthreads, Nthreads) blocks = (Mblocks, Nblocks) _parallel_for_cuda_MN(f, x...)
end

function JACC.parallel_reduce(N::I, f::F, x...) where {I <: Integer, F <: Function}
function JACC.parallel_reduce(::CUDATag, N::I, f::F, x...) where {I <: Integer, F <: Function}
numThreads = 512
threads = min(N, numThreads)
blocks = ceil(Int, N / threads)
@@ -37,7 +44,7 @@ function JACC.parallel_reduce(N::I, f::F, x...) where {I <: Integer, F <: Functi
end


function JACC.parallel_reduce((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
function JACC.parallel_reduce(::CUDATag, (M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
numThreads = 16
Mthreads = min(M, numThreads)
Nthreads = min(N, numThreads)
@@ -304,6 +311,7 @@ end

function __init__()
const JACC.Array = CUDA.CuArray{T, N} where {T, N}
const JACC.Tag = CUDATag
end

end # module JACCCUDA
11 changes: 7 additions & 4 deletions ext/JACCONEAPI/JACCONEAPI.jl
@@ -3,15 +3,17 @@ module JACCONEAPI

using JACC, oneAPI

function JACC.parallel_for(N::I, f::F, x...) where {I <: Integer, F <: Function}
struct oneAPITag end

function JACC.parallel_for(::oneAPITag, N::I, f::F, x...) where {I <: Integer, F <: Function}
#maxPossibleItems = oneAPI.oneL0.compute_properties(device().maxTotalGroupSize)
maxPossibleItems = 256
items = min(N, maxPossibleItems)
groups = ceil(Int, N / items)
oneAPI.@sync @oneapi items = items groups = groups _parallel_for_oneapi(f, x...)
end

function JACC.parallel_for((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
function JACC.parallel_for(::oneAPITag, (M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
maxPossibleItems = 16
Mitems = min(M, maxPossibleItems)
Nitems = min(N, maxPossibleItems)
@@ -20,7 +22,7 @@ function JACC.parallel_for((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer,
oneAPI.@sync @oneapi items = (Mitems, Nitems) groups = (Mgroups, Ngroups) _parallel_for_oneapi_MN(f, x...)
end

function JACC.parallel_reduce(N::I, f::F, x...) where {I <: Integer, F <: Function}
function JACC.parallel_reduce(::oneAPITag, N::I, f::F, x...) where {I <: Integer, F <: Function}
numItems = 256
items = min(N, numItems)
groups = ceil(Int, N / items)
@@ -31,7 +33,7 @@ function JACC.parallel_reduce(N::I, f::F, x...) where {I <: Integer, F <: Functi
return rret
end

function JACC.parallel_reduce((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
function JACC.parallel_reduce(::oneAPITag, (M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
numItems = 16
Mitems = min(M, numItems)
Nitems = min(N, numItems)
@@ -295,6 +297,7 @@ end

function __init__()
const JACC.Array = oneAPI.oneArray{T, N} where {T, N}
const JACC.Tag = oneAPITag
end

end # module JACCONEAPI
34 changes: 25 additions & 9 deletions src/JACC.jl
@@ -1,33 +1,41 @@
__precompile__(false)
# __precompile__(false)
module JACC

import Atomix: @atomic
# module to set back end preferences
include("JACCPreferences.jl")
include("helper.jl")
# overloaded array functions
include("array.jl")

export Array, @atomic
export Array
export parallel_for

global Array
global Tag

function parallel_for(N::I, f::F, x...) where {I <: Integer, F <: Function}
struct ThreadsTag end

function parallel_for(::ThreadsTag, N::I, f::F, x...) where {I <: Integer, F <: Function}
@maybe_threaded for i in 1:N
f(i, x...)
end
end

function parallel_for((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
@inline function parallel_for(N::I, f::F, x...) where {I <: Integer, F <: Function}
parallel_for(Tag(), N, f, x...)
end

function parallel_for(::ThreadsTag, (M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
@maybe_threaded for j in 1:N
for i in 1:M
f(i, j, x...)
end
end
end

function parallel_reduce(N::I, f::F, x...) where {I <: Integer, F <: Function}
@inline function parallel_for((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
parallel_for(Tag(), (M, N), f, x...)
end

function parallel_reduce(::ThreadsTag, N::I, f::F, x...) where {I <: Integer, F <: Function}
tmp = zeros(Threads.nthreads())
ret = zeros(1)
@maybe_threaded for i in 1:N
@@ -39,7 +47,11 @@ function parallel_reduce(N::I, f::F, x...) where {I <: Integer, F <: Function}
return ret
end

function parallel_reduce((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
@inline function parallel_reduce(N::I, f::F, x...) where {I <: Integer, F <: Function}
parallel_reduce(Tag(), N, f, x...)
end

function parallel_reduce(::ThreadsTag, (M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
tmp = zeros(Threads.nthreads())
ret = zeros(1)
@maybe_threaded for j in 1:N
@@ -53,9 +65,13 @@ function parallel_reduce((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F
return ret
end

@inline function parallel_reduce((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
parallel_reduce(Tag(), (M, N), f, x...)
end

function __init__()
const JACC.Array = Base.Array{T, N} where {T, N}
const JACC.Tag = ThreadsTag
end


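For context, a usage sketch of how calls flow through the new dispatch layer (the array contents and sizes below are made up; this is not part of the commit):

using JACC

# With only the base package loaded, __init__ binds JACC.Array to Base.Array
# and JACC.Tag to ThreadsTag, so the wrapper parallel_for(N, f, x...) forwards
# to the @maybe_threaded CPU method.
x = JACC.Array(ones(Float64, 1_000))
JACC.parallel_for(length(x), (i, v) -> (v[i] *= 2.0), x)

# Loading a back end such as CUDA.jl activates the JACCCUDA extension, whose
# __init__ rebinds JACC.Tag to CUDATag; the same call then dispatches to
# JACC.parallel_for(::CUDATag, N, f, x...) with no change at the call site.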
