Skip to content

Commit

Permalink
Add __precompile__(false)
Browse files Browse the repository at this point in the history
Avoid triggering a Julia v1.10 bug
Enable passing any type (for now)
  • Loading branch information
williamfgc committed Apr 17, 2024
1 parent 9c7e2f9 commit 555bd77
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 21 deletions.
8 changes: 4 additions & 4 deletions ext/JACCAMDGPU/JACCAMDGPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ module JACCAMDGPU

using JACC, AMDGPU

# AMDGPU backend: run `f` over a 1D range of N indices.
# NOTE(review): this span is a rendered diff hunk without +/- markers; the two
# signature lines below appear to be the old (typed Vararg) and the new
# (untyped `x...`) form of the same method — confirm against the repository.
function JACC.parallel_for(N::I, f::F, x::Vararg{Union{<:Number, <:ROCArray}}) where {I <: Integer, F <: Function}
function JACC.parallel_for(N::I, f::F, x...) where {I <: Integer, F <: Function}
# Use at most 512 threads per group; fewer when N itself is smaller.
numThreads = 512
threads = min(N, numThreads)
# Number of groups needed so that threads * blocks >= N.
# NOTE(review): N == 0 makes threads == 0, so N / threads is 0/0 (NaN) and
# ceil(Int, NaN) cannot produce an Int — confirm callers never pass N == 0.
blocks = ceil(Int, N / threads)
# `_parallel_for_amdgpu` is defined outside this view; it receives `f` and the
# forwarded extra arguments.
@roc groupsize = threads gridsize = blocks _parallel_for_amdgpu(f, x...)
# Presumably waits for the launched kernel to finish before returning — confirm.
AMDGPU.synchronize()
end

function JACC.parallel_for((M, N)::Tuple{I, I}, f::F, x::Vararg{Union{<:Number, <:ROCArray}}) where {I <: Integer, F <: Function}
function JACC.parallel_for((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
numThreads = 16
Mthreads = min(M, numThreads)
Nthreads = min(N, numThreads)
Expand All @@ -20,7 +20,7 @@ function JACC.parallel_for((M, N)::Tuple{I, I}, f::F, x::Vararg{Union{<:Number,
AMDGPU.synchronize()
end

function JACC.parallel_reduce(N::I, f::F, x::Vararg{Union{<:Number, <:ROCArray}}) where {I <: Integer, F <: Function}
function JACC.parallel_reduce(N::I, f::F, x...) where {I <: Integer, F <: Function}
numThreads = 512
threads = min(N, numThreads)
blocks = ceil(Int, N / threads)
Expand All @@ -34,7 +34,7 @@ function JACC.parallel_reduce(N::I, f::F, x::Vararg{Union{<:Number, <:ROCArray}}

end

function JACC.parallel_reduce((M, N)::Tuple{I, I}, f::F, x::Vararg{Union{<:Number, <:ROCArray}}) where {I <: Integer, F <: Function}
function JACC.parallel_reduce((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
numThreads = 16
Mthreads = min(M, numThreads)
Nthreads = min(N, numThreads)
Expand Down
8 changes: 4 additions & 4 deletions ext/JACCCUDA/JACCCUDA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ module JACCCUDA

using JACC, CUDA

# CUDA backend: run `f` over a 1D range of N indices.
# NOTE(review): this span is a rendered diff hunk without +/- markers; the two
# signature lines below appear to be the old (typed Vararg) and the new
# (untyped `x...`) form of the same method — confirm against the repository.
function JACC.parallel_for(N::I, f::F, x::Vararg{Union{<:Number, <:CuArray}}) where {I <: Integer, F <: Function}
function JACC.parallel_for(N::I, f::F, x...) where {I <: Integer, F <: Function}
# Ask the current device for its MAX_BLOCK_DIM_X attribute and use it as the
# upper bound on threads per block.
maxPossibleThreads = attribute(device(), CUDA.DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X)
threads = min(N, maxPossibleThreads)
# Number of blocks needed so that threads * blocks >= N.
# NOTE(review): N == 0 makes threads == 0, so N / threads is 0/0 (NaN) and
# ceil(Int, NaN) cannot produce an Int — confirm callers never pass N == 0.
blocks = ceil(Int, N / threads)
# `_parallel_for_cuda` is defined outside this view; the launch is wrapped in
# CUDA.@sync.
CUDA.@sync @cuda threads = threads blocks = blocks _parallel_for_cuda(f, x...)
end

function JACC.parallel_for((M, N)::Tuple{I, I}, f::F, x::Vararg{Union{<:Number, <:CuArray}}) where {I <: Integer, F <: Function}
function JACC.parallel_for((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
numThreads = 16
Mthreads = min(M, numThreads)
Nthreads = min(N, numThreads)
Expand All @@ -18,7 +18,7 @@ function JACC.parallel_for((M, N)::Tuple{I, I}, f::F, x::Vararg{Union{<:Number,
CUDA.@sync @cuda threads = (Mthreads, Nthreads) blocks = (Mblocks, Nblocks) _parallel_for_cuda_MN(f, x...)
end

function JACC.parallel_reduce(N::I, f::F, x::Vararg{Union{<:Number, <:CuArray}}) where {I <: Integer, F <: Function}
function JACC.parallel_reduce(N::I, f::F, x...) where {I <: Integer, F <: Function}
numThreads = 512
threads = min(N, numThreads)
blocks = ceil(Int, N / threads)
Expand All @@ -30,7 +30,7 @@ function JACC.parallel_reduce(N::I, f::F, x::Vararg{Union{<:Number, <:CuArray}})
end


function JACC.parallel_reduce((M, N)::Tuple{I, I}, f::F, x::Vararg{Union{<:Number, <:CuArray}}) where {I <: Integer, F <: Function}
function JACC.parallel_reduce((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
numThreads = 16
Mthreads = min(M, numThreads)
Nthreads = min(N, numThreads)
Expand Down
8 changes: 4 additions & 4 deletions ext/JACCONEAPI/JACCONEAPI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@ module JACCONEAPI

using JACC, oneAPI

# oneAPI backend: run `f` over a 1D range of N indices.
# NOTE(review): this span is a rendered diff hunk without +/- markers; the two
# signature lines below appear to be the old (typed Vararg) and the new
# (untyped `x...`) form of the same method — confirm against the repository.
function JACC.parallel_for(N::I, f::F, x::Vararg{Union{<:Number, <:oneArray}}) where {I <: Integer, F <: Function}
function JACC.parallel_for(N::I, f::F, x...) where {I <: Integer, F <: Function}
# The device query below is disabled; a fixed limit of 256 items is used instead.
#maxPossibleItems = oneAPI.oneL0.compute_properties(device().maxTotalGroupSize)
maxPossibleItems = 256
items = min(N, maxPossibleItems)
# Number of groups needed so that items * groups >= N.
# NOTE(review): N == 0 makes items == 0, so N / items is 0/0 (NaN) and
# ceil(Int, NaN) cannot produce an Int — confirm callers never pass N == 0.
groups = ceil(Int, N / items)
# `_parallel_for_oneapi` is defined outside this view; the launch is wrapped in
# oneAPI.@sync.
oneAPI.@sync @oneapi items = items groups = groups _parallel_for_oneapi(f, x...)
end

function JACC.parallel_for((M, N)::Tuple{I, I}, f::F, x::Vararg{Union{<:Number, <:oneArray}}) where {I <: Integer, F <: Function}
function JACC.parallel_for((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
maxPossibleItems = 16
Mitems = min(M, maxPossibleItems)
Nitems = min(N, maxPossibleItems)
Expand All @@ -20,7 +20,7 @@ function JACC.parallel_for((M, N)::Tuple{I, I}, f::F, x::Vararg{Union{<:Number,
oneAPI.@sync @oneapi items = (Mitems, Nitems) groups = (Mgroups, Ngroups) _parallel_for_oneapi_MN(f, x...)
end

function JACC.parallel_reduce(N::I, f::F, x::Vararg{Union{<:Number, <:oneArray}}) where {I <: Integer, F <: Function}
function JACC.parallel_reduce(N::I, f::F, x...) where {I <: Integer, F <: Function}
numItems = 256
items = min(N, numItems)
groups = ceil(Int, N / items)
Expand All @@ -31,7 +31,7 @@ function JACC.parallel_reduce(N::I, f::F, x::Vararg{Union{<:Number, <:oneArray}}
return rret
end

function JACC.parallel_reduce((M, N)::Tuple{I, I}, f::F, x::Vararg{Union{<:Number, <:oneArray}}) where {I <: Integer, F <: Function}
function JACC.parallel_reduce((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
numItems = 16
Mitems = min(M, numItems)
Nitems = min(N, numItems)
Expand Down
15 changes: 6 additions & 9 deletions src/JACC.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
__precompile__(false)
module JACC

# module to set back end preferences
Expand All @@ -9,21 +10,21 @@ export parallel_for

global Array

# Default (CPU) method: call `f(i, x...)` for every i in 1:N.
# NOTE(review): this span is a rendered diff hunk without +/- markers; the two
# signature lines below appear to be the old (typed Vararg) and the new
# (untyped `x...`) form of the same method — confirm against the repository.
function parallel_for(N::I, f::F, x::Vararg{Union{<:Number, <:Base.Array}}) where {I <: Integer, F <: Function}
function parallel_for(N::I, f::F, x...) where {I <: Integer, F <: Function}
# `@maybe_threaded` is defined outside this view; presumably it runs the loop
# with threads when enabled — confirm.
@maybe_threaded for i in 1:N
f(i, x...)
end
end

# Default (CPU) method for a 2D range: call `f(i, j, x...)` for every
# i in 1:M and j in 1:N.
# NOTE(review): this span is a rendered diff hunk without +/- markers; the two
# signature lines below appear to be the old (typed Vararg) and the new
# (untyped `x...`) form of the same method — confirm against the repository.
function parallel_for((M, N)::Tuple{I, I}, f::F, x::Vararg{Union{<:Number, <:Base.Array}}) where {I <: Integer, F <: Function}
function parallel_for((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
# Only the outer loop over j carries `@maybe_threaded` (defined outside this
# view); the inner loop over i is a plain sequential loop.
@maybe_threaded for j in 1:N
for i in 1:M
f(i, j, x...)
end
end
end

function parallel_reduce(N::I, f::F, x::Vararg{Union{<:Number, <:Base.Array}}) where {I <: Integer, F <: Function}
function parallel_reduce(N::I, f::F, x...) where {I <: Integer, F <: Function}
tmp = zeros(Threads.nthreads())
ret = zeros(1)
@maybe_threaded for i in 1:N
Expand All @@ -35,7 +36,7 @@ function parallel_reduce(N::I, f::F, x::Vararg{Union{<:Number, <:Base.Array}}) w
return ret
end

function parallel_reduce((M, N)::Tuple{I, I}, f::F, x::Vararg{Union{<:Number, <:Base.Array}}) where {I <: Integer, F <: Function}
function parallel_reduce((M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function}
tmp = zeros(Threads.nthreads())
ret = zeros(1)
@maybe_threaded for j in 1:N
Expand All @@ -50,11 +51,7 @@ function parallel_reduce((M, N)::Tuple{I, I}, f::F, x::Vararg{Union{<:Number, <:
end

# Module initializer: binds `JACC.Array` to `Base.Array`.
# NOTE(review): this span is a rendered diff hunk without +/- markers. Judging by
# the "6 additions & 9 deletions" count for this file, the `@info` line, the
# blank line and the `if ... end` block appear to be the removed version, and the
# final unconditional `const` line the added one — confirm against the repository.
function __init__()
@info("Using JACC backend: $(JACCPreferences.backend)")

# Old form: alias only when the configured backend is "threads".
if JACCPreferences.backend == "threads"
const JACC.Array = Base.Array{T, N} where {T, N}
end
# New form: alias unconditionally.
# NOTE(review): a `const` declaration inside a function body is unusual —
# confirm it is accepted by the Julia versions this package targets.
const JACC.Array = Base.Array{T, N} where {T, N}
end


Expand Down

0 comments on commit 555bd77

Please sign in to comment.