-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added JACCASYNC API for threads.jl backend. Note that JACC.Async work…
…s as regular JACC on threads.jl backend. The idea of JACC.Async is to enable JACC for multi-device (GPUs) concurrent executions.
- Loading branch information
pedrovalerolara
committed
Dec 18, 2024
1 parent
0aaff70
commit bdad0ff
Showing
2 changed files
with
64 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
module Async | ||
|
||
using JACC | ||
import JACC: ThreadsBackend | ||
|
||
function Array(::ThreadsBackend, queue_id::I, x::Base.Array{T,N}) where {I <: Integer, T, N} | ||
return JACC.Array(x) | ||
end | ||
|
||
function copy(::ThreadsBackend, queue_id_dest::I, x::Base.Array{T,N}, queue_id_orig::I, y::Base.Array{T,N}) where {I <: Integer, T, N} | ||
copyto!(x, y) | ||
end | ||
|
||
function parallel_for(::ThreadsBackend, queue_id::I, N::I, f::F, x...) where {I <: Integer, F <: Function} | ||
JACC.parallel_for(N, f, x...) | ||
end | ||
|
||
function parallel_for(::ThreadsBackend, queue_id::I, (M, N)::Tuple{I,I}, f::F, x...) where {I <: Integer, F <: Function} | ||
JACC.parallel_for((M, N), f, x...) | ||
end | ||
|
||
function parallel_reduce(::ThreadsBackend, queue_id::I, N::I, f::F, x...) where {I <: Integer, F <: Function} | ||
return JACC.parallel_reduce(N, f, x...) | ||
end | ||
|
||
function parallel_reduce(::ThreadsBackend, queue_id::I, (M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function} | ||
return JACC.parallel_reduce((M, N), f, x...) | ||
end | ||
|
||
function synchronize(::ThreadsBackend) | ||
end | ||
|
||
function Array(queue_id::I, x::Base.Array{T,N}) where {I <: Integer, T, N} | ||
return Array(JACC.default_backend(), queue_id, x) | ||
end | ||
|
||
function copy(queue_id_dest::I, x::Base.Array{T,N}, queue_id_orig::I, y::Base.Array{T,N}) where {I <: Integer, T, N} | ||
return copy(JACC.default_backend(), queue_id_dest, x, queue_id_orig, y) | ||
end | ||
|
||
function parallel_for(queue_id::I, N::I, f::F, x...) where {I <: Integer, F <: Function} | ||
return parallel_for(JACC.default_backend(), queue_id, N, f, x...) | ||
end | ||
|
||
function parallel_for(queue_id::I, (M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function} | ||
return parallel_for(JACC.default_backend(), queue_id, (M, N), f, x...) | ||
end | ||
|
||
function parallel_reduce(queue_id::I, N::I, f::F, x...) where {I <: Integer, F <: Function} | ||
return parallel_reduce(JACC.default_backend(), queue_id, N, f, x...) | ||
end | ||
|
||
function parallel_reduce(queue_id::I, (M, N)::Tuple{I, I}, f::F, x...) where {I <: Integer, F <: Function} | ||
return parallel_reduce(JACC.default_backend(), queue_id, (M, N), f, x...) | ||
end | ||
|
||
function synchronize() | ||
return synchronize(JACC.default_backend()) | ||
end | ||
|
||
end # module Async |