# Patch summary (preamble text placed before the first `diff --git` header).
#
# Files touched: ext/JACCAMDGPU/JACCAMDGPU.jl, ext/JACCCUDA/JACCCUDA.jl,
# ext/JACCONEAPI/JACCONEAPI.jl and src/JACC.jl.
#
# What the hunks below do:
#   * Each extension module adds an empty struct (`AMDGPUTag`, `CUDATag`,
#     `oneAPITag`) and inserts it as an unnamed leading argument of its
#     `JACC.parallel_for` / `JACC.parallel_reduce` methods, for both the
#     `N` and the `(M, N)` forms.
#   * Each extension's `__init__` additionally assigns `JACC.Tag` to its tag type.
#   * src/JACC.jl declares `global Tag`, adds `struct ThreadsTag end`, retags the
#     existing `@maybe_threaded` methods with `::ThreadsTag`, and adds `@inline`
#     untagged methods that forward to `parallel_for(Tag(), ...)` /
#     `parallel_reduce(Tag(), ...)`. Its `__init__` assigns `JACC.Tag = ThreadsTag`.
#   * src/JACC.jl also comments out `__precompile__(false)`, removes
#     `import Atomix: @atomic` and `include("array.jl")`, and shrinks the export
#     list from `Array, @atomic` to `Array`.
#   * The CUDA `parallel_for(::CUDATag, N, ...)` hunk adds four commented-out
#     lines showing a `CUDA.@sync @cuda ... _parallel_for_cuda(f, x...)` launch.
#
# NOTE(review): `@atomic` is no longer imported or exported and array.jl is no
#   longer included in src/JACC.jl -- confirm this removal is intentional.
# NOTE(review): `Tag` is declared with a plain `global` and assigned inside
#   `__init__`; presumably the forwarding methods then dispatch on `Tag()` at
#   run time despite `@inline` -- TODO confirm against the intended design.
# NOTE(review): the commented-out CUDA launch lines look like leftover
#   experimentation -- confirm whether they should be part of the patch.
# NOTE(review): hunk line breaks appear collapsed in this copy of the patch;
#   the text below is reproduced unchanged.
diff --git a/ext/JACCAMDGPU/JACCAMDGPU.jl b/ext/JACCAMDGPU/JACCAMDGPU.jl index 728814d..c3bb884 100644 --- a/ext/JACCAMDGPU/JACCAMDGPU.jl +++ b/ext/JACCAMDGPU/JACCAMDGPU.jl @@ -4,7 +4,9 @@ using JACC, AMDGPU include("array.jl") -function JACC.parallel_for(N::I, f::F, x...) where {I <: Integer, F <: Function} +struct AMDGPUTag end + +function JACC.parallel_for(::AMDGPUTag, N::I, f::F, x...) where {I <: Integer, F <: Function} numThreads = 512 threads = min(N, numThreads) blocks = ceil(Int, N / threads) @@ -12,7 +14,7 @@ function JACC.parallel_for(N::I, f::F, x...) where {I <: Integer, F <: Function} AMDGPU.synchronize() end -function JACC.parallel_for((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function} +function JACC.parallel_for(::AMDGPUTag, (M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function} numThreads = 16 Mthreads = min(M, numThreads) Nthreads = min(N, numThreads) @@ -22,7 +24,7 @@ function JACC.parallel_for((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, AMDGPU.synchronize() end -function JACC.parallel_reduce(N::I, f::F, x...) where {I <: Integer, F <: Function} +function JACC.parallel_reduce(::AMDGPUTag, N::I, f::F, x...) where {I <: Integer, F <: Function} numThreads = 512 threads = min(N, numThreads) blocks = ceil(Int, N / threads) @@ -36,7 +38,7 @@ function JACC.parallel_reduce(N::I, f::F, x...) where {I <: Integer, F <: Functi end -function JACC.parallel_reduce((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function} +function JACC.parallel_reduce(::AMDGPUTag, (M, N)::Tuple{I, I}, f::F, x...) 
where {I <: Integer, F <: Function} numThreads = 16 Mthreads = min(M, numThreads) Nthreads = min(N, numThreads) @@ -303,6 +305,7 @@ end function __init__() const JACC.Array = AMDGPU.ROCArray{T, N} where {T, N} + const JACC.Tag = AMDGPUTag end end # module JACCAMDGPU diff --git a/ext/JACCCUDA/JACCCUDA.jl b/ext/JACCCUDA/JACCCUDA.jl index 26b5882..bc9a331 100644 --- a/ext/JACCCUDA/JACCCUDA.jl +++ b/ext/JACCCUDA/JACCCUDA.jl @@ -5,7 +5,9 @@ using JACC, CUDA # overloaded array functions include("array.jl") -function JACC.parallel_for(N::I, f::F, x...) where {I <: Integer, F <: Function} +struct CUDATag end + +function JACC.parallel_for(::CUDATag, N::I, f::F, x...) where {I <: Integer, F <: Function} parallel_args = (N, f, x...) parallel_kargs = cudaconvert.(parallel_args) parallel_tt = Tuple{Core.Typeof.(parallel_kargs)...} @@ -14,9 +16,14 @@ function JACC.parallel_for(N::I, f::F, x...) where {I <: Integer, F <: Function} threads = min(N, maxPossibleThreads) blocks = ceil(Int, N / threads) parallel_kernel(parallel_kargs...; threads=threads, blocks=blocks) + +# maxPossibleThreads = attribute(device(), CUDA.DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X) +# threads = min(N, maxPossibleThreads) +# blocks = ceil(Int, N / threads) +# CUDA.@sync @cuda threads = threads blocks = blocks _parallel_for_cuda(f, x...) end -function JACC.parallel_for((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function} +function JACC.parallel_for(::CUDATag, (M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function} numThreads = 16 Mthreads = min(M, numThreads) Nthreads = min(N, numThreads) @@ -25,7 +32,7 @@ function JACC.parallel_for((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, CUDA.@sync @cuda threads = (Mthreads, Nthreads) blocks = (Mblocks, Nblocks) _parallel_for_cuda_MN(f, x...) end -function JACC.parallel_reduce(N::I, f::F, x...) where {I <: Integer, F <: Function} +function JACC.parallel_reduce(::CUDATag, N::I, f::F, x...) 
where {I <: Integer, F <: Function} numThreads = 512 threads = min(N, numThreads) blocks = ceil(Int, N / threads) @@ -37,7 +44,7 @@ function JACC.parallel_reduce(N::I, f::F, x...) where {I <: Integer, F <: Functi end -function JACC.parallel_reduce((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function} +function JACC.parallel_reduce(::CUDATag, (M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function} numThreads = 16 Mthreads = min(M, numThreads) Nthreads = min(N, numThreads) @@ -304,6 +311,7 @@ end function __init__() const JACC.Array = CUDA.CuArray{T, N} where {T, N} + const JACC.Tag = CUDATag end end # module JACCCUDA diff --git a/ext/JACCONEAPI/JACCONEAPI.jl b/ext/JACCONEAPI/JACCONEAPI.jl index 26c332e..39ddc35 100644 --- a/ext/JACCONEAPI/JACCONEAPI.jl +++ b/ext/JACCONEAPI/JACCONEAPI.jl @@ -3,7 +3,9 @@ module JACCONEAPI using JACC, oneAPI -function JACC.parallel_for(N::I, f::F, x...) where {I <: Integer, F <: Function} +struct oneAPITag end + +function JACC.parallel_for(::oneAPITag, N::I, f::F, x...) where {I <: Integer, F <: Function} #maxPossibleItems = oneAPI.oneL0.compute_properties(device().maxTotalGroupSize) maxPossibleItems = 256 items = min(N, maxPossibleItems) @@ -11,7 +13,7 @@ function JACC.parallel_for(N::I, f::F, x...) where {I <: Integer, F <: Function} oneAPI.@sync @oneapi items = items groups = groups _parallel_for_oneapi(f, x...) end -function JACC.parallel_for((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function} +function JACC.parallel_for(::oneAPITag, (M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function} maxPossibleItems = 16 Mitems = min(M, maxPossibleItems) Nitems = min(N, maxPossibleItems) @@ -20,7 +22,7 @@ function JACC.parallel_for((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, oneAPI.@sync @oneapi items = (Mitems, Nitems) groups = (Mgroups, Ngroups) _parallel_for_oneapi_MN(f, x...) end -function JACC.parallel_reduce(N::I, f::F, x...) 
where {I <: Integer, F <: Function} +function JACC.parallel_reduce(::oneAPITag, N::I, f::F, x...) where {I <: Integer, F <: Function} numItems = 256 items = min(N, numItems) groups = ceil(Int, N / items) @@ -31,7 +33,7 @@ function JACC.parallel_reduce(N::I, f::F, x...) where {I <: Integer, F <: Functi return rret end -function JACC.parallel_reduce((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function} +function JACC.parallel_reduce(::oneAPITag, (M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function} numItems = 16 Mitems = min(M, numItems) Nitems = min(N, numItems) @@ -295,6 +297,7 @@ end function __init__() const JACC.Array = oneAPI.oneArray{T, N} where {T, N} + const JACC.Tag = oneAPITag end end # module JACCONEAPI diff --git a/src/JACC.jl b/src/JACC.jl index ec12153..cfa11cc 100644 --- a/src/JACC.jl +++ b/src/JACC.jl @@ -1,25 +1,29 @@ -__precompile__(false) +# __precompile__(false) module JACC -import Atomix: @atomic # module to set back end preferences include("JACCPreferences.jl") include("helper.jl") -# overloaded array functions -include("array.jl") -export Array, @atomic +export Array export parallel_for global Array +global Tag -function parallel_for(N::I, f::F, x...) where {I <: Integer, F <: Function} +struct ThreadsTag end + +function parallel_for(::ThreadsTag, N::I, f::F, x...) where {I <: Integer, F <: Function} @maybe_threaded for i in 1:N f(i, x...) end end -function parallel_for((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function} +@inline function parallel_for(N::I, f::F, x...) where {I <: Integer, F <: Function} + parallel_for(Tag(), N, f, x...) +end + +function parallel_for(::ThreadsTag, (M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function} @maybe_threaded for j in 1:N for i in 1:M f(i, j, x...) @@ -27,7 +31,11 @@ function parallel_for((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: end end -function parallel_reduce(N::I, f::F, x...) 
where {I <: Integer, F <: Function} +@inline function parallel_for((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function} + parallel_for(Tag(), (M, N), f, x...) +end + +function parallel_reduce(::ThreadsTag, N::I, f::F, x...) where {I <: Integer, F <: Function} tmp = zeros(Threads.nthreads()) ret = zeros(1) @maybe_threaded for i in 1:N @@ -39,7 +47,11 @@ function parallel_reduce(N::I, f::F, x...) where {I <: Integer, F <: Function} return ret end -function parallel_reduce((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function} +@inline function parallel_reduce(N::I, f::F, x...) where {I <: Integer, F <: Function} + parallel_reduce(Tag(), N, f, x...) +end + +function parallel_reduce(::ThreadsTag, (M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function} tmp = zeros(Threads.nthreads()) ret = zeros(1) @maybe_threaded for j in 1:N @@ -53,9 +65,13 @@ function parallel_reduce((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F return ret end +@inline function parallel_reduce((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function} + parallel_reduce(Tag(), (M, N), f, x...) +end function __init__() const JACC.Array = Base.Array{T, N} where {T, N} + const JACC.Tag = ThreadsTag end