MPIComm.jl
import MPI
export MPIComm
"""
MPIComm()
MPIComm(comm::MPI.Comm)
An implementation of Comm using MPI
The no argument constructor uses MPI.COMM_WORLD
"""
struct MPIComm{GID <: Integer, PID <:Integer, LID <: Integer} <: Comm{GID, PID, LID}
mpiComm::MPI.Comm
end
function MPIComm(GID::Type, PID::Type, LID::Type)
MPIInit()
comm = MPIComm{GID, PID, LID}(MPI.COMM_WORLD)
comm
end

# Tracks whether MPI still needs to be initialized by MPIInit()
MPINeedsInitialization = true

"""
    MPIInit()

On the first call, initializes MPI and registers an exit hook to finalize MPI.
Does nothing on subsequent calls.
"""
function MPIInit()
    global MPINeedsInitialization
    if MPINeedsInitialization
        MPI.Init()
        atexit(() -> MPI.Finalize())
        MPINeedsInitialization = false
    end
end

# Blocks until every process in the communicator has reached this call
function barrier(comm::MPIComm)
    MPI.Barrier(comm.mpiComm)
end

# Broadcasts myvals from the (1-based) root process to every process
function broadcastAll(comm::MPIComm, myvals::AbstractArray{T}, root::Integer)::Array{T} where T
    vals = copy(myvals)
    # MPI ranks are 0-based, so convert the 1-based root PID
    MPI.Bcast!(vals, root - 1, comm.mpiComm)
end

# Gathers every process's myVals; all processes receive the results concatenated in PID order
function gatherAll(comm::MPIComm, myVals::AbstractArray{T})::Array{T} where T
    lengths = MPI.Allgather([convert(Cint, length(myVals))], comm.mpiComm)
    MPI.Allgatherv(myVals, lengths, comm.mpiComm)
end
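
# An illustrative sketch of the collective semantics above (values assumed, with 3 processes):
# if process p calls gatherAll(comm, [p, 10p]), every process receives [1, 10, 2, 20, 3, 30];
# broadcastAll(comm, vals, 1) returns process 1's vals on every process.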

# Elementwise sum of partialsums across all processes; every process receives the result
function sumAll(comm::MPIComm, partialsums::AbstractArray{T})::Array{T} where T
    MPI.allreduce(partialsums, +, comm.mpiComm)
end

# Elementwise maximum of partialmaxes across all processes
function maxAll(comm::MPIComm, partialmaxes::AbstractArray{T})::Array{T} where T
    MPI.allreduce(partialmaxes, max, comm.mpiComm)
end

# Bool arrays are reduced as UInt8, then converted back
function maxAll(comm::MPIComm, partialmaxes::AbstractArray{Bool})::Array{Bool}
    Array{Bool}(maxAll(comm, Array{UInt8}(partialmaxes)))
end

# Elementwise minimum of partialmins across all processes
function minAll(comm::MPIComm, partialmins::AbstractArray{T})::Array{T} where T
    MPI.allreduce(partialmins, min, comm.mpiComm)
end

# Bool arrays are reduced as UInt8, then converted back
function minAll(comm::MPIComm, partialmins::AbstractArray{Bool})::Array{Bool}
    Array{Bool}(minAll(comm, Array{UInt8}(partialmins)))
end

# Inclusive prefix sum: each process receives the elementwise sum of myvals from processes 1 through myPid
function scanSum(comm::MPIComm, myvals::AbstractArray{T})::Array{T} where T
    MPI.Scan(myvals, length(myvals), MPI.SUM, comm.mpiComm)
end

# Returns this process's ID, converting MPI's 0-based rank to a 1-based PID
function myPid(comm::MPIComm{GID, PID})::PID where {GID <: Integer, PID <: Integer}
    MPI.Comm_rank(comm.mpiComm) + 1
end

# Returns the number of processes in the communicator
function numProc(comm::MPIComm{GID, PID})::PID where {GID <: Integer, PID <: Integer}
    MPI.Comm_size(comm.mpiComm)
end

# Creates an MPI-based Distributor for this communicator
function createDistributor(comm::MPIComm{GID, PID, LID})::MPIDistributor{GID, PID, LID} where {GID <: Integer, PID <: Integer, LID <: Integer}
    MPIDistributor(comm)
end
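
# An illustrative sketch of the reduction and scan semantics (values assumed, with 3 processes):
# if each process calls sumAll(comm, [myPid(comm)]), every process receives [6];
# scanSum(comm, [myPid(comm)]) returns [1] on process 1, [3] on process 2, and [6] on
# process 3 (an inclusive prefix sum in PID order).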