diff --git a/REQUIRE b/REQUIRE
new file mode 100644
index 0000000..1a0c5d6
--- /dev/null
+++ b/REQUIRE
@@ -0,0 +1,2 @@
julia 0.6
MPI
diff --git a/src/BasicDirectory.jl b/src/BasicDirectory.jl
new file mode 100644
index 0000000..05de9d2
--- /dev/null
+++ b/src/BasicDirectory.jl
@@ -0,0 +1,251 @@
export BasicDirectory

"""
    BasicDirectory(map::BlockMap)

Creates a BasicDirectory, which implements the methods of Directory with
basic implementations
"""
type BasicDirectory{GID <: Integer, PID <:Integer, LID <: Integer} <: Directory{GID, PID, LID}
    map::BlockMap{GID, PID, LID}

    # only populated in the general (non-linear, distributed) case
    directoryMap::Nullable{BlockMap}

    # owning process for each directory-local entry (-1 == unset)
    procList::Vector{PID}
    # when a GID is shared, the sorted list of all owners for that entry
    procListLists::Vector{Vector{PID}}

    # true if any GID exists on more than one process
    entryOnMultipleProcs::Bool

    # local index of each directory-local entry on its owning process
    localIndexList::Vector{LID}
    # for linear maps: minimum GID of each process, plus a sentinel of maxAllGID+1
    allMinGIDs::Vector{GID}

    function BasicDirectory{GID, PID, LID}(map::BlockMap{GID, PID, LID}) where {GID, PID, LID}

        if !(distributedGlobal(map))
            new(map, Nullable{BlockMap}(), [], [], numProc(comm(map))!=1, [], [])
        elseif linearMap(map)
            commObj = comm(map)

            allMinGIDs = gatherAll(commObj, minMyGID(map))
            # sentinel so allMinGIDs[p+1] is always a valid upper bound
            allMinGIDs = vcat(allMinGIDs, [1+maxAllGID(map)])

            # duplicated minimum GIDs imply some GID lives on multiple processes
            entryOnMultipleProcs = length(Set(allMinGIDs)) != length(allMinGIDs)

            new(map, Nullable{BlockMap}(), [], [], entryOnMultipleProcs, [], allMinGIDs)
        else
            generateContent(
                new(map, Nullable{BlockMap}(), [], [], false, [], []),
                map)
        end
    end
end

"""
internal method to assist constructor; builds the directory map, the owning
process list and the local index list for the general (arbitrary
distribution) case
"""
function generateContent(dir::BasicDirectory{GID, PID, LID}, map::BlockMap{GID, PID, LID}) where GID <: Integer where PID <: Integer where LID <: Integer
    # BUG FIX: commObj was used below but never defined
    commObj = comm(map)

    # BUG FIX: the original locals shadowed the functions minAllGID/maxAllGID
    minGID = minAllGID(map)
    maxGID = maxAllGID(map)

    dirNumGlobalElements = maxGID - minGID + 1

    directoryMap = BlockMap(dirNumGlobalElements, minGID, commObj)
    dir.directoryMap = Nullable(directoryMap)

    dirNumMyElements = numMyElements(directoryMap)

    if dirNumMyElements > 0
        dir.procList = Array{PID}(dirNumMyElements)
        dir.localIndexList = Array{LID}(dirNumMyElements)

        # -1 marks "not yet assigned an owner"
        fill!(dir.procList, -1)
        fill!(dir.localIndexList, -1)
    else
        dir.procList = []
        dir.localIndexList = []
    end

    map_numMyElements = numMyElements(map)
    map_myGlobalElements = myGlobalElements(map)

    # TODO confirm signature: BlockMap's remoteIDList takes (map, gidList)
    sendProcs = remoteIDList(directoryMap, map_numMyElements, map_myGlobalElements)

    distributor = createDistributor(commObj)
    numRecvs = createFromSends(distributor, map_numMyElements, sendProcs)

    # BUG FIX: was sized/iterated with the undefined `numMyElements`
    exportElements = Array{Tuple{GID, PID, LID}}(map_numMyElements)

    myPIDVal = myPid(commObj)
    for i = 1:map_numMyElements
        exportElements[i] = (map_myGlobalElements[i], myPIDVal, i)
    end

    importElements = resolve(distributor, exportElements)

    # BUG FIX: numProcLists is read after the loop; must exist even if the
    # shared-GID branch is never taken
    numProcLists = 0
    for i = 1:numRecvs
        currLID = lid(directoryMap, importElements[i][1])
        @assert currLID > 0 # internal error (BUG FIX: `//` was rational division, not a comment)

        proc = importElements[i][2]
        if dir.procList[currLID] >= 0
            if dir.procList[currLID] != proc
                if dir.procListLists == []
                    numProcLists = numMyElements(directoryMap)
                    # BUG FIX: fill!(…, []) aliased ONE array into every slot;
                    # each entry needs its own list.  Also assign to dir.
                    dir.procListLists = [PID[] for _ in 1:numProcLists]
                end

                l = dir.procListLists[currLID]

                # keep the owner list sorted (BUG FIX: insert -> insert!,
                # procList -> dir.procList)
                index = searchsortedfirst(l, dir.procList[currLID])
                insert!(l, index, dir.procList[currLID])

                index = searchsortedfirst(l, proc)
                insert!(l, index, proc)

                # lowest-ranked owner becomes the canonical owner
                dir.procList[currLID] = dir.procListLists[currLID][1]
            end
        else
            dir.procList[currLID] = proc
        end

        dir.localIndexList[currLID] = importElements[i][3]
    end

    # BUG FIX: globalval -> globalVal; `x > 0 ? true : false` simplified
    globalVal = maxAll(commObj, numProcLists)
    dir.entryOnMultipleProcs = globalVal > 0

    dir
end


"""
Internal implementation for Directory's getDirectoryEntries.  Returns, for
each requested global ID, the owning process ID and the local index on that
process (0 for entries not found).
"""
function getDirectoryEntries(directory::BasicDirectory{GID, PID, LID}, map::BlockMap{GID, PID, LID}, globalEntries::AbstractVector{GID},
        high_rank_sharing_procs::Bool)::Tuple{Vector{PID}, Vector{LID}} where GID <: Integer where PID <: Integer where LID <: Integer
    numEntries = length(globalEntries)
    procs = Vector{PID}(numEntries)
    localEntries = Vector{LID}(numEntries)

    if !distributedGlobal(map)
        myPIDVal = myPid(comm(map))

        for i = 1:numEntries
            lidVal = lid(map, globalEntries[i])

            if lidVal == 0
                procs[i] = 0
                warn("GID $(globalEntries[i]) is not part of this map")
            else
                procs[i] = myPIDVal
            end
            localEntries[i] = lidVal
        end
    elseif linearMap(map)
        minAllGIDVal = minAllGID(map)
        maxAllGIDVal = maxAllGID(map)

        numProcVal = numProc(comm(map))

        n_over_p = numGlobalElements(map)/numProcVal

        allMinGIDs_list = copy(directory.allMinGIDs)
        order = sortperm(allMinGIDs_list)
        permute!(allMinGIDs_list, order)

        for i = 1:numEntries
            # BUG FIX: renamed from `lid`/`proc` — a local named `lid` would
            # shadow the lid(map, gid) function used elsewhere in this function
            lidVal = 0
            procVal = 0

            gid = globalEntries[i]
            if gid < minAllGIDVal || gid > maxAllGIDVal
                throw(InvalidArgumentError("GID=$gid out of valid range [$minAllGIDVal, $maxAllGIDVal]"))
            end
            #guess uniform distribution and start a little above it
            # BUG FIX: removed the stray `proc1 = 1` that clobbered the guess
            proc1 = min(GID(fld(gid, max(n_over_p, 1)) + 2), numProcVal)
            found = false

            while proc1 >= 1 && proc1 <= numProcVal
                if allMinGIDs_list[proc1] <= gid
                    if (gid < allMinGIDs_list[proc1+1])
                        found = true
                        break
                    else
                        proc1 += 1
                    end
                else
                    proc1 -= 1
                end
            end
            if found
                procVal = order[proc1]
                lidVal = gid - allMinGIDs_list[proc1] + 1
            end

            procs[i] = procVal
            localEntries[i] = lidVal
        end
    else #general case
        distributor = createDistributor(comm(map))

        dirProcs = remoteIDList(map, numEntries, globalEntries)

        numMissing = 0
        for i = 1:numEntries
            if dirProcs[i] == 0
                procs[i] = 0
                localEntries[i] = 0
                numMissing += 1
            end
        end

        # BUG FIX: distrbutor -> distributor
        (sendGIDs, sendPIDs) = createFromRecvs(distributor, globalEntries, dirProcs)
        numSends = length(sendGIDs)

        # BUG FIX: always defined (was undefined when numSends == 0), and
        # tuples are immutable so each element is built in one assignment
        exports = Array{Tuple{GID, PID, LID}}(numSends)
        for i = 1:numSends
            currGID = sendGIDs[i]

            currLID = lid(map, currGID)
            @assert currLID > 0 #internal error
            if !high_rank_sharing_procs
                # BUG FIX: procList -> directory.procList
                exportPID = directory.procList[currLID]
            else
                numProcLists = numMyElements(directory.directoryMap)
                if numProcLists > 0
                    num = length(directory.procListLists[currLID])
                    if num > 1
                        exportPID = directory.procListLists[currLID][num]
                    else
                        exportPID = directory.procList[currLID]
                    end
                else
                    exportPID = directory.procList[currLID]
                end
            end
            exports[i] = (currGID, exportPID, directory.localIndexList[currLID])
        end

        numRecv = numEntries - numMissing
        imports = resolve(distributor, exports)

        offsets = sortperm(globalEntries)
        sortedGE = globalEntries[offsets]

        for i = 1:numRecv
            currGID = imports[i][1]
            j = searchsortedfirst(sortedGE, currGID)
            # BUG FIX: searchsortedfirst is always > 0; verify a real match
            if j <= numEntries && sortedGE[j] == currGID
                procs[offsets[j]] = imports[i][2]
                localEntries[offsets[j]] = imports[i][3]
            end
        end
    end
    (procs, localEntries)
end

"""
    gidsAllUniquelyOwned(directory::BasicDirectory)

Return true if every GID is owned by exactly one process.
"""
function gidsAllUniquelyOwned(directory::BasicDirectory)
    !directory.entryOnMultipleProcs
end
diff --git a/src/BlockMap.jl b/src/BlockMap.jl
new file mode 100644
index 0000000..66efce2
--- /dev/null
+++ b/src/BlockMap.jl
@@ -0,0 +1,721 @@
export BlockMap
export remoteIDList, lid, gid, findLocalElementID
export minAllGID, maxAllGID, minMyGID, maxMyGID, minLID, maxLID
export numGlobalElements, myGlobalElements
export uniqueGIDs, globalIndicesType, sameBlockMapDataAs, sameAs
export linearMap, myGlobalElementIDs, comm
export myGID, myLID, distributedGlobal, numMyElements


# methods and docs based straight off Epetra_BlockMap to match Comm

# ignoring indexBase methods and sticking with 1-based indexing
# ignoring elementSize methods since, type information is carried anyways
# ignoring point-related code, since elementSize is ignored


# TODO figure out expert users and developers only functions
"""
A type for partitioning block element vectors and matrices
"""
struct BlockMap{GID <: Integer, PID <:Integer, LID <: Integer}
    data::BlockMapData{GID, PID, LID}

    function BlockMap{GID, PID, LID}(data::BlockMapData) where {GID <: Integer, PID <:Integer, LID <: Integer}
        new(data)
    end
end


"""
    BlockMap(numGlobalElements, comm)

Constructor for petra-defined uniform linear distribution of elements
"""
function BlockMap(numGlobalElements::Integer, comm::Comm{GID, PID, LID}) where GID <: Integer where PID <: Integer where LID <: Integer
    BlockMap(GID(numGlobalElements), comm)
end

function BlockMap(numGlobalElements::GID, comm::Comm{GID, PID, LID}) where GID <: Integer where PID <: Integer where LID <: Integer
    if numGlobalElements < 0
        throw(InvalidArgumentError("NumGlobalElements = $(numGlobalElements).  Should be >= 0"))
    end

    # BUG FIX: `const` is not valid on local variables
    data = BlockMapData(numGlobalElements, comm)
    map = BlockMap{GID, PID, LID}(data)

    numProcVal = numProc(comm)
    data.linearMap = true

    # 0-based process index for the start-offset arithmetic below
    myPIDVal = myPid(comm) - 1

    # BUG FIX: integer division via div, instead of floor(typeof(…), x/y)
    data.numMyElements = div(data.numGlobalElements, numProcVal)
    remainder = data.numGlobalElements % numProcVal
    startIndex = myPIDVal * (data.numMyElements+1)

    # the first `remainder` processes take one extra element
    if myPIDVal < remainder
        data.numMyElements += 1
    else
        startIndex -= (myPIDVal - remainder)
    end

    data.minAllGID = 1
    data.maxAllGID = data.minAllGID + data.numGlobalElements - 1
    data.minMyGID = startIndex + 1
    data.maxMyGID = data.minMyGID + data.numMyElements - 1
    data.distributedGlobal = isDistributedGlobal(map, data.numGlobalElements,
        data.numMyElements)

    EndOfConstructorOps(map)
    map
end

"""
    BlockMap(numGlobalElements, numMyElements, comm)

Constructor for user-defined linear distribution of elements
"""
function BlockMap(numGlobalElements::Integer, numMyElements::Integer, comm::Comm{GID, PID, LID}) where GID <: Integer where PID <: Integer where LID <: Integer
    BlockMap(GID(numGlobalElements), LID(numMyElements), comm)
end

function BlockMap(numGlobalElements::GID, numMyElements::LID, comm::Comm{GID, PID, LID}) where GID <: Integer where PID <: Integer where LID <: Integer
    if numGlobalElements < -1
        throw(InvalidArgumentError("NumGlobalElements = $(numGlobalElements).  Should be >= -1"))
    end
    if numMyElements < 0
        throw(InvalidArgumentError("NumMyElements = $(numMyElements). Should be >= 0"))
    end

    data = BlockMapData(numGlobalElements, comm)
    map = BlockMap{GID, PID, LID}(data)

    data.numMyElements = numMyElements
    data.linearMap = true

    data.distributedGlobal = isDistributedGlobal(map, numGlobalElements, numMyElements)

    #Local Map and uniprocessor case:  Each processor gets a complete copy of all elements
    if !data.distributedGlobal || numProc(comm) == 1
        data.numGlobalElements = data.numMyElements

        data.minAllGID = 1
        data.maxAllGID = data.minAllGID + data.numGlobalElements - 1
        data.minMyGID = 1
        data.maxMyGID = data.minMyGID + data.numMyElements - 1
    else
        tmp_numMyElements = data.numMyElements
        data.numGlobalElements = sumAll(data.comm, tmp_numMyElements)

        data.minAllGID = 1
        data.maxAllGID = data.minAllGID + data.numGlobalElements - 1

        # prefix sum of element counts gives this process's upper GID bound
        tmp_numMyElements = data.numMyElements
        data.maxMyGID = scanSum(data.comm, tmp_numMyElements)

        startIndex = data.maxMyGID - data.numMyElements
        data.minMyGID = startIndex + 1
        data.maxMyGID = data.minMyGID + data.numMyElements - 1
    end
    checkValidNGE(map, numGlobalElements)

    EndOfConstructorOps(map)
    map
end


"""
    BlockMap(myGlobalElements, comm)

Constructor for user-defined arbitrary distribution of elements
"""
function BlockMap(myGlobalElements::AbstractArray{<:Integer}, comm::Comm{GID, PID,LID}
        ) where GID <: Integer where PID <: Integer where LID <: Integer
    BlockMap(Array{GID}(myGlobalElements), comm)
end

function BlockMap(myGlobalElements::AbstractArray{GID}, comm::Comm{GID, PID,LID}
        ) where GID <: Integer where PID <: Integer where LID <: Integer
    numMyElements = LID(length(myGlobalElements))

    data = BlockMapData(GID(0), comm)
    map = BlockMap{GID, PID, LID}(data)

    data.numMyElements = numMyElements

    # 1 while the local GIDs remain consecutive; reduced across processes below
    linear = 1
    if numMyElements > 0
        data.myGlobalElements = Array{GID, 1}(numMyElements)

        data.myGlobalElements[1] = myGlobalElements[1]
        data.minMyGID = myGlobalElements[1]
        data.maxMyGID = myGlobalElements[1]

        for i = 2:numMyElements
            data.myGlobalElements[i] = myGlobalElements[i]
            data.minMyGID = min(data.minMyGID, myGlobalElements[i])
            data.maxMyGID = max(data.maxMyGID, myGlobalElements[i])

            if myGlobalElements[i] != myGlobalElements[i-1] + 1
                linear = 0
            end
        end
    else
        data.minMyGID = 1
        data.maxMyGID = 0
    end

    data.linearMap = Bool(minAll(data.comm, linear))

    if numProc(comm) == 1
        data.numGlobalElements = data.numMyElements
        data.minAllGID = data.minMyGID
        data.maxAllGID = data.maxMyGID
    else
        # single maxAll finds both the global min (negated) and global max.
        # BUG FIX: the empty-map sentinel was `Inf`, which made the array
        # Float64 and tripped the Integer assertion below
        tmp_send = [
            -((data.numMyElements > 0) ? data.minMyGID : typemax(GID)),
            data.maxMyGID]

        tmp_recv = maxAll(data.comm, tmp_send)

        @assert typeof(tmp_recv[1]) <: Integer "Result type is $(typeof(tmp_recv[1])), should be subtype of Integer"

        data.minAllGID = -tmp_recv[1]
        data.maxAllGID = tmp_recv[2]

        if data.linearMap
            data.numGlobalElements = sumAll(data.comm, data.numMyElements)
        else
            #if 1+ GIDs shared between processors, need to total that correctly
            allIDs = gatherAll(data.comm, myGlobalElements)

            # shift GIDs so the smallest maps to array index 1
            indexModifier = 1 - data.minAllGID
            maxGID = data.maxAllGID

            count = 0
            arr = falses(maxGID + indexModifier)
            for id in allIDs
                if !arr[GID(id + indexModifier)]
                    arr[GID(id + indexModifier)] = true
                    count += 1
                end
            end
            data.numGlobalElements = count
        end
    end

    data.distributedGlobal = isDistributedGlobal(map, data.numGlobalElements, numMyElements)

    EndOfConstructorOps(map)
    map
end

"""
    BlockMap(numGlobalElements, numMyElements, myGlobalElements, isDistributedGlobal, minAllGID, maxAllGID, comm)

Constructor for user-defined arbitrary distribution of elements with all information on globals provided by the user
"""
function BlockMap(numGlobalElements::Integer, numMyElements::Integer,
        myGlobalElements::AbstractArray{GID}, userIsDistributedGlobal::Bool,
        userMinAllGID::Integer, userMaxAllGID::Integer, comm::Comm{GID, PID, LID}) where GID <: Integer where PID <: Integer where LID <: Integer
    BlockMap(GID(numGlobalElements), LID(numMyElements), Array{GID}(myGlobalElements), userIsDistributedGlobal,
        GID(userMinAllGID), GID(userMaxAllGID), comm)
end

function BlockMap(numGlobalElements::GID, numMyElements::LID,
        myGlobalElements::AbstractArray{GID}, userIsDistributedGlobal::Bool,
        userMinAllGID::GID, userMaxAllGID::GID, comm::Comm{GID, PID, LID}) where GID <: Integer where PID <: Integer where LID <: Integer
    if numGlobalElements < -1
        throw(InvalidArgumentError("NumGlobalElements = $(numGlobalElements).  Should be >= -1"))
    end
    if numMyElements < 0
        throw(InvalidArgumentError("NumMyElements = $(numMyElements). Should be >= 0"))
    end
    if userMinAllGID < 1
        # BUG FIX: message referenced `data.minAllGID` before `data` existed
        throw(InvalidArgumentError("Minimum global element index = $(userMinAllGID).  Should be >= 1"))
    end

    data = BlockMapData(numGlobalElements, comm)
    map = BlockMap{GID, PID, LID}(data)

    data.numMyElements = numMyElements

    linear = 1
    if numMyElements > 0
        data.myGlobalElements = Array{GID}(numMyElements)

        data.myGlobalElements[1] = myGlobalElements[1]
        data.minMyGID = myGlobalElements[1]
        data.maxMyGID = myGlobalElements[1]

        for i = 2:numMyElements
            data.myGlobalElements[i] = myGlobalElements[i]
            data.minMyGID = min(data.minMyGID, myGlobalElements[i])
            data.maxMyGID = max(data.maxMyGID, myGlobalElements[i])

            if myGlobalElements[i] != myGlobalElements[i-1] + 1
                linear = 0
            end
        end

    else
        data.minMyGID = 1
        data.maxMyGID = 0
    end

    data.linearMap = Bool(minAll(comm, linear))

    data.distributedGlobal = userIsDistributedGlobal

    if !data.distributedGlobal || numProc(comm) == 1
        data.numGlobalElements = data.numMyElements
        checkValidNGE(map, numGlobalElements)
        data.minAllGID = data.minMyGID
        data.maxAllGID = data.maxMyGID
    else
        if numGlobalElements == -1
            data.numGlobalElements = sumAll(data.comm, data.numMyElements)
        else
            data.numGlobalElements = numGlobalElements
        end
        # BUG FIX: was checkValidNGE(data.numGlobalELements) — typo and
        # missing the map argument
        checkValidNGE(map, data.numGlobalElements)

        data.minAllGID = userMinAllGID
        data.maxAllGID = userMaxAllGID
    end
    EndOfConstructorOps(map)
    map
end



##### internal construction methods #####

# true when the map's elements are actually split across processes (as opposed
# to every process holding a full replicated copy)
function isDistributedGlobal(map::BlockMap{GID, PID, LID}, numGlobalElements::GID,
        numMyElements::LID) where GID <: Integer where PID <: Integer where LID <: Integer
    data = map.data
    if numProc(data.comm) > 1
        localReplicated = numGlobalElements == numMyElements
        !Bool(minAll(data.comm, localReplicated))
    else
        false
    end
end

# shared tail work for every constructor: set LID bounds and build lookup state
function EndOfConstructorOps(map::BlockMap)
    map.data.minLID = 1
    map.data.maxLID = max(map.data.numMyElements, 1)

    GlobalToLocalSetup(map);
end

# builds lastContiguousGID/lastContiguousGIDLoc and the GID->LID hash used by
# lid() for non-linear maps
function GlobalToLocalSetup(map::BlockMap)
    data = map.data
    numMyElements = data.numMyElements
    myGlobalElements = data.myGlobalElements

    if data.linearMap || numMyElements == 0
        return map
    end
    # BUG FIX: was `length(data.numGlobalElements)`, which is a scalar;
    # the guard is for maps without an explicit element list
    if length(data.myGlobalElements) == 0
        return map
    end

    # find the length of the contiguous prefix of GIDs.
    # BUG FIX: the original relied on the for-loop variable leaking out of the
    # loop (Julia 0.6 behavior); a while loop makes the index explicit
    i = 1
    val = myGlobalElements[1]
    while i < numMyElements && val + 1 == myGlobalElements[i+1]
        val += 1
        i += 1
    end

    data.lastContiguousGIDLoc = i
    if data.lastContiguousGIDLoc <= 1
        data.lastContiguousGID = myGlobalElements[1]
    else
        data.lastContiguousGID = myGlobalElements[data.lastContiguousGIDLoc]
    end

    # everything after the contiguous prefix goes into the hash
    if i < numMyElements
        empty!(data.lidHash)

        sizehint!(data.lidHash, numMyElements - i + 2)

        for j = i:numMyElements
            data.lidHash[myGlobalElements[j]] = j
        end
    end
    map
end

# throws if the user-supplied global element count disagrees with the computed one
function checkValidNGE(map::BlockMap{GID, PID, LID}, numGlobalElements::GID) where GID <: Integer where PID <: Integer where LID <: Integer
    if (numGlobalElements != -1) && (numGlobalElements != map.data.numGlobalElements)
        throw(InvalidArgumentError("Invalid NumGlobalElements.  "
                * "NumGlobalElements = $(numGlobalElements)"
                * ".  Should equal $(map.data.numGlobalElements)"
                * ", or be set to -1 to compute automatically"))
    end
end

##### external methods #####

"""
    myGID(map::BlockMap, gidVal::Integer)

Return true if the GID passed in belongs to the calling processor in this
map, otherwise returns false.
"""
function myGID(map::BlockMap, gidVal::Integer)
    lid(map, gidVal) != 0
end

"""
    myLID(map::BlockMap, lidVal::Integer)

Return true if the LID passed in belongs to the calling processor in this
map, otherwise returns false.
"""
@inline function myLID(map::BlockMap, lidVal::Integer)
    gid(map, lidVal) != 0
end

"""
    distributedGlobal(map::BlockMap)

Return true if map is defined across more than one processor
"""
function distributedGlobal(map::BlockMap)
    map.data.distributedGlobal
end

"""
    numMyElements(map::BlockMap{GID, PID, LID})::LID

Return the number of elements across the calling processor
"""
function numMyElements(map::BlockMap{GID, PID, LID})::LID where GID <: Integer where PID <: Integer where LID <: Integer
    map.data.numMyElements
end

"""
    minMyGID(map::BlockMap{GID, PID, LID})::GID

Return the minimum global ID owned by this processor
"""
function minMyGID(map::BlockMap{GID, PID, LID})::GID where GID <: Integer where PID <: Integer where LID <: Integer
    map.data.minMyGID
end

"""
    maxMyGID(map::BlockMap{GID, PID, LID})::GID

Return the maximum global ID owned by this processor
"""
function maxMyGID(map::BlockMap{GID, PID, LID})::GID where GID <: Integer where PID <: Integer where LID <: Integer
    map.data.maxMyGID
end

"""
    getLocalMap(::BlockMap{GID, PID, LID})::BlockMap{GID, PID, LID}

Creates a copy of the given map that doesn't support any inter-process actions
"""
function getLocalMap(map::BlockMap{GID, PID, LID})::BlockMap{GID, PID, LID} where {GID, PID, LID}
    oldData = map.data
    data = BlockMapData(oldData.numGlobalElements, LocalComm(oldData.comm))

    data.directory = Nullable{Directory}()
    data.lid = copy(oldData.lid)
    #maps shouldn't be modified anyways, may as well share array
    #data.myGlobalElements = copy(oldData.myGlobalElements)
    data.myGlobalElements = oldData.myGlobalElements
    data.numMyElements = oldData.numMyElements
    data.minAllGID = oldData.minAllGID
    data.maxAllGID = oldData.maxAllGID
    data.minMyGID = oldData.minMyGID
    data.maxMyGID = oldData.maxMyGID
    data.minLID = oldData.minLID
    data.maxLID = oldData.maxLID
    data.linearMap = oldData.linearMap
    data.distributedGlobal = oldData.distributedGlobal
    data.oneToOneIsDetermined = oldData.oneToOneIsDetermined
    data.oneToOne = oldData.oneToOne
    data.lastContiguousGID = oldData.lastContiguousGID
    data.lastContiguousGIDLoc = oldData.lastContiguousGIDLoc
    data.lidHash = copy(oldData.lidHash)

    BlockMap{GID, PID, LID}(data)
end

##local/global ID accessor methods##

"""
    remoteIDList(map::BlockMap{GID, PID, LID}, gidList::AbstractArray{<: Integer})::Tuple{AbstractArray{PID}, AbstractArray{LID}}

Return the processor ID and local index value for a given list of global indices.
The returned value is a tuple containing
1. an Array of processors owning the global ID's in question
2. an Array of local IDs of the global on the owning processor
"""
function remoteIDList(map::BlockMap{GID, PID, LID}, gidList::AbstractArray{<:Integer}
        )::Tuple{AbstractArray{PID}, AbstractArray{LID}} where GID <: Integer where PID <: Integer where LID <: Integer
    remoteIDList(map, Array{GID}(gidList))
end

function remoteIDList(map::BlockMap{GID, PID, LID}, gidList::AbstractArray{GID}
        )::Tuple{AbstractArray{PID}, AbstractArray{LID}} where GID <: Integer where PID <: Integer where LID <: Integer
    data = map.data
    # lazily build the directory the first time remote IDs are needed
    if isnull(data.directory)
        data.directory = createDirectory(data.comm, map)
    end

    getDirectoryEntries(get(data.directory), map, gidList)
end


"""
    lid(map::BlockMap{GID, PID, LID}, gid::Integer)::LID

Return local ID of global ID, or 0 if not found on this processor
"""
@inline function lid(map::BlockMap{GID, PID, LID}, gid::Integer) where GID <: Integer where PID <: Integer where LID <: Integer
    data = map.data
    if (gid < data.minMyGID) || (gid > data.maxMyGID)
        LID(0)
    elseif data.linearMap
        LID(gid - data.minMyGID + 1)
    elseif gid >= data.myGlobalElements[1] && gid <= data.lastContiguousGID
        # contiguous prefix: direct offset, no hash lookup needed
        LID(gid - data.myGlobalElements[1] + 1)
    elseif haskey(data.lidHash, GID(gid))
        data.lidHash[gid]
    else
        LID(0)
    end
end

"""
    gid(map::BlockMap{GID, PID, LID}, lid::Integer)::GID

Return global ID of local ID, or 0 if not found on this processor
"""
@inline function gid(map::BlockMap{GID, PID, LID}, lid::Integer) where GID <: Integer where PID <: Integer where LID <: Integer
    data = map.data
    if (data.numMyElements == LID(0)) || (lid < data.minLID) || (lid > data.maxLID)
        GID(0)
    elseif data.linearMap
        GID(lid + data.minMyGID - 1)
    else
        GID(data.myGlobalElements[lid])
    end
end


"""
    minAllGID(map::BlockMap{GID, PID, LID})::GID

Return the minimum global ID across the entire map
"""
function minAllGID(map::BlockMap{GID})::GID where GID <: Integer
    map.data.minAllGID
end

"""
    maxAllGID(map::BlockMap{GID, PID, LID})::GID

Return the maximum global ID across the entire map
"""
function maxAllGID(map::BlockMap{GID})::GID where GID <: Integer
    map.data.maxAllGID
end

"""
    minLID(map::BlockMap{GID, PID, LID})::LID

Return the minimum local index value on the calling processor
"""
function minLID(map::BlockMap{GID, PID, LID})::LID where GID <: Integer where PID <: Integer where LID <: Integer
    map.data.minLID
end

"""
    maxLID(map::BlockMap{GID, PID, LID})::LID

Return the maximum local index value on the calling processor
"""
function maxLID(map::BlockMap{GID, PID, LID})::LID where GID <: Integer where PID <: Integer where LID <: Integer
    map.data.maxLID
end

##size/dimension accessor functions##

"""
    numGlobalElements(map::BlockMap{GID, PID, LID})::GID

Return the number of elements across all processors
"""
function numGlobalElements(map::BlockMap{GID})::GID where GID <: Integer
    map.data.numGlobalElements
end

"""
    myGlobalElements(map::BlockMap{GID, PID, LID})::AbstractArray{GID}

Return a list of global elements on this processor
"""
function myGlobalElements(map::BlockMap{GID})::AbstractArray{GID} where GID <: Integer
    data = map.data

    if length(data.myGlobalElements) == 0
        # linear maps store no explicit list; build and cache one on demand
        myGlobalElements = Vector{GID}(data.numMyElements)
        @inbounds for i = GID(1):GID(data.numMyElements)
            myGlobalElements[i] = data.minMyGID + i - 1
        end
        data.myGlobalElements = myGlobalElements
    else
        data.myGlobalElements
    end
end


##Miscellaneous boolean tests##

"""
    uniqueGIDs(map::BlockMap)::Bool

Return true if each map GID exists on at most 1 processor
"""
function uniqueGIDs(map::BlockMap)::Bool
    isOneToOne(map)
end


"""
    globalIndicesType(map::BlockMap{GID, PID, LID})::Type{GID}

Return the type used for global indices in the map
"""
function globalIndicesType(map::BlockMap{GID})::Type{GID} where GID <: Integer
    GID
end

"""
    sameBlockMapDataAs(this::BlockMap, other::BlockMap)::Bool

Return true if the maps have the same data
"""
function sameBlockMapDataAs(this::BlockMap, other::BlockMap)::Bool
    this.data == other.data
end

"""
    sameAs(this::BlockMap, other::BlockMap)::Bool

Return true if this and other are identical maps
"""
function sameAs(this::BlockMap, other::BlockMap)
    # behavior by specification: maps with different index types never compare equal
    false
end

function sameAs(this::BlockMap{GID, PID, LID}, other::BlockMap{GID, PID, LID})::Bool where GID <: Integer where PID <: Integer where LID <: Integer
    tData = this.data
    oData = other.data
    if tData == oData
        return true
    end

    if ((tData.minAllGID != oData.minAllGID)
            || (tData.maxAllGID != oData.maxAllGID)
            || (tData.numGlobalElements != oData.numGlobalElements))
        return false
    end

    # 1/0 flag so the cross-process minAll reduction works on integers
    mySameMap = 1

    if tData.numMyElements != oData.numMyElements
        mySameMap = 0
    end

    if tData.linearMap && oData.linearMap
        # For linear maps, just need to check whether lower bound is the same
        if tData.minMyGID != oData.minMyGID
            mySameMap = 0
        end
    else
        for i = 1:tData.numMyElements
            if gid(this, i) != gid(other, i)
                mySameMap = 0
                break
            end
        end
    end

    # the maps are the same only if they are the same on every process
    Bool(minAll(tData.comm, mySameMap))
end


"""
    linearMap(map::BlockMap)::Bool

Return true if the global ID space is contiguously divided (but
not necessarily uniformly) across all processors
"""
function linearMap(map::BlockMap)::Bool
    map.data.linearMap
end


##Array accessor functions##

"""
    myGlobalElementIDs(map::BlockMap{GID, PID, LID})::AbstractArray{GID}

Return list of global IDs assigned to the calling processor
"""
function myGlobalElementIDs(map::BlockMap{GID})::AbstractArray{GID} where GID <: Integer
    data = map.data
    if length(data.myGlobalElements) == 0
        # linear map: the GIDs are just a contiguous range
        base = 0:data.numMyElements-1
        rng = data.minMyGID + base
        myGlobalElements = collect(rng)
    else
        myGlobalElements = copy(data.myGlobalElements)
    end

    myGlobalElements
end


# lazily computes and caches whether each GID is owned by exactly one process
function isOneToOne(map::BlockMap)::Bool
    data = map.data
    if !(data.oneToOneIsDetermined)
        data.oneToOne = determineIsOneToOne(map)
        data.oneToOneIsDetermined = true
    end
    data.oneToOne
end

function determineIsOneToOne(map::BlockMap)::Bool
    data = map.data
    if numProc(data.comm) < 2
        true
    else
        if isnull(data.directory)
            data.directory = Nullable(createDirectory(data.comm, map))
        end
        gidsAllUniquelyOwned(get(data.directory))
    end
end

"""
    comm(map::BlockMap{GID, PID, LID})::Comm{GID, PID, LID}

Return the Comm for the map
"""
function comm(map::BlockMap{GID, PID, LID})::Comm{GID, PID, LID} where GID <: Integer where PID <: Integer where LID <: Integer
    map.data.comm
end
diff --git a/src/BlockMapData.jl b/src/BlockMapData.jl
new file mode 100644
index 0000000..85c559b
--- /dev/null
+++ b/src/BlockMapData.jl
@@ -0,0 +1,63 @@
"""
Contains the data for a BlockMap
"""
type BlockMapData{GID <: Integer, PID <:Integer, LID <: Integer}
    comm::Comm{GID, PID, LID}
    directory::Nullable{Directory}
    lid::Vector{LID}
    myGlobalElements::Vector{GID}
#    firstPointInElementList::Array{Integer}
#    elementSizeList::Array{Integer}
#    pointToElementList::Array{Integer}

    numGlobalElements::GID
    numMyElements::LID
#    elementSize::Integer
#    minMyElementSize::Integer
#    maxMyElementSize::Integer
#    minElementSize::Integer
#    maxElementSize::Integer
minAllGID::GID + maxAllGID::GID + minMyGID::GID + maxMyGID::GID + minLID::LID + maxLID::LID +# numGlobalPoints::Integer +# numMyPoints::Integer + +# constantElementSize::Bool + linearMap::Bool + distributedGlobal::Bool + oneToOneIsDetermined::Bool + oneToOne::Bool + lastContiguousGID::GID + lastContiguousGIDLoc::GID + lidHash::Dict{GID, LID} +end + +function BlockMapData(numGlobalElements::GID, comm::Comm{GID, PID, LID}) where GID <: Integer where PID <: Integer where LID <: Integer + BlockMapData( + comm, + Nullable{Directory}(), + LID[], + GID[], + + numGlobalElements, + LID(0), + GID(0), + GID(0), + GID(0), + GID(0), + LID(0), + LID(0), + + false, + false, + false, + false, + GID(0), + GID(0), + Dict{GID, LID}() + ) +end diff --git a/src/CRSGraphConstructors.jl b/src/CRSGraphConstructors.jl new file mode 100644 index 0000000..12b8e67 --- /dev/null +++ b/src/CRSGraphConstructors.jl @@ -0,0 +1,421 @@ +export CRSGraph + +#TODO document the type and constructors + +mutable struct CRSGraph{GID <: Integer, PID <: Integer, LID <: Integer} <: DistRowGraph{GID, PID, LID} + rowMap::BlockMap{GID, PID, LID} + colMap::Nullable{BlockMap{GID, PID, LID}} + rangeMap::Nullable{BlockMap{GID, PID, LID}} + domainMap::Nullable{BlockMap{GID, PID, LID}} + + #may be null if domainMap and colMap are the same + importer::Nullable{Import{GID, PID, LID}} + #may be null if rangeMap and rowMap are the same + exporter::Nullable{Export{GID, PID, LID}} + + localGraph::LocalCRSGraph{LID, LID} + + #Local number of (populated) entries; must always be consistent + nodeNumEntries::LID + + #Local number of (populated) diagonal entries. + nodeNumDiags::LID + + #Local maximum of the number of entries in each row. + nodeMaxNumRowEntries::LID + + #Global number of entries in the graph. + globalNumEntries::GID + + #Global number of (populated) diagonal entries. + globalNumDiags::GID + + #Global maximum of the number of entries in each row. 
+ globalMaxNumRowEntries::GID + + #Whether the graph was allocated with static or dynamic profile. + pftype::ProfileType + + ## 1-D storage (Static profile) data structures ## + localIndices1D::Array{LID, 1} + globalIndices1D::Array{GID, 1} + rowOffsets::Array{LID, 1} #Tpetra: k_rowPts_ + + ## 2-D storage (Dynamic profile) data structures ## + localIndices2D::Array{Array{LID, 1}, 1} + globalIndices2D::Array{Array{GID, 1}, 1} + #may exist in 1-D storage if not packed + numRowEntries::Array{LID, 1} + + storageStatus::StorageStatus + + indicesAllowed::Bool + indicesType::IndexType + fillComplete::Bool + + lowerTriangle::Bool + upperTriangle::Bool + indicesAreSorted::Bool + noRedundancies::Bool + haveLocalConstants::Bool + haveGlobalConstants::Bool + sortGhostsAssociatedWithEachProcessor::Bool + + plist::Dict{Symbol} + + nonLocals::Dict{GID, Array{GID, 1}} + + #Large ammounts of duplication between the constructors, so group it in an inner constructor + function CRSGraph( + rowMap::BlockMap{GID, PID, LID}, + colMap::Nullable{BlockMap{GID, PID, LID}}, + rangeMap::Nullable{BlockMap{GID, PID, LID}}, + domainMap::Nullable{BlockMap{GID, PID, LID}}, + + localGraph::LocalCRSGraph, + + nodeNumEntries::LID, + + pftype::ProfileType, + storageStatus::StorageStatus, + + indicesType::IndexType, + plist::Dict{Symbol} + ) where {GID <: Integer, PID <: Integer, LID <: Integer} + + graph = new{GID, PID, LID}( + rowMap, + colMap, + rangeMap, + domainMap, + + Nullable{Import{GID, PID, LID}}(), + Nullable{Export{GID, PID, LID}}(), + + localGraph, + + #Local number of (populated) entries; must always be consistent + nodeNumEntries, + + #using 0 to indicate uninitiallized, since -1 isn't gareenteed to work + 0, #nodeNumDiags + 0, #nodeMaxNumRowEntries + 0, #globalNumEntries + 0, #globalNumDiags + 0, #globalMaxNumRowEntries + + #Whether the graph was allocated with static or dynamic profile. 
+ pftype, + + + ## 1-D storage (Static profile) data structures ## + LID[], + GID[], + LID[], + + ## 2-D storage (Dynamic profile) data structures ## + Array{Array{LID, 1}, 1}(0), + Array{Array{GID, 1}, 1}(0), + LID[], + + storageStatus, + + false, + indicesType, + false, + + false, + false, + true, + true, + false, + false, + true, + + plist, + + Dict{GID, Array{GID, 1}}() + ) + + ## staticAssertions() + #skipping sizeof checks + #skipping max value checks related to size_t + + @assert(indicesType != LOCAL_INDICES || !isnull(colMap), + "Cannot have local indices with a null column Map") + + #ensure LID is a subset of GID (for positive numbers) + if !(LID <: GID) && (GID != BigInt) && (GID != Integer) + # all ints are assumed to be able to handle 1, up to their max + if LID == BigInt || LID == Integer || typemax(LID) > typemax(GID) + throw(InvalidArgumentError("The positive values of GID must " + * "be a superset of the positive values of LID")) + end + end + + graph + end +end + + +#### Constructors ##### + +function CRSGraph(rowMap::BlockMap{GID, PID, LID}, maxNumEntriesPerRow::Integer, + pftype::ProfileType; plist...) where { + GID <: Integer, PID <: Integer, LID <: Integer} + CRSGraph(rowMap, LID(maxNumEntriesPerRow), pftype, Dict(Array{Tuple{Symbol, Any}, 1}(plist))) +end +function CRSGraph(rowMap::BlockMap{GID, PID, LID}, maxNumEntriesPerRow::Integer, + pftype::ProfileType, plist::Dict{Symbol}) where { + GID <: Integer, PID <: Integer, LID <: Integer} + CRSGraph(rowMap, Nullable{BlockMap{GID, PID, LID}}(), LID(maxNumEntriesPerRow), pftype, plist) +end + +function CRSGraph(rowMap::BlockMap{GID, PID, LID}, colMap::BlockMap{GID, PID, LID}, + maxNumEntriesPerRow::Integer, pftype::ProfileType; plist...) 
where { + GID <: Integer, PID <: Integer, LID <: Integer} + CRSGraph(rowMap, colMap, LID(maxNumEntriesPerRow), pftype, Dict(Array{Tuple{Symbol, Any}, 1}(plist))) +end +function CRSGraph(rowMap::BlockMap{GID, PID, LID}, colMap::BlockMap{GID, PID, LID}, + maxNumEntriesPerRow::Integer, pftype::ProfileType, plist::Dict{Symbol}) where { + GID <: Integer, PID <: Integer, LID <: Integer} + CRSGraph(rowMap, Nullable(colMap), LID(maxNumEntriesPerRow), pftype, plist) +end + +function CRSGraph(rowMap::BlockMap{GID, PID, LID}, colMap::Nullable{BlockMap{GID, PID, LID}}, + maxNumEntriesPerRow::LID, pftype::ProfileType; plist...) where { + GID <: Integer, PID <: Integer, LID <: Integer} + CRSGraph(rowMap, colMap, maxNumEntriesPerRow, pftype, Dict(Array{Tuple{Symbol, Any}, 1}(plist))) +end +function CRSGraph(rowMap::BlockMap{GID, PID, LID}, colMap::Nullable{BlockMap{GID, PID, LID}}, + maxNumEntriesPerRow::LID, pftype::ProfileType, plist::Dict{Symbol}) where { + GID <: Integer, PID <: Integer, LID <: Integer} + graph = CRSGraph( + rowMap, + colMap, + Nullable{BlockMap{GID, PID, LID}}(), + Nullable{BlockMap{GID, PID, LID}}(), + + LocalCRSGraph{LID, LID}(), #localGraph + + LID(0), #nodeNumEntries + + pftype, + + (pftype == STATIC_PROFILE ? + STORAGE_1D_UNPACKED + : STORAGE_2D), + + isnull(colMap)?GLOBAL_INDICES:LOCAL_INDICES, + plist + ) + + allocateIndices(graph, graph.indicesType, maxNumEntriesPerRow) + + resumeFill(graph, plist) + checkInternalState(graph) + + graph +end + +function CRSGraph(rowMap::BlockMap{GID, PID, LID}, numEntPerRow::AbstractArray{<:Integer, 1}, + pftype::ProfileType; plist...) 
where { + GID <: Integer, PID <: Integer, LID <: Integer} + CRSGraph(rowMap, numEntPerRow, pftype, Dict(Array{Tuple{Symbol, Any}, 1}(plist))) +end +function CRSGraph(rowMap::BlockMap{GID, PID, LID}, numEntPerRow::AbstractArray{<:Integer, 1}, + pftype::ProfileType, plist::Dict{Symbol}) where { + GID <: Integer, PID <: Integer, LID <: Integer} + CRSGraph(rowMap, Nullable{BlockMap{GID, PID, LID}}(), numEntPerRow, pftype, plist) +end + +function CRSGraph(rowMap::BlockMap{GID, PID, LID}, colMap::BlockMap{GID, PID, LID}, + numEntPerRow::AbstractArray{<:Integer, 1}, pftype::ProfileType; + plist...) where {GID <: Integer, PID <: Integer, LID <: Integer} + CRSGraph(rowMap, colMap, numEntPerRow, pftype, Dict(Array{Tuple{Symbol, Any}, 1}(plist))) +end +function CRSGraph(rowMap::BlockMap{GID, PID, LID}, colMap::BlockMap{GID, PID, LID}, + numEntPerRow::AbstractArray{<:Integer, 1}, pftype::ProfileType, + plist::Dict{Symbol}) where {GID <: Integer, PID <: Integer, LID <: Integer} + CRSGraph(rowMap, Nullable(colMap), numEntPerRow, pftype, plist) +end + +function CRSGraph(rowMap::BlockMap{GID, PID, LID}, colMap::Nullable{BlockMap{GID, PID, LID}}, + numEntPerRow::AbstractArray{<:Integer, 1}, pftype::ProfileType; + plist...) where {GID <: Integer, PID <: Integer, LID <: Integer} + CRSGraph(rowMap, colMap, numEntPerRow, pftype, Dict(Array{Tuple{Symbol, Any}, 1}(plist))) +end +function CRSGraph(rowMap::BlockMap{GID, PID, LID}, colMap::Nullable{BlockMap{GID, PID, LID}}, + numEntPerRow::AbstractArray{<:Integer, 1}, pftype::ProfileType, + plist::Dict{Symbol}) where {GID <: Integer, PID <: Integer, LID <: Integer} + graph = CRSGraph( + rowMap, + colMap, + Nullable{BlockMap{GID, PID, LID}}(), + Nullable{BlockMap{GID, PID, LID}}(), + + LocalCRSGraph{LID, LID}(), #localGraph + + LID(0), #nodeNumEntries + + #Whether the graph was allocated with static or dynamic profile. + pftype, + + (pftype == STATIC_PROFILE ? 
+ STORAGE_1D_UNPACKED + : STORAGE_2D), + + isnull(colMap)?GLOBAL_INDICES:LOCAL_INDICES, + plist + ) + + localNumRows = numMyElements(rowMap) + if length(numEntPerRow) != localNumRows + throw(InvalidArgumentError("numEntPerRows has length $(length(numEntPerRow)) " * + "!= the local number of rows $lclNumRows as spcified by the input row Map")) + end + + if @debug + for curRowCount in numEntPerRow + if curRowCount <= 0 + throw(InvalidArgumentError("numEntPerRow[$r] = $curRowCount is not valid")) + end + end + end + + allocateIndices(graph, graph.indicesType, Array{LID, 1}(numEntPerRow)) + + resumeFill(graph, plist) + checkInternalState(graph) + + graph +end + + +function CRSGraph(rowMap::BlockMap{GID, PID, LID}, colMap::BlockMap{GID, PID, LID}, + rowPointers::AbstractArray{LID, 1}, columnIndices::Array{LID, 1}; + plist...) where {GID <: Integer, PID <: Integer, LID <: Integer} + CRSGraph(rowMap, colMap, rowPointers, columnIndices, Dict(Array{Tuple{Symbol, Any}, 1}(plist))) +end +function CRSGraph(rowMap::BlockMap{GID, PID, LID}, colMap::BlockMap{GID, PID, LID}, + rowPointers::AbstractArray{LID, 1}, columnIndices::Array{LID, 1}, + plist::Dict{Symbol}) where {GID <: Integer, PID <: Integer, LID <: Integer} + graph = CRSGraph( + rowMap, + Nullable(colMap), + Nullable{BlockMap{GID, PID, LID}}(), + Nullable{BlockMap{GID, PID, LID}}(), + + LocalCRSGraph{LID, LID}(), #localGraph + + LID(0), #nodeNumEntries + + STATIC_PROFILE, + + STORAGE_1D_PACKED, + + LOCAL_INDICES, + plist + ) + #seems to be already taken care of + #allocateIndices(graph, LOCAL_INDICES, numEntPerRow) + + setAllIndicies(graph, rowPointers, columnIndicies) + checkInternalState(graph) + + graph +end + +function CRSGraph(rowMap::BlockMap{GID, PID, LID}, colMap::BlockMap{GID, PID, LID}, + localGraph::LocalCRSGraph{LID, LID}; plist...) 
where { + GID <: Integer, PID <: Integer, LID <: Integer} + CRSGraph(rowMap, colMap, localGraph, Dict(Array{Tuple{Symbol, Any}, 1}(plist))) +end +function CRSGraph(rowMap::BlockMap{GID, PID, LID}, colMap::BlockMap{GID, PID, LID}, + localGraph::LocalCRSGraph{LID, LID}, plist::Dict{Symbol}) where { + GID <: Integer, PID <: Integer, LID <: Integer} + mapRowCount = numMyElements(rowMap) + graph = CRSGraph( + rowMap, + Nullable(colMap), + Nullable{}(rowMap), + Nullable{}(colMap), + + localGraph, + + localGraph.rowMap[mapRowCount+1], #nodeNumEntries + + STATIC_PROFILE, + + STORAGE_1D_PACKED, + + LOCAL_INDICES, + plist + ) + + if numRows(localGraph) != numMyElements(rowMap) + throw(InvalidArgumentError("input row map and input local " + * "graph need to have the same number of rows. The " + * "row map claims $(numMyElements(rowMap)) row(s), " + * "but the local graph claims $(numRows(localGraph)) " + * "row(s).")) + end + + #seems to be already taken care of + #allocateIndices(graph, LOCAL_INDICES, numEntPerRow) + + makeImportExport(graph) + + d_inds = localGraph.entries + graph.localIndices1D = d_inds + + d_ptrs = localGraph.rowMap + graph.rowOffsets = d_ptrs + + #reset local properties + graph.upperTriangle = true + graph.lowerTriangle = true + graph.nodeMaxNumRowEntries = 0 + graph.nodeNumDiags + + for localRow = 1:mapRowCount + globalRow = gid(rowMap, localRow) + rowLID = lid(colMap, globalRow) + + #possible that the local matrix has no entries in the column + #corrisponding to the current row, in that case, the column map + #might not contain that GID. 
Hence, the index validity check + if rowLID != 0 + if rowLID +1 > length(d_ptrs) + throw(InvalidArgumentError("The given row Map and/or column Map " + * "is/are not compatible with the provided local graphs.")) + end + if d_ptrs[rowLID] != d_ptr[rowLID+1] + const smallestCol = d_inds[d_ptrs[rowLID]] + const largestCol = d_inds[d_ptrs[rowLID+1]-1] + + if smallestCol < localRow + graph.upperTriangle = false + end + if localRow < largestCol + graph.lowerTriangle = false + end + + if rowLID in d_inds[d_ptrs[rowLID]:d_ptrs[rowLID]-1] + graph.nodeNumDiags += 1 + end + end + + graph.nodeMaxNumRowEntries = max((d_ptrs[rowLID + 1] - d_ptrs[rowLID]), + graph.nodeMaxNumRowEntries) + end + end + + graph.hasLocalConstants = true + computeGlobalConstants(graph) + + graph.fillComplete = true + checkInternalState(graph) + + graph +end diff --git a/src/CRSGraphExternalMethods.jl b/src/CRSGraphExternalMethods.jl new file mode 100644 index 0000000..9adfe76 --- /dev/null +++ b/src/CRSGraphExternalMethods.jl @@ -0,0 +1,643 @@ +export getProfileType, getColMap +export resumeFill, fillComplete +export insertLocalIndices, insertGlobalIndices + +#### RowGraph methods #### + +getRowMap(graph::CRSGraph) = graph.rowMap +getColMap(graph::CRSGraph) = get(graph.colMap) +getDomainMap(graph::CRSGraph) = get(graph.domainMap) +getRangeMap(graph::CRSGraph) = get(graph.rangeMap) +getImporter(graph::CRSGraph) = get(graph.importer) +getExporter(graph::CRSGraph) = get(graph.exporter) + +getGlobalNumRows(graph::CRSGraph) = numGlobalElements(getRowMap(graph)) +getGlobalNumCols(graph::CRSGraph) = numGlobalElements(getColMap(graph)) +getLocalNumRows(graph::CRSGraph) = numMyElements(getRowMap(graph)) +getLocalNumCols(graph::CRSGraph) = numMyElements(getColMap(graph)) + +getGlobalNumEntries(graph::CRSGraph) = graph.globalNumEntries +getLocalNumEntries(graph::CRSGraph) = graph.nodeNumEntries + +function getNumEntriesInGlobalRow(graph::CRSGraph{GID}, globalRow::Integer)::Integer where {GID <: Integer} + localRow 
= lid(graph.rowMap, GID(globalRow)) + getNumEntriesInLocalRow(graph, localRow) +end + +function getNumEntriesInLocalRow(graph::CRSGraph{GID, PID, LID}, localRow::Integer)::Integer where {GID, PID, LID <: Integer} + if hasRowInfo(graph) && myLID(graph.rowMap, LID(localRow)) + info = getRowInfo(graph, LID(localRow)) + retVal = info.numEntries + recycleRowInfo(info) + retVal + else + -1 + end +end + +getGlobalNumDiags(graph::CRSGraph) = graph.globalNumDiags +getLocalNumDiags(graph::CRSGraph) = graph.nodeNumDiags + +getGlobalMaxNumRowEntries(graph::CRSGraph) = graph.globalMaxNumRowEntries +getLocalMaxNumRowEntries(graph::CRSGraph) = graph.nodeMaxNumRowEntries + +hasColMap(graph::CRSGraph) = !isnull(graph.colMap) + +isLowerTriangular(graph::CRSGraph) = graph.lowerTriangle +isUpperTriangular(graph::CRSGraph) = graph.upperTriangle + +isGloballyIndexed(graph::CRSGraph) = graph.indicesType == GLOBAL_INDICES +isLocallyIndexed(graph::CRSGraph) = graph.indicesType == LOCAL_INDICES + +isFillComplete(g::CRSGraph) = g.fillComplete + +function getGlobalRowCopy(graph::CRSGraph{GID}, globalRow::GID)::Array{GID, 1} where {GID <: Integer} + Array{GID, 1}(getGlobalRowView(graph, globalRow)) +end + +function getLocalRowCopy(graph::CRSGraph{GID, PID, LID}, localRow::LID)::Array{LID, 1} where {GID, PID, LID <: Integer} + Array{LID, 1}(getLocalRowView(graph, localRow)) +end + +function pack(source::CRSGraph{GID, PID, LID}, exportLIDs::AbstractArray{LID, 1}, distor::Distributor{GID, PID, LID})::Array{Array{LID, 1}, 1} where {GID, PID, LID} + srcMap = map(source) + [getGlobalRowCopy(source, gid(srcMap, lid)) for lid in exportLIDs] +end + + +#### DistObject methods #### +function checkSizes(source::RowGraph{GID, PID, LID}, + target::CRSGraph{GID, PID, LID}) where {GID, PID, LID} + #T and E petra's don't do any checks + true +end + +function copyAndPermute(source::RowGraph{GID, PID, LID}, + target::CRSGraph{GID, PID, LID}, numSameIDs::LID, + permuteToLIDs::AbstractArray{LID, 1}, 
permuteFromLIDs::AbstractArray{LID, 1}) where { + GID, PID, LID} + copyAndPermuteNoViewMode(source, target, + numSameIDs, permuteToLIDs, permuteFromLIDs) +end + +function copyAndPermute(source::CRSGraph{GID, PID, LID}, + target::CRSGraph{GID, PID, LID}, numSameIDs::LID, + permuteToLIDs::AbstractArray{LID, 1}, permuteFromLIDs::AbstractArray{LID, 1}) where { + GID, PID, LID} + if isFillComplete(target) + throw(InvalidStateError("Target cannot be fill complete")) + end + if isFillComplete(source) + copyAndPermuteNoViewMode(source, target, + numSameIDs, permuteToLIDs, permuteFromLIDs) + else + if length(permuteToLIDs) != length(permuteFromLIDs) + throw(InvalidArgumentError("permuteToLIDs and " + * "permuteFromLIDs must have the same size")) + end + + srcRowMap = getRowMap(source) + tgtRowMap = getRowMap(target) + + #copy part + for myid = 1:numSameIDs + myGID = gid(srcRowMap, myid) + row = getGlobalRowView(source, myGID) + insertGlobalIndices(target, myGID, row) + end + + #permute part + for i = 1:length(permuteToLIDs) + srcGID = gid(srcRowMap, permuteFromLIDs[i]) + tgtGID = gid(tgtRowMap, permuteToLIDs[i]) + row = getGlobalRowView(source, srcGID) + insertGlobalIndices(target, tgtGID, row) + end + end +end + + +function copyAndPermuteNoViewMode(source::RowGraph{GID, PID, LID}, + target::CRSGraph{GID, PID, LID}, numSameIDs::LID, + permuteToLIDs::AbstractArray{LID, 1}, permuteFromLIDs::AbstractArray{LID, 1}) where { + GID, PID, LID} + if length(permuteToLIDs) != length(premuteFromLIDs) + throw(InvalidArgumentError("permuteToLIDs and " + * "permuteFromLIDs must have the same size")) + end + + srcRowMap = getRowMap(source) + tgtRowMap = getRowMap(target) + + #copy part + for myid = 1:numSameIDs + myGID = gid(srcRowMap, myid) + rowCopy = getGlobalRowCopy(sourceRowGraph, myGID) + insertGlobalIndices(target, myGID, rowCopy) + end + + #permute part + for i = 1:length(permuteToLIDs) + tgtGID = gid(tgtRowMap, permuteToLIDs[i]) + srcGID = gid(srcRowMap, permuteFromLIDs[i]) + 
rowCopy = getGlobalRowCopy(source, srcGID) + insertGlobalIndices(target, tgtGID, rowCopy) + end +end + + + +function packAndPrepare(source::RowGraph{GID, PID, LID}, + target::CRSGraph{GID, PID, LID}, exportLIDs::AbstractArray{LID, 1}, + distor::Distributor{GID, PID, LID})::Array{GID, 1} where { + GID, PID, LID} + + pack(source, exportLIDs, distor) +end + +function unpackAndCombine(target::CRSGraph{GID, PID, LID}, + importLIDs::AbstractArray{LID, 1}, imports::AbstractArray, + distor::Distributor{GID, PID, LID}, cm::CombineMode) where { + GID, PID, LID} + #should be caught else where + @assert(isFillActive(target), + "Import and Export operations require a fill active graph") + + tgtMap = map(target) + + for i = 1:length(importLIDs) + row = imports[i] + insertGlobalIndicesFiltered(target, gid(tgtMap, importLIDs[i]), row) + end +end + + + +#### CRSGraph methods #### +#TODO document the CRSGraph methods + +""" + insertLocalIndices(::CRSGraph{GID, PID, LID}, localRow::Integer, [numEntries::Integer,] inds::AbstractArray{<: Integer, 1}) + +Inserts the given local indices into the graph. 
If `numEntries` is given, +only the first `numEntries` elements are inserted +""" +function insertLocalIndices(graph::CRSGraph{GID, PID, LID}, localRow::Integer, + numEntries::Integer, inds::AbstractArray{<:Integer, 1}) where {GID, PID, LID <: Integer} + insertLocalIndices(graph, LID(localRow), LID(numEntries), Array{LID, 1}(inds)) +end +function insertLocalIndices(graph::CRSGraph{GID, PID, LID}, localRow::LID, + numEntries::LID, inds::AbstractArray{LID, 1}) where { + GID, PID, LID <: Integer} + indicesView = view(inds, 1:numEntries) + insertLocalIndices(graph, localRow, indsT) +end + +function insertLocalIndices(graph::CRSGraph{GID, PID, LID}, localRow::Integer, + inds::AbstractArray{<:Integer, 1}) where {GID, PID, LID <: Integer} + insertLocalIndices(graph, LID(localRow), Array{LID, 1}(inds)) +end +function insertLocalIndices(graph::CRSGraph{GID, PID, LID}, + localRow::LID, indices::AbstractArray{LID, 1}) where{ + GID, PID, LID <: Integer} + if !isFillActive(graph) + throw(InvalidStateError("insertLocalIndices requires that fill is active")) + end + if isGloballyIndexed(graph) + throw(InvalidStateError("graph indices are global, use insertGlobalIndices(...) 
instead")) + end + if !hasColMap(graph) + throw(InvalidStateError("Cannot insert local indices without a column map")) + end + if !myLID(map(graph), localRow) + throw(InvalidArgumentError("Row does not belong to this process")) + end + if !hasRowInfo(graph) + throw(InvalidStateError("Row information was deleted")) + end + + if @debug + colMap = getColMap(graph) + badColIndices = [ind for ind in indices if myLID(colMap, ind)] + + if length(badColIndices) != 0 + throw(InvalidArgumentError( + "Attempting to insert entries in owned row $localRow, " + * "at the following column indices: $indices.\n" + + * "Of those, the following indices are not in " + * "the column map on this process: $badColIndices.\n" + + * "Since the graph has a column map already, it is " + * "invalid to insert entries at those locations")) + end + end + + insertLocalIndicesImpl(graph, localRow, indices) + + if @debug + @assert isLocallyIndexed(graph) "Post condtion violated" + end +end + + +""" + insertGlobalIndices(::CRSGraph{GID, PID, LID}, localRow::Integer, [numEntries::Integer,] inds::AbstractArray{<: Integer, 1}) + +Inserts the given global indices into the graph. 
If `numEntries` is given, +only the first `numEntries` elements are inserted +""" +function insertGlobalIndices(graph::CRSGraph{GID, PID, LID}, globalRow::Integer, + numEntries::Integer, inds::AbstractArray{<: Integer, 1}) where { + GID <: Integer, PID, LID <: Integer} + insertGlobalIndices(graph, GID(globalRow), LID(numEntries), Array{GID, 1}(inds)) +end +function insertGlobalIndices(graph::CRSGraph{GID, PID, LID}, globalRow::GID, + numEntries::LID, inds::AbstractArray{GID, 1}) where { + GID <: Integer, PID, LID <: Integer} + indicesView = view(inds, 1:numEntries) + insertGlobalIndices(graph, globalRow, indsT) +end +function insertGlobalIndices(graph::CRSGraph{GID, PID, LID}, globalRow::Integer, + inds::AbstractArray{<: Integer, 1}) where {GID <: Integer, PID, LID <: Integer} + insertGlobalIndices(graph, GID(globalRow), Array{GID, 1}(inds)) +end +function insertGlobalIndices(graph::CRSGraph{GID, PID, LID}, globalRow::GID, + indices::AbstractArray{GID, 1}) where {GID <: Integer, PID, LID <: Integer} + if isLocallyIndexed(graph) + throw(InvalidStateError("Graph indices are local, use insertLocalIndices()")) + end + if !hasRowInfo(graph) + throw(InvalidStateError("Graph row information was deleted")) + end + if isFillComplete(graph) + throw(InvalidStateError("Cannot call this method if the fill is not active")) + end + + myRow = lid(graph.rowMap, globalRow) + if myRow != 0 + if @debug + if hasColMap(graph) + colMap = get(graph.colMap) + + #appearently jupyter can't render the generator if properly + badColInds = [index for index in indices + if myGid(colMap, index)==0] + if length(badColInds) != 0 + throw(InvalidArgumentError("$(myPid(comm(graph))): " + * "Attempted to insert entries in owned row $globalRow, " + * "at the following column indices: $indices.\n" + + * "Of those, the following indices are not in the " + * "column Map on this process: $badColInds.\n" + + * "Since the matrix has a column map already, it " + * "is invalid to insert entries at those 
locations")) + end + end + end + insertGlobalIndicesImpl(graph, myRow, indices) + else + append!(graph.nonlocalRow, indices) + end +end + +""" + insertGlobalIndicesFiltered(::CRSGraph{GID, PID, LID}, localRow::Integer, inds::AbstractArray{<: Integer, 1}) + +As `insertGlobalIndices(...)` but filters by those present in +the column map, if present +""" +function insertGlobalIndicesFiltered(graph::CRSGraph{GID, PID, LID}, globalRow::GID, + indices::AbstractArray{GID, 1}) where{GID, PID, LID} + if isLocallyIndexed(graph) + throw(InvalidStateError( + "graph indices are local, use insertLocalIndices(...)")) + end + if !hasRowInfo(graph) + throw(InvalidStateError("Graph row information was deleted")) + end + if isFillComplete(graph) + throw(InvalidStateError("Cannot insert into fill complete graph")) + end + + myRow = lid(graph.rowMap, globalRow) + if myRow != 0 + #if column map present, use it to filter the entries + if hasColMap(graph) + colMap = getColMap(graph) + indices = [index for index in indices if myLid(colMap, index)] + end + insertGlobalIndicesImpl(myRow, indices) + else + #nonlocal row + append!(graph.nonlocals[globalRow], indices) + end +end + +function getGlobalView(graph::CRSGraph{GID, PID, LID}, rowInfo::RowInfo{LID}) where {GID <: Integer, PID, LID <: Integer} + if rowInfo.allocSize > 0 + if length(graph.globalIndices1D) != 0 + range = rowInfo.offset1D : rowInfo.offset1D + rowInfo.allocSize + view(graph.globalIndices1D, range) + elseif length(graph.globalIndices2D[rowInfo.localRow]) == 0 + globalIndices2D[rowInfo.localRow] + else + GID[] + end + else + GID[] + end +end + +@inline function getGlobalViewPtr(graph::CRSGraph{GID, PID, LID}, rowInfo::RowInfo{LID})::Tuple{Ptr{GID}, LID} where {GID <: Integer, PID, LID <: Integer} + if rowInfo.allocSize > 0 + if length(graph.globalIndices1D) != 0 + return (pointer(graph.globalIndices1D, rowInfo.offset1D), rowInfo.allocSize) + elseif length(graph.globalIndices2D[rowInfo.localRow]) == 0 + baseArray = 
graph.globalIndices2D[rowInfo.localRow]::Vector{GID} + return (pointer(baseArray), GID(length(baseArray))) + end + end + return (Ptr{GID}(C_NULL), 0) +end + +function getLocalView(graph::CRSGraph{GID, PID, LID}, rowInfo::RowInfo{LID}) where {GID <: Integer, PID, LID <: Integer} + if rowInfo.allocSize > 0 + if length(graph.localIndices1D) != 0 + range = rowInfo.offset1D : rowInfo.offset1D + rowInfo.allocSize-LID(1) + return view(graph.localIndices1D, range) + elseif length(graph.localIndices2D[rowInfo.localRow]) == 0 + baseArray = graph.localIndices2D[rowInfo.localRow] + return view(baseArray, LID(1):LID(length(baseArray))) + end + end + return LID +end + + +Base.@propagate_inbounds @inline function getLocalViewPtr(graph::CRSGraph{GID, PID, LID}, rowInfo::RowInfo{LID})::Tuple{Ptr{LID}, LID} where {GID <: Integer, PID, LID <: Integer} + if rowInfo.allocSize > 0 + if length(graph.localIndices1D) != 0 + return (pointer(graph.localIndices1D, rowInfo.offset1D), rowInfo.allocSize) + elseif length(graph.localIndices2D[rowInfo.localRow]) == 0 + baseArray::Array{LID, 1} = graph.localIndices2D[rowInfo.localRow] + return (pointer(baseArray), LID(length(baseArray))) + end + end + return (C_NULL, 0) +end + +function getGlobalRowView(graph::CRSGraph{GID}, globalRow::GID)::AbstractArray{GID, 1} where {GID <: Integer} + if isLocallyIndexed(graph) + throw(InvalidArgumentError("The graph's indices are currently stored as local indices, so a view with global column indices cannot be returned. 
Use getGlobalRowCopy(::CRSGraph) instead")) + end + + if @debug + @assert hasRowInfo(graph) "Graph row information was deleted" + end + rowInfo = getRowInfoFromGlobalRow(graph, globalRow) + + if rowInfo.localRow != 0 && rowInfo.numEntries > 0 + retVal = view(getGlobalView(graph, rowInfo), 1:rowInfo.numEntries) + else + retVal = GID[] + end + recycleRowInfo(rowInfo) + retVal +end + +Base.@propagate_inbounds function getGlobalRowViewPtr(graph::CRSGraph{GID, PID, LID}, globalRow::GID)::Tuple{Ptr{GID}, LID} where {GID <: Integer, PID <: Integer, LID <: Integer} + if isLocallyIndexed(graph) + throw(InvalidArgumentError("The graph's indices are currently stored as local indices, so a view with global column indices cannot be returned. Use getGlobalRowCopy(::CRSGraph) instead")) + end + + if @debug + @assert hasRowInfo(graph) "Graph row information was deleted" + end + rowInfo = getRowInfoFromGlobalRow(graph, globalRow) + + isLocalRow = true + @boundscheck isLocalRow = rowInfo.localRow != 0 + + if isLocalRow && rowInfo.numEntries > 0 + retVal = (getGlobalViewPtr(graph, rowInfo)[1], rowInfo.numEntries) + else + retVal = (Ptr{GID}(C_NULL), LID(0)) + end + recycleRowInfo(rowInfo) + retVal +end + +function getLocalRowView(graph::CRSGraph{GID}, localRow::GID)::AbstractArray{GID, 1} where {GID} + if @debug + @assert hasRowInfo() "Graph row information was deleted" + end + rowInfo = getRowInfoFromLocalRowIndex(graph, localRow) + + retVal = getLocalRowView(graph, rowInfo) + recycleRowInfo(rowInfo) + retVal +end + +function getLocalRowView(graph::CRSGraph{GID, PID, LID}, rowInfo::RowInfo{LID} + )::AbstractArray{GID, 1} where {GID, PID, LID} + + if isGloballyIndexed(graph) + throw(InvalidArgumentError("The graph's indices are currently stored as global indices, so a view with local column indices cannot be returned. 
Use getLocalRowCopy(::CRSGraph) instead")) + end + + if rowInfo.localRow != 0 && rowInfo.numEntries > 0 + indices = view(getLocalView(graph, rowInfo), 1:rowInfo.numEntries) + + if @debug + @assert(length(indices) == getNumEntriesInLocalRow(graph, localRow), + "length(indices) = $(length(indices)) " + * "!= getNumEntriesInLocalRow(graph, $localRow) " + * "= $(getNumEntriesInLocalRow(graph, localRow))") + end + indices + else + LID[] + end +end + +resumeFill(graph::CRSGraph; plist...) = resumeFill(graph, Dict(Array{Tuple{Symbol, Any}, 1}(plist))) + +function resumeFill(graph::CRSGraph, plist::Dict{Symbol}) + if !hasRowInfo(graph) + throw(InvalidStateError("Cannot resume fill of the CRSGraph, " + * "since the graph's row information was deleted.")) + end + + clearGlobalConstants(graph) + graph.plist = plist + graph.lowerTriangle = false + graph.upperTriangle = false + graph.indicesAreSorted = true + graph.noRedundancies = true + graph.fillComplete = false + + if @debug + @assert isFillActive(graph) && !isFillComplete(graph) "Post condition violated" + end +end + + +fillComplete(graph::CRSGraph; plist...) = fillComplete(graph, Dict(Array{Tuple{Symbol, Any}, 1}(plist))) + +function fillComplete(graph::CRSGraph, plist::Dict{Symbol}) + if isnull(graph.domainMap) + domMap = graph.rowMap + else + domMap = get(graph.domainMap) + end + + if isnull(graph.rangeMap) + ranMap = graph.colMap + else + ranMap = get(graph.rangeMap) + end + + fillComplete(graph, ranMap, domMap, plist) +end + +function fillComplete(graph::CRSGraph{GID, PID, LID}, + domainMap::BlockMap{GID, PID, LID}, rangeMap::BlockMap{GID, PID, LID}; + plist...) 
where {GID, PID, LID} + fillComplete(graph, Dict(Array{Tuple{Symbol, Any}, 1}(plist))) +end + +function fillComplete(graph::CRSGraph{GID, PID, LID}, domainMap::BlockMap{GID, PID, LID}, rangeMap::BlockMap{GID, PID, LID}, plist::Dict{Symbol}) where {GID, PID, LID} + if !isFillActive(graph) || isFillComplete(graph) + throw(InvalidStateError("Graph fill state must be active to call fillComplete(...)")) + end + + const numProcs = numProc(comm(graph)) + + const assertNoNonlocalInserts = get(plist, :noNonlocalChanges, false) + + const mayNeedGlobalAssemble = !assertNoNonlocalInserts && numProcs > 1 + if mayNeedGlobalAssemble + globalAssemble(graph) + else + if numProcs == 1 && length(graph.nonLocals) > 0 + throw(InvalidStateError("Only one process, but nonlocal entries are present")) + end + end + + setDomainRangeMaps(graph, domainMap, rangeMap) + + if !hasColMap(graph) + makeColMap(graph) + end + + makeIndicesLocal(graph) + + sortAndMergeAllIndices(graph, isSorted(graph), isMerged(graph)) + + makeImportExport(graph) + computeGlobalConstants(graph) + fillLocalGraph(graph, plist) + graph.fillComplete(true) + + if @debug + @assert !isFillActive(graph) && isFillComplete(graph) "post conditions violated" + end + + checkInternalState(graph) +end + +function makeColMap(graph::CRSGraph{GID, PID, LID}) where {GID, PID, LID} + const localNumRows = getLocalNumEntries(graph) + + #TODO look at FIXME on line 4898 + + error, colMap = __makeColMap(graph, graph.domainMap) + + if @debug + comm = JuliaPetra.comm(graph) + globalError = maxAll(comm, error) + + if globalError + Base.error("makeColMap reports an error on at least one process") + end + end + + graph.colMap = colMap + + checkInternalState(graph) +end + +""" + isSorted(::CRSGraph) + +Whether the indices are sorted +""" +isSorted(graph::CRSGraph) = graph.indicesAreSorted + +""" + isMerged(::CRSGraph) + +Whether duplicate column indices in each row have been merged +""" +isMerged(graph::CRSGraph) = graph.noRedundancies + +""" + 
setAllIndices(graph::CRSGraph{GID, PID, LID}, rowPointers::Array{LID, 1}, columnIndices::AbstractArray{LID, 1}) + +Sets the graph's data directly, using 1D storage +""" +function setAllIndices(graph::CRSGraph{GID, PID, LID}, + rowPointers::AbstractArray{LID, 1},columnIndices::Array{LID, 1}) where { + GID, PID, LID <: Integer} + + localNumRows = getLocalNumRows(graph) + + if isnull(graph.colMap) + throw(InvalidStateError("The graph must have a " + * "column map before calling setAllIndices")) + end + if length(rowPointers) != localNumRows + 1 + throw(InvalidArgumentError("length(rowPointers) = $(length(rowPointers)) " + * "!= localNumRows+1 = $(localNumRows+1)")) + end + + localNumEntries = rowPointers[localNumRows+1] + + graph.indicesType = LOCAL_INDICES + graph.pftype = STATIC_PROFILE + graph.localIndices1D = columnIndices + graph.rowOffsets = rowPointers + graph.nodeNumEntries = localNumEntries + graph.storageStatus = STORAGE_1D_UNPACKED + + graph.localGraph = LocalCRSGraph(columnIndices, rowPointers) + + checkInternalState(graph) +end + + +""" + isStorageOptimized(::CRSGraph) + +Whether the graph's storage is optimized +""" +function isStorageOptimized(graph::CRSGraph) + const isOpt = length(graph.numRowEntries) == 0 && getLocalNumRows(graph) > 0 + if (@debug) && isOpt + @assert(getProfileType(graph) == STATIC_PROFILE, + "Matrix claims optimized storage by profile type " + * "is dynamic. This shouldn't happend.") + end + isOpt +end + +""" + getProfileType(::CRSGraph) + +Gets the profile type of the graph +""" +function getProfileType(graph::CRSGraph) + graph.pftype +end diff --git a/src/CRSGraphInternalMethods.jl b/src/CRSGraphInternalMethods.jl new file mode 100644 index 0000000..8f46813 --- /dev/null +++ b/src/CRSGraphInternalMethods.jl @@ -0,0 +1,980 @@ + + +#DECISION put this somewhere else? 
Its only an internal grouping +mutable struct RowInfo{LID <: Integer} + graph::CRSGraph{<:Integer, <:Integer, LID} + localRow::LID + allocSize::LID + numEntries::LID + offset1D::LID +end + +#RowInfo object for re-use +const rowInfoSpare = Union{Void, RowInfo}[nothing] + +""" + Gets a `RowInfo` object with the given values, reusing an intance if able +""" +@inline function createRowInfo(graph::CRSGraph{<:Integer, <:Integer, LID}, localRow::LID, + allocSize::LID, numEntries::LID, offset1D::LID)::RowInfo{LID} where {LID <: Integer} + global rowInfoSpare + + @inbounds begin#if length(rowInfoSpare) > 0 + #rowInfoSpare should always be size 1 + nextVal = rowInfoSpare[1] + #ensure object pool haves right type + if nextVal isa RowInfo{LID} + rowInfoSpare[1] = nothing + + rowInfo::RowInfo{LID} = nextVal + rowInfo.graph = graph + rowInfo.localRow = localRow + rowInfo.allocSize = allocSize + rowInfo.numEntries = numEntries + rowInfo.offset1D = offset1D + return rowInfo + end + end + + (@debug) && (myPid(comm(graph)) == 1) && println("can't reuse $(@inbounds rowInfoSpare[1])") + #couldn't reuse, create new instance + return RowInfo{LID}(graph, localRow, allocSize, numEntries, offset1D) +end + +""" + Puts the `RowInfo` object back in the object pool. + After calling this method remove all references to the object. 
"""
@inline function recycleRowInfo(rowInfo::RowInfo{T}) where T
    #return the instance to the single-slot pool; see createRowInfo
    @inbounds rowInfoSpare[1] = rowInfo
    nothing
end


#TODO implement getLocalDiagOffsets(::CRSGraph)
getLocalGraph(graph::CRSGraph) = graph.localGraph


"""
Grows row `rowInfo.localRow`'s 2D global-index storage and the matching
values array to `newAllocSize`.  Used by dynamic-profile insertion.
"""
function updateGlobalAllocAndValues(graph::CRSGraph{GID, PID, LID}, rowInfo::RowInfo{LID}, newAllocSize::Integer, rowValues::AbstractArray{Data, 1}) where {Data, GID, PID, LID}

    resize!(graph.globalIndices2D[rowInfo.localRow], newAllocSize)
    #BUG FIX: was `resize!(rowVals, ...)` — `rowVals` is undefined; the
    #parameter is named `rowValues`
    resize!(rowValues, newAllocSize)

    nothing
end

"""
Inserts the new indices via `insertIndices`, then copies the matching
values into `oldRowVals` after the row's existing entries.
"""
function insertIndicesAndValues(graph::CRSGraph{GID, PID, LID}, rowInfo::RowInfo{LID}, newInds::Union{AbstractArray{GID, 1}, AbstractArray{LID, 1}}, oldRowVals::AbstractArray{Data, 1}, newRowVals::AbstractArray{Data, 1}, lg::IndexType) where {Data, GID, PID, LID}
    numNewInds = insertIndices(graph, rowInfo, newInds, lg)
    oldInd = rowInfo.numEntries+1

    oldRowVals[range(oldInd, 1, numNewInds)] = newRowVals[1:numNewInds]
end

"""
Appends `newInds` to row `rowInfo.localRow`, converting between global and
local indices as needed.  Returns the number of indices inserted.
"""
function insertIndices(graph::CRSGraph{GID, PID, LID}, rowInfo::RowInfo{LID}, newInds::Union{AbstractArray{GID, 1}, AbstractArray{LID, 1}}, lg::IndexType) where {GID, PID, LID}
    numNewInds = LID(length(newInds))
    if lg == GLOBAL_INDICES
        if isGloballyIndexed(graph)
            numEntries = rowInfo.numEntries
            (gIndPtr, gIndLen) = getGlobalViewPtr(graph, rowInfo)
            @assert gIndLen >= numNewInds+numEntries
            for i in 1:numNewInds
                unsafe_store!(gIndPtr, GID(newInds[i]), numEntries+i)
            end
        else
            lIndView = getLocalView(graph, rowInfo)
            colMap = graph.colMap
            #NOTE(review): `colMap` is the Nullable field; `lid(colMap, ...)` below
            #presumably relies on a `lid(::Nullable, ...)` method — confirm.
            #NOTE(review): `dest` starting at rowInfo.numEntries (not +1) looks like
            #an off-by-one relative to the pointer branch above — confirm before changing.
            dest = range(rowInfo.numEntries, 1, numNewInds)
            lIndView[dest] = [lid(colMap, GID(id)) for id in newInds]
        end
    elseif lg == LOCAL_INDICES
        if isLocallyIndexed(graph)
            numEntries = rowInfo.numEntries
            (lIndPtr, lIndLen) = getLocalViewPtr(graph, rowInfo)
            #BUG FIX: was `@assert gIndLen ...` — `gIndLen` is not defined in this
            #branch; the local view's length is `lIndLen`
            @assert lIndLen >= numNewInds+numEntries
            for i in 1:numNewInds
                unsafe_store!(lIndPtr, LID(newInds[i]), numEntries+i)
            end
        else
            @assert(false,"lg=LOCAL_INDICES, isGloballyIndexed(g) not implemented, "
* "because it doesn't make sense") + end + end + + graph.numRowEntries[rowInfo.localRow] += numNewInds + graph.nodeNumEntries += numNewInds + setLocallyModified(graph) + + numNewInds +end + + +function computeGlobalConstants(graph::CRSGraph{GID, PID, LID}) where { + GID <: Integer, PID <: Integer, LID <: Integer} + + #short circuit if already computed + graph.haveGlobalConstants && return + + if @debug + @assert !isnull(graph.colMap) "The graph must have a column map at this point" + end + + computeLocalConstants(graph) + + commObj = comm(map(graph)) + + #if graph.haveGlobalConstants == false #short circuited above + graph.globalNumEntries, graph.globalNumDiags = sumAll(commObj, + [GID(graph.nodeNumEntries), GID(graph.nodeNumDiags)]) + + graph.globalMaxNumRowEntries = maxAll(commObj, GID(graph.nodeMaxNumRowEntries)) + graph.haveGlobalConstants = true +end + +function clearGlobalConstants(graph::CRSGraph) + graph.globalNumEntries = 0 + graph.globalNumDiags = 0 + graph.globalMaxNumRowEntries = 0 + graph.haveGlobalConstants = false +end + + +function computeLocalConstants(graph::CRSGraph{GID, PID, LID}) where { + GID <: Integer, PID <: Integer, LID <: Integer} + + #short circuit if already computed + graph.haveLocalConstants && return + + if @debug + @assert !isnull(graph.colMap) "The graph must have a column map at this point" + end + + #if graph.haveLocalConstants == false #short circuited above + + graph.upperTriangle = true + graph.lowerTriangle = true + graph.nodeMaxNumRowEntries = 0 + graph.nodeNumDiags = 0 + + rowMap = graph.rowMap + colMap = get(graph.colMap) + + #indicesAreAllocated => true + if hasRowInfo(graph) + const numLocalRows = numMyElements(rowMap) + for localRow = LID(1):numLocalRows + const globalRow = gid(rowMap, localRow) + const rowLID = lid(colMap, globalRow) + + const rowInfo = getRowInfo(graph, localRow) + (rowPtr, rowLen) = getLocalViewPtr(graph, rowInfo) + + + for i in 1:rowLen + if unsafe_load(rowPtr) == rowLID + graph.nodeNumDiags += 1 + 
break + end + end + + const smallestCol::LID = unsafe_load(rowPtr, 1) + const largestCol::LID = unsafe_load(rowPtr, rowLen) + if smallestCol < localRow + graph.upperTriangle = false + end + if localRow < largestCol + graph.lowerTriangle = false + end + graph.nodeMaxNumRowEntries = max(graph.nodeMaxNumRowEntries, rowInfo.numEntries) + + recycleRowInfo(rowInfo) + end + end + graph.haveLocalConstants = true +end + + +hasRowInfo(graph::CRSGraph) = (getProfileType(graph) != STATIC_PROFILE + || length(graph.rowOffsets) != 0) + +Base.@propagate_inbounds function getRowInfoFromGlobalRow(graph::CRSGraph{GID, PID, LID}, + row::Integer)::RowInfo{LID} where {GID, PID, LID <: Integer} + getRowInfo(graph, lid(graph.rowMap, row)) +end + +@inline function getRowInfo(graph::CRSGraph{GID, PID, LID}, row::LID)::RowInfo{LID} where {GID, PID, LID <: Integer} + if @debug + @assert hasRowInfo(graph) "Graph does not have row info anymore. Should have been caught earlier" + end + + emptyRowInfo = !hasRowInfo(graph) + @boundscheck emptyRowInfo = emptyRowInfo || !myLID(graph.rowMap, row) + if emptyRowInfo + return createRowInfo(graph, row, LID(0), LID(0), LID(1)) + end + + offset1D::LID = 1 + allocSize::LID = 0 + + @inbounds if getProfileType(graph) == STATIC_PROFILE + if length(graph.rowOffsets) != 0 + offset1D = LID(graph.rowOffsets[row]) + allocSize = LID(graph.rowOffsets[row+1] - graph.rowOffsets[row]) + end + numEntries = (length(graph.numRowEntries) == 0 ? + allocSize : LID(graph.numRowEntries[row])) + else #dynamic profile + if isLocallyIndexed(graph) && length(graph.localIndices2D) == 0 + allocSize = LID(length(graph.localIndices2D[row])) + + elseif isGloballyIndexed(graph) && length(graph.globalIndices2D) == 0 + allocSize = LID(length(graph.globalIndices2D[row])) + end + numEntries = (length(graph.numRowEntries) == 0 ? 
            LID(0) : LID(graph.numRowEntries[row]))
    end
    createRowInfo(graph, row, allocSize, numEntries, offset1D)
end

# Returns a view of the local column indices of the row described by rowInfo,
# or an empty array when nothing is allocated for it.
function getLocalView(rowInfo::RowInfo{LID})::AbstractArray{LID, 1} where LID <: Integer
    graph = rowInfo.graph
    if rowInfo.allocSize == 0
        LID[]
    elseif length(graph.localIndices1D) != 0
        #1D (static-profile) storage: slice out this row's segment
        start = rowInfo.offset1D
        len = rowInfo.allocSize

        view(graph.localIndices1D, range(start, 1, len))
    elseif length(graph.localIndices2D[rowInfo.localRow]) != 0
        #2D (dynamic-profile) storage: the row owns its own array
        graph.localIndices2D[rowInfo.localRow]
    else
        LID[]
    end
end

# Allocates index storage with a per-row allocation size array.
function allocateIndices(graph::CRSGraph{GID, <:Integer, LID},
        lg::IndexType, numAllocPerRow::AbstractArray{<:Integer, 1}) where {
        GID <: Integer, LID <: Integer}
    numRows = getLocalNumRows(graph)
    @assert(length(numAllocPerRow) == numRows,
        "numAllocRows has length = $(length(numAllocPerRow)) "
        * "!= numRows = $numRows")
    allocateIndices(graph, lg, numAllocPerRow, i -> numAllocPerRow[i])
end

# Allocates index storage with the same allocation size for every row.
function allocateIndices(graph::CRSGraph{GID, <:Integer, LID},
        lg::IndexType, numAllocPerRow::Integer) where {
        GID <: Integer, LID <: Integer}
    allocateIndices(graph, lg, numAllocPerRow, i-> numAllocPerRow)
end

# Core allocation: numAlloc is the scalar-or-array spec (used by computeOffsets
# for static profile), numAllocPerRow is a function row -> allocation size
# (used for dynamic profile).
function allocateIndices(graph::CRSGraph{GID, <:Integer, LID},
        lg::IndexType, numAlloc, numAllocPerRow::Function) where {
        GID <: Integer, LID <: Integer}

    @assert(isLocallyIndexed(graph) == (lg == LOCAL_INDICES),
        "Graph is $(isLocallyIndexed(graph)?"":"not ")locally indexed, but lg=$lg")
    @assert(isGloballyIndexed(graph) == (lg == GLOBAL_INDICES),
        "Graph is $(isGloballyIndexed(graph)?"":"not ")globally indexed but lg=$lg")

    numRows = getLocalNumRows(graph)

    if getProfileType(graph) == STATIC_PROFILE
        #static profile: one flat 1D array with a CRS offsets array
        rowPtrs = Array{LID, 1}(numRows + 1)

        computeOffsets(rowPtrs, numAlloc)

        graph.rowOffsets = rowPtrs
        numInds = rowPtrs[numRows+1]

        if lg == LOCAL_INDICES
            graph.localIndices1D = Array{LID, 1}(numInds)
        else
            graph.globalIndices1D = Array{GID, 1}(numInds)
        end
        graph.storageStatus =
STORAGE_1D_UNPACKED
    else
        #dynamic profile: one array per row
        if lg == LOCAL_INDICES
            graph.localIndices2D = Array{Array{LID, 1}, 1}(numRows)
            for row = 1:numRows
                graph.localIndices2D[row] = Array{LID, 1}(numAllocPerRow(row))
            end
        else #lg == GLOBAL_INDICES
            graph.globalIndices2D = Array{Array{GID, 1}, 1}(numRows)
            for row = 1:numRows
                graph.globalIndices2D[row] = Array{GID, 1}(numAllocPerRow(row))
            end
        end
        graph.storageStatus = STORAGE_2D
    end

    graph.indicesType = lg

    if numRows > 0
        numRowEntries = zeros(LID, numRows)
        graph.numRowEntries = numRowEntries
    end

    #let the calling constructor take care of this
    #checkInternalState(graph)
end


"""
Creates the graph's Import (domain map -> column map) and Export
(row map -> range map) objects, when the respective maps differ.
Requires the column map to be set.
"""
function makeImportExport(graph::CRSGraph{GID, PID, LID}) where {
        GID <: Integer, PID <: Integer, LID <: Integer}
    @assert !isnull(graph.colMap) "Cannot make imports and exports without a column map"

    if isnull(graph.importer)
        if get(graph.domainMap) !== get(graph.colMap) && !sameAs(get(graph.domainMap), get(graph.colMap))
            graph.importer = Import(get(graph.domainMap), get(graph.colMap), graph.plist)
        end
    end

    if isnull(graph.exporter)
        if get(graph.rangeMap) !== graph.rowMap && !sameAs(get(graph.rangeMap), graph.rowMap)
            #BUG FIX: previously passed the Nullable field `graph.rangeMap`
            #itself to Export; unwrap it with get(...) as the importer branch does
            graph.exporter = Export(graph.rowMap, get(graph.rangeMap), graph.plist)
        end
    end
end

#TODO migrate this to testing
# Debug-only consistency checks over the graph's storage invariants.
function checkInternalState(graph::CRSGraph)
    if @debug
        const localNumRows = getLocalNumRows(graph)

        @assert(isFillActive(graph) != isFillComplete(graph),
            "Graph must be either fill active or fill "
            * "complete$(isFillActive(graph)?"not both":"").")
        @assert(!isFillComplete(graph)
                || (!isnull(graph.colMap)
                    && !isnull(graph.domainMap)
                    && !isnull(graph.rangeMap)),
            "Graph is fill complete, but at least one of {column, range, domain} map is null")
        @assert((graph.storageStatus != STORAGE_1D_PACKED
                    && graph.storageStatus != STORAGE_1D_UNPACKED)
                || graph.pftype != DYNAMIC_PROFILE,
            "Graph claims 1D storage, but dynamic profile")
        if graph.storageStatus ==
STORAGE_2D
            @assert(graph.pftype != STATIC_PROFILE ,
                "Graph claims 2D storage, but static profile")
            #NOTE(review): "calims"/"index" typos in the messages below are
            #preserved; they are runtime strings.
            @assert(!isLocallyIndexed(graph)
                    || length(graph.localIndices2D) == localNumRows,
                "Graph calims to be locally index and have 2D storage, "
                * "but length(graph.localIndices2D) = $(length(graph.localIndices2D)) "
                * "!= getLocalNumRows(graph) = $localNumRows")
            @assert(!isGloballyIndexed(graph)
                    || length(graph.globalIndices2D) == localNumRows,
                "Graph calims to be globally index and have 2D storage, "
                * "but length(graph.globalIndices2D) = $(length(graph.globalIndices2D)) "
                * "!= getLocalNumRows(graph) = $localNumRows")
        end

        #global constants must be zeroed while unset, and dominate the local
        #constants once set
        @assert(graph.haveGlobalConstants
                || (graph.globalNumEntries == 0
                    && graph.globalNumDiags == 0
                    && graph.globalMaxNumRowEntries == 0),
            "Graph claims to not have global constants, "
            * "but some of the global constants are not 0")

        #NOTE(review): this assert rejects a legitimately empty graph that has
        #computed its (zero) global constants — confirm intent.
        @assert(!graph.haveGlobalConstants
                || (graph.globalNumEntries != 0
                    && graph.globalMaxNumRowEntries != 0),
            "Graph claims to have global constants, but also says 0 global entries")

        @assert(!graph.haveGlobalConstants
                || (graph.globalNumEntries >= graph.nodeNumEntries
                    && graph.globalNumDiags >= graph.nodeNumDiags
                    && graph.globalMaxNumRowEntries >= graph.nodeMaxNumRowEntries),
            "Graph claims to have global constants, but some of the local "
            * "constants are greater than their corresponding global constants")

        @assert(!isStorageOptimized(graph)
                || graph.pftype == STATIC_PROFILE,
            "Storage is optimized, but graph is not STATIC_PROFILE")

        @assert(!isGloballyIndexed(graph)
                || length(graph.rowOffsets) == 0
                || (length(graph.rowOffsets) == localNumRows +1
                    && graph.rowOffsets[localNumRows+1] == length(graph.globalIndices1D)),
            "If rowOffsets has nonzero size and the graph is globally "
            * "indexed, then rowOffsets must have N+1 rows and rowOffsets[N+1] "
            * "must equal the length of globalIndices1D")

        @assert(!isLocallyIndexed(graph)
                ||
length(graph.rowOffsets) == 0
                || (length(graph.rowOffsets) == localNumRows +1
                    && graph.rowOffsets[localNumRows+1] == length(graph.localIndices1D)),
            #NOTE(review): message says "globally indexed" but this is the
            #locally-indexed check — runtime string preserved as-is.
            "If rowOffsets has nonzero size and the graph is globally "
            * "indexed, then rowOffsets must have N+1 rows and rowOffsets[N+1] "
            * "must equal the length of localIndices1D")

        if graph.pftype == DYNAMIC_PROFILE
            @assert(localNumRows == 0
                    || length(graph.localIndices2D) > 0
                    || length(graph.globalIndices2D) > 0,
                "Graph has dynamic profile, the calling process has nonzero "
                * "rows, but no 2-D column index storage is present.")
            @assert(localNumRows == 0
                    || length(graph.numRowEntries) != 0,
                "Graph has dynamic profiles and the calling process has "
                * "nonzero rows, but numRowEntries is not present")

            @assert(length(graph.localIndices1D) == 0
                    && length(graph.globalIndices1D) == 0,
                "Graph has dynamic profile, but 1D allocations are present")

            @assert(length(graph.rowOffsets) == 0,
                "Graph has dynamic profile, but row offsets are present")

        elseif graph.pftype == STATIC_PROFILE
            #NOTE(review): this fails for a static-profile graph with zero
            #entries (both 1D arrays legitimately empty) — confirm intent.
            @assert(length(graph.localIndices1D) != 0
                    || length(graph.globalIndices1D) != 0,
                "Graph has static profile, but 1D allocations are not present")

            @assert(length(graph.localIndices2D) == 0
                    && length(graph.globalIndices2D) == 0,
                "Graph has static profile, but 2D allocations are present")
        else
            error("Unknown profile type: $(graph.pftype)")
        end

        if graph.indicesType == LOCAL_INDICES
            @assert(length(graph.globalIndices1D) == 0
                    && length(graph.globalIndices2D) == 0,
                "Indices are local, but global allocations are present")
            @assert(graph.nodeNumEntries == 0
                    || length(graph.localIndices1D) > 0
                    || length(graph.localIndices2D) > 0,
                "Indices are local and local entries exist, but there aren't local allocations present")
        elseif graph.indicesType == GLOBAL_INDICES
            @assert(length(graph.localIndices1D) == 0
                    && length(graph.localIndices2D) == 0,
                "Indices are global, but local
allocations are present")
            @assert(graph.nodeNumEntries == 0
                    || length(graph.globalIndices1D) > 0
                    || length(graph.globalIndices2D) > 0,
                "Indices are global and local entries exist, but there aren't global allocations present")
        else
            warn("Unknown indices type: $(graph.indicesType)")
        end

        #check actual allocations
        const lenRowOffsets = length(graph.rowOffsets)
        if graph.pftype == STATIC_PROFILE && lenRowOffsets != 0
            @assert(lenRowOffsets == localNumRows+1,
                "Graph has static profile, rowOffsets has a nonzero length "
                * "($lenRowOffsets), but is not equal to the "
                * "local number of rows plus one ($(localNumRows+1))")
            const actualNumAllocated = graph.rowOffsets[localNumRows+1]
            @assert(!isLocallyIndexed(graph)
                    || length(graph.localIndices1D) == actualNumAllocated,
                "Graph has static profile, rowOffsets has a nonzero length, "
                * "but length(localIndices1D) = $(length(graph.localIndices1D)) "
                * "!= actualNumAllocated = $actualNumAllocated")
            @assert(!isGloballyIndexed(graph)
                    || length(graph.globalIndices1D) == actualNumAllocated,
                "Graph has static profile, rowOffsets has a nonzero length, "
                * "but length(globalIndices1D) = $(length(graph.globalIndices1D)) "
                * "!= actualNumAllocated = $actualNumAllocated")
        end
    end
end

"""
Marks the graph's local state as modified: indices may be unsorted, may
contain duplicates, and local constants are stale.
"""
function setLocallyModified(graph::CRSGraph)
    graph.indicesAreSorted = false
    graph.noRedundancies = false
    graph.haveLocalConstants = false
end

"""
Sorts and/or merges (deduplicates) the indices of every local row.
`sorted`/`merged` indicate which invariants already hold and can be skipped.
"""
function sortAndMergeAllIndices(graph::CRSGraph, sorted::Bool, merged::Bool)
    @assert(isLocallyIndexed(graph),
        "This method may only be called after makeIndicesLocal(graph)")
    #BUG FIX: was `merged || isStoragedOptimized(graph)` — `isStoragedOptimized`
    #is an undefined name (typo for isStorageOptimized), and per the message the
    #invariant is: if we are going to merge (!merged), storage must NOT already
    #be optimized
    @assert(merged || !isStorageOptimized(graph),
        "The graph is already storage optimized, "
        * "so we shouldn't be merging any indices.")

    if !sorted || !merged
        localNumRows = getLocalNumRows(graph)
        totalNumDups = 0
        for localRow = 1:localNumRows
            rowInfo = getRowInfo(graph, localRow)
            if !sorted
                sortRowIndices(graph, rowInfo)
            end
            if !merged
                #BUG FIX: was `numDups +=` — `numDups` is undefined; the
                #accumulator declared above is `totalNumDups`
                totalNumDups +=
mergeRowIndices(graph, rowInfo)
            end
            recycleRowInfo(rowInfo)
        end
        graph.nodeNumEntries -= totalNumDups
        #BUG FIX: were `graph.indiciesAreSorted` / `graph.noRedunancies` —
        #typos for the fields used by setLocallyModified
        graph.indicesAreSorted = true
        graph.noRedundancies = true
    end
end

"""
Sorts the local column indices of a single row in place.
"""
function sortRowIndices(graph::CRSGraph{GID, PID, LID}, rowInfo::RowInfo{LID}) where {GID, PID, LID <: Integer}
    if rowInfo.numEntries > 0
        localColumnIndices = getLocalView(graph, rowInfo)
        sort!(localColumnIndices)
    end
end

"""
Removes duplicate local column indices from a (sorted) row.
Returns the number of duplicates removed.
"""
function mergeRowIndices(graph::CRSGraph{GID, PID, LID}, rowInfo::RowInfo{LID}) where {GID, PID, LID <: Integer}
    localColIndices = getLocalView(graph, rowInfo)
    #BUG FIX: was `localColIndices[:] = unique(localColIndices)` — when
    #duplicates exist, unique() is shorter than the full slice and the
    #assignment throws a DimensionMismatch; write the merged prefix instead
    mergedIndices = unique(localColIndices)
    mergedEntries = length(mergedIndices)
    localColIndices[1:mergedEntries] = mergedIndices
    graph.numRowEntries[rowInfo.localRow] = mergedEntries

    rowInfo.numEntries - mergedEntries
end


"""
Sets the graph's domain and range maps, invalidating the importer/exporter
when the corresponding map changes.
"""
function setDomainRangeMaps(graph::CRSGraph{GID, PID, LID}, domainMap::BlockMap{GID, PID, LID}, rangeMap::BlockMap{GID, PID, LID}) where {GID, PID, LID}
    if graph.domainMap != domainMap
        graph.domainMap = domainMap
        graph.importer = Nullable{Import{GID, PID, LID}}()
    end
    if graph.rangeMap != rangeMap
        graph.rangeMap = rangeMap
        graph.exporter = Nullable{Export{GID, PID, LID}}()
    end
end


"""
Communicates entries inserted into non-owned rows (graph.nonlocals) to their
owning processes.  Must be called while fill is active on all processes.
"""
#BUG FIX: signature was `globalAssemble(graph::CRSGraph)` but the body uses
#the type parameters GID/LID, which were not in scope
function globalAssemble(graph::CRSGraph{GID, PID, LID}) where {GID, PID, LID}
    @assert isFillActive(graph) "Fill must be active before calling globalAssemble(graph)"

    comm = JuliaPetra.comm(graph)
    myNumNonlocalRows = length(graph.nonlocals)

    maxNonlocalRows = maxAll(comm, myNumNonlocalRows)
    #BUG FIX: was `if maxNonlocalRows != 0; return` — that returned exactly
    #when some process HAS nonlocal rows; nothing to do only when no process
    #has any
    if maxNonlocalRows == 0
        return
    end

    #skipping: nonlocalRowMap = null

    numEntPerNonlocalRow = Array{LID, 1}(myNumNonlocalRows)
    myNonlocalGlobalRows = Array{GID, 1}(myNumNonlocalRows)

    for (i, (key, val)) = zip(1:length(graph.nonlocals), graph.nonlocals)
        myNonlocalGlobalRows[i] = key
        const globalCols = val #const b/c changing in place
        sort!(globalCols)
        #BUG FIX: was `globalCols[:] = unique(globalCols)` — DimensionMismatch
        #when duplicates exist; shrink the array to the deduplicated contents
        dedupedCols = unique(globalCols)
        resize!(globalCols, length(dedupedCols))
        globalCols[:] = dedupedCols
        numEntPerNonlocalRow[i] = length(globalCols)
    end

    #BUG FIX: casing was `myMinNonLocalGlobalRow = minimum(myNonLocalGlobalRows)`,
    #neither of which matched the names defined above / used below
    myMinNonlocalGlobalRow = minimum(myNonlocalGlobalRows)

    globalMinNonlocalRow =
minAll(comm, myMinNonlocalGlobalRow)

    nonlocalRowMap = BlockMap(-1, myNonlocalGlobalRows, comm)

    #build a temporary graph holding just the nonlocal rows, then export it
    #onto the original distribution
    nonlocalGraph = CRSGraph(nonlocalRowMap, numEntPerNonlocalRow, STATIC_PROFILE)
    for (i, (key, val)) = zip(1:length(graph.nonlocals), graph.nonlocals)
        globalRow = key
        globalColumns = val
        #BUG FIX: was `length(numEntPerNonlocalRow[i])` — length of a scalar;
        #the entry count is the element itself
        numEnt = numEntPerNonlocalRow[i]
        #BUG FIX: was `nonLocalGraph` (undefined; casing mismatch)
        insertGlobalIndices(nonlocalGraph, globalRow, numEnt, globalColumns)
    end

    const origRowMap = graph.rowMap
    const origRowMapIsOneToOne = isOneToOne(origRowMap)

    if origRowMapIsOneToOne
        exportToOrig = Export(nonlocalRowMap, origRowMap)
        #BUG FIX: was `nonLocalGraph` (undefined; casing mismatch)
        doExport(nonlocalGraph, graph, exportToOrig, INSERT)
    else
        #row map is overlapping: funnel through a one-to-one map first
        oneToOneRowMap = createOneToOne(origRowMap)
        exportToOneToOne = Export(nonlocalRowMap, oneToOneRowMap)

        oneToOneGraph = CRSGraph(oneToOneRowMap, 0)
        doExport(nonlocalGraph, oneToOneGraph, exportToOneToOne, INSERT)

        #keep memory highwater mark down
        #nonlocalGraph = null

        #BUG FIX: was `importToOrig(oneToOneRowMap, origRowMap)` — a call to an
        #undefined function; this constructs the Import used on the next line
        importToOrig = Import(oneToOneRowMap, origRowMap)
        doImport(oneToOneGraph, graph, importToOrig, INSERT)
    end
    #BUG FIX: field is `nonlocals`, was `graph.nonLocals`
    #NOTE(review): `clear!` is assumed to be a project helper (Base has
    #`empty!`) — confirm it exists.
    clear!(graph.nonlocals)

    checkInternalState(graph)
end

"""
Converts the graph's indices from global to local (requires a column map),
then rebuilds the local graph and marks the graph locally indexed.
"""
function makeIndicesLocal(graph::CRSGraph{GID, PID, LID}) where {GID, PID, LID}
    @assert hasColMap(graph) "The graph does not have a column map yet. This method should never be called in that case"

    colMap = get(graph.colMap)
    localNumRows = getLocalNumRows(graph)

    if isGloballyIndexed(graph) && localNumRows != 0
        numRowEntries = graph.numRowEntries

        if getProfileType(graph) == STATIC_PROFILE
            if GID == LID
                #same integer type: reuse the global array as local storage
                graph.localIndices1D = graph.globalIndices1D

            else
                @assert(length(graph.rowOffsets) != 0,
                    "length(graph.rowOffsets) == 0. 
"
                    * "This should never happen at this point")
                const numEnt = graph.rowOffsets[localNumRows+1]
                graph.localIndices1D = Array{LID, 1}(numEnt)
            end


            localColumnMap = getLocalMap(colMap)

            numBad = convertColumnIndicesFromGlobalToLocal(
                graph.localIndices1D,
                graph.globalIndices1D,
                graph.rowOffsets,
                localColumnMap,
                numRowEntries)

            if numBad != 0
                throw(InvalidArgumentError("When converting column indices from "
                    * "global to local, we encountered $numBad indices that "
                    * "do not live in the column map on this process"))
            end

            #NOTE(review): element type is LID here although the field holds
            #global indices (GID) — works only as an empty placeholder; confirm.
            graph.globalIndices1D = Array{LID, 1}(0)
        else #graph has dynamic profile
            graph.localIndices2D = Array{Array{LID, 1}, 1}(localNumRows)
            for localRow = 1:localNumRows
                if length(graph.globalIndices2D[localRow]) != 0
                    globalIndices = graph.globalIndices2D[localRow]

                    graph.localIndices2D[localRow] = [lid(colMap, gid) for gid in globalIndices]
                    if @debug
                        @assert(minimum(graph.localIndices2D[localRow]) > 0,
                            "Globalal indices were not found in the column Map")
                    end
                end
            end
            graph.globalIndices2D = Array{GID, 1}[]
        end
    end

    graph.localGraph = LocalCRSGraph(graph.localIndices1D, graph.rowOffsets)
    graph.indicesType = LOCAL_INDICES
    checkInternalState(graph)
end


# Translates 1D global column indices into local indices using localColumnMap.
# Writes into localColumnIndices in place; returns the number of indices that
# were not found in the column map (lid(...) == 0).
function convertColumnIndicesFromGlobalToLocal(localColumnIndices::AbstractArray{LID, 1},
        globalColumnIndices::AbstractArray{GID, 1}, ptr::AbstractArray{LID, 1},
        localColumnMap::BlockMap{GID, PID, LID}, numRowEntries::AbstractArray{LID, 1}
        )::LID where {GID, PID, LID}


    localNumRows = max(length(ptr)-1, 0)
    numBad = 0
    for localRow = 1:localNumRows
        offset = ptr[localRow]

        #NOTE(review): j starts at 0, so the first entry written is
        #localColumnIndices[offset] — assumes ptr offsets point one past the
        #row start; confirm against how rowOffsets is built.
        for j = 0:numRowEntries[localRow]-1
            gid = globalColumnIndices[offset+j]
            localColumnIndices[offset+j] = lid(localColumnMap, gid)
            if localColumnIndices[offset+j] == 0
                numBad += 1
            end
        end
    end
    numBad
end

#covers the overlap between insert methods
# Expands to the shared insertion skeleton; `indicesType` selects the
# "local"/"global" field pair and `innards` is the static-profile overflow
# handler spliced in at $innards.  Expects `graph`, `myRow` and `indices`
# to be in scope at the expansion site.
macro insertIndicesImpl(indicesType, innards)
    indices1D =
Symbol(indicesType*"Indices1D")
    indices2D = Symbol(indicesType*"Indices2D")

    esc(quote
        rowInfo = getRowInfo(graph, myRow)
        numNewIndices = length(indices)
        newNumEntries = rowInfo.numEntries + numNewIndices

        if newNumEntries > rowInfo.allocSize
            if getProfileType(graph) == STATIC_PROFILE
                $innards
            else
                #dynamic profile: grow geometrically (at least to the new size)
                newAllocSize = 2*rowInfo.allocSize
                if newAllocSize < newNumEntries
                    newAllocSize = newNumEntries
                end
                resize!(graph.$indices2D[myRow], newAllocSize)
            end
        end

        if length(graph.$indices1D) != 0
            offset = rowInfo.offset1D + rowInfo.numEntries
            destRange = offset+1:offset+numNewIndices

            graph.$indices1D[destRange] = indices[:]
        else
            graph.$indices2D[myRow][rowInfo.numEntries+1:newNumEntries] = indices[:]
        end

        graph.numRowEntries[myRow] += numNewIndices
        graph.nodeNumEntries += numNewIndices
        setLocallyModified(graph)

        recycleRowInfo(rowInfo)
        if @debug
            chkNewNumEntries = getNumEntriesInLocalRow(graph, myRow)
            @assert(chkNewNumEntries == newNumEntries,
                "Internal Logic error: chkNewNumEntries = $chkNewNumEntries "
                * "!= newNumEntries = $newNumEntries")
        end
        nothing
    end)
end

# Local-index insertion: static-profile overflow is a hard error.
function insertLocalIndicesImpl(graph::CRSGraph{GID, PID, LID},
        myRow::LID, indices::AbstractArray{LID, 1}) where {
        GID, PID, LID <: Integer}
    @insertIndicesImpl "local" begin
        throw(InvalidArgumentError("new indices exceed statically allocated graph structure"))
    end
end

#TODO figure out if this all can be moved to @insertIndicesImpl
# Global-index insertion: static-profile overflow deduplicates the incoming
# indices against the row before giving up.
function insertGlobalIndicesImpl(graph::CRSGraph{GID, PID, LID},
        myRow::LID, indices::AbstractArray{GID, 1}) where {
        GID <: Integer, PID, LID <: Integer}
    @insertIndicesImpl "global" begin
        @assert(rowInfo.numEntries <= rowInfo.allocSize,
            "For local row $myRow, rowInfo.numEntries = $(rowInfo.numEntries) "
            * "> rowInfo.allocSize = $(rowInfo.allocSize).")

        dupCount = 0
        if length(graph.globalIndices1D) != 0
            curOffset = rowInfo.offset1D
            @assert(length(graph.globalIndices1D) >=
curOffset,
                "length(graph.globalIndices1D) = $(length(graph.globalIndices1D)) "
                * ">= curOffset = $curOffset")
            #NOTE(review): `curOffset + rowInfo.offset1D` doubles the offset;
            #`curOffset + rowInfo.numEntries` looks intended — confirm.
            @assert(length(graph.globalIndices1D) >= curOffset + rowInfo.offset1D,
                "length(graph.globalIndices1D) = $(length(graph.globalIndices1D)) "
                * ">= curOffset+rowInfo.offset1D = $(curOffset + rowInfo.offset1D)")

            range = curOffset:curOffset+rowInfo.numEntries
            globalIndicesCur = view(graph.globalIndices1D, range)
        else
            #line 1959

            globalIndices = graph.globalIndices2D[myRow]
            @assert(rowInfo.allocSize == length(globalIndices),
                "rowInfo.allocSize = $(rowInfo.allocSize) "
                * "== length(globalIndices) = $(length(globalIndices))")
            @assert(rowInfo.numEntries <= length(globalIndices),
                "rowInfo.numEntries = $(rowInfo.numEntries) "
                * "== length(globalIndices) = $(length(globalIndices))")

            #NOTE(review): `view(globalIndices, 0, rowInfo.numEntries)` is not a
            #valid view call; `view(globalIndices, 1:rowInfo.numEntries)` looks
            #intended (cf. the 2D branch further down) — confirm before fixing.
            globalIndicesCur = view(globalIndices, 0, rowInfo.numEntries)
        end
        for newIndex = indices
            dupCount += count(old -> old==newIndex, globalIndicesCur)
        end

        #NOTE(review): `numNewInds` is undefined in this scope; the macro
        #skeleton defines `numNewIndices` — confirm.
        numNewToInsert = numNewInds - dupCount
        @assert numNewToInsert >= 0 "More duplications than indices"

        if rowInfo.numEntries + numNewToInsert > rowInfo.allocSize
            throw(InvalidArgumentError("$(myPid(comm(graph))): "
                * "For local row $myRow, even after excluding "
                * "$dupCount duplicate(s) in input, the new number "
                * "of entries $(rowInfo.numEntries + numNewToInsert) "
                * "still exceeds this row's static allocation size "
                * "$(rowInfo.allocSize). 
You must either fix the upper "
                * "bound on number of entries in this row, or switch "
                * "to dynamic profile."))
        end

        #NOTE(review): `graph.globalIndices` is not a field (globalIndices1D?),
        #and `currOffset` below is a typo for `curOffset` — confirm.
        if length(graph.globalIndices) != 0
            curOffset = rowInfo.offset1D
            globalIndicesCur = view(graph.globalIndices1D,
                range(curOffset, 1, rowInfo.numEntries))
            globalIndicesNew = view(graph.globalIndices1D,
                curOffset+rowInfo.numEntries+1 : currOffset+rowInfo.allocSize)
        else
            #line 2036

            globalIndices = graph.globalIndices2D[myRow]
            globalIndicesCur = view(globalIndices, 1:rowInfo.numEntries)
            #NOTE(review): upper bound `rowInfo.allocSize-rowInfo.numEntries`
            #looks like it should be just `rowInfo.allocSize` — confirm.
            globalIndicesNew = view(globalIndices,
                rowInfo.numEntries+1 : rowInfo.allocSize-rowInfo.numEntries)
        end

        curPos = 1
        for globalIndexToInsert = indices

            alreadyInOld = globalIndexToInsert in globalIndicesCur
            if !alreadyInOld
                #NOTE(review): `$newToInsert` in the message is undefined
                #(numNewToInsert?) — runtime string preserved as-is.
                @assert(curPos <= numNewToInsert,
                    "curPos = $curPos >= numNewToInsert = $newToInsert.")
                globalIndicesNew[curPos] = globalIndexToInsert
                curPos += 1
            end
        end

        graph.numRowEntries[myRow] = rowInfo.numEntries+numNewToInsert
        graph.nodeNumEntries += numNewToInsert
        setLocallyModified(graph)

        if @debug
            newNumEntries = rowInfo.numEntries + numNewToInsert
            #NOTE(review): `getNumEntiresInLocalRow` is a typo for
            #getNumEntriesInLocalRow (used in the macro skeleton) — confirm.
            chkNewNumEntries = getNumEntiresInLocalRow(graph, myRow)
            @assert(chkNewNumEntries == newNumEntries,
                "chkNewNumEntries = $chkNewNumEntries "
                * "!= newNumEntries = $newNumEntries")
        end
        return
    end
end



#internal implementation of makeColMap, needed to handle some return and debuging stuff
#returns Tuple(error, colMap)
function __makeColMap(graph::CRSGraph{GID, PID, LID}, wrappedDomMap::Nullable{BlockMap{GID, PID, LID}}
        ) where {GID, PID, LID}
    error = false

    #NOTE(review): this early return yields a bare Nullable rather than the
    #(error, map) tuple the docs and other returns use — confirm callers.
    if isnull(wrappedDomMap)
        return Nullable{BlockMap{GID, PID, LID}}()
    end
    domMap = get(wrappedDomMap)

    if isLocallyIndexed(graph)
        wrappedColMap = graph.colMap

        if isnull(wrappedColMap)
            warn("$(myPid(comm(graph))): The graph is locally indexed, but does not have a column map")

            error = true
            myColumns = GID[]
        else
            colMap = get(wrappedColMap)
if linearMap(colMap) #i think isContiguous(map) <=> linearMap(map)?
                numCurGIDs = numMyElements(colMap)
                #BUG FIX: was `minMyGIDs(colMap)` — undefined name; the accessor
                #used elsewhere in this function is minMyGID
                myFirstGlobalIndex = minMyGID(colMap)

                myColumns = collect(range(myFirstGlobalIndex, 1, numCurGIDs))
            else
                myColumns = copy(myGlobalElements(colMap))
            end
        end
        return (error, BlockMap(myColumns, comm(domMap)))
    end

    #else if graph.isGloballyIndexed
    const localNumRows = getLocalNumRows(graph)

    numLocalColGIDs = 0

    #flags, indexed by domain-map local id, of which domain GIDs appear as
    #column indices on this process
    #NOTE(review): sized by localNumRows but indexed by domain-map LIDs and
    #scanned up to numMyElements(domMap) below — should this be
    #falses(numMyElements(domMap))? Confirm before changing.
    gidIsLocal = falses(localNumRows)
    remoteGIDSet = Set{GID}()

    #if rowMap != null
    const rowMap = graph.rowMap

    for localRow = 1:localNumRows
        globalRow = gid(rowMap, localRow)
        (rowGIDPtr, numEnt) = getGlobalRowViewPtr(graph, globalRow)

        if numEnt != 0
            for k = 1:numEnt
                gid::GID = unsafe_load(rowGIDPtr, k)
                lid::LID = JuliaPetra.lid(domMap, gid)
                if lid != 0
                    @inbounds if !gidIsLocal[lid]
                        gidIsLocal[lid] = true
                        numLocalColGIDs += 1
                    end
                else
                    #don't need containment checks, set already takes care of that
                    push!(remoteGIDSet, gid)
                end
            end
        end
    end



    numRemoteColGIDs = length(remoteGIDSet)

    #line 214, abunch of explanation of serial short circuit
    if numProc(comm(domMap)) == 1
        if numRemoteColGIDs != 0
            error = true
        end
        if numLocalColGIDs == localNumRows
            return (error, domMap)
        end
    end
    #local column GIDs first, then remote ones, in one backing array
    myColumns = Vector{GID}(numLocalColGIDs+numRemoteColGIDs)
    localColGIDs = view(myColumns, 1:numLocalColGIDs)
    remoteColGIDs = view(myColumns, numLocalColGIDs+1:numLocalColGIDs+numRemoteColGIDs)

    remoteColGIDs[:] = [el for el in remoteGIDSet]

    #dead store removed: remotePIDs was pre-allocated here and immediately
    #overwritten by the assignment below
    remotePIDs = remoteIDList(domMap, remoteColGIDs)[1]
    if any(remotePIDs .== 0)
        if @debug
            warn("Some column indices are not in the domain Map")
        end
        error = true
    end

    #group remote columns by owning process
    order = sortperm(remotePIDs)
    permute!(remotePIDs, order)
    permute!(remoteColGIDs, order)

    #line 333

    numDomainElts = numMyElements(domMap)
    if numLocalColGIDs == numDomainElts
        if linearMap(domMap) #I think 
isContiguous() <=> linearMap()
            localColGIDs[1:numLocalColGIDs] = range(minMyGID(domMap), 1, numLocalColGIDs)
        else
            domElts = myGlobalElements(domMap)
            localColGIDs[1:length(domElts)] = domElts
        end
    else
        #only a subset of the domain GIDs appear locally: copy just those
        numLocalCount = 0
        if linearMap(domMap) #I think isContiguous() <=> linearMap()
            curColMapGID = minMyGID(domMap)
            for i = 1:numDomainElts
                if gidIsLocal[i]
                    numLocalCount += 1
                    localColGIDs[numLocalCount] = curColMapGID
                end
                curColMapGID += 1
            end
        else
            #BUG FIX: was `myGlobalElement(domMap)` — undefined name; the
            #accessor used in the branch above is myGlobalElements
            domainElts = myGlobalElements(domMap)
            for i = 1:numDomainElts
                if gidIsLocal[i]
                    numLocalCount += 1
                    localColGIDs[numLocalCount] = domainElts[i]
                end
                #BUG FIX: removed stray `curColMapGID += 1` — that variable is
                #only defined in the linear-map branch and is meaningless here
            end
        end

        if numLocalCount != numLocalColGIDs
            if @debug
                warn("$(myPid(comm(graph))): numLocalCount = $numLocalCount "
                    * "!= numLocalColGIDs = $numLocalColGIDs. "
                    * "This should not happen.")
            end
            error = true
        end
    end

    return (error, BlockMap(myColumns, comm(domMap)))
end
diff --git a/src/CSRMatrix.jl b/src/CSRMatrix.jl
new file mode 100644
index 0000000..19799dc
--- /dev/null
+++ b/src/CSRMatrix.jl
@@ -0,0 +1,1241 @@
export CSRMatrix, insertGlobalValues
#TODO export other CSRMatrix-specific symbols

using TypeStability

#Distributed compressed-sparse-row matrix built on a CRSGraph structure
mutable struct CSRMatrix{Data <: Number, GID <: Integer, PID <: Integer, LID <: Integer} <: DistRowMatrix{Data, GID, PID, LID}
    rowMap::BlockMap{GID, PID, LID}
    colMap::Nullable{BlockMap{GID, PID, LID}}

    importMV::Nullable{MultiVector{Data, GID, PID, LID}}
    exportMV::Nullable{MultiVector{Data, GID, PID, LID}}

    myGraph::CRSGraph{GID, PID, LID}

    localMatrix::LocalCSRMatrix{Data, LID}

    values2D::Array{Array{Data, 1}, 1}

    #pull storageStatus and fillComplete from graph

    #Dict keys are row indices
    #first element of each tuple is a column index
    #second element of each tuple is the matching entry
    nonlocals::Dict{GID, Tuple{Array{Data, 1}, Array{GID, 1}}}

    plist::Dict{Symbol}

    function CSRMatrix{Data, GID, PID, LID}(rowMap::BlockMap{GID, PID, LID},
colMap::Nullable{BlockMap{GID, PID, LID}}, myGraph::CRSGraph{GID, PID, LID}, localMatrix::LocalCSRMatrix{Data, LID}, plist::Dict{Symbol}) where {Data, GID, PID, LID}

        #allocate values to mirror the graph's index storage layout
        localNumRows = getLocalNumRows(myGraph)
        if getProfileType(myGraph) == STATIC_PROFILE
            ptrs = myGraph.rowOffsets
            localTotalNumEntries = ptrs[localNumRows+1]

            resize!(localMatrix.values, localTotalNumEntries)

            values2D = Array{Array{Data, 1}, 1}(0)
        else #DYNAMIC_PROFILE
            #BUG FIX: fields were referenced as `myGraph.localIndices` /
            #`myGraph.globalIndices`, which don't exist; the dynamic-profile
            #storage fields are localIndices2D / globalIndices2D
            if isLocallyIndexed(myGraph)
                graphIndices = myGraph.localIndices2D
            else
                graphIndices = myGraph.globalIndices2D
            end
            values2D = Array{Array{Data, 1}, 1}(localNumRows)
            for r = 1:length(graphIndices)
                #BUG FIX: was `Array{Array{Data, 1}, 1}(...)` — each row's
                #value buffer is a vector of Data, not a vector of vectors
                values2D[r] = Array{Data, 1}(length(graphIndices[r]))
            end
        end

        new(rowMap,
            colMap,
            Nullable{MultiVector{Data, GID, PID, LID}}(),
            Nullable{MultiVector{Data, GID, PID, LID}}(),
            myGraph,
            localMatrix,
            values2D,
            Dict{GID, Tuple{Array{Data, 1}, Array{GID, 1}}}(),
            plist)
    end

end

#### Constructors ####
#TODO document Constructors

#keyword-splat front end: collect plist into a Dict and forward
function CSRMatrix{Data}(rowMap::BlockMap{GID, PID, LID},
        maxNumEntriesPerRow::Union{Integer, Array{<:Integer, 1}},
        pftype::ProfileType; plist...) where {Data, GID, PID, LID}
    CSRMatrix{Data}(rowMap, maxNumEntriesPerRow, pftype, Dict(Array{Tuple{Symbol, Any}}(plist)))
end
#no column map given: forward with an empty Nullable
function CSRMatrix{Data}(rowMap::BlockMap{GID, PID, LID},
        maxNumEntriesPerRow::Union{Integer, Array{<:Integer, 1}},
        pftype::ProfileType, plist::Dict{Symbol}) where {Data, GID, PID, LID}
    CSRMatrix{Data}(rowMap, Nullable{BlockMap{GID, PID, LID}}(),
        maxNumEntriesPerRow, pftype, plist)
end

function CSRMatrix{Data}(rowMap::BlockMap{GID, PID, LID},
        colMap::BlockMap{GID, PID, LID},
        maxNumEntriesPerRow::Union{Integer, Array{<:Integer, 1}},
        pftype::ProfileType; plist...)
where {Data, GID, PID, LID} + CSRMatrix{Data}(rowMap, colMap, maxNumEntriesPerRow, pftype, Dict(Array{Tuple{Symbol, Any}}(plist))) +end +function CSRMatrix{Data}(rowMap::BlockMap{GID, PID, LID}, + colMap::BlockMap{GID, PID, LID}, + maxNumEntriesPerRow::Union{Integer, Array{<:Integer, 1}}, + pftype::ProfileType, plist::Dict{Symbol}) where {Data, GID, PID, LID} + CSRMatrix{Data}(rowMap, Nullable(colMap), maxNumEntriesPerRow, + pftype, plist) +end + +function CSRMatrix{Data}(rowMap::BlockMap{GID, PID, LID}, + colMap::Nullable{BlockMap{GID, PID, LID}}, + maxNumEntriesPerRow::Union{Integer, Array{<:Integer, 1}}, + pftype::ProfileType; plist...) where {Data, GID, PID, LID} + CSRMatrix{Data}(rowMap, colMap, maxNumEntriesPerRow, pftype, Dict(Array{Tuple{Symbol, Any}}(plist))) +end +function CSRMatrix{Data}(rowMap::BlockMap{GID, PID, LID}, + colMap::Nullable{BlockMap{GID, PID, LID}}, + maxNumEntriesPerRow::Union{Integer, Array{<:Integer, 1}}, + pftype::ProfileType, plist::Dict{Symbol}) where {Data, GID, PID, LID} + graph = CRSGraph(rowMap, maxNumEntriesPerRow, pftype, plist) + + matrix = CSRMatrix{Data, GID, PID, LID}(rowMap, colMap, + graph, LocalCSRMatrix{Data, LID}(), plist) + + resumeFill(matrix, plist) + + matrix +end + +function CSRMatrix{Data}(graph::CRSGraph{GID, PID, LID}; plist... + ) where {Data, GID, PID, LID} + CSRMatrix{Data}(graph, Dict(Array{Tuple{Symbol, Any}}(plist))) +end +function CSRMatrix{Data}(graph::CRSGraph{GID, PID, LID},plist::Dict{Symbol} + ) where {Data, GID, PID, LID} + numCols = numMyElements(getColMap(graph)) + localGraph = getLocalGraph(graph) + val = Array{Data, 1}(length(localGraph.entries)) + localMatrix = LocalCSRMatrix(numCols, val, localGraph) + + CSRMatrix(graph.rowMap, graph.colMap, graph, localMatrix, plist) +end + +function CSRMatrix(rowMap::BlockMap{GID, PID, LID}, colMap::BlockMap{GID, PID, LID}, + rowOffsets::AbstractArray{LID, 1}, colIndices::AbstractArray{LID, 1}, values::AbstractArray{Data, 1}; + plist...) 
where {Data, GID, PID, LID} + CSRMatrix(rowMap, colMap, rowOffsets, colIndices, values, Dict(Array{Tuple{Symbol, Any}}(plist))) +end +function CSRMatrix(rowMap::BlockMap{GID, PID, LID}, colMap::BlockMap{GID, PID, LID}, + rowOffsets::AbstractArray{LID, 1}, colIndices::AbstractArray{LID, 1}, values::AbstractArray{Data, 1}, + plist::Dict{Symbol}) where {Data, GID, PID, LID} + + #check user's input. Might throw on only some processes, causing deadlock + if length(values) != length(colIndices) + throw(InvalidArgumentError("values and columnIndices must " + * "have the same length")) + end + + graph = CRSGraph(rowMap, colMap, rowOffsets, columnIndices, plist) + localGraph = getLocalGraph(graph) + + numCols = numMyElements(getColMap(graph)) + localMatrix = LocalCSRMatrix(numCols, values, localGraph) + + CSRMatrix(rowMap, Nullable(colMap), graph, localMatrix, plist) +end + +function CSRMatrix(rowMap::BlockMap{GID, PID, LID}, colMap::BlockMap{GID, PID, LID}, + localMatrix::LocalCSRMatrix{Data, LID}; plist... + ) where {Data, GID, PID, LID} + CSRMatrix(rowMap, colMap, localMatrix, Dict(Array{Tuple{Symbol, Any}}(plist))) +end +function CSRMatrix(rowMap::BlockMap{GID, PID, LID}, colMap::BlockMap{GID, PID, LID}, + localMatrix::LocalCSRMatrix{Data, LID}, plist::Dict{Symbol} + ) where {Data, GID, PID, LID} + + graph = CRSGraph(rowMap, colMap, localMatrix.graph, plist) + + matrix = CSRMatrix(rowMap, colMap, graph, localMatrix, plist) + + computeGlobalConstants(matrix) + matrix +end + +function CSRMatrix(rowMap::BlockMap{GID, PID, LID}, colMap::BlockMap{GID, PID, LID}, + localMatrix::AbstractArray{Data, 2}; plist... 
+ ) where {Data, GID, PID, LID} + CSRMatrix(rowmap, colMap, localMatrix, Dict(Array{Tuple{Symbol, Any}}(plist))) +end +function CSRMatrix(rowMap::BlockMap{GID, PID, LID}, colMap::BlockMap{GID, PID, LID}, + localMatrix::AbstractArray{Data, 2}, plist::Dict{Symbol} + ) where {Data, GID, PID, LID} + linearIndices = find(x -> x!=0, localMatrix) + rowIndicesIter, colIndicesIter, valuesIter = zip( + sort!(collect(zip(ind2sub(size(localMatrix), linearIndices)..., + localMatrix[linearIndices])))...) + rowIndices = collect(rowIndicesIter) + rowOffsets = Array{LID, 1}(size(localMatrix, 1)+1) + row = 1 + j = 1 + for i in 1:length(rowIndices) + if rowIndices[i] > row + row += 1 + rowOffsets[row] = i + end + end + rowOffsets[length(rowOffsets)] = length(rowIndices)+1 + + CSRMatrix(rowMap, colMap, rowOffsets, + collect(colIndicesIter), collect(valuesIter), plist) +end + + +#### Internal methods #### +function combineGlobalValues(matrix::CSRMatrix{Data, GID, PID, LID}, + globalRow::GID, indices::AbstractArray{GID, 1}, + values::AbstractArray{Data, 1}, cm::CombineMode + ) where {Data, GID, PID, LID} + + if cm == ADD || cm == INSERT + insertGlobalValuesFiltered(globalRow, indices, values) + else + #TODO implement ABSMAX and REPLACE + #not implmenented in TPetra, because its not a common use case and difficult (see FIXME on line 6225) + throw(InvalidArgumentError("Not yet implemented for combine mode $cm")) + end +end + +""" +Returns a nullable object of the column map multivector +""" +function getColumnMapMultiVector(mat::CSRMatrix{Data, GID, PID, LID}, X::MultiVector{Data, GID, PID, LID}, force = false) where {Data, GID, PID, LID} + if !hasColMap(mat) + throw(InvalidStateError("Can only call getColumnMapMultiVector with a matrix that has a column map")) + end + if !isFillComplete(mat) + throw(InvalidStateError("Can only call getColumnMapMultiVector if the matrix is fill active")) + end + + numVecs = numVectors(X) + importer = getGraph(mat).importer + colMap = getColMap(mat) + + 
#if import object is trivial, don't need a seperate column map multivector + if !isnull(importer) || force + if isnull(mat.importMV) || numVectors(get(mat.importMV)) != numVecs + mat.importMV = Nullable(MultiVector{Data, GID, PID, LID}(colMap, numVecs)) + else + mat.importMV + end + else + Nullable{MultiVector{Data, GID, PID, LID}}() + end +end + +""" +Returns a nullable object of the row map multivector +""" +function getRowMapMultiVector(mat::CSRMatrix{Data, GID, PID, LID}, Y::MultiVector{Data, GID, PID, LID}, force = false) where {Data, GID, PID, LID} + if !isFillComplete(mat) + throw(InvalidStateError("Cannot call getRowMapMultiVector wif the matrix is fill active")) + end + + numVecs = numVectors(Y) + exporter = getGraph(mat).exporter + rowMap = getRowMap(mat) + + if !isnull(exporter) || force + if isnull(mat.exportMV) || getNumVectors(get(map.exportMV)) != numVecs + mat.exportMV = Nullable(MultiVector(rowMap, numVecs)) + else + mat.exportMV + end + else + Nullable{MultiVector{Data, GID, PID, LID}}() + end +end + + +#does nothing, exists only to be a parallel to CRSGraph +computeGlobalConstants(matrix::CSRMatrix) = nothing + +#Tpetra's only clears forbNorm, exists only to be a parallel to CRSGraph +clearGlobalConstants(matrix::CSRMatrix) = nothing + +function globalAssemble(matrix::CSRMatrix) + comm = JuliaPetra.comm(matrix) + if !isFillActive(matrix) + throw(InvalidStateError("Fill must be active to call this method")) + end + + myNumNonLocalRows = length(matrix.nonlocals) + nooneHasNonLocalRows = maxAll(comm, myNumNonLocalRows == 0) + if nooneHasNonLocalRows + #no process has nonlocal rows, so nothing to do + return + end + + #nonlocalRowMap = BlockMap{GID, PID, LID}() + + numEntPerNonlocalRow = Array{Integer, 1}(numNonLocalRows) + myNonlocalGlobalRows = Array{GID, 1}(numNonLocalRows) + + curPos = 1 + for (key, val) in matrix.nonlocal + myNonlocalGlobalRows[curPos] = key + + values = val[1] + globalColumns = val[2] + + order = sortperm(globalColumns) + 
permute!(globalColumns, order) + permute!(values, order) + + + curPos += 1 + end + + #merge2 + if length(globalColumns) > 0 + setIndex = 1 + valuesSum = 0 + currentIndex = globalColumns[1] + for i = 1:length(globalColumns) + if currentIndex != globalColumns[i] + values[setIndex] = valuesSum + globalColumns[setIndex] = currentIndex + setIndex += 1 + valuesSum = 0 + end + valuesSum += values[i] + i += 1 + end + values[setIndex] = valuesSum + globalColumns[setIndex] = currentIndex + + resize!(values, setIndex) + resize!(globalColumns, setIndex) + end + + numEntPerNonloalRow[curPose] = length(globalColumns) + + + #don't need to worry about finding the indexBase + nonlocalRowMap = BlockMap(0, myNonlocalGlobalRows, comm) + + nonlocalMatrix = CSRMatrix(nonlocalRowMap, numEntPernonlocalRow, STATIC_PROFILE) + + curPos = 1 + for (key, val) in matrix.nonlocals + globalRow = key + + vals = val[1] + globalCols = val[2] + + insertGlobalValues(nonlocalMatrix, globalRow, globalCols, vals) + end + + origRowMap = rowMap(matrix) + origRowMapIsOneToOne = isOneToOne(origRowMap) + + + if origRowMapIsOneToOne + exportToOrig = Export(nonlocalRowMap, origRowMap) + isLocallyComplete = isLocallyComplete(exportToOrig) + doExport(nonlocalMatrix, matrix, exportToOrig, ADD) + else + oneToOneRowMap = createOneToOne(origRowMap) + exportToOneToOne = Export(nonlocalRowMap, oneToOneRowMap) + + isLocallyComplete = isLocallyComplete(exportToOneToOne) + + oneToOneMatrix = CSRMatrix{Data}(oneToOneRowMap, 0) + + doExport(nonlocalMatrix, onToOneMatrix, exportToOneToOne, ADD) + + #nonlocalMatrix = null + + importToOrig = Import(oneToOneRowMap, origRowMap) + doImport(oneToOneMatrix, matrix, importToOrig, ADD) + end + + empty!(matrix.nonlocals) + + globallyComplete = minAll(comm, isLocallyComplete) + if !globallyComplete + throw(InvalidArgumentError("On at least one process, insertGlobalValues " + * "was called with a global row index which is not in the matrix's " + * "row map on any process in its 
communicator.")) + end +end + + + +function fillLocalGraphAndMatrix(matrix::CSRMatrix{Data, GID, PID, LID}, + plist::Dict{Symbol}) where {Data, GID, PID, LID} + localNumRows = getLocalNumRows(matrix) + + myGraph = matrix.myGraph + localMatrix = matrix.localMatrix + + matrix.localMatrix.graph.entries = myGraph.localIndices1D + + #most of the debug sections were taken out, they could be re-added wrapped with `if @debug` + if getProfileType(matrix) == DYNAMIC_PROFILE + numRowEntries = myGraph.numRowEntries + + ptrs = Array{LID, 1}(localNumRows+1) + localTotalNumEntries = computeOffsets(ptrs, numRowEntries) + + inds = Array{LID, 1}(localTotalNumEntries) + #work around type instability required by localMatrix.values + vals_concrete = Array{Data, 1}(localTotalNumEntries) + vals = vals_concrete + + localIndices2D = myGraph.localIndices2D + for row = 1:localNumRows + numEnt = numRowEnt[row] + dest = range(ptrs[row], 1, numEnt) + + inds[dest] = localIndices2D[row][:] + vals_concrete[dest] = matrix.values2D[row][:] + end + elseif getProfileType(matrix) == STATIC_PROFILE + curRowOffsets = myGraph.rowOffsets + + if myGraph.storageStatus == STORAGE_1D_UNPACKED + #pack row offsets into ptrs + + localTotalNumEntries = 0 + + ptrs = Array{LID, 1}(localNumRows + 1) + numRowEnt = myGraph.numRowEntries + localTotalNumEntries = computeOffsets(ptrs, numRowEnt) + + inds = Array{LID, 1}(localTotalNumEntries) + #work around type instability required by localMatrix.values + vals_concrete = Array{Data, 1}(localTotalNumEntries) + vals = vals_concrete + + #line 1234 + for row in 1:localNumRows + srcPos = curRowOffsets[row] + dstPos = ptrs[row] + dstEnd = ptrs[row+1]-1 + dst = dstPos:dstEnd + src = srcPos:srcPos+dstEnd-dstPos + + inds[dst] = myGraph.localIndices1D[src] + vals_concrete[dst] = localMatrix.values[src] + end + else + #dont have to pack, just set pointers + ptrs = myGraph.rowOffsets + inds = myGraph.localIndices1D + vals = localMatrix.values + end + end + + if get(plist, 
:optimizeStorage, true) + empty!(myGraph.localIndices2D) + empty!(myGraph.numRowEntries) + + empty!(matrix.values2D) + + myGraph.rowOffsets = ptrs + myGraph.localIndices1D = inds + + myGraph.pftype = STATIC_PROFILE + myGraph.storageStatus = STORAGE_1D_PACKED + end + + myGraph.localGraph = LocalCRSGraph(inds, ptrs) + matrix.localMatrix = LocalCSRMatrix(myGraph.localGraph, vals, getLocalNumCols(matrix)) +end + +function insertNonownedGlobalValues(matrix::CSRMatrix{Data, GID, PID, LID}, + globalRow::GID, indices::AbstractArray{GID, 1}, values::AbstractArray{Data, 1} + ) where {Data, GID, PID, LID} + + curRow = matrix.nonlocals[globalRow] + curRowVals = curRow[1] + curRowInds = curRow[2] + + newCapacity = length(curRowInds) + length(indices) + + append!(curRowVals, values) + append!(curRowInds, indices) +end + +function getView(matrix::CSRMatrix{Data, GID, PID, LID}, rowInfo::RowInfo{LID})::AbstractArray{Data, 1} where {Data, GID, PID, LID} + if getProfileType(matrix) == STATIC_PROFILE && rowInfo.allocSize > 0 + range = rowInfo.offset1D:rowInfo.offset1D+rowInfo.allocSize-LID(1) + baseArray = matrix.localMatrix.values + if baseArray isa Vector{Data} + view(matrix.localMatrix.values, range) + else + view(matrix.localMatrix.values, range) + end + elseif getProfileType(matrix) == DYNAMIC_PROFILE + baseArray = matrix.values2D[rowInfo.localRow] + view(baseArray, LID(1):LID(length(baseArray))) + else + Data[] + end +end + +function getDiagCopyWithoutOffsets(rowMap, colMap, A::CSRMatrix{Data}) where {Data} + errCount = 0 + + D = Array{Data, 1}(getNumLocalRows(A)) + + for localRowIndex = 1:length(D) + D[localRowIndex] = 0 + globalIndex = gid(rowMap, localRowIndex) + localColIndex = lid(colMap, globalIndex) + if localColIndex != 0 + colInds, vals = getLocalRowView(A, localRowIndex) + + offset = 1 + numEnt = length(curRow) + while offset <= numEnt + if colInds[offset] == localColIndex + break; + end + offset += 1 + end + + if offset > numEnt + errCount += 1 + else + 
D[localRowIndex] = vals[offset] + end + end + end + D +end + + +function sortAndMergeIndicesAndValues(matrix::CSRMatrix{Data, GID, PID, LID}, + sorted, merged) where {Data, GID, PID, LID} + graph = getGraph(matrix) + localNumRows = getLocalNumRows(graph) + totalNumDups = 0 + + for localRow in LID(1):localNumRows + rowInfo = getRowInfo(graph, localRow) + if !sorted + inds, vals = getLocalRowView(matrix, rowInfo) + + order = sortperm(inds) + permute!(inds, order) + permute!(vals, order) + end + if !merged + totalNumDups += mergeRowIndicesAndValues(matrix, rowInfo) + end + + recycleRowInfo(rowInfo) + end + + if !sorted + graph.indicesAreSorted = true + end + if !merged + graph.nodeNumEntries -= totalNumDups + graph.noRedundancies = true + end +end + +function mergeRowIndicesAndValues(matrix::CSRMatrix{Data, GID, PID, LID}, + rowInfo::RowInfo{LID})::LID where {Data, GID, PID, LID} + + graph = getGraph(matrix) + indsView, valsView = getLocalRowView(matrix, rowInfo) + + if rowInfo.numEntries != 0 + newend = 1 + for cur in 2:rowInfo.numEntries + if indsView[newend]::LID != indsView[cur]::LID + #new entry, save it + newend += 1 + indsView[newend] = indsView[cur]::LID + valsView[newend] = valsView[cur]::Data + else + #old entry, merge it + valsView[newend] += valsView[cur]::Data + end + end + else + newend = 0 + end + + graph.numRowEntries[rowInfo.localRow] = newend + + rowInfo.numEntries - newend +end + + + +#### External methods #### +#TODO document external methods + +function insertGlobalValues(matrix::CSRMatrix{Data, GID, PID, LID}, globalRow::Integer, + indices::AbstractArray{LID, 1}, values::AbstractArray{Data, 1} + ) where {Data, GID, PID, LID} + myGraph = matrix.myGraph + + localRow = lid(getRowMap(matrix), globalRow) + + if localRow == 0 + insertNonownedGlobalValues(matrix, globalRow, indices, values) + else + numEntriesToInsert = length(indices) + if hasColMap(matrix) + colMap = getColMap(matrix) + + for k = 1:numEntriesToInsert + if !myGID(colMap, indices[k]) + 
throw(InvalidArgumentError("Attempted to insert entries into " + * "owned row $globalRow, at the following column indices " + * "$indices. At least one of those indices ($(indices[k])" + * ") is not in the column map on this process")) + end + end + + rowInfo = getRowInfo(myGraph, localRow) + curNumEntries = rowInfo.numEntries + newNumEntries = curNumEntries + numEntriesToInsert + if newNumEntries > rowInfo.allocSize + if(getProfileType(matrix) == STATIC_PROFILE + && newNumEntries > rowInfo.allocSize) + throw(InvalidArgumentError("number of new indices exceed " + * "statically allocated graph structure")) + end + + updateGlobalAllocAndValues(myGraph, rowInfo, newNumEntries, + matrix.values2D[localRow]) + + recycleRowInfo(rowInfo) + rowInfo = getRowInfo(myGraph, localRow); + end + + insertIndicesAndValues(myGraph, rowInfo, indices, getView(matrix, rowInfo), + values, GLOBAL_INDICES) + + recycleRowInfo(rowInfo) + end + nothing +end + + +function resumeFill(matrix::CSRMatrix, plist::Dict{Symbol}) + resumeFill(matrix.myGraph, plist) + + clearGlobalConstants(matrix) + #graph handles fillComplete variable +end + +fillComplete(matrix::CSRMatrix; plist...) = fillComplete(matrix, Dict(Array{Tuple{Symbol, Any}, 1}(plist))) + +function fillComplete(matrix::CSRMatrix, plist::Dict{Symbol}) + #TODO figure out if the second arg should be getColMap(matrix) + fillComplete(matrix, getRowMap(matrix), getRowMap(matrix), plist) +end + +function fillComplete(matrix::CSRMatrix{Data, GID, PID, LID}, + domainMap::BlockMap{GID, PID, LID}, rangeMap::BlockMap{GID, PID, LID}; + plist...) 
where {Data, GID, PID, LID} + fillComplete(matrix, domainMap, rangeMap, Dict(Array{Tuple{Symbol, Any}, 1}(plist))) +end + +function fillComplete(matrix::CSRMatrix{Data, GID, PID, LID}, + domainMap::BlockMap{GID, PID, LID}, rangeMap::BlockMap{GID, PID, LID}, + plist::Dict{Symbol}) where {Data, GID, PID, LID} + if isFillComplete(matrix) + throw(InvalidStateError( + "Matrix cannot be fill complete when fillComplete(...) is called")) + end + + const myGraph = matrix.myGraph + + assertNoNonlocalInserts = get(plist, :noNonlocalChanges, false) + #skipping sort ghosts stuff + + numProcs = numProc(comm(matrix)) + + needGlobalAssemble = !assertNoNonlocalInserts && numProcs > 1 + if needGlobalAssemble + globalAssemble(matrix) + else + if numProcs == 1 && length(matrix.nonlocals) != 0 + throw(InvalidStateError("Cannot have nonlocal entries on a serial run. An invalid entry is present.")) + end + end + + setDomainRangeMaps(myGraph, domainMap, rangeMap) + if !hasColMap(myGraph) + makeColMap(myGraph) + matrix.colMap = myGraph.colMap + end + + makeIndicesLocal(myGraph) + + sortAndMergeIndicesAndValues(matrix, isSorted(myGraph), isMerged(myGraph)) + + makeImportExport(myGraph) + computeGlobalConstants(myGraph) + + myGraph.fillComplete = true + checkInternalState(myGraph) + + fillLocalGraphAndMatrix(matrix, plist) +end + + +getProfileType(mat::CSRMatrix) = getProfileType(mat.myGraph) +isStorageOptimized(mat::CSRMatrix) = isStorageOptimized(mat.myGraph) + + +function getLocalDiagOffsets(matrix::CSRMatrix{Data, GID, PID, LID})::AbstractArray{LID, 1} where {Data, GID, PID, LID} + graph = matrix.myGraph + localNumRows = getLocalNumRows(graph) + getLocalDiagOffsets(graph) +end + + +#### Row Matrix functions #### + +isFillActive(mat::CSRMatrix) = isFillActive(mat.myGraph) +isFillComplete(mat::CSRMatrix) = isFillComplete(mat.myGraph) +getRowMap(mat::CSRMatrix) = mat.rowMap +hasColMap(mat::CSRMatrix) = !isnull(mat.colMap) +getColMap(mat::CSRMatrix) = get(mat.colMap) 
+isGloballyIndexed(mat::CSRMatrix) = isGloballyIndexed(mat.myGraph) +isLocallyIndexed(mat::CSRMatrix) = isLocallyIndexed(mat.myGraph) +getGraph(mat::CSRMatrix) = mat.myGraph + +getGlobalNumRows(mat::CSRMatrix) = getGlobalNumRows(mat.myGraph) +getGlobalNumCols(mat::CSRMatrix) = getGlobalNumCols(mat.myGraph) +getLocalNumRows(mat::CSRMatrix) = getLocalNumRows(mat.myGraph) +getLocalNumCols(mat::CSRMatrix) = numCols(mat.localMatrix) +getGlobalNumEntries(mat::CSRMatrix) = getGlobalNumEntries(mat.myGraph) +getLocalNumEntries(mat::CSRMatrix) = getLocalNumEntries(mat.myGraph) +getNumEntriesInGlobalRow(mat::CSRMatrix, grow) = getNumEntriesInGlobalRow(mat.myGraph, grow) +getNumEntriesInLocalRow(mat::CSRMatrix, lrow) = getNumEntriesInLocalRow(mat.myGraph, lrow) +getGlobalNumDiags(mat::CSRMatrix) = getGlobalNumDiags(mat.myGraph) +getLocalNumDiags(mat::CSRMatrix) = getLocalNumDiags(mat.myGraph) +getGlobalMaxNumRowEntries(mat::CSRMatrix) = getGlobalMaxNumRowEntries(mat.myGraph) +getLocalMaxNumRowEntries(mat::CSRMatrix) = getLocalMaxNumRowEntries(mat.myGraph) + +isLowerTriangular(mat::CSRMatrix) = isLowerTriangular(mat.myGraph) +isUpperTriangular(mat::CSRMatrix) = isUpperTriangular(mat.myGraph) + +function getGlobalRowCopy(matrix::CSRMatrix{Data, GID, PID, LID}, + globalRow::Integer + )::Tuple{Array{GID, 1}, Array{Data, 1}} where {Data, GID, PID, LID} + myGraph = matrix.myGraph + + rowInfo = getRowInfoFromGlobalRow(myGraph, GID(globalRow)) + viewRange = 1:rowInfo.numEntries + + + retVal = if rowInfo.localRow != 0 + if isLocallyIndexed(myGraph) + colMap = getColMap(myGraph) + curLocalIndices = getLocalView(myGraph, rowInfo)[viewRange] + curGlobalIndices = @. 
gid(colMap, curLocalIndices) + else + curGlobalIndices = getGlobalView(myGraph, rowInfo)[viewRange] + end + curValues = getView(matrix, rowInfo)[viewRange] + + (curGlobalIndices, curValues) + else + (GID[], Data[]) + end + + recycleRowInfo(rowInfo) + + retVal +end + + +function getLocalRowCopy(matrix::CSRMatrix{Data, GID, PID, LID}, + localRow::Integer + )::Tuple{AbstractArray{LID, 1}, AbstractArray{Data, 1}} where { + Data, GID, PID, LID} + myGraph = matrix.myGraph + + rowInfo = getRowInfo(myGraph, LID(localRow)) + viewRange = 1:rowInfo.numEntries + + retVal = if rowInfo.localRow != 0 + if isLocallyIndexed(myGraph) + curLocalIndices = Array{LID}(getLocalView(myGraph, rowInfo)[viewRange]) + else + colMap = getColMap(myGraph) + curGlobalIndices = getGlobalView(myGraph, rowInfo)[viewRange] + curLocalIndices = @. lid(colMap, curGlobalIndices) + end + curValues = Array{Data}(getView(matrix, rowInfo)[viewRange]) + + (curLocalIndices, curValues) + else + (LID[], Data[]) + end + + recycleRowInfo(rowInfo) + + retVal +end + + +function getGlobalRowView(matrix::CSRMatrix{Data, GID, PID, LID}, + globalRow::Integer + )::Tuple{AbstractArray{GID, 1}, AbstractArray{Data, 1}} where { + Data, GID, PID, LID} + if isLocallyIndexed(matrix) + throw(InvalidStateError("The matrix is locally indexed, so cannot return a " + * "view of the row with global column indices. Use " + * "getGlobalRowCopy(...) 
instead.")) + end + myGraph = matrix.myGraph + + rowInfo = getRowInfoFromGlobalRow(myGraph, globalRow) + if rowInfo.localRow != 0 && rowInfo.numEntries > 0 + viewRange = 1:rowInfo.numEntries + indices = getGlobalView(myGraph, rowInfo)[viewRange] + values = getView(matrix, rowInfo)[viewRange] + else + indices = GID[] + values = Data[] + end + recycleRowInfo(rowInfo) + (indices, values) +end + +function getLocalRowView(matrix::CSRMatrix{Data, GID, PID, LID}, + localRow::Integer + )::Tuple{AbstractArray{LID, 1}, AbstractArray{Data, 1}} where { + Data, GID, PID, LID} + rowInfo = getRowInfo(matrix.myGraph, LID(localRow)) + retVal = getLocalRowView(matrix, rowInfo) + recycleRowInfo(rowInfo) + + retVal +end + +function getLocalRowView(matrix::CSRMatrix{Data, GID, PID, LID}, + rowInfo::RowInfo{LID} + )::Tuple{AbstractArray{LID, 1}, AbstractArray{Data, 1}} where { + Data, GID, PID, LID} + + if isGloballyIndexed(matrix) + throw(InvalidStateError("The matrix is globally indexed, so cannot return a " + * "view of the row with local column indices. Use " + * "getLocalalRowCopy(...) 
instead.")) + end + + const myGraph = matrix.myGraph + + if rowInfo.localRow != 0 && rowInfo.numEntries > 0 + viewRange = LID(1):rowInfo.numEntries + indices = view(getLocalView(myGraph, rowInfo), viewRange) + values = view(getView(matrix, rowInfo), viewRange) + else + indices = LID[] + values = Data[] + end + (indices, values) +end + + +TypeStability.@stable_function [(CSRMatrix{D, G, P, L}, L) + for (D, G, P, L) in Base.Iterators.product( + [Float64, Complex64], #Data + [UInt64, Int64, UInt32], #GID + [UInt8, Int8, UInt32], #PID + [UInt32, Int32]) #LID +] RegexDict((r"rowValues", Any)) begin + +function getLocalRowViewPtr end + +Base.@propagate_inbounds @inline function getLocalRowViewPtr( + matrix::CSRMatrix{Data, GID, PID, LID}, localRow::LID + )::Tuple{Ptr{LID}, Ptr{Data}, LID} where {Data, GID, PID, LID} + row = LID(localRow) + const graph = matrix.myGraph + + if getProfileType(graph) == STATIC_PROFILE + if (@debug) && !hasRowInfo(graph) + error("Row Info was deleted, but is still needed") + end + offset1D = graph.rowOffsets[row] + numEntries = (length(graph.numRowEntries) == 0 ? 
+ graph.rowOffsets[row+1] - offset1D + : graph.numRowEntries[row]) + if numEntries > 0 + indicesPtr = pointer(graph.localIndices1D, offset1D) + rowValues = matrix.localMatrix.values + if rowValues isa SubArray{Data, 1, Vector{Data}, Tuple{UnitRange{LID}}, true} + valuesPtr = pointer(rowValues.parent, offset1D + rowValues.indexes[1].start) + elseif rowValues isa Vector{Data} + #else should be Vector, but assert anyways + valuesPtr = pointer(rowValues, offset1D) + else + error("localMatrix.values is of unsupported type $(typeof(rowValues)).") + end + + return (indicesPtr, valuesPtr, numEntries) + else + return (C_NULL, C_NULL, 0) + end + else #dynamic profile + indices = graph.localIndices2D[row] + values = matrix.values2D[row] + + return (pointer(indices, 0), pointer(values, 0), length(indices)) + end +end +end + + + +function getLocalDiagCopy(matrix::CSRMatrix{Data, GID, PID, LID})::MultiVector{Data, GID, PID, LID} where {Data, GID, PID, LID} + if !hasColMap(matrix) + throw(InvalidStateError("This method requires a column map")) + end + + rowMap = getRowMap(matrix) + colMap = getColMap(matrix) + + numLocalRows = getLocalNumRows(matrix) + + if isFillComplete(matrix) + diag = MultiVector{Data, GID, PID, LID}(rowMap, 1, false) + + + diag1D = getVectorView(diag, 1) + localRowMap = getLocalMap(rowMap) + localColMap = getLocalMap(colMap) + localMatrix = matrix.localMatrix + + diag1D[:] = getDiagCopyWithoutOffsets(matrix, localRowMap, localColMap, localMatrix) + + diag + else + getLocalDiagCopyWithoutOffsetsNotFillComplete(matrix) + end +end + +function leftScale!(matrix::CSRMatrix{Data}, X::AbstractArray{Data, 1}) where {Data <: Number} + for row in 1:getLocalNumRows(matrix) + _, vals = getLocalRowView(matrix, row) + LinAlg.scale!(vals, X[row]) + end +end + +function rightScale!(matrix::CSRMatrix{Data}, X::AbstractArray{Data, 1}) where {Data <: Number} + for row in 1:getLocalNumRows(matrix) + inds, vals = getLocalRowView(matrix, row) + for entry in 1:length(inds) + 
vals[entry] *= X[inds[entry]] + end + end +end + + +#### DistObject methods #### +function checkSizes(source::RowMatrix{Data, GID, PID, LID}, + target::CSRMatrix{Data, GID, PID, LID})::Bool where {Data, GID, PID, LID} + true +end + + +function copyAndPermute(source::RowMatrix{Data, GID, PID, LID}, + target::CSRMatrix{Data, GID, PID, LID}, numSameIDs::LID, + permuteToLIDs::AbstractArray{LID, 1}, permuteFromLIDs::AbstractArray{LID, 1} + ) where {Data, GID, PID, LID} + sourceIsLocallyIndexed = isLocallyIndexed(source) + + srcRowMap = getRowMap(source) + tgtRowMap = getRowMap(target) + + sameGIDs = @. gid(srcRowMap, collect(1:numSameIDs)) + permuteFromGIDs = @. gid(srcRowMap, permuteFromLIDs) + permuteToGIDs = @. gid(srcRowMap, permuteToLIDs) + + for (sourceGID, targetGID) in zip(vcat(sameGIDs, permuteFromGIDs), vcat(sameGIDs, permuteToGIDs)) + if sourceIsLocallyIndexed + rowInds, rowVals = getGlobalRowCopy(source, sourceGID) + else + rowInds, rowVals = getGlobalRowView(source, sourceGID) + end + combineGlobalValues(target, targetGID, rowInds, rowVals, INSERT) + end +end + +#TODO move this to RowMatrix. +#DECSION make a packable trait to encompassRowMatrix and RowGraph? 
as sources, checkSizes handles matching pairs of objects +function packAndPrepare(source::RowMatrix{Data, GID, PID, LID}, + target::CSRMatrix{Data, GID, PID, LID}, exportLIDs::AbstractArray{LID, 1}, + distor::Distributor{GID, PID, LID})::AbstractArray where {Data, GID, PID, LID} + pack(source, exportLIDs, distor) +end + +function pack(source::CSRMatrix{Data, GID, PID, LID}, exportLIDs::AbstractArray{LID, 1}, + distor::Distributor{GID, PID, LID})::AbstractArray where {Data, GID, PID, LID} + numExportLIDs = length(exportLIDs) + + localMatrix = source.localMatrix + localGraph = localMatrix.graph + + packed = Array{Tuple{AbstractArray{GID, 1}, AbstractArray{Data, 1}}}(numExportLIDs) + result = 0 + for i in 1:numExportLIDs + exportLID = exportLIDs[i] + start = localGraph.rowOffsets[exportLID] + last = localGraph.rowOffsets[exportLID+1]-1 + numEnt = last - start +1 + if numEnt == 0 + packed[i] = GID[], Data[] + else + values = view(localMatrix.values, start:last) + lids = view(localGraph.entries, start:last) + gids = @. gid(getColMap(source), lids) + packed[i] = gids, values + end + end + packed +end + +function unpackAndCombine(target::CSRMatrix{Data, GID, PID, LID}, + importLIDs::AbstractArray{LID, 1}, imports::AbstractArray, distor::Distributor{GID, PID, LID}, + cm::CombineMode) where{Data, GID, PID, LID} + + numImportLIDs = length(importLIDs) + + for i = 1:numImportLIDs + if length(imports[i]) > 0 #ensure there's actually something in the row + combineGlobalValues(target, importLIDs[i], imports[i][1], imports[i][2], cm) + end + end +end + + +#### Operator methods #### +function apply!(Y::MultiVector{Data, GID, PID, LID}, + operator::CSRMatrix{Data, GID, PID, LID}, X::MultiVector{Data, GID, PID, LID}, + mode::TransposeMode, alpha::Data, beta::Data) where {Data, GID, PID, LID} + + const ZERO = Data(0) + + if isFillActive(operator) + throw(InvalidStateError("Cannot call apply(...) 
until fillComplete(...)")) + end + + if alpha == ZERO + if beta == ZERO + fill!(Y, ZERO) + elseif beta != Data(1) + scale!(Y, beta) + end + return Y + end + + if mode == NO_TRANS + applyNonTranspose!(Y, operator, X, alpha, beta) + else + applyTranspose!(Y, operator, X, mode, alpha, beta) + end +end + +function applyNonTranspose!(Y::MultiVector{Data, GID, PID, LID}, operator::CSRMatrix{Data, GID, PID, LID}, X::MultiVector{Data, GID, PID, LID}, alpha::Data, beta::Data) where {Data, GID, PID, LID} + const ZERO = Data(0) + + #These are nullable + importer = getGraph(operator).importer + exporter = getGraph(operator).exporter + + #assumed to be shared by all data structures + resultComm = comm(Y) + + YIsOverwritten = (beta == ZERO) + YIsReplicated = !distributedGlobal(Y) && numProc(resultComm) != 0 + + #part of special case for replicated MV output + if YIsReplicated && myPid(resultComm) != 1 + beta = ZERO + end + + if isnull(importer) + XColMap = X + else + #need to import source multivector + + XColMap = get(getColumnMapMultiVector(operator, X)) + + doImport(X, XColMap, get(importer), INSERT) + end + + YRowMap = getRowMapMultiVector(operator, Y) + if !isnull(exporter) + localApply(YRowMap, operator, XColMap, NO_TRANS, alpha, ZERO) + + if YIsOverwritten + fill!(Y, ZERO) + else + scale!(Y, beta) + end + + doExport(YRowMap, Y, get(exporter), ADD) + else + #don't do export row Map and range map are the same + + if XColMap === Y + + YRowMap = getRowMapMultiVector(operator, Y, true) + + if beta != 0 + copy!(YRowMap, Y) + end + + localApply(YRowMap, operator, XColMap, NO_TRANS, alpha, ZERO) + copy!(Y, YRowMap) + else + localApply(Y, operator, XColMap, NO_TRANS, alpha, beta) + end + end + + if YIsReplicated + commReduce(Y) + end + Y +end + +function applyTranspose!(Yin::MultiVector{Data, GID, PID, LID}, operator::CSRMatrix{Data, GID, PID, LID}, Xin::MultiVector{Data, GID, PID, LID}, mode::TransposeMode, alpha::Data, beta::Data) where {Data, GID, PID, LID} + const ZERO = 
Data(0) + + nVects = numVectors(Xin) + importer = getGraph(operator).importer + exporter = getGraph(operator).exporter + + YIsReplicated = distributedGlobal(Yin) + YIsOverwritten = (beta == ZERO) + if YIsReplicated && myPid(comm(operator)) != 1 + beta = ZERO + end + + if isnull(importer) + X = copy(Xin) + else + X = Xin + end + + if !isnull(importer) + if !isnull(operator.importMV) && getNumVectors(get(operator.importMV)) != nVects + operator.importMV = Nullable{MultiVector{Data, GID, PID, LID}}() + end + if isnull(operator.importMV) + operator.importMV = Nullable(MultiVector(getColMap(operator), nVects)) + end + end + + if !isnull(exporter) + if !isnull(operator.exportMV) && getNumVectors(get(operator.exportMV)) != nVects + operator.exportMV = Nullable{MultiVector{Data, GID, PID, LID}}() + end + if isnull(operator.exportMV) + operator.exportMV = Nullable(MultiVector(getRowMap(operator), nVects)) + end + end + + if !isnull(exporter) + doImport(Xin, get(operator.exportMV), get(exporter), INSERT) + X = get(operator.exportMV) + end + + if !isnull(importer) + localApply(get(operator.importMV), operator, X, mode, alpha, ZERO) + + if YIsOverwritten + fill!(Yin, ZERO) + else + scale!(Yin, beta) + end + doExport(get(operator.importMV), Yin, get(importer), ADD) + else + if X === Yin + Y = copy(Yin) + localApply(Y, operator, X, mode, alpha, beta) + copy!(Yin, Y) + else + localApply(Yin, operator, X, mode, alpha, beta) + end + end + if YIsReplicated + commReduce(Yin) + end + Yin +end + +TypeStability.@stable_function [(MultiVector{D, G, P, L}, CSRMatrix{D, G, P, L}, + MultiVector{D, G, P, L}, TransposeMode, D, D) + for (D, G, P, L) in Base.Iterators.product( + [Float64, Complex64], #Data + [UInt64, Int64, UInt32], #GID + [UInt8, Int8, UInt32], #PID + [UInt32, Int32]) #LID +# createRowInfo has a necessary Union variable +] RegexDict((r"rowValues", Any)) begin +function localApply(Y::MultiVector{Data, GID, PID, LID}, + A::CSRMatrix{Data, GID, PID, LID}, X::MultiVector{Data, GID, PID, 
LID}, + mode::TransposeMode, alpha::Data, beta::Data) where {Data, GID, PID, LID} + + const rawY = Y.data + const rawX = X.data + + + #TODO implement this better, can BLAS be used? + if !isTransposed(mode) + #TODO look at best way to order the loops to avoid cache misses + # I think this is the better order, since MultiVector is column oriented + numRows = getLocalNumRows(A) + for vect = LID(1):numVectors(Y) + for row = LID(1):numRows + sum::Data = Data(0) + @inbounds (indices, values, len) = getLocalRowViewPtr(A, row) + for i in LID(1):LID(len) + ind::LID = unsafe_load(indices, i) + val::Data = unsafe_load(values, i) + @inbounds sum += val*rawX[ind, vect] + end + sum = applyConjugation(mode, sum*alpha) + @inbounds rawY[row, vect] *= beta + @inbounds rawY[row, vect] += sum + end + end + else + rawY[:, :] *= beta + numRows = getLocalNumRows(A) + for vect = LID(1):numVectors(Y) + for mRow in LID(1):numRows + @inbounds (indices, values, len) = getLocalRowViewPtr(A, mRow) + for i in LID(1):LID(len) + ind::LID = unsafe_load(indices, i) + val::Data = unsafe_load(values, i) + @inbounds rawY[ind, vect] += applyConjugation(mode, alpha*rawX[mRow, vect]*val) + end + end + end + end + + Y +end +end diff --git a/src/Comm.jl b/src/Comm.jl new file mode 100644 index 0000000..2180f7a --- /dev/null +++ b/src/Comm.jl @@ -0,0 +1,194 @@ +export Comm +export barrier, broadcastAll, gatherAll, sumAll, maxAll, minAll, scanSum +export myPid, numProc, createDistributor + +# methods (and docs) are currently based straight off Epetra_Comm +# tpetra's equivalent seemed to be a wrapper to other Trilinos packages + +# following julia's convention, n processors are labled 1 through n +# count variables are removed, since that information is contained in the arrays + +""" +The base type for types that represent communication in parallel computing. 
All subtypes must have the following methods, with CommImpl standing in for the subtype:

barrier(comm::CommImpl) - Each processor must wait until all processors have arrived

broadcastAll(comm::CommImpl, myvals::AbstractArray{T}, Root::Integer)::Array{T} where T
    - Takes a list of input values from the root processor and sends to all
    other processors. The values are returned (including on the root process)

gatherAll(comm::CommImpl, myVals::AbstractArray{T})::Array{T} where T
    - Takes a list of input values from all processors and returns an ordered
    contiguous list of those values on each processor

sumAll(comm::CommImpl, partialsums::AbstractArray{T})::Array{T} where T
    - Take a list of input values from all processors and returns the sum on each
    processor. The method +(::T, ::T)::T can be assumed to exist

maxAll(comm::CommImpl, partialmaxes::AbstractArray{T})::Array{T} where T
    - Takes a list of input values from all processors and returns the max to all
    processors. The method <(::T, ::T)::Bool can be assumed to exist

minAll(comm::CommImpl, partialmins::AbstractArray{T})::Array{T} where T
    - Takes a list of input values from all processors and returns the min to all
    processors. The method <(::T, ::T)::Bool can be assumed to exist

scanSum(comm::CommImpl, myvals::AbstractArray{T})::Array{T} where T
    - Takes a list of input values from all processors, computes the scan sum and
    returns it to all processors such that processor i contains the sum of
    values from processor 1 up to and including processor i. The method
    +(::T, ::T)::T can be assumed to exist

myPid(comm::CommImpl{GID, PID, LID})::PID - Returns the process rank

numProc(comm::CommImpl{GID, PID, LID})::PID - Returns the total number of processes

createDistributor(comm::CommImpl{GID, PID, LID})::Distributor{GID, PID, LID} - Create a distributor object

"""
abstract type Comm{GID <: Integer, PID <:Integer, LID <: Integer}
end

function Base.show(io::IO, comm::Comm)
    # NOTE(review): assumes the type name is module-qualified (Module.Type);
    # an unqualified name would make the [2] index throw — confirm all Comm
    # subtypes live in a submodule
    print(io, split(String(Symbol(typeof(comm))), ".")[2]," with PID ", myPid(comm),
        " and ", numProc(comm), " processes")
end

"""
    broadcastAll(::Comm, ::T, ::Integer)::T

As `broadcastAll(::Comm, ::AbstractArray{T, 1}, ::Integer)::Array{T, 1}`, except only broadcasts a single element
"""
function broadcastAll(comm::Comm, myVal::T, root::Integer)::T where T
    broadcastAll(comm, [myVal], root)[1]
end

"""
    gatherAll(::Comm, ::T)::Array{T, 1}

As `gatherAll(::Comm, ::AbstractArray{T, 1})::Array{T, 1}`, except each process only sends a single element
"""
function gatherAll(comm::Comm, myVal::T)::Array{T, 1} where T
    gatherAll(comm, [myVal])
end

"""
    sumAll(::Comm, ::T)::T

As `sumAll(::Comm, ::AbstractArray{T, 1})::Array{T, 1}`, except for a single element
"""
function sumAll(comm::Comm, val::T)::T where T
    sumAll(comm, [val])[1]
end

"""
    maxAll(::Comm, ::T)::T

As `maxAll(::Comm, ::AbstractArray{T, 1})::Array{T, 1}`, except for a single element
"""
function maxAll(comm::Comm, val::T)::T where T
    maxAll(comm, [val])[1]
end

"""
    minAll(::Comm, ::T)::T

As `minAll(::Comm, ::AbstractArray{T, 1})::Array{T, 1}`, except for a single element
"""
function minAll(comm::Comm, val::T)::T where T
    minAll(comm, [val])[1]
end

"""
    scanSum(::Comm, ::T)::T

As `scanSum(::Comm, ::AbstractArray{T, 1})::Array{T, 1}`, except for a single element
"""
function scanSum(comm::Comm, val::T)::T where T
    scanSum(comm, [val])[1]
end



#### documentation for required methods ####

"""
    barrier(::Comm)

Causes the process to pause until all processes have called barrier. Used to synchronize the processes
"""
function barrier end


"""
    broadcastAll(comm::Comm, myVals::AbstractArray{T, 1}, root::Integer)::Array{T, 1}

Takes a list of input values from the root processor and sends it to each
other processor. The broadcasted values are then returned, including on
the root process.
"""
function broadcastAll end

"""
    gatherAll(comm::Comm, myVals::AbstractArray{T, 1})::Array{T, 1}

Takes a list of input values from all processors and returns an ordered,
contiguous list of those values.
"""
function gatherAll end

"""
    sumAll(comm::Comm, partialsums::AbstractArray{T, 1})::Array{T, 1}

Takes a list of input values from all processors and returns the sum on each
processor. The method `+(::T, ::T)::T` must exist.
"""
function sumAll end

"""
    maxAll(comm::Comm, partialmaxes::AbstractArray{T, 1})::Array{T, 1}

Takes a list of input values from all processors and returns the max to all
processors. The method `<(::T, ::T)::Bool` must exist.
"""
function maxAll end

"""
    minAll(comm::Comm, partialmins::AbstractArray{T, 1})::Array{T, 1}

Takes a list of input values from all processors and returns the min to all
processors. The method `<(::T, ::T)::Bool` must exist.
"""
function minAll end

"""
    scanSum(comm::Comm, myvals::AbstractArray{T, 1})::Array{T, 1}

Takes a list of input values from all processors, computes the scan sum and
returns it to all processors such that processor `i` contains the sum of
values from processor 1 up to, and including, processor `i`. The method
+(::T, ::T)::T must exist
"""
function scanSum end

"""
    myPid(::Comm{GID, PID, LID})::PID

Returns the rank of the calling processor
"""
function myPid end

"""
    numProc(::Comm{GID, PID, LID})::PID

Returns the total number of processes
"""
function numProc end

"""
    createDistributor(comm::Comm{GID, PID, LID})::Distributor{GID, PID, LID}

Creates a distributor for the given Comm object
"""
function createDistributor end
\ No newline at end of file
diff --git a/src/ComputeOffsets.jl b/src/ComputeOffsets.jl
new file mode 100644
index 0000000..e45153e
--- /dev/null
+++ b/src/ComputeOffsets.jl
@@ -0,0 +1,28 @@

#used in implementation of CRSGraph, CSRMatrix and, if added, FixedHashTable


"""
    computeOffsets(rowPtrs, numEnts::Integer)

Fills `rowPtrs` with 1-based offsets for rows that all contain `numEnts`
entries (`rowPtrs[i] = numEnts*(i-1)+1`). Returns the filled `rowPtrs`.
"""
function computeOffsets(rowPtrs::AbstractArray{<: Integer, 1}, numEnts::Integer)
    numOffsets = length(rowPtrs)
    @simd for i = 1:numOffsets
        @inbounds rowPtrs[i] = numEnts*(i-1)+1
    end
    rowPtrs
end


"""
    computeOffsets(rowPtrs, numEnts::Array{<:Integer, 1})

Fills `rowPtrs` with 1-based offsets where row `i` contains `numEnts[i]`
entries; trailing slots are set to one past the final entry.
Throws `InvalidArgumentError` unless `length(rowPtrs) > length(numEnts)`.
Returns the total number of entries.
"""
function computeOffsets(rowPtrs::AbstractArray{<: Integer, 1}, numEnts::Array{<: Integer, 1})
    numOffsets = length(rowPtrs)
    numCounts = length(numEnts)
    if numCounts >= numOffsets
        throw(InvalidArgumentError("length(numEnts) = $numCounts "
            * ">= length(rowPtrs) = $numOffsets"))
    end
    # running 1-based offset; renamed from `sum` to avoid shadowing Base.sum
    offset = 1
    for i = 1:numCounts
        @inbounds rowPtrs[i] = offset
        @inbounds offset += numEnts[i]
    end
    @inbounds rowPtrs[numCounts+1:numOffsets] = offset
    offset-1
end
\ No newline at end of file
diff --git a/src/Directory.jl b/src/Directory.jl
new file mode 100644
index 0000000..882f9dd
--- /dev/null
+++ b/src/Directory.jl
@@ -0,0 +1,24 @@
export Directory

# methods and docs based straight off Epetra_Directory to match Comm

"""
A base type as an interface to allow Map and BlockMap objects to reference non-local
elements.

All subtypes must have the following methods, with DirectoryImpl standing in for
the subtype:

getDirectoryEntries(directory::DirectoryImpl, map::BlockMap, globalEntries::AbstractArray{GID},
        high_rank_sharing_procs::Bool)::Tuple{AbstractArray{PID}, AbstractArray{LID}}
        where GID <: Integer where PID <: Integer where LID <:Integer
    - Returns processor and local id info for non-local map entries. Returns a tuple
    containing
        1 - an Array of processors owning the global ID's in question
        2 - an Array of local IDs of the global on the owning processor

gidsAllUniquelyOwned(directory::DirectoryImpl)
    - Returns true if all GIDs appear on just one processor
"""
abstract type Directory{GID <: Integer, PID <:Integer, LID <: Integer}
end
\ No newline at end of file
diff --git a/src/DirectoryMethods.jl b/src/DirectoryMethods.jl
new file mode 100644
index 0000000..ba2ea89
--- /dev/null
+++ b/src/DirectoryMethods.jl
@@ -0,0 +1,41 @@
export getDirectoryEntries, gidsAllUniquelyOwned
export createDirectory

# has to be split from the declaration of Directory due to dependancy on files that require Directory

"""
Convenience wrapper that converts the entries to the map's GID type
before dispatching to the implementation method.
"""
function getDirectoryEntries(directory::Directory{GID, PID, LID}, map::BlockMap{GID, PID, LID},
        # BUG FIX: was `AbstractArray{Number}`, which only matches arrays whose
        # element type is exactly the abstract type `Number`; `<:Number` matches
        # arrays of any numeric element type, as the conversion below intends
        globalEntries::AbstractArray{<:Number}, high_rank_sharing_procs::Bool=false)::Tuple{AbstractArray{PID}, AbstractArray{LID}} where GID <: Integer where PID <: Integer where LID <: Integer
    getDirectoryEntries(directory, map, Array{GID, 1}(globalEntries), high_rank_sharing_procs)
end

"""
Convenience wrapper defaulting `high_rank_sharing_procs` to `false`.
"""
function getDirectoryEntries(directory::Directory{GID, PID, LID}, map::BlockMap{GID, PID, LID},
        globalEntries::AbstractArray{GID})::Tuple{AbstractArray{PID}, AbstractArray{LID}} where GID <: Integer where PID <: Integer where LID <: Integer
    getDirectoryEntries(directory, map, globalEntries, false)
end


"""
    createDirectory(comm::Comm, map::BlockMap)
Create a directory object for the given Map
"""
function createDirectory(comm::Comm{GID, PID, LID}, map::BlockMap{GID, PID, LID})::BasicDirectory{GID, PID, LID} where GID <: Integer where PID <: Integer where LID <: Integer
    BasicDirectory{GID, PID, LID}(map)
end

#### required methods documentation stubs ####

"""
    getDirectoryEntries(directory, map::BlockMap{GID, PID, LID}, globalEntries::AbstractArray{GID}, high_rank_sharing_procs::Bool)::Tuple{AbstractArray{PID}, AbstractArray{LID}}

Returns processor and local id information for non-local map entries. Returns a tuple containing
1. an Array of processors owning the global ID's in question
2. an Array of local IDs of the global on the owning processor
"""
function getDirectoryEntries end

"""
    gidsAllUniquelyOwned(directory)

Returns true if all GIDs appear on just one processor
"""
function gidsAllUniquelyOwned end
\ No newline at end of file
diff --git a/src/DistObject.jl b/src/DistObject.jl
new file mode 100644
index 0000000..4cf0ebd
--- /dev/null
+++ b/src/DistObject.jl
@@ -0,0 +1,190 @@

export DistObject
export doImport, doExport
# BUG FIX: was `checkSize`, a name that is never defined; the predicate
# implemented and documented below is `checkSizes`
export copyAndPermute, packAndPrepare, unpackAndCombine, checkSizes
export releaseViews, createViews, createViewsNonConst

# Note that all packet size information was removed due to the use of julia's
# built in serialization/objects

"""
A base type for constructing and using dense multi-vectors, vectors and matrices in parallel.


To support transfers the following methods must be implemented for the source type and the target type

    checkSizes(source::<:SrcDistObject{GID, PID, LID}, target::<:DistObject{GID, PID, LID})::Bool
Whether the source and target are compatible for a transfer

    copyAndPermute(source::<:SrcDistObject{GID, PID, LID}, target::<:DistObject{GID, PID, LID}, numSameIDs::LID, permuteToLIDs::AbstractArray{LID, 1}, permuteFromLIDs::AbstractArray{LID, 1})
Perform copies and permutations that are local to this process.
+ + packAndPrepare(source::<:SrcDistObject{GID, PID, LID}, target::<:DistObjectGID, PID, LID}, exportLIDs::AbstractArray{LID, 1}, distor::Distributor{GID, PID, LID})::AbstractArray +Perform any packing or preparation required for communications. The +method returns the array of objects to export + + unpackAndCombine(target::<:DistObject{GID, PID, LID},importLIDs::AbstractArray{LID, 1}, imports::AAbstractrray, distor::Distributor{GID, PID, LID}, cm::CombineMode) +Perform any unpacking and combining after communication +""" +abstract type DistObject{GID <:Integer, PID <: Integer, LID <: Integer} <: SrcDistObject{GID, PID, LID} +end + + +## import/export interface ## + +""" + doImport(target::Impl{GID, PID, LID}, source::SrcDistObject{GID, PID, LID}, importer::Import{GID, PID, LID}, cm::CombineMode) + +Import data into this object using an Import object ("forward mode") +""" +function doImport(source::SrcDistObject{GID, PID, LID}, + target::DistObject{GID, PID, LID}, importer::Import{GID, PID, LID}, + cm::CombineMode) where {GID <:Integer, PID <: Integer, LID <: Integer} + doTransfer(source, target, cm, numSameIDs(importer), permuteToLIDs(importer), + permuteFromLIDs(importer), remoteLIDs(importer), exportLIDs(importer), + distributor(importer), false) +end + +""" + doExport(target::Impl{GID, PID, LID}, source::SrcDistObject{GID, PID, LID}, exporter::Export{GID, PID, LID}, cm::CombineMode) + +Export data into this object using an Export object ("forward mode") +""" +function doExport(source::SrcDistObject{GID, PID, LID}, target::DistObject{GID, PID, LID}, + exporter::Export{GID, PID, LID}, cm::CombineMode) where { + GID <:Integer, PID <: Integer, LID <: Integer} + doTransfer(source, target, cm, numSameIDs(exporter), permuteToLIDs(exporter), + permuteFromLIDs(exporter), remoteLIDs(exporter), exportLIDs(exporter), + distributor(exporter), false) +end + +""" + doImport(source::SrcDistObject{GID, PID, LID}, target::DistObject{GID, PID, LID}, exporter::Export{GID, PID, 
LID}, cm::CombineMode) + +Import data into this object using an Export object ("reverse mode") +""" +function doImport(source::SrcDistObject{GID, PID, LID}, target::DistObject{GID, PID, LID}, + exporter::Export{GID, PID, LID}, cm::CombineMode) where { + GID <:Integer, PID <: Integer, LID <: Integer} + doTransfer(source, target, cm, numSameIDs(exporter), permuteToLIDs(exporter), + permuteFromLIDs(exporter), remoteLIDs(exporter), exportLIDs(exporter), + distributor(exporter), true) +end + +""" + doExport(source::SrcDistObject{GID, PID, LID}, target::DistObject{GID, PID, LID}, importer::Import{GID, PID, LID}, cm::CombineMode) + +Export data into this object using an Import object ("reverse mode") +""" +function doExport(source::SrcDistObject{GID, PID, LID}, target::DistObject{GID, PID, LID}, + importer::Import{GID, PID, LID}, cm::CombineMode) where { + GID <:Integer, PID <: Integer, LID <: Integer} + doTransfer(source, target, cm, numSameIDs(importer), permuteToLIDs(importer), + permuteFromLIDs(importer), remoteLIDs(importer), exportLIDs(importer), + distributor(importer), true) +end + + +## import/export functionality ## + +""" + checkSizes(source, target)::Bool + +Compare the source and target objects for compatiblity. By default, returns false. 
Override this to allow transfering to/from subtypes +""" +function checkSizes(source::SrcDistObject{GID, PID, LID}, + target::SrcDistObject{GID, PID, LID})::Bool where { + GID <: Integer, PID <: Integer, LID <: Integer} + false +end + +""" + doTransfer(src::SrcDistObject{GID, PID, LID}, target::Impl{GID, PID, LID}, cm::CombineMode, numSameIDs::LID, permuteToLIDs::AbstractArray{LID, 1}, permuteFromLIDs::AbstractArray{LID, 1}, remoteLIDs::AbstractArray{LID, 1}, exportLIDs::AbstractArray{LID, 1}, distor::Distributor{GID, PID, LID}, reversed::Bool) + +Perform actual redistribution of data across memory images +""" +function doTransfer(source::SrcDistObject{GID, PID, LID}, + target::DistObject{GID, PID, LID}, cm::CombineMode, + numSameIDs::LID, permuteToLIDs::AbstractArray{LID, 1}, + permuteFromLIDs::AbstractArray{LID, 1}, remoteLIDs::AbstractArray{LID, 1}, + exportLIDs::AbstractArray{LID, 1}, distor::Distributor{GID, PID, LID}, + reversed::Bool) where {GID <: Integer, PID <: Integer, LID <: Integer} + + if !checkSizes(source, target) + throw(InvalidArgumentError("checkSize() indicates that the destination " * + "object is not a legal target for redistribution from the " * + "source object. This probably means that they do not have " * + "the same dimensions. 
For example, MultiVectors must have " * + "the same number of rows and columns.")) + end + + readAlso = true #from TPetras rwo + if cm == INSERT || cm == REPLACE + numIDsToWrite = numSameIDs + length(permuteToLIDs) + length(remoteLIDs) + if numIDsToWrite == numMyElements(map(target)) + # overwriting all local data in the destination, so write-only suffices + + #TODO look at FIXME on line 503 + readAlso = false + end + end + + #TODO look at FIXME on line 514 + createViews(source) + + #tell target to create a view of its data + #TODO look at FIXME on line 531 + createViewsNonConst(target, readAlso) + + + if numSameIDs + length(permuteToLIDs) != 0 + copyAndPermute(source, target, numSameIDs, permuteToLIDs, permuteFromLIDs) + end + + # only need to pack & send comm buffers if combine mode is not ZERO + # ZERO combine mode indicates results are the same as if all zeros were recieved + if cm != ZERO + exports = packAndPrepare(source, target, exportLIDs, distor) + + if ((reversed && distributedGlobal(target)) + || (!reversed && distributedGlobal(source))) + if reversed + #do exchange of remote data + imports = resolveReverse(distor, exports) + else + imports = resolve(distor, exports) + end + + unpackAndCombine(target, remoteLIDs, imports, distor, cm) + end + end + + releaseViews(source) + releaseViews(target) +end + +""" + createViews(obj::SrcDistObject) + +doTransfer calls this on the source object. By default it does nothing, but the source object can use this as a hint to fetch data from a compute buffer on an off-CPU decice (such as GPU) into host memory +""" +function createViews(obj::SrcDistObject) +end + +""" + createViewsNonConst(obj::SrcDistObject, readAlso::Bool) + +doTransfer calls this on the target object. 
By default it does nothing, but the target object can use this as a hint to fetch data from a compute buffer on an off-CPU decice (such as GPU) into host memory +readAlso indicates whether the doTransfer might read from the original buffer +""" +function createViewsNonConst(obj::SrcDistObject, readAlso::Bool) +end + + +""" + releaseViews(obj::SrcDistObject) + +doTransfer calls this on the target and source as it completes to allow any releasing of buffers or views. By default it does nothing +""" +function releaseViews(obj::SrcDistObject) +end diff --git a/src/Distributor.jl b/src/Distributor.jl new file mode 100644 index 0000000..25b86ec --- /dev/null +++ b/src/Distributor.jl @@ -0,0 +1,121 @@ + + +export Distributor +export createFromSends, createFromRecvs +export resolve, resolveReverse, resolveWaits +export resolvePosts, resolveReversePosts, resolveReverseWaits + +# methods (and docs) are currently based straight off Epetra_Distributor to match Comm + +""" +The base type for gather/scatter setup. +All subtypes must have the following methods, with DistributorImpl standing +in for the subtype: + +createFromSends(dist::DistributorImpl,exportPIDs::AbstractArray{PID}) + ::Integer where PID <:Integer + - sets up the Distributor object using a list of process IDs to which we + export and the number of IDs being exported. Returns the number of + IDs this processor will be receiving + +createFromRecvs(dist::DistributorImpl, remoteGIDs::AbstractArray{GID}, + remotePIDs::AbstractArray{PID})::Tuple{AbstractArray{GID}, AbstractArray{PID}} + where GID <: Integer where PID <: Integer + - sets up the Distributor object using a list of remote global IDs and + corresponding PIDs. Returns a tuple with the global IDs and their + respective processor IDs being sent by me. + +resolvePosts(dist::DistributorImpl, exportObjs::AbstractArray) + - Post buffer of export objects (can do other local work before executing + Waits). 
Otherwise, as do(::DistributorImpl, ::Array{T})::Array{T}

resolveWaits(dist::DistributorImpl)::AbstractArray - wait on a set of posts

resolveReversePosts(dist::DistributorImpl, exportObjs::AbstractArray)
    - Do reverse post of buffer of export objects (can do other local work
    before executing Waits). Otherwise, as
    doReverse(::DistributorImpl, ::AbstractArray{T})::AbstractArray{T}

resolveReverseWaits(dist::DistributorImpl)::AbstractArray - wait on a reverse set of posts

"""
abstract type Distributor{GID <: Integer, PID <: Integer, LID <: Integer}
end

"""
    createFromSends(dist::Distributor, exportPIDs::AbstractArray{<:Integer})::Integer

Sets up the Distributor object using a list of process IDs to which we
export and the number of IDs being exported. Returns the number of
IDs this processor will be receiving
"""
function createFromSends(dist::Distributor{GID, PID, LID}, exportPIDs::AbstractArray{<:Integer}) where GID <: Integer where PID <: Integer where LID <: Integer
    createFromSends(dist, Array{PID}(exportPIDs))
end

"""
    createFromRecvs(dist::Distributor{GID, PID, LID}, remoteGIDs::AbstractArray{<:Integer}, remotePIDs::AbstractArray{<:Integer})::Tuple{AbstractArray{GID}, AbstractArray{PID}}

Sets up the Distributor object using a list of remote global IDs and
corresponding PIDs. Returns a tuple with the global IDs and their
respective processor IDs being sent by me.
"""
function createFromRecvs(dist::Distributor{GID, PID, LID}, remoteGIDs::AbstractArray{<:Integer},
        remotePIDs::AbstractArray{<:Integer}) where GID <: Integer where PID <: Integer where LID <: Integer
    createFromRecvs(dist, Array{GID}(remoteGIDs), Array{PID}(remotePIDs))
end

"""
    resolve(dist::Distributor, exportObjs::AbstractArray{T})::AbstractArray{T}

Execute the current plan on buffer of export objects and return the
objects set to this processor
"""
function resolve(dist::Distributor, exportObjs::AbstractArray{T})::AbstractArray{T} where T
    resolvePosts(dist, exportObjs)
    resolveWaits(dist)
end

"""
    resolveReverse(dist::Distributor, exportObjs::AbstractArray{T})::AbstractArray{T}

Execute the reverse of the current plan on buffer of export objects and
return the objects set to this processor
"""
function resolveReverse(dist::Distributor, exportObjs::AbstractArray{T})::AbstractArray{T} where T
    resolveReversePosts(dist, exportObjs)
    resolveReverseWaits(dist)
end


#### required method documentation stubs ####

"""
    resolvePosts(dist::Distributor, exportObjs::AbstractArray)

Post buffer of export objects (can do other local work before executing
Waits. Otherwise, as resolve(::DistributorImpl, ::AbstractArray{T})::AbstractArray{T}
"""
function resolvePosts end

"""
    resolveWaits(dist::Distributor)::AbstractArray

wait on a set of posts
"""
function resolveWaits end

"""
    resolveReversePosts(dist::Distributor, exportObjs::AbstractArray)
Do reverse post of buffer of export objects (can do other local work
before executing Waits). Otherwise, as
doReverse(::DistributorImpl, ::AbstractArray{T})::AbstractArray{T}
"""
function resolveReversePosts end

"""
    resolveReverseWaits(dist::Distributor)::AbstractArray

wait on a reverse set of posts
"""
function resolveReverseWaits end
diff --git a/src/Enums.jl b/src/Enums.jl
new file mode 100644
index 0000000..e4ba845
--- /dev/null
+++ b/src/Enums.jl
@@ -0,0 +1,67 @@
export CombineMode, ADD, INSERT, REPLACE, ABSMAX, ZERO
export TransposeMode, NO_TRANS, TRANS, CONJ_TRANS
export ProfileType, STATIC_PROFILE, DYNAMIC_PROFILE
export IndexType, LOCAL_INDICES, GLOBAL_INDICES
export StorageStatus, STORAGE_2D, STORAGE_1D_UNPACKED, STORAGE_1D_PACKED

"""
Tells petra how to combine data received from other processes with existing data on the calling process for specific import or export options.

Here is the list of combine modes:
  * ADD: Sum new values into existing values
  * INSERT: Insert new values that don't currently exist
  * REPLACE: Replace existing values with new values
  * ABSMAX: If ``x_{old}`` is the old value and ``x_{new}`` the incoming new value, replace ``x_{old}`` with ``\\max(x_{old}, x_{new})``
  * ZERO: Replace old values with zero
"""
@enum CombineMode ADD=1 INSERT=2 REPLACE=3 ABSMAX=4 ZERO=5



"""
Tells petra whether to use the transpose or conjugate transpose of the matrix
"""
@enum TransposeMode NO_TRANS=1 TRANS=2 CONJ_TRANS=3

"""
    isTransposed(mode::TransposeMode)::Bool

Checks whether the given TransposeMode is transposed
"""
@inline function isTransposed(mode::TransposeMode)::Bool
    mode != NO_TRANS
end


"""
    applyConjugation(mode::TransposeMode, val::T)::T

Conjugates `val` when `mode` is `CONJ_TRANS`; otherwise returns it unchanged.
"""
function applyConjugation(mode::TransposeMode, val::T)::T where T
    if mode == CONJ_TRANS
        conj(val)
    else
        val
    end
end

# real values are their own conjugate, so skip the branch entirely
function applyConjugation(mode::TransposeMode, val::T)::T where {T <: Real}
    val
end


"""
Allocation profile for matrix/graph entries
"""
@enum ProfileType STATIC_PROFILE DYNAMIC_PROFILE


"""
Can be used to differentiate global and local indices
"""
@enum IndexType LOCAL_INDICES GLOBAL_INDICES


"""
Status of the graph's or matrix's storage, when not in
a fill-complete state.
"""
@enum StorageStatus STORAGE_2D STORAGE_1D_UNPACKED STORAGE_1D_PACKED


diff --git a/src/Error.jl b/src/Error.jl
new file mode 100644
index 0000000..9591da4
--- /dev/null
+++ b/src/Error.jl
@@ -0,0 +1,23 @@

export InvalidArgumentError, InvalidStateError

"""
    InvalidArgumentError(msg)

The values passed as arguments are not valid. Argument `msg`
is a descriptive error string.
"""
struct InvalidArgumentError <: Exception
    msg::AbstractString
end


"""
    InvalidStateError(msg)

An object is not in a valid state for this method. Argument `msg`
is a descriptive error string.
"""
struct InvalidStateError <: Exception
    msg::AbstractString
end
\ No newline at end of file
diff --git a/src/Export.jl b/src/Export.jl
new file mode 100644
index 0000000..87780c9
--- /dev/null
+++ b/src/Export.jl
@@ -0,0 +1,257 @@

export Export

"""
Communication plan for data redistribution from a (possibly) multiple-owned to a uniquely owned distribution
"""
struct Export{GID <: Integer, PID <:Integer, LID <: Integer}
    exportData::ImportExportData{GID, PID, LID}
end

#TODO document

## Constructors ##

function Export(source::BlockMap{GID, PID, LID}, target::BlockMap{GID, PID, LID}, remotePIDs::Nullable{AbstractArray{PID}}=Nullable{AbstractArray{PID}}(); plist...)
where {GID <: Integer, PID <: Integer, LID <: Integer} + Export(source, target, + Dict(Array{Tuple{Symbol, Any}, 1}(plist))) +end + +function Export(source::BlockMap{GID, PID, LID}, target::BlockMap{GID, PID, LID}, + plist::Dict{Symbol}) where {GID <: Integer, PID <: Integer, LID <: Integer} + Export(source, target, Nullable{AbstractArray{PID}}(), plist) +end + +function Export(source::BlockMap{GID, PID, LID}, target::BlockMap{GID, PID, LID}, + remotePIDs::Nullable{AbstractArray{PID}}, plist::Dict{Symbol}) where { + GID <: Integer, PID <: Integer, LID <: Integer} + + if @debug + info("$(myPid(comm(source))): Export ctor\n") + end + + expor = Export(ImportExportData(source, target)) + + exportGIDs = setupSamePermuteExport(expor) + + if @debug + info("$(myPid(comm(source))): Export ctor: setupSamePermuteExport done\n") + end + if distributedGlobal(source) + setupRemote(expor, exportGIDs) + end + if @debug + info("$(myPid(comm(source))): Export ctor: done\n") + end + + expor +end + + +## internal construction methods ## + +function setupSamePermuteExport(expor::Export{GID, PID, LID})::AbstractArray{GID} where {GID <: Integer, PID <: Integer, LID <: Integer} + + data = expor.exportData + + source = sourceMap(expor) + target = targetMap(expor) + + sourceGIDs = myGlobalElements(source) + targetGIDs = myGlobalElements(target) + + + numSrcGIDs = length(sourceGIDs) + numTgtGIDs = length(targetGIDs) + numGIDs = min(numSrcGIDs, numTgtGIDs) + + numSameGIDs = 1 + while numSameGIDs <= numGIDs && sourceGIDs[numSameGIDs] == targetGIDs[numSameGIDs] + numSameGIDs += 1 + end + numSameGIDs -= 1 + numSameIDs(data, numSameGIDs) + + exportGIDs = Array{GID, 1}(0) + permuteToLIDs = JuliaPetra.permuteToLIDs(data) + permuteFromLIDs = JuliaPetra.permuteFromLIDs(data) + exportLIDs = JuliaPetra.exportLIDs(data) + + for srcLID = (numSameGIDs+1):numSrcGIDs + const curSrcGID = sourceGIDs[srcLID] + const tgtLID = lid(target, curSrcGID) + if tgtLID != 0 + push!(permuteToLIDs, tgtLID) + 
push!(permuteFromLIDs, srcLID) + else + push!(exportGIDs, curSrcGID) + push!(exportLIDs, srcLID) + end + end + + if length(exportLIDs) != 0 && !distributedGlobal(source) + isLocallyComplete(data, false) + warn("Source has export LIDs but source not distributed globally. " * + "Exporting to a submap of the target map.") + end + + if distributedGlobal(source) + #resize!(JuliaPetra.exportPIDs(data), length(exportGIDs)) + + (exportPIDs, exportLIDs) = remoteIDList(target, exportGIDs) + JuliaPetra.exportPIDs(data, exportPIDs) + missingGIDs = 0 + for i = 1:length(exportPIDs) + if exportPIDs[i] == 0 + missingGIDs += 1 + end + end + + if missingGIDs != 0 + warn("The source Map has GIDs not found in the target Map") + + isLocallyComplete(data, false) + numInvalidExports = missingGIDs + totalNumExports = length(exportPIDs) + + if numInvalidExports == totalNumExports + # all exports invalid, can delete all exports + resize!(exportGIDs, 0) + resize!(exportLIDs, 0) + resize!(exportPIDs, 0) + else + #some exports are valid, need to keep the valid exports + numValidExports = 1 + for e = 1:totalNumExports + if exportPIDs[e] != 0 + exportGIDs[numValidExports] = exportGIDs[e] + exportLIDs[numValidExports] = exportLIDs[e] + exportPIDs[numValidExports] = exportPIDs[e] + numValidExports += 1 + end + end + numValidExports -= 1 + + resize!(exportGIDs, numValidExports) + resize!(exportLIDs, numValidExports) + resize!(exportPIDs, numValidExports) + end + end + end + exportGIDs +end + +function setupRemote(expor::Export{GID, PID, LID}, exportGIDs::AbstractArray{GID, 1}) where {GID <: Integer, PID <: Integer, LID <: Integer} + + data = expor.exportData + + target = targetMap(data) + + if @debug + info("$(myPid(comm(target))): setupRemote\n") + end + + exportPIDs = JuliaPetra.exportPIDs(data) + + order = sortperm(exportPIDs) + permute!(exportPIDs, order) + permute!(exportLIDs(data), order) + permute!(exportGIDs, order) + + if @debug + info("$(myPid(comm(target))): setupRemote: Calling 
createFromSends\n") + end + + numRemoteIDs = createFromSends(distributor(data), exportPIDs) + + if @debug + info("$(myPid(comm(target))): setupRemote: Calling doPostsAndWaits\n") + end + + remoteGIDs = resolve(distributor(data), exportGIDs) + + remoteLIDs = JuliaPetra.remoteLIDs(data) + + resize!(remoteLIDs, numRemoteIDs) + + for i in 1:length(remoteGIDs) + remoteLIDs[i] = lid(target, remoteGIDs[i]) + end + + if @debug + info("$(myPid(comm(target))): setupRemote: done\n") + end +end + +## Getters ## + +""" +Get the source map for the given ImportExportData +""" +function sourceMap(impor::Export{GID, PID, LID})::BlockMap{GID, PID, LID} where GID <: Integer where PID <:Integer where LID <: Integer + impor.exportData.source +end + +""" +Get the target map for the given ImportExportData +""" +function targetMap(impor::Export{GID, PID, LID})::BlockMap{GID, PID, LID} where GID <: Integer where PID <:Integer where LID <: Integer + impor.exportData.target +end + +""" +List of elements in the target map that are permuted. +""" +function permuteToLIDs(impor::Export{GID, PID, LID})::Array{LID} where GID <: Integer where PID <:Integer where LID <: Integer + impor.exportData.permuteToLIDs +end + +""" +List of elements in the source map that are permuted. 
+""" +function permuteFromLIDs(impor::Export{GID, PID, LID})::Array{LID} where GID <: Integer where PID <:Integer where LID <: Integer + impor.exportData.permuteFromLIDs +end + +""" +List of elements in the target map that are coming from other processors +""" +function remoteLIDs(impor::Export{GID, PID, LID})::Array{LID} where GID <: Integer where PID <: Integer where LID <: Integer + impor.exportData.remoteLIDs +end + +""" +List of elements that will be sent to other processors +""" +function exportLIDs(impor::Export{GID, PID, LID})::Array{LID} where GID <: Integer where PID <: Integer where LID <: Integer + impor.exportData.exportLIDs +end + +""" +List of processors to which elements will be sent `exportLID[i]` will be sent to processor `exportPIDs[i]` +""" +function exportPIDs(impor::Export{GID, PID, LID})::Array{PID} where GID <: Integer where PID <: Integer where LID <: Integer + impor.exportData.exportPIDs +end + +""" +Returns the number of elements that are identical between the source and target maps, up to the first different ID +""" +function numSameIDs(impor::Export{GID, PID, LID})::LID where GID <: Integer where PID <: Integer where LID <: Integer + impor.exportData.numSameIDs +end + + +""" +Returns the distributor being used +""" +function distributor(impor::Export{GID, PID, LID})::Distributor{GID, PID, LID} where GID <: Integer where PID <: Integer where LID <: Integer + impor.exportData.distributor +end + +""" +Returns whether the import or export is locally complete +""" +function isLocallyComplete(impor::Export{GID, PID, LID})::Bool where GID <: Integer where PID <: Integer where LID <: Integer + impor.exportData.isLocallyComplete +end diff --git a/src/Import.jl b/src/Import.jl new file mode 100644 index 0000000..d53949b --- /dev/null +++ b/src/Import.jl @@ -0,0 +1,356 @@ + +export Import + +""" +Communication plan for data redistribution from a uniquely-owned to a (possibly) multiply-owned distribution. 
+""" +struct Import{GID <: Integer, PID <:Integer, LID <: Integer} + importData::ImportExportData{GID, PID, LID} + + #default constructor appeared to accept a pair of BlockMaps + function Import{GID, PID, LID}( + importData::ImportExportData{GID, PID, LID}) where { + GID <: Integer, PID <: Integer, LID <: Integer} + new(importData) + end +end + +## Constructors ## + +function Import(source::BlockMap{GID, PID, LID}, target::BlockMap{GID, PID, LID}, + userRemotePIDs::AbstractArray{PID}, remoteGIDs::AbstractArray{GID}, + userExportLIDs::AbstractArray{LID}, userExportPIDs::AbstractArray{PID}, + useRemotePIDGID::Bool; plist...) where { + GID <: Integer, PID <: Integer, LID <: Integer} + Import{GID, PID, LID}(source, target, userRemotePIDs, remoteGIDs, + userExportLIDs, userExportPIDs, useRemotePIDGID, Dict(plist)) +end + +function Import(source::BlockMap{GID, PID, LID}, target::BlockMap{GID, PID, LID}, + userRemotePIDs::AbstractArray{PID}, remoteGIDs::AbstractArray{GID}, + userExportLIDs::AbstractArray{LID}, userExportPIDs::AbstractArray{PID}, + useRemotePIDGID::Bool, plist::Dict{Symbol} + ) where {GID <: Integer, PID <: Integer, LID <: Integer} + + importData = ImportExportData(source, target) + + if @debug + info("$(myPid(comm(source))): Import ctor expert\n") + end + + const remoteLIDs = JuliaPetra.remoteLIDs(importData) + + if !userRemotePIDGID + empty!(remoteGIDs) + empty!(remoteLIDs) + end + + getIDSource(data, remoteGIDs, !userRemotePIDGID) + + if length(remoteGIDs) > 0 && !isDistributed(source) + throw(InvalidArgumentError("Target has remote LIDs but source is not distributed globally")) + end + + (remotePIDs, _) = remoteIDList(source, remoteGIDs) + + remoteProcIDs = (useRemotePIDGID) ? 
userRemotePIDs : remotePIDs + + if !(length(remoteProcIDs) == length(remoteGIDs) && length(remoteGIDs) == length(remoteLIDs)) + throw(InvalidArgumentError("Size miss match on remoteProcIDs, remoteGIDs and remoteLIDs")) + end + + # ensure remoteProcIDs[i], remoteGIDs[i] and remoteLIDs[i] refer to the same thing + order = sortperm(remoteProcIDs) + permute!(remoteProcIDs, order) + permute!(remoteGIDs, order) + permute!(remoteLIDs, order) + + exportPIDs = Array{PID, 1}(length(userExportPIDs)) + exportLIDs = Array{PID, 1}(length(userExportPIDs)) + + #need the funcitons with these names, not the variables + JuliaPetra.remoteLIDs(importData, remoteLIDs) + JuliaPetra.exportPIDs(importData, exportPIDs) + JuliaPetra.exportLIDs(importData, exportLIDs) + + locallyComplete = true + for i = 1:length(userExportPIDs) + if userExportPIDs[i] == 0 + locallyComplete = false + end + + exportPIDs[i] = userExportPIDs[i] + exportLIDs[i] = userExportLIDs[i] + end + + isLocallyComplete(importData, locallyComplete) + #TODO create and upgrade to createFromSendsAndRecvs + #createFromSendsAndRecvs(distributor(importData), exportPIDs, remoteProcIDs) + createFromRecvs(distributor(importData), remoteGIDs, remotePIDs) + + Import(importData) +end + +function Import(source::BlockMap{GID, PID, LID}, target::BlockMap{GID, PID, LID}, remotePIDs::Nullable{AbstractArray{PID}}=Nullable{AbstractArray{PID}}(); plist...) 
where {GID <: Integer, PID <: Integer, LID <: Integer} + Import(source, target, remotePIDs, + Dict(Array{Tuple{Symbol, Any}, 1}(plist))) +end + +function Import(source::BlockMap{GID, PID, LID}, target::BlockMap{GID, PID, LID}, plist::Dict{Symbol}) where {GID <: Integer, PID <: Integer, LID <: Integer} + Import(source, target, Nullable{AbstractArray{PID}}(), plist) +end + +function Import(source::BlockMap{GID, PID, LID}, target::BlockMap{GID, PID, LID}, remotePIDs::Nullable{AbstractArray{PID}}, plist::Dict{Symbol}) where {GID <: Integer, PID <: Integer, LID <: Integer} + + if @debug + info("$(myPid(comm(source))): Import ctor\n") + end + + const impor = Import{GID, PID, LID}(ImportExportData(source, target)) + + const remoteGIDs = setupSamePermuteRemote(impor) + + if @debug + info("$(myPid(comm(source))): Import ctor: setupSamePermuteRemote done\n") + end + if distributedGlobal(source) + setupExport(impor, remoteGIDs, remotePIDs) + end + if @debug + info("$(myPid(comm(source))): Import ctor: done\n") + end + + impor +end + +## internal construction methods ## + +function setupSamePermuteRemote(impor::Import{GID, PID, LID}) where {GID <: Integer, PID <: Integer, LID <: Integer} + + data = impor.importData + + remoteGIDs = Array{GID, 1}(0) + + getIDSources(data, remoteGIDs) + + if length(remoteLIDs(data)) != 0 && !distributedGlobal(sourceMap(impor)) + isLocallyComplete(data, false) + + warn("Target has remote LIDs but source is not distributed globally. 
" * + "Importing a submap of the target map") + end + + remoteGIDs +end + + +function getIDSources(data, remoteGIDs, useRemotes=true) + const source = sourceMap(data) + const target = targetMap(data) + + const sourceGIDs = myGlobalElements(source) + const targetGIDs = myGlobalElements(target) + + const numSrcGIDs = length(sourceGIDs) + const numTgtGIDs = length(targetGIDs) + const numGIDs = min(numSrcGIDs, numTgtGIDs) + + numSameGIDs = 1 + while numSameGIDs <= numGIDs && sourceGIDs[numSameGIDs] == targetGIDs[numSameGIDs] + numSameGIDs += 1 + end + numSameGIDs -= 1 + numSameIDs(data, numSameGIDs) + + const permuteToLIDs = JuliaPetra.permuteToLIDs(data) + const permuteFromLIDs = JuliaPetra.permuteFromLIDs(data) + const remoteLIDs = JuliaPetra.remoteLIDs(data) + + + for tgtLID = (numSameGIDs+1):numTgtGIDs + const curTargetGID = targetGIDs[tgtLID] + const srcLID = lid(source, curTargetGID) + if srcLID != 0 + push!(permuteToLIDs, tgtLID) + push!(permuteFromLIDs, srcLID) + elseif useRemotes + push!(remoteGIDs, curTargetGID) + push!(remoteLIDs, tgtLID) + end + end +end + +function setupExport(impor::Import{GID, PID, LID}, remoteGIDs::AbstractArray{GID}, userRemotePIDs::Nullable{AbstractArray{PID}}) where {GID <: Integer, PID <: Integer, LID <: Integer} + data = impor.importData + const source = sourceMap(impor) + + useRemotePIDs = !isnull(userRemotePIDs) + + # Sanity Checks + if useRemotePIDs && length(get(userRemotePIDs)) != length(remoteGIDs) + throw(InvalidArgumentError("remotePIDs must either be null " * + "or match the size of remoteGIDs.")) + end + + + missingGID = 0 + + if !useRemotePIDs + newRemotePIDs = Array{PID, 1}(length(remoteGIDs)) + if @debug + info("$(myPid(comm(source))): setupExport(Import): about to call " * + "getRemoteIndexList on sourceMap\n") + end + (remoteProcIDs, remoteLIDs) = remoteIDList(source, remoteGIDs) + for e in remoteLIDs + if e == 0 + missingGID += 1 + end + end + else + remoteProcIDs = get(userRemotePIDs) + end + + #line 688 + + if 
missingGID != 0 + isLocallyComplete(data, false) + + warn("Source map was un-able to figure out which process owns one " * + "or more of the GIDs in the list of remote GIDs. This probably " * + "means that there is at least one GID owned by some process in " * + "the target map which is not owned by any process in the source " * + "Map. (That is, the source and target maps do not contain the " * + "same set of GIDs globally") + + #ignore remote GIDs that aren't owned by any process in the source Map + numInvalidRemote = missingGID + totalNumRemote = length(remoteGIDs) + if numInvalidRemote == totalNumRemote + #if all remotes are invalid, can delete them all + empty!(remoteProcIDs) + empty!(remoteGIDs) + empty!(JuliaPetra.remoteLIDs(data)) + else + numValidRemote = 1 + + remoteLIDs = JuliaPetra.remoteLIDs(data) + + for r = 1:totalNumRemote + if remoteProcIds[r] != 0 + remoteProcIds[numValidRemote] = remoteProcIDs[r] + remoteGIDs[numValidRemote] = remoteGIDs[r] + remoteLIDs[numValidRemote] = remoteLIDs[r] + numValidRemote += 1 + end + end + numValidRemote -= 1 + + if numValidRemote != totalNumRemote - numInvalidRemote + throw(InvalidStateError("numValidRemote = $numValidRemote " * + "!= totalNumRemote - numInvalidRemote " * + "= $(totalNumRemote - numInvalidRemote)")) + end + + resize!(remoteProcIDs, numValidRemote) + resize!(remoteGIDs, numValidRemote) + resize!(remoteLIDs, numValidRemote) + end + end + + order = sortperm(remoteProcIDs) + permute!(remoteProcIDs, order) + permute!(remoteGIDs, order) + permute!(remoteLIDs, order) + + (exportGIDs, exportPIDs) = createFromRecvs(distributor(data), remoteGIDs, remoteProcIDs) + + JuliaPetra.exportPIDs(data, exportPIDs) + + numExportIDs = length(exportGIDs) + + if numExportIDs > 0 + exportLIDs = JuliaPetra.exportLIDs(data) + resize!(exportLIDs, numExportIDs) + for k in 1:numExportIDs + exportLIDs[k] = lid(source, exportGIDs[k]) + end + end + + if @debug + info("$(myPid(comm(source))): setupExport: done\n") + end +end + +## 
Getters ## + +""" +Get the source map for the given ImportExportData +""" +function sourceMap(impor::Import{GID, PID, LID})::BlockMap{GID, PID, LID} where GID <: Integer where PID <:Integer where LID <: Integer + impor.importData.source +end + +""" +Get the target map for the given ImportExportData +""" +function targetMap(impor::Import{GID, PID, LID})::BlockMap{GID, PID, LID} where GID <: Integer where PID <:Integer where LID <: Integer + impor.importData.target +end + +""" +List of elements in the target map that are permuted. +""" +function permuteToLIDs(impor::Import{GID, PID, LID})::Array{LID} where GID <: Integer where PID <:Integer where LID <: Integer + impor.importData.permuteToLIDs +end + +""" +List of elements in the source map that are permuted. +""" +function permuteFromLIDs(impor::Import{GID, PID, LID})::Array{LID} where GID <: Integer where PID <:Integer where LID <: Integer + impor.importData.permuteFromLIDs +end + +""" +List of elements in the target map that are coming from other processors +""" +function remoteLIDs(impor::Import{GID, PID, LID})::Array{LID} where GID <: Integer where PID <: Integer where LID <: Integer + impor.importData.remoteLIDs +end + +""" +List of elements that will be sent to other processors +""" +function exportLIDs(impor::Import{GID, PID, LID})::Array{LID} where GID <: Integer where PID <: Integer where LID <: Integer + impor.importData.exportLIDs +end + +""" +List of processors to which elements will be sent `exportLID[i]` will be sent to processor `exportPIDs[i]` +""" +function exportPIDs(impor::Import{GID, PID, LID})::Array{PID} where GID <: Integer where PID <: Integer where LID <: Integer + impor.importData.exportPIDs +end + +""" +Returns the number of elements that are identical between the source and target maps, up to the first different ID +""" +function numSameIDs(impor::Import{GID, PID, LID})::LID where GID <: Integer where PID <: Integer where LID <: Integer + impor.importData.numSameIDs +end + + +""" +Returns 
the distributor being used +""" +function distributor(impor::Import{GID, PID, LID})::Distributor{GID, PID, LID} where GID <: Integer where PID <: Integer where LID <: Integer + impor.importData.distributor +end + +""" +Returns whether the import or export is locally complete +""" +function isLocallyComplete(impor::Import{GID, PID, LID})::Bool where GID <: Integer where PID <: Integer where LID <: Integer + impor.importData.isLocallyComplete +end diff --git a/src/ImportExportData.jl b/src/ImportExportData.jl new file mode 100644 index 0000000..a5fe76b --- /dev/null +++ b/src/ImportExportData.jl @@ -0,0 +1,84 @@ +mutable struct ImportExportData{GID <: Integer, PID <: Integer, LID <: Integer} + source::BlockMap{GID, PID, LID} + target::BlockMap{GID, PID, LID} + + permuteToLIDs::Array{LID, 1} + permuteFromLIDs::Array{LID, 1} + remoteLIDs::Array{LID, 1} + + exportLIDs::Array{LID, 1} + exportPIDs::Array{PID, 1} + + numSameIDs::GID + distributor::Distributor{GID, PID, LID} + + isLocallyComplete::Bool +end + +## Constructors ## +function ImportExportData(source::BlockMap{GID, PID, LID}, target::BlockMap{GID, PID, LID})::ImportExportData{GID, PID, LID} where GID <: Integer where PID <:Integer where LID <: Integer + ImportExportData{GID, PID, LID}(source, target, [], [], [], [], [], 0, createDistributor(comm(source)), true) +end + + +## Getters ## +function sourceMap(data::ImportExportData{GID, PID, LID})::BlockMap{GID, PID, LID} where GID <: Integer where PID <:Integer where LID <: Integer + data.source +end + +function targetMap(data::ImportExportData{GID, PID, LID})::BlockMap{GID, PID, LID} where GID <: Integer where PID <:Integer where LID <: Integer + data.target +end + +function permuteToLIDs(data::ImportExportData{GID, PID, LID})::Array{LID} where GID <: Integer where PID <:Integer where LID <: Integer + data.permuteToLIDs +end + +function permuteFromLIDs(data::ImportExportData{GID, PID, LID})::Array{LID} where GID <: Integer where PID <:Integer where LID <: Integer + 
data.permuteFromLIDs +end + +function remoteLIDs(data::ImportExportData{GID, PID, LID})::Array{LID} where GID <: Integer where PID <: Integer where LID <: Integer + data.remoteLIDs +end + +function remoteLIDs(data::ImportExportData{GID, PID, LID}, remoteLIDs::AbstractArray{<: Integer}) where GID <: Integer where PID <: Integer where LID <: Integer + data.remoteLIDs = remoteLIDs +end + +function exportLIDs(data::ImportExportData{GID, PID, LID})::Array{LID} where GID <: Integer where PID <: Integer where LID <: Integer + data.exportLIDs +end + +function exportLIDs(data::ImportExportData{GID, PID, LID}, exportLIDs::AbstractArray{<: Integer}) where GID <: Integer where PID <: Integer where LID <: Integer + data.exportLIDs = exportLIDs +end + +function exportPIDs(data::ImportExportData{GID, PID, LID})::Array{PID} where GID <: Integer where PID <: Integer where LID <: Integer + data.exportPIDs +end + +function exportPIDs(data::ImportExportData{GID, PID, LID}, exportPIDs::AbstractArray{<: Integer}) where GID <: Integer where PID <: Integer where LID <: Integer + data.exportPIDs = exportPIDs +end + +function numSameIDs(data::ImportExportData{GID, PID, LID})::LID where GID <: Integer where PID <: Integer where LID <: Integer + data.numSameIDs +end + +function numSameIDs(data::ImportExportData{GID, PID, LID}, numSame::Integer)::LID where GID <: Integer where PID <: Integer where LID <: Integer + data.numSameIDs = numSame +end + + +function distributor(data::ImportExportData{GID, PID, LID})::Distributor{GID, PID, LID} where GID <: Integer where PID <: Integer where LID <: Integer + data.distributor +end + +function isLocallyComplete(data::ImportExportData{GID, PID, LID})::Bool where GID <: Integer where PID <: Integer where LID <: Integer + data.isLocallyComplete +end + +function isLocallyComplete(data::ImportExportData{GID, PID, LID}, isLocallyComplete::Bool) where GID <: Integer where PID <: Integer where LID <: Integer + data.isLocallyComplete = isLocallyComplete +end \ No 
newline at end of file diff --git a/src/JuliaPetra.jl b/src/JuliaPetra.jl new file mode 100644 index 0000000..2b16458 --- /dev/null +++ b/src/JuliaPetra.jl @@ -0,0 +1,66 @@ +using TypeStability + +module JuliaPetra + +# Internal Utilities +include("Enums.jl") +include("Error.jl") +include("Macros.jl") + +include("ComputeOffsets.jl") + + +# Communication interface +include("Distributor.jl") +include("Directory.jl") +include("Comm.jl") +include("LocalComm.jl") + +include("BlockMapData.jl") +include("BlockMap.jl") + +include("DirectoryMethods.jl") +include("BasicDirectory.jl") + + +# Serial Communication +include("SerialDistributor.jl") +#include("SerialDirectory.jl") +include("SerialComm.jl") + + +# MPI Communication +include("MPIUtil.jl") +include("MPIComm.jl") +include("MPIDistributor.jl") + + +# Data interface +include("ImportExportData.jl") +include("Import.jl") +include("Export.jl") + +include("SrcDistObject.jl") +include("DistObject.jl") + + +# Dense Data types +include("MultiVector.jl") + + +# Sparse Data types +include("SparseRowView.jl") +include("LocalCRSGraph.jl") +include("LocalCSRMatrix.jl") + +include("RowGraph.jl") +include("CRSGraphConstructors.jl") +include("CRSGraphInternalMethods.jl") +include("CRSGraphExternalMethods.jl") + +include("RowMatrix.jl") +include("CSRMatrix.jl") + +include("Operator.jl") + +end # module diff --git a/src/LocalCRSGraph.jl b/src/LocalCRSGraph.jl new file mode 100644 index 0000000..4d5c6be --- /dev/null +++ b/src/LocalCRSGraph.jl @@ -0,0 +1,62 @@ +export LocalCRSGraph, numRows, maxEntry, minEntry + +""" + LocalCRSGraph{EntriesType, IndexType}() + LocalCRSGraph(entries::AbstractArray{EntriesType, 1}, rowMap::AbstractArray{IndexType, 1}) + +A compressed row storage array. Used by CRSGraph to store local structure. 
+`EntriesType` is the type of the data being held +`IndexType` is the type used to represent the indices +""" +mutable struct LocalCRSGraph{EntriesType, IndexType <: Integer} + entries::AbstractArray{EntriesType, 1} + rowMap::AbstractArray{IndexType, 1} +end + +function LocalCRSGraph{EntriesType, IndexType}() where{EntriesType, IndexType <: Integer} + LocalCRSGraph(Array{EntriesType, 1}(0), Array{IndexType, 1}(0)) +end + + +""" + numRows(::LocalCRSGraph{EntriesType, IndexType})::IndexType + +Gets the number of rows in the storage +""" +function numRows(graph::LocalCRSGraph{EntriesType, IndexType})::IndexType where { + EntriesType, IndexType <: Integer} + len = length(graph.rowMap) + if len != 0 + len - 1 + else + 0 + end +end + +""" + maxEntry(::LocalCRSGraph{EntriesType})::EntriesType + +Finds the entry with the maximum value. +""" +function maxEntry(graph::LocalCRSGraph{EntriesType})::EntriesType where { + EntriesType} + if length(graph.entries) != 0 + maximum(graph.entries) + else + throw(InvalidArgumentError("Cannot find the maximum of an empty graph")) + end +end + +""" + minEntry(::LocalCRSGraph{EntriesType})::EntriesType + +Finds the entry with the minimum value. 
+""" +function minEntry(graph::LocalCRSGraph{EntriesType})::EntriesType where { + EntriesType} + if length(graph.entries) != 0 + minimum(graph.entries) + else + throw(InvalidArgumentError("Cannot find the minimum of an empty graph")) + end +end \ No newline at end of file diff --git a/src/LocalCSRMatrix.jl b/src/LocalCSRMatrix.jl new file mode 100644 index 0000000..493a1e0 --- /dev/null +++ b/src/LocalCSRMatrix.jl @@ -0,0 +1,78 @@ +export LocalCSRMatrix, numRows, numCols, getRowView + +struct LocalCSRMatrix{Data, IndexType <: Integer} + graph::LocalCRSGraph{IndexType, IndexType} + values::Union{Vector{Data}, SubArray{Data, 1, Vector{Data}, Tuple{UnitRange{IndexType}}, true}} + + numCols::IndexType +end + +""" + LocalCSRMatrix{Data, IndexType}() + +Creates an empty LocalCSRMatrix +""" +function LocalCSRMatrix{Data, IndexType}() where {Data, IndexType} + LocalCSRMatrix(LocalCRSGraph{IndexType, IndexType}(), Data[], IndexType(0)) +end + +""" + LocalCSRMatrix(nRows::Integer, nCols::Integer, vals::AbstractArray{Data, 1}, rows::AbstractArray{IndexType, 1}, cols::AbstractArray{IndexType, 1}} + +Creates the specified LocalCSRMatrix +""" +function LocalCSRMatrix(nRows::Integer, nCols::Integer, + vals::AbstractArray{Data, 1}, rows::AbstractArray{IndexType, 1}, + cols::AbstractArray{IndexType, 1}) where {Data, IndexType} + if length(rows) != nRows + 1 + throw(InvalidArgumentError("length(rows) = $(length(rows)) != nRows+1 " + * "= $(nRows + 1)")) + end + LocalCSRMatrix(LocalCRSGraph(cols, rows), vals, IndexType(nCols)) +end + +""" + LocalCSRMatrix(numCols::IndexType, values::AbstractArray{Data, 1}, localGraph::LocalCRSGraph{IndexType, IndexType}) where {IndexType, Data <: Number} + +Creates the specified LocalCSRMatrix +""" +function LocalCSRMatrix(numCols::IndexType, values::AbstractArray{Data, 1}, + localGraph::LocalCRSGraph{IndexType, IndexType}) where {IndexType, Data <: Number} + if numCols < 0 + throw(InvalidArgumentError("Cannot have a negative number of rows")) + end 
+ + LocalCSRMatrix(localGraph, values, numCols) +end + + +""" + numRows(::LocalCSRMatrix{Data, IndexType})::IndexType + +Gets the number of rows in the matrix +""" +numRows(matrix::LocalCSRMatrix) = numRows(matrix.graph) + +""" + numCols(::LocalCSRMatrix{Data, IndexType})::IndexType + +Gets the number of columns in the matrix +""" +numCols(matrix::LocalCSRMatrix) = matrix.numCols + +""" + getRowView((matrix::LocalCSRMatrix{Data, IndexType}, row::Integer)::SparseRowView{Data, IndexType} + +Gets a view of the requested row +""" +function getRowView(matrix::LocalCSRMatrix{Data, IndexType}, + row::Integer)::SparseRowView{Data, IndexType} where {Data, IndexType} + start = matrix.graph.rowMap[row] + count = matrix.graph.rowMap[row+1] - start + + if count == 0 + SparseRowView(Data[], IndexType[]) + else + SparseRowView(matrix.values, matrix.graph.entries, count, start) + end +end diff --git a/src/LocalComm.jl b/src/LocalComm.jl new file mode 100644 index 0000000..1814e8b --- /dev/null +++ b/src/LocalComm.jl @@ -0,0 +1,53 @@ +export LocalComm + +#This class is to create a stand in for Tpetra's local maps + +""" + LocalComm(::Comm{GID, PID, LID}) + +Creates a comm object that creates an error when inter-process communication is attempted, but still allows access to the correct process ID information +""" +struct LocalComm{GID <: Integer, PID <: Integer, LID <: Integer} <: Comm{GID, PID, LID} + original::Comm{GID, PID, LID} +end + + +function barrier(comm::LocalComm) + throw(InvalidStateError("Cannot call barrier on a local comm")) +end + +function broadcastAll(comm::LocalComm, v::AbstractArray, root::Integer) + throw(InvalidStateError("Cannot call broadcastAll on a local comm")) +end + +function gatherAll(comm::LocalComm, v::AbstractArray) + throw(InvalidStateError("Cannot call gatherAll on a local comm")) +end + +function sumAll(comm::LocalComm, v::AbstractArray) + throw(InvalidStateError("Cannot call sumAll on a local comm")) +end + +function maxAll(comm::LocalComm, 
v::AbstractArray) + throw(InvalidStateError("Cannot call maxAll on a local comm")) +end + +function minAll(comm::LocalComm, v::AbstractArray) + throw(InvalidStateError("Cannot call minAll on a local comm")) +end + +function scanSum(comm::LocalComm, v::AbstractArray) + throw(InvalidStateError("Cannot call scanSum on a local comm")) +end + +function myPid(comm::LocalComm) + myPid(comm.original) +end + +function numProc(comm::LocalComm) + numProc(comm.original) +end + +function createDistributor(comm::LocalComm) + throw(InvalidStateError("Cannot call createDistributor on a local comm")) +end \ No newline at end of file diff --git a/src/MPIComm.jl b/src/MPIComm.jl new file mode 100644 index 0000000..9bdf14b --- /dev/null +++ b/src/MPIComm.jl @@ -0,0 +1,91 @@ + +import MPI + +export MPIComm + +""" + MPIComm() + MPIComm(comm::MPI.Comm) + +An implementation of Comm using MPI +The no argument constructor uses MPI.COMM_WORLD +""" +struct MPIComm{GID <: Integer, PID <:Integer, LID <: Integer} <: Comm{GID, PID, LID} + mpiComm::MPI.Comm +end + +function MPIComm(GID::Type, PID::Type, LID::Type) + MPIInit() + comm = MPIComm{GID, PID, LID}(MPI.COMM_WORLD) + + comm +end + +MPINeedsInitialization = true + +""" + MPIInit() + +On the first call, initializes MPI and adds an exit hook to finalize MPI +Does nothing on subsequent calls +""" +function MPIInit() + global MPINeedsInitialization + if MPINeedsInitialization + MPI.Init() + atexit(() -> MPI.Finalize()) + + MPINeedsInitialization = false + end +end + +function barrier(comm::MPIComm) + MPI.Barrier(comm.mpiComm) +end + +function broadcastAll(comm::MPIComm, myvals::AbstractArray{T}, root::Integer)::Array{T} where T + vals = copy(myvals) + result = MPI.Bcast!(vals, root-1, comm.mpiComm) + result +end + +function gatherAll(comm::MPIComm, myVals::AbstractArray{T})::Array{T} where T + lengths = MPI.Allgather([convert(Cint, length(myVals))], comm.mpiComm) + MPI.Allgatherv(myVals, lengths, comm.mpiComm) +end + +function 
sumAll(comm::MPIComm, partialsums::AbstractArray{T})::Array{T} where T + MPI.allreduce(partialsums, +, comm.mpiComm) +end + +function maxAll(comm::MPIComm, partialmaxes::AbstractArray{T})::Array{T} where T + MPI.allreduce(partialmaxes, max, comm.mpiComm) +end + +function maxAll(comm::MPIComm, partialmaxes::AbstractArray{Bool})::Array{Bool} + Array{Bool}(maxAll(comm, Array{UInt8}(partialmaxes))) +end + +function minAll(comm::MPIComm, partialmins::AbstractArray{T})::Array{T} where T + MPI.allreduce(partialmins, min, comm.mpiComm) +end + +function minAll(comm::MPIComm, partialmins::AbstractArray{Bool})::Array{Bool} + Array{Bool}(minAll(comm, Array{UInt8}(partialmins))) +end + +function scanSum(comm::MPIComm, myvals::AbstractArray{T})::Array{T} where T + MPI.Scan(myvals, length(myvals), MPI.SUM, comm.mpiComm) +end + +function myPid(comm::MPIComm{GID, PID})::PID where {GID <: Integer, PID <: Integer} + MPI.Comm_rank(comm.mpiComm) + 1 +end + +function numProc(comm::MPIComm{GID, PID})::PID where {GID <: Integer, PID <:Integer} + MPI.Comm_size(comm.mpiComm) +end + +function createDistributor(comm::MPIComm{GID, PID, LID})::MPIDistributor{GID, PID, LID} where {GID <: Integer, PID <: Integer, LID <: Integer} + MPIDistributor(comm) +end diff --git a/src/MPIDistributor.jl b/src/MPIDistributor.jl new file mode 100644 index 0000000..8bf8190 --- /dev/null +++ b/src/MPIDistributor.jl @@ -0,0 +1,535 @@ + +import MPI + +export MPIDistributor + +""" + MPIDistributor{GID, PID, LID}(comm::MPIComm{GID, PID, LID}) +Creates an Distributor to work with MPIComm. 
Created by +createDistributor(::MPIComm{GID, PID, LID}) +""" +mutable struct MPIDistributor{GID <: Integer, PID <: Integer, LID <: Integer} <: Distributor{GID, PID, LID} + comm::MPIComm{GID, PID, LID} + + lengths_to::Vector{GID} + procs_to::Vector{PID} + indices_to::Vector{GID} + + lengths_from::Vector{GID} + procs_from::Vector{PID} + indices_from::Vector{GID} + + resized::Bool + sizes::Vector{GID} + + sizes_to::Vector{GID} + starts_to::Vector{GID} + #starts_to_ptr::Array{Integer} + #indices_to_ptr::Array{Integer} + + sizes_from::Vector{GID} + starts_from::Vector{GID} + #sizes_from_ptr::Array{Integer} + #starts_from_ptr::Array{Integer} + + numRecvs::GID + numSends::GID + numExports::GID + + selfMsg::GID + + maxSendLength::GID + totalRecvLength::GID + tag::GID + + request::Vector{MPI.Request} + status::Vector{MPI.Status} + + #sendArray::Array{UInt8} + + planReverse::Nullable{MPIDistributor{GID, PID, LID}} + + importObjs::Nullable{Vector{Vector{UInt8}}} + + #never seem to be used + #lastRoundBytesSend::Integer + #lastRoundBytesRecv::Integer + + function MPIDistributor(comm::MPIComm{GID, PID, LID}) where GID <: Integer where PID <: Integer where LID <: Integer + new{GID, PID, LID}(comm, [], [], [], [], [], [], false, [], [], [], [], [], + 0, 0, 0, 0, 0, 0, 0, [], [], Nullable{MPIDistributor}(), + Nullable{Array{UInt8}}()) + end +end + + +#### internal methods #### +function createSendStructure(dist::MPIDistributor{GID, PID, LID}, pid::PID, + nProcs::PID, exportPIDs::AbstractArray{PID} + ) where {GID <: Integer, PID <: Integer, LID <: Integer} + + numExports = length(exportPIDs) + dist.numExports = numExports + + #starts = Array{Integer}(nProcs + 1) + #fill!(starts, 0) + starts = zeros(GID, nProcs+1) + + nactive = 0 + noSendBuff = true + numDeadIndices::GID = 0 #for GIDs not owned by any processors + + for i = 1:numExports + if noSendBuff && i > 1 && exportPIDs[i] < exportPIDs[i-1] + noSendBuff = false + end + if exportPIDs[i] >= 1 + starts[exportPIDs[i]] += 1 + 
nactive += 1
        else
            numDeadIndices += 1
        end
    end

    dist.selfMsg = starts[pid] != 0
    dist.numSends = 0

    if noSendBuff #grouped by processor, no send buffer or indices_to needed
        for i = 1:nProcs
            if starts[i] > 0
                dist.numSends += 1
            end
        end

        dist.procs_to = Vector{PID}(dist.numSends)
        dist.starts_to = Vector{GID}(dist.numSends)
        dist.lengths_to = Vector{GID}(dist.numSends)

        # walk the export list, jumping ahead by each processor's count
        index = numDeadIndices+1
        for i = 1:dist.numSends
            dist.starts_to[i] = index
            proc = exportPIDs[index]
            dist.procs_to[i] = proc
            index += starts[proc]
        end

        perm = sortperm(dist.procs_to)
        dist.procs_to = dist.procs_to[perm]
        dist.starts_to = dist.starts_to[perm]

        # line 430

        dist.maxSendLength = 0

        for i = 1:dist.numSends
            proc = dist.procs_to[i]
            dist.lengths_to[i] = starts[proc]
            if (proc != pid) && (dist.lengths_to[i] > dist.maxSendLength)
                # BUG FIX: the original assigned to a fresh local
                # `maxSendLength`, silently discarding the update to the
                # plan's field.
                dist.maxSendLength = dist.lengths_to[i]
            end
        end
    else #not grouped by processor, need send buffer and indices_to
        if starts[1] != 0
            dist.numSends = 1
        end

        # convert per-processor counts into cumulative offsets
        for i = 2:nProcs
            if starts[i] != 0
                dist.numSends += 1
            end
            starts[i] += starts[i-1]
        end

        for i = nProcs:-1:2
            starts[i] = starts[i-1] + 1
        end
        starts[1] = 1

        if nactive > 0
            dist.indices_to = Array{GID, 1}(nactive)
        end


        for i = 1:numExports
            if exportPIDs[i] >= 1
                dist.indices_to[starts[exportPIDs[i]]] = i
                starts[exportPIDs[i]] += 1
            end
        end

        #reconstruct starts array to index into indices_to

        for i = nProcs:-1:2
            starts[i] = starts[i-1]
        end
        starts[1] = 1
        starts[nProcs+1] = nactive+1


        if dist.numSends > 0
            dist.lengths_to = Array{GID}(dist.numSends)
            dist.procs_to = Array{PID}(dist.numSends)
            dist.starts_to = Array{GID}(dist.numSends)
        end

        j::GID = 1
        dist.maxSendLength = 0
        for i = 1:nProcs
            if starts[i+1] != starts[i]
                dist.lengths_to[j] = starts[i+1] - starts[i]
                dist.starts_to[j] = starts[i]
                if (i != pid) && (dist.lengths_to[j] > dist.maxSendLength)
                    dist.maxSendLength = 
dist.lengths_to[j] + end + dist.procs_to[j] = i + j += 1 + end + end + end + + dist.numSends -= dist.selfMsg + + dist +end + + +function computeRecvs(dist::MPIDistributor{GID, PID, LID}, myProc::PID, nProcs::PID) where GID <: Integer where PID <: Integer where LID <: Integer + + msgCount = zeros(Int, nProcs) + + for i = 1:(dist.numSends + dist.selfMsg) + msgCount[dist.procs_to[i]] += 1 + end + + #bug fix for reduce-scatter bug applied since no reduce_scatter is present in julia's MPI + rawCounts = MPI.Reduce(msgCount, MPI.SUM, 0, dist.comm.mpiComm) + if rawCounts isa Void + counts = Int[] + else + counts = rawCounts + end + dist.numRecvs = MPI.Scatter(counts, 1, 0, dist.comm.mpiComm)[1] + + dist.lengths_from = zeros(Int, dist.numRecvs) + dist.procs_from = zeros(PID, dist.numRecvs) + + #using NEW_COMM_PATTERN (see line 590) + + if dist.request == [] + dist.request = Array{MPI.Request}(dist.numRecvs - dist.selfMsg) + end + + #line 616 + + lengthWrappers = [Array{Int, 1}(1) for i in 1:(dist.numRecvs - dist.selfMsg)] + for i = 1:(dist.numRecvs - dist.selfMsg) + dist.request[i] = MPI.Irecv!(lengthWrappers[i], MPI.ANY_SOURCE, dist.tag, dist.comm.mpiComm) + end + + for i = 1:(dist.numSends+dist.selfMsg) + if dist.procs_to[i] != myProc + #have to use Rsend in MPIUtil + MPI_Rsend(dist.lengths_to[i], dist.procs_to[i]-1, dist.tag, dist.comm.mpiComm) + else + dist.lengths_from[dist.numRecvs] = dist.lengths_to[i] + dist.procs_from[dist.numRecvs] = myProc + end + end + + if dist.numRecvs > dist.selfMsg + dist.status = MPI.Waitall!(dist.request) + end + + for i = 1:(dist.numRecvs - dist.selfMsg) + dist.lengths_from[i] = lengthWrappers[i][1] + end + + + for i = 1:(dist.numRecvs - dist.selfMsg) + dist.procs_from[i] = MPI.Get_source(dist.status[i])+1 + end + + perm = sortperm(dist.procs_from) + dist.procs_from = dist.procs_from[perm] + dist.lengths_from = dist.lengths_from[perm] + + dist.starts_from = Vector{GID}(dist.numRecvs) + j = GID(1) + for i = 1:dist.numRecvs + 
dist.starts_from[i] = j + j += dist.lengths_from[i] + end + + dist.totalRecvLength = 0 + for i = 1:dist.numRecvs + dist.totalRecvLength += dist.lengths_from[i] + end + + dist.numRecvs -= dist.selfMsg + + dist +end + +function computeSends(dist::MPIDistributor{GID, PID, LID}, + remoteGIDs::AbstractArray{GID, 1}, remotePIDs::AbstractArray{PID, 1} + )::Tuple{AbstractArray{GID, 1}, AbstractArray{PID, 1} + } where {GID <:Integer, PID <:Integer, LID <:Integer} + numImports = length(remoteGIDs) + + tmpPlan = MPIDistributor(dist.comm) + + importObjs = Array{Tuple{GID, PID}}(numImports) + for i = 1:numImports + importObjs[i] = (remoteGIDs[i], myPid(dist.comm))#remotePIDs[i]) + end + + numExports = createFromSends(tmpPlan, copy(remotePIDs)) + + exportIDs = Array{GID}(numExports) + exportProcs = Array{PID}(numExports) + + exportObjs = resolve(tmpPlan, importObjs) + for i = 1:numExports + exportIDs[i] = exportObjs[i][1] + exportProcs[i] = exportObjs[i][2] + end + (exportIDs, exportProcs) +end + +""" +Creates a reverse distributor for the given MPIDistributor +""" +function createReverseDistributor(dist::MPIDistributor{GID, PID, LID} + ) where {GID <: Integer, PID <: Integer, LID <: Integer} + myProc = myPid(dist.comm) + + if isnull(dist.planReverse) + totalSendLength = reduce(+, dist.lengths_to) + + maxRecvLength::GID = 0 + for i = 1:dist.numRecvs + if dist.procs_from[i] != myProc + maxRecvLength = max(maxRecvLength, dist.lengths_from[i]) + end + end + + reverse = MPIDistributor(dist.comm) + dist.planReverse = Nullable(reverse) + + reverse.lengths_to = dist.lengths_from + reverse.procs_to = dist.procs_from + reverse.indices_to = dist.indices_from + reverse.starts_to = dist.starts_from + + reverse.lengths_from = dist.lengths_to + reverse.procs_from = dist.procs_to + reverse.indices_from = dist.indices_to + reverse.starts_from = dist.starts_to + + reverse.numSends = dist.numRecvs + reverse.numRecvs = dist.numSends + reverse.selfMsg = dist.selfMsg + + reverse.maxSendLength = 
maxRecvLength + reverse.totalRecvLength = totalSendLength + + reverse.request = Array{MPI.Request}(reverse.numRecvs) + reverse.status = Array{MPI.Status}(reverse.numRecvs) + end + + nothing +end + + +#### Distributor interface #### + +function createFromSends(dist::MPIDistributor{GID, PID, LID}, exportPIDs::AbstractArray{PID, 1})::Integer where GID <:Integer where PID <:Integer where LID <:Integer + const pid = myPid(dist.comm) + const nProcs = numProc(dist.comm) + createSendStructure(dist, pid, nProcs, exportPIDs) + computeRecvs(dist, pid, nProcs) + if dist.numRecvs > 0 + if dist.request == [] + dist.request = Vector{MPI.Request}(dist.numRecvs) + dist.status = Vector{MPI.Status}(dist.numRecvs) + end + end + dist.totalRecvLength +end + + +function createFromRecvs(dist::MPIDistributor{GID, PID, LID}, + remoteGIDs::AbstractArray{GID, 1}, remotePIDs::AbstractArray{PID, 1} + )::Tuple{AbstractArray{GID, 1}, AbstractArray{PID, 1} + } where {GID <: Integer, PID <: Integer, LID <: Integer} + if length(remoteGIDs) != length(remotePIDs) + throw(InvalidArgumentError("remote lists must be the same length")) + end + (exportGIDs, exportPIDs) = computeSends(dist, remoteGIDs, remotePIDs) + createFromSends(dist, exportPIDs) + + exportGIDs, exportPIDs +end + +function resolvePosts(dist::MPIDistributor{GID, PID, LID}, exportObjs::AbstractArray{T, 1}) where {T, GID<:Integer, PID<:Integer, LID<:Integer} + myProc = myPid(dist.comm) + + nBlocks::GID = dist.numSends + dist.selfMsg + procIndex::GID = 1 + while procIndex <= nBlocks && dist.procs_to[procIndex] < myProc + procIndex += 1 + end + if procIndex == nBlocks + procIndex = 1 + end + + exportBytes = Vector{Vector{UInt8}}(dist.numRecvs + dist.selfMsg) + + j::GID = 1 + + buffer = IOBuffer() + if length(dist.indices_to) == 0 #data already grouped by processor + for i = 1:nBlocks + Serializer.serialize(buffer, exportObjs[j:j+dist.lengths_to[i]-1]) + j += dist.lengths_to[i] + exportBytes[i] = take!(buffer) + end + else #data not grouped by 
proc, must be grouped first + for i = 1:nBlocks + j = dist.starts_to[i] + sendArray = Array{T}(dist.lengths_to[i]) + for k = 1:dist.lengths_to[i] + sendArray[k] = exportObjs[dist.indices_to[j+k-1]] + end + Serializer.serialize(buffer, sendArray) + exportBytes[i] = take!(buffer) + end + end + + + ## get sizes of data begin received ## + lengthRequests = Array{MPI.Request}(dist.numRecvs) + lengths = Array{Array{Int, 1}}(dist.numRecvs + dist.selfMsg) + for i = 1:dist.numRecvs + dist.selfMsg + lengths[i] = Array{Int}(1) + end + + j = 1 + + for i = 1:dist.numRecvs + dist.selfMsg + if dist.procs_from[i] != myProc + lengthRequests[j] = MPI.Irecv!(lengths[i], dist.procs_from[i]-1, dist.tag, dist.comm.mpiComm) + j += 1 + end + end + + barrier(dist.comm) + + for i = 1:dist.numSends + dist.selfMsg + p = procIndex + 1 + if p > nBlocks + p -= nBlocks + end + if dist.procs_to[i] != myProc + MPI_Rsend(length(exportBytes[i]), dist.procs_to[i]-1, dist.tag, dist.comm.mpiComm) + else + lengths[i][1] = length(exportBytes[i]) + end + end + + MPI.Waitall!(lengthRequests) + #at this point `lengths` should contain the sizes of incoming data + + importObjs = Vector{Vector{UInt8}}(dist.numRecvs+dist.selfMsg) + for i = 1:length(importObjs) + importObjs[i] = Vector{UInt8}(lengths[i][1]) + end + + dist.importObjs = Nullable(importObjs) + + ## back to the regularly scheduled program ## + + k::GID = 0 + j = 0 + selfRecvAddress::GID = 0 + for i = 1:dist.numRecvs + dist.selfMsg + if dist.procs_from[i] != myProc + MPI.Irecv!(importObjs[i], dist.procs_from[i]-1, dist.tag, dist.comm.mpiComm) + else + selfRecvAddress = i + end + end + + barrier(dist.comm) + + #line 844 + + selfNum::GID = 0 + +# if dist.indices_to == [] #data already grouped by processor + for i = 1:nBlocks + p = i + procIndex - 1 + if p > nBlocks + p -= nBlocks + end + if dist.procs_to[p] != myProc + MPI_Rsend(exportBytes[p], dist.procs_to[p]-1, dist.tag, dist.comm.mpiComm) + else + selfNum = p + end + end + + if dist.selfMsg != 0 + 
importObjs[selfRecvAddress] = exportBytes[selfNum] + end + + nothing +end + + + +function resolveWaits(dist::MPIDistributor)::Array + barrier(dist.comm)#run into issues deserializing otherwise + if dist.numRecvs> 0 + dist.status = MPI.Waitall!(dist.request) + end + + if isnull(dist.importObjs) + throw(InvalidStateError("Cannot resolve waits when no posts have been made")) + end + + importObjs = get(dist.importObjs) + deserializedObjs = Vector{Vector{Any}}(length(importObjs)) + for i = 1:length(importObjs) + deserializedObjs[i] = Serializer.deserialize(IOBuffer(importObjs[i])) + end + + dist.importObjs = Nullable{Array{Array{UInt8}}}() + + reduce(vcat, [], deserializedObjs) +end + + +function resolveReversePosts(dist::MPIDistributor{GID, PID, LID}, + exportObjs::AbstractVector{T} + ) where {GID <: Integer, PID <: Integer, LID <: Integer, T} + if dist.indices_to != [] + throw(InvalidStateError("Cannot do reverse comm when data is not blocked by processor")) + end + + if isnull(dist.planReverse) + createReverseDistributor(dist) + end + + resolvePosts(get(dist.planReverse), exportObjs) +end + + +function resolveReverseWaits(dist::MPIDistributor{GID, PID, LID} + )::Vector where {GID <: Integer, PID <: Integer, LID <: Integer} + if isnull(dist.planReverse) + throw(InvalidStateError("Cannot resolve reverse waits if there is no reverse plan")) + end + + resolveWaits(get(dist.planReverse)) +end diff --git a/src/MPIUtil.jl b/src/MPIUtil.jl new file mode 100644 index 0000000..7125325 --- /dev/null +++ b/src/MPIUtil.jl @@ -0,0 +1,20 @@ +#Contains MPI related things that the library is missing +import MPI + +#TODO document + +function MPI_Rsend{T}(buf::MPI.MPIBuffertype{T}, count::Integer, + dest::Integer, tag::Integer, comm::MPI.Comm) + ccall(MPI.MPI_RSEND, Void, + (Ptr{T}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, Ptr{Cint}, + Ptr{Cint}), + buf, &count, &MPI.mpitype(T), &dest, &tag, &comm.val, &0) +end + +function MPI_Rsend{T}(buf::AbstractArray{T}, dest::Integer, 
tag::Integer, comm::MPI.Comm) + MPI_Rsend(buf, length(buf), dest, tag, comm) +end + +function MPI_Rsend{T}(obj::T, dest::Integer, tag::Integer, comm::MPI.Comm) + MPI_Rsend([obj], dest, tag, comm) +end \ No newline at end of file diff --git a/src/Macros.jl b/src/Macros.jl new file mode 100644 index 0000000..ce1ac17 --- /dev/null +++ b/src/Macros.jl @@ -0,0 +1,14 @@ +#contains random utility macros + +""" +Returns the global debug value + +Has an optional ignored argument +""" +macro debug() + if isdefined(Main, :globalDebug) + Main.globalDebug::Bool + else + false + end +end diff --git a/src/MultiVector.jl b/src/MultiVector.jl new file mode 100644 index 0000000..4a1e02a --- /dev/null +++ b/src/MultiVector.jl @@ -0,0 +1,301 @@ +export MultiVector +export localLength, globalLength, numVectors, map +export scale!, scale +export getVectorView, getVectorCopy +export commReduce, norm2 + + + +""" +MultiVector represents a dense multi-vector. Note that all the vectors in a single MultiVector are the same size +""" +type MultiVector{Data <: Number, GID <: Integer, PID <: Integer, LID <: Integer} <: DistObject{GID, PID, LID} + data::Array{Data, 2} # data[1, 2] is the first element of the second matrix + localLength::LID + globalLength::GID + numVectors::LID + + map::BlockMap{GID, PID, LID} +end + +## Constructors ## + +""" + MultiVector{Data, GID, PID, LID}(::BlockMap{GID, PID, LID}, numVecs::Integer, zeroOut=true) + +Creates a new MultiVector based on the given map +""" +function MultiVector{Data, GID, PID, LID}(map::BlockMap{GID, PID, LID}, numVecs::Integer, zeroOut=true) where {Data <: Number, GID <: Integer, PID <: Integer, LID <: Integer} + localLength = numMyElements(map) + if zeroOut + data = zeros(Data, (localLength, numVecs)) + else + data = Array{Data, 2}(localLength, numVecs) + end + MultiVector{Data, GID, PID, LID}(data, localLength, numGlobalElements(map), numVecs, map) +end + +""" + MultiVector{Data, GID, PID, LID}(map::BlockMap{GID, PID, LID}, 
data::AbstractArray{Data, 2}) + +Creates a new MultiVector wrapping the given array. Changes to the MultiVector or Array will affect the other +""" +function MultiVector(map::BlockMap{GID, PID, LID}, data::AbstractArray{Data, 2}) where {Data <: Number, GID <: Integer, PID <: Integer, LID <: Integer} + localLength = numMyElements(map) + if size(data, 1) != localLength + throw(InvalidArgumentError("Length of vectors does not match local length indicated by map")) + end + MultiVector{Data, GID, PID, LID}(data, localLength, numGlobalElements(map), size(data, 2), map) +end + +## External methods ## + +""" + copy(::MutliVector{Data, GID, PID, LID})::MultiVector{Data, GID, PID, LID} +Returns a copy of the multivector +""" +function Base.copy(vect::MultiVector{Data, GID, PID, LID})::MultiVector{Data, GID, PID, LID} where {Data <: Number, GID <: Integer, PID <: Integer, LID <: Integer} + MultiVector{Data, GID, PID, LID}(copy(vect.data), vect.localLength, vect.globalLength, vect.numVectors, vect.map) +end + +function Base.copy!(dest::MultiVector{Data, GID, PID, LID}, src::MultiVector{Data, GID, PID, LID})::MultiVector{Data, GID, PID, LID} where {Data, GID, PID, LID} + copy!(dest.data, src.data) + dest.localLength = src.localLength + dest.globalLength = src.globalLength + dest.numVectors = src.numVectors + dest.map = src.map + + dest +end + +""" + localLength(::MutliVector{Data, GID, PID, LID})::LID + +Returns the local length of the vectors in the multivector +""" +function localLength(vect::MultiVector{Data, GID, PID, LID})::LID where {Data <: Number, GID <: Integer, PID <: Integer, LID <: Integer} + vect.localLength +end + +""" + globalLength(::MultiVector{Data, GID, PID, LID})::GID + +Returns the global length of the vectors in the mutlivector +""" +function globalLength(vect::MultiVector{Data, GID})::GID where {Data <: Number, GID <: Integer} + vect.globalLength +end + +""" + numVectors(::MultiVector{Data, GID, PID, LID})::LID + +Returns the number of vectors in this 
multivector +""" +function numVectors(vect::MultiVector{Data, GID, PID, LID})::LID where {Data <: Number, GID <: Integer, PID <: Integer, LID <: Integer} + vect.numVectors +end + +""" + map(::MultiVector{Data, GID, PID, LID})::BlockMap{GID, PID, LID} + +Returns the BlockMap used by this multivector +""" +function map(vect::MultiVector{Data, GID, PID, LID})::BlockMap{GID, PID, LID} where {Data <: Number, GID <: Integer, PID <: Integer, LID <: Integer} + vect.map +end + + +# have to use Base.scale! to avoid requiring module qualification everywhere +""" + scale!(::MultiVector{Data, GID, PID, LID}, ::Data})::MultiVector{Data, GID, PID, LID} + +Scales the mulitvector in place and returns it +""" +function Base.scale!(vect::MultiVector{Data, GID, PID, LID}, alpha::Data)::MultiVector{Data, GID, PID, LID} where {Data <: Number, GID <: Integer, PID <: Integer, LID <: Integer} + @. vect.data = vect.data*alpha + vect +end + +""" + scale!(::MultiVector{Data, GID, PID, LID}, ::Data)::MultiVector{Data, GID, PID, LID} + +Scales a copy of the mulitvector and returns the copy +""" +function scale(vect::MultiVector{Data, GID, PID, LID}, alpha::Data)::MultiVector{Data, GID, PID, LID} where {Data <: Number, GID <: Integer, PID <: Integer, LID <: Integer} + scale!(copy(vect), alpha) +end + +""" + scale!(::MultiVector{Data, GID, PID, LID}, ::AbstractArray{Data, 1})::MultiVector{Data, GID, PID, LID} + +Scales each column of the mulitvector in place and returns it +""" +function Base.scale!(vect::MultiVector{Data, GID, PID, LID}, alpha::AbstractArray{Data, 1})::MultiVector{Data, GID, PID, LID} where {Data <: Number, GID <: Integer, PID <: Integer, LID <: Integer} + for v = 1:vect.numVectors + vect.data[:, v] *= alpha[v] + end + vect +end + +""" + scale(::MultiVector{Data, GID, PID, LID}, ::AbstractArray{Data, 1})::MultiVector{Data, GID, PID, LID} + +Scales each column of a copy of the mulitvector and returns the copy +""" +function scale(vect::MultiVector{Data, GID, PID, LID}, 
alpha::AbstractArray{Data, 1})::MultiVector{Data, GID, PID, LID} where {Data <: Number, GID <: Integer, PID <: Integer, LID <: Integer} + scale!(copy(vect), alpha) +end + + +function Base.dot(vect1::MultiVector{Data, GID, PID, LID}, vect2::MultiVector{Data, GID, PID, LID} + )::AbstractArray{Data} where {Data, GID, PID, LID} + numVects = numVectors(vect1) + length = localLength(vect1) + if numVects != numVectors(vect2) + throw(InvalidArgumentError("MultiVectors must have the same number of vectors to take the dot product of them")) + end + if length != localLength(vect2) + throw(InvalidArgumentError("Vectors must have the same length to take the dot product of them")) + end + dotProducts = Array{Data, 1}(numVects) + + data1 = vect1.data + data2 = vect2.data + + for vect in 1:numVects + sum = Data(0) + for i = 1:length + sum += data1[i, vect]*data2[i, vect] + end + dotProducts[vect] = sum + end + + dotProducts = sumAll(comm(vect1), dotProducts) + + dotProducts +end + +""" + getVectorView(::MultiVector{Data}, columns)::AbstractArray{Data} + +Gets a view of the requested column vector(s) in this multivector +""" +function getVectorView(mVect::MultiVector{Data}, column)::AbstractArray{Data} where {Data} + view(mVect.data, :, column) +end + +""" + getVectorCopy(::MultiVector{Data}, columns)::Array{Data} + +Gets a copy of the requested column vector(s) in this multivector +""" +function getVectorCopy(mVect::MultiVector{Data}, column)::Array{Data} where {Data} + mVect.data[:, column] +end + +function Base.fill!(mVect::MultiVector, values) + fill!(mVect.data, values) + mVect +end + +""" + commReduce(::MultiVector) + +Reduces the content of the MultiVector across all processes. Note that the MultiVector cannot be distributed globally. 
+""" +function commReduce(mVect::MultiVector) + #can only reduce locally replicated mutlivectors + if distributedGlobal(mVect) + throw(InvalidArgumentError("Cannot reduce distributed MultiVector")) + end + + mVect.data = sumAll(comm(mVect), mVect.data) +end + +""" +Handles the non-infinate norms +""" +macro normImpl(mVect, Data, normType) + quote + const numVects = numVectors($(esc(mVect))) + const localVectLength = localLength($(esc(mVect))) + norms = Array{$(esc(Data)), 1}(numVects) + for i = 1:numVects + sum = $(esc(Data))(0) + for j = 1:localVectLength + $(if normType == 2 + quote + val = $(esc(mVect)).data[j, i] + sum += val*val + end + else + :(sum += $(esc(Data))($(esc(mVect)).data[j, i]^$normType)) + end) + end + norms[i] = sum + end + + norms = sumAll(comm(map($(esc(mVect)))), norms) + + $(if normType == 2 + :(@. norms = sqrt(norms)) + else + :(@. norms = norms^(1/$normType)) + end) + norms + end +end + + +function norm2(mVect::MultiVector{Data})::AbstractArray{Data, 1} where Data + @normImpl mVect Data 2 +end + + +## DistObject interface ## + +function checkSizes(source::MultiVector{Data, GID, PID, LID}, + target::MultiVector{Data, GID, PID, LID})::Bool where { + Data <: Number, GID <: Integer, PID <: Integer, LID <: Integer} + (source.numVectors == target.numVectors + && source.globalLength == target.globalLength + )#&& source.localLength == target.localLength) +end + + +function copyAndPermute(source::MultiVector{Data, GID, PID, LID}, + target::MultiVector{Data, GID, PID, LID}, numSameIDs::LID, + permuteToLIDs::AbstractArray{LID, 1}, permuteFromLIDs::AbstractArray{LID, 1} + ) where {Data <: Number, GID <: Integer, PID <: Integer, LID <: Integer} + numPermuteIDs = length(permuteToLIDs) + @inbounds for j in 1:numVectors(source) + for i in 1:numSameIDs + target.data[i, j] = source.data[i, j] + end + + #don't need to sort permute[To/From]LIDs, since the orders match + for i in 1:numPermuteIDs + target.data[permutToLIDs[i], j] = 
source.data[permuteFromLIDs[i], j] + end + end +end + +function packAndPrepare(source::MultiVector{Data, GID, PID, LID}, + target::MultiVector{Data, GID, PID, LID}, exportLIDs::AbstractArray{LID, 1}, + distor::Distributor{GID, PID, LID})::Array where { + Data <: Number, GID <: Integer, PID <: Integer, LID <: Integer} + exports = Array{Array{Data, 1}}(length(exportLIDs)) + for i = 1:length(exports) + exports[i] = source.data[exportLIDs[i], :] + end + exports +end + +function unpackAndCombine(target::MultiVector{Data, GID, PID, LID}, + importLIDs::AbstractArray{LID, 1}, imports::AbstractArray, + distor::Distributor{GID, PID, LID},cm::CombineMode) where { + Data <: Number, GID <: Integer, PID <: Integer, LID <: Integer} + for i = 1:length(importLIDs) + target.data[importLIDs[i], :] = imports[i] + end +end diff --git a/src/Operator.jl b/src/Operator.jl new file mode 100644 index 0000000..2cc6d15 --- /dev/null +++ b/src/Operator.jl @@ -0,0 +1,64 @@ +export apply!, apply + + +""" +Operator is a description of all types that have a specific set of methods. ``@operatorFunctions typ`` must be called +for Operator type ``typ``. Operators must have 4 parametric types: + Data - the type of the data + GID - the type of the global indexes + PID - the type of the processor ranks + LID - the type of the local indexes + +All Operator types must implement the following methods (with Op standing in for the Operator): + +apply!(Y::MultiVector{Data, GID, PID, LID}, operator::Op{Data, GID, PID, LID}, X::MultiVector{Data, GID, PID, LID}, mode::TransposeMode, alpha::Data, beta::Data) + Computes ``Y = α\cdot A^{mode}\cdot X + β\cdot Y``, with the following exceptions + If beta == 0, apply MUST overwrite Y, so that any values in Y (including NaNs) are ignored. 
+ If alpha == 0, apply MAY short-circuit the operator, so that any values in X (including NaNs) are ignored + + +getDomainMap(operator::Op{Data, GID, PID, LID})::BlockMap{GID, PID, LID} + Returns the BlockMap associated with the domain of this operation + +getRangeMap(operator::Op{Data, GID, PID, LID})::BlockMap{GID, PID, LID} + Returns the BlockMap associated with the range of this operation +The field operators contains all types that have had ``@operatorFunctions`` called on them +""" +const Operator = Any #allow Operator to be documented + + +""" + apply!(Y::MultiVector, operator, X::MultiVector, mode::TransposeMode=NO_TRANS, alpha=1, beta=0) + apply!(Y::MultiVector, operator, X::MultiVector, alpha=1, beta=0) + +Computes ``Y = α\cdot A^{mode}\cdot X + β\cdot Y``, with the following exceptions: +* If beta == 0, apply MUST overwrite Y, so that any values in Y (including NaNs) are ignored. +* If alpha == 0, apply MAY short-circuit the operator, so that any values in X (including NaNs) are ignored +""" +function apply! 
end + + +function apply!(Y::MultiVector{Data, GID, PID, LID}, operator::Any, X::MultiVector{Data, GID, PID, LID}, mode::TransposeMode=NO_TRANS, alpha::Data=1) where {Data <: Number, GID <: Integer, PID <: Integer, LID <: Integer} + apply!(Y, operator, X, mode, alpha, 0) +end + +function apply!(Y::MultiVector{Data, GID, PID, LID}, operator::Any, X::MultiVector{Data, GID, PID, LID}, alpha::Data, beta::Data=0) where {Data <: Number, GID <: Integer, PID <: Integer, LID <: Integer} + apply!(Y, operator, X, NO_TRANS, alpha, beta) +end + +""" + apply(Y::MultiVector,operator, X::MultiVector, mode::TransposeMode=NO_TRANS, alpha=1, beta=0) + apply(Y::MultiVector, operator, X::MultiVector, alpha=1, beta=0) + +As `apply!` except returns a new array for the results +""" +function apply(Y::MultiVector{Data, GID, PID, LID}, operator::Any, X::MultiVector{Data, GID, PID, LID}, mode::TransposeMode=NO_TRANS, alpha::Data=1, beta=0)::MultiVector{Data, GID, PID, LID} where {Data <: Number, GID <: Integer, PID <: Integer, LID <: Integer} + Y = copy(Y) + apply!(Y, operator, X, mode, alpha, beta) + Y +end + +function apply(Y::MultiVector{Data, GID, PID, LID}, operator::Any, X::MultiVector{Data, GID, PID, LID}, alpha::Data, beta=0)::MultiVector{Data, GID, PID, LID} where {Data <: Number, GID <: Integer, PID <: Integer, LID <: Integer} + apply(Y, operator, X, NO_TRANS, alpha, beta) +end + diff --git a/src/RowGraph.jl b/src/RowGraph.jl new file mode 100644 index 0000000..12bc784 --- /dev/null +++ b/src/RowGraph.jl @@ -0,0 +1,346 @@ +export SrcDistRowGraph, DistRowGraph, RowGraph +#required methods +export getRowMap, getColMap, getDomainMap, getRangeMap, getImporter, getExporter +export getGlobalNumRows, getGlobalNumCols, getGlobalNumEntries, getGlobalNumDiags +export getLocalNumRows, getLocalNumCols, getLocalNumEntries, getLocalNumDiags +export getNumEntriesInGlobalRow, getNumEntriesInLocalRow +export getGlobalMaxNumRowEntries, getLocalMaxNumRowEntries +export hasColMap, 
isLowerTriangular, isUpperTriangular +export isLocallyIndexed, isGloballyIndexed, isFillComplete +export getGlobalRowCopy, getLocalRowCopy, pack +#implemented methods +export isFillActive + +""" +The version of RowGraph that isn't a subtype of DistObject +""" +abstract type SrcDistRowGraph{GID <: Integer, PID <: Integer, LID <: Integer} <: SrcDistObject{GID, PID, LID} +end + +""" +The version of RowGraph that is a subtype of DistObject +""" +abstract type DistRowGraph{GID <: Integer, PID <: Integer, LID <: Integer} <: DistObject{GID, PID, LID} +end + +""" +RowGraph is the base "type" for all row oriented storage graphs + +RowGraph is actually a type union of SrcDistRowGraph and DistRowGraph, +which are (direct) subtypes of SrcDistObject and DistObject, respectively. + +Instances of these types are required to implement the following submethods + + getRowMap(::RowGraph{GID, PID, LID})::BlockMap{GID, PID, LID} +Gets the row map for the graph + + getColMap(::RowGraph{GID, PID, LID})::BlockMap{GID, PID, LID} +Gets the column map for the graph + + getDomainMap(::RowGraph{GID, PID, LID})::BlockMap{GID, PID, LID} +Gets the domain map for the graph + + getRangeMap(::RowGraph{GID, PID, LID})::BlockMap{GID, PID, LID} +Gets the range map for the graph + + getImporter(::RowGraph{GID, PID, LID})::Import{GID, PID, LID} +Gets the graph's Import object + + getExporter(::RowGraph{GID, PID, LID})::Export{GID, PID, LID} +Gets the graph's Export object + + getGlobalNumRows(::RowGraph{GID})::GID +Returns the number of global rows in the graph + + getGlobalNumCols(::RowGraph{GID})::GID +Returns the number of global columns in the graph + + getLocalNumRows(::RowGraph{GID, PID, LID})::LID +Returns the number of rows owned by the calling process + + getLocalNumCols(::RowGraph{GID, PID, LID})::LID +Returns the number of columns owned by the calling process + + getGlobalNumEntries(::RowGraph{GID, PID, LID})::GID +Returns the global number of entries in the graph + + 
getLocalNumEntries(::RowGraph{GID, PID, LID})::LID +Returns the local number of entries in the graph + + getNumEntriesInGlobalRow(::RowGraph{GID, PID, LID}, row::GID)::LID +Returns the current number of local entries in the given row + + getNumEntriesInLocalRow(::RowGraph{GID, PID, LID}, row::LID)::LID +Returns the current number of local entries in the given row + + getGlobalNumDiags(::RowGraph{GID, PID, LID})::GID +Returns the global number of diagonal entries + + getLocalNumDiags(::RowGraph{GID, PID, LID})::LID +Returns the local number of diagonal entries + + getGlobalMaxNumRowEntries(::RowGraph{GID, PID, LID})::LID +Returns the maximum number of entries across all rows/columns on all processors + + getLocalMaxNumRowEntries(::RowGraph{GID, PID, LID})::LID +Returns the maximum number of entries across all rows/columns on this processor + + hasColMap(::RowGraph{GID, PID, LID})::Bool +Whether the graph has a well-defined column map + + isLowerTriangular(::RowGraph{GID, PID, LID})::Bool +Whether the graph is lower trianguluar + + isUpperTriangular(::RowGraph{GID, PID, LID})::Bool +Whether the graph is upper trianguluar + + isLocallyIndexed(::RowGraph)::Bool +Whether the graph is using local indices + + isGloballyIndexed(::RowGraph)::Bool +Whether the graph is using global indices + + isFillComplete(::RowGraph) +Whether `fillComplete()` has been called + + getGlobalRowCopy(::RowGraph{GID, PID, LID}, row::GID)::AbstractArray{GID, 1} +Extracts a copy of the given row of the graph + + getLocalRowCopy(::RowGraph{GID, PID, LID}, row::LID)::AbstractArray{LID, 1} +Extracts a copy of the given row of the graph + + pack(::RowGraph{GID, PID, LID}, exportLIDs::AbstractArray{LID, 1}, distor::Distributor{GID, PID, LID})::AbstractArray{AbstractArray{LID, 1}} +Packs this object's data for import or export +""" +const RowGraph{GID <: Integer, PID <: Integer, LID <: Integer} = Union{SrcDistRowGraph{GID, PID, LID}, DistRowGraph{GID, PID, LID}} + +""" + isFillActive(::RowGraph) + 
+Whether the graph is being built +""" +isFillActive(graph::RowGraph) = !isFillComplete(graph) + + +""" + getNumEntriesInGlobalRow(graph::RowGraph{GID, PID, LID}, row::Integer)::LID + +Returns the current number of local entries in the given row +""" +function getNumEntriesInGlobalRow(graph::RowGraph{GID, PID, LID}, + row::Integer)::LID where{GID, PID, LID} + getNumEntriesInGlobalRow(graph, GID(row)) +end + +""" + getNumEntriesInLocalRow(::RowGraph{GID, PID, LID}, row::Integer)::LID + +Returns the current number of local entries in the given row +""" +function getNumEntriesInLocalRow(graph::RowGraph{GID, PID, LID}, + row::Integer)::LID where{GID, PID, LID} + getNumEntriesInLocalRow(graph, LID(row)) +end + +""" + getGlobalRowCopy(::RowGraph{GID, PID, LID}, row::Integer)::AbstractArray{GID, 1} + +Extracts a copy of the given row of the graph +""" +function getGlobalRowCopy(graph::RowGraph{GID, PID, LID}, + row::Integer)::AbstractArray{GID, 1} where{GID, PID, LID} + getGlobalRowCopy(graph, GID(row)) +end + +""" + getLocalRowCopy(::RowGraph{GID, PID, LID}, row::Integer)::AbstractArray{LID, 1} + +Extracts a copy of the given row of the graph +""" +function getLocalRowCopy(graph::RowGraph{GID, PID, LID}, + row::Integer)::AbstractArray{LID, 1} where{GID, PID, LID} + getLocalRowCopy(graph, LID(row)) +end + + +""" + isLocallyIndexed(::RowGraph)::Bool + +Whether the graph is using local indices +""" +isLocallyIndexed(graph::RowGraph) = !isGloballyIndexed(graph) + + +#### SrcDistObject methods #### +map(graph::RowGraph) = getRowMap(graph) + + +#### documentation for required methods #### + +""" + isFillComplete(mat::RowGraph) + +Whether `fillComplete(...)` has been called +""" +function isFillComplete end + +""" + getRowMap(::RowGraph{GID, PID, LID})::BlockMap{GID, PID, LID} + +Gets the row map for the graph +""" +function getRowMap end + +""" + getColMap(::RowGraph{GID, PID, LID})::BlockMap{GID, PID, LID} + +Gets the column map for the graph +""" +function getColMap end + 
+""" + hasColMap(::RowGraph{GID, PID, LID})::Bool + +Whether the graph has a well-defined column map +""" +function hasColMap end + +""" + getDomainMap(::RowGraph{GID, PID, LID})::BlockMap{GID, PID, LID} + +Gets the domain map for the graph +""" +function getDomainMap end + +""" + getRangeMap(::RowGraph{GID, PID, LID})::BlockMap{GID, PID, LID} + +Gets the range map for the graph +""" +function getRangeMap end + +""" + getImporter(::RowGraph{GID, PID, LID})::Import{GID, PID, LID} + +Gets the graph's Import object +""" +function getImporter end + +""" + getExporter(::RowGraph{GID, PID, LID})::Export{GID, PID, LID} + +Gets the graph's Export object +""" +function getExporter end + +""" + getGlobalNumRows(::RowGraph{GID})::GID + +Returns the number of global rows in the graph +""" +function getGlobalNumRows end + +""" + getGlobalNumCols(::RowGraph{GID})::GID + +Returns the number of global columns in the graph +""" +function getGlobalNumCols end #is this really a thing???? + +""" + getLocalNumRows(::RowGraph{GID, PID, LID})::LID + +Returns the number of rows owned by the calling process +""" +function geLocalNumRows end + +""" + getLocalNumCols(::RowGraph{GID, PID, LID})::LID + +Returns the number of columns owned by the calling process +""" +function getLocalNumCols end + +""" + getGlobalNumEntries(::RowGraph{GID, PID, LID})::GID + +Returns the global number of entries in the graph +""" +function getGlobalNumEntries end + +""" + getLocalNumEntries(::RowGraph{GID, PID, LID})::LID + +Returns the local number of entries in the graph +""" +function getLocalNumEntries end + +""" + getNumEntriesInGlobalRow(::RowGraph{GID, PID, LID}, row::GID)::LID + +Returns the current number of local entries in the given row +""" +function getNumEntriesInGlobalRow end + +""" + getNumEntriesInLocalRow(::RowGraph{GID, PID, LID}, row::LID)::LID + +Returns the current number of local entries in the given row +""" +function getNumEntriesInLocalRow end + +""" + getGlobalNumDiags(::RowGraph{GID, 
PID, LID})::GID + +Returns the global number of diagonal entries +""" +function getGlobalNumDiags end + +""" + getLocalNumDiags(::RowGraph{GID, PID, LID})::LID + +Returns the local number of diagonal entries +""" +function getLocalNumDiags end + +""" + getGlobalMaxNumRowEntries(::RowGraph{GID, PID, LID})::LID + +Returns the maximum number of entries across all rows/columns on all processors +""" +function getGlobalMaxNumRowEntries end + +""" + getLocalMaxNumRowEntries(::RowGraph{GID, PID, LID})::LID + +Returns the maximum number of entries across all rows/columns on the calling processor +""" +function getLocalMaxNumRowEntries end + +""" + isLowerTriangular(::RowGraph{GID, PID, LID})::Bool + +Whether the graph is lower trianguluar +""" +function isLowerTriangular end + +""" + isUpperTriangular(::RowGraph{GID, PID, LID})::Bool + +Whether the graph is upper trianguluar +""" +function isUpperTriangular end + +""" + isGloballyIndexed(::RowGraph)::Bool + +Whether the graph is using global indices +""" +function isGloballyIndexed end + +""" + pack(::RowGraph{GID, PID, LID}, exportLIDs::AbstractArray{LID, 1}, distor::Distributor{GID, PID, LID})::AbstractArray{AbstractArray{LID, 1}} + +Packs this object's data for import or export +""" +function pack end diff --git a/src/RowMatrix.jl b/src/RowMatrix.jl new file mode 100644 index 0000000..83c15a0 --- /dev/null +++ b/src/RowMatrix.jl @@ -0,0 +1,378 @@ + +export SrcDistRowMatrix, DistRowMatrix, RowMatrix +export isFillActive, isLocallyIndexed +export getGraph, getGlobalRowCopy, getLocalRowCopy, getGlobalRowView, getLocalRowView, getLocalDiagCopy, leftScale!, rightScale! 
+ +""" +The version of RowMatrix that isn't a subtype of DestObject +""" +abstract type SrcDistRowMatrix{Data <: Number, GID <: Integer, PID <: Integer, LID <: Integer} <: SrcDistObject{GID, PID, LID} +end + +""" +The version of RowMatrix that is a subtype of DestObject +""" +abstract type DistRowMatrix{Data <: Number, GID <: Integer, PID <: Integer, LID <: Integer} <: DistObject{GID, PID, LID} +end + +#DECISION are any other mathmatical operations needed? + +""" +RowMatrix is the base "type" for all row oriented matrices + +RowMatrix is actually a type union of SrcDestRowMatrix and DestRowMatrix, +which are (direct) subtypes of SrcDestObject and DestObject, respectively. + +All subtypes must have the following methods, with Impl standing in for the subtype: + + getGraph(mat::RowMatrix) +Returns the graph that represents the structure of the row matrix + + getGlobalRowCopy(matrix::RowMatrix{Data, GID, PID, LID}, globalRow::Integer)::Tuple{AbstractArray{GID, 1}, Array{Data, 1}} +Returns a copy of the given row using global indices + + getLocalRowCopy(matrix::RowMatrix{Data, GID, PID, LID},localRow::Integer)::Tuple{AbstractArray{LID, 1}, AbstractArray{Data, 1}} +Returns a copy of the given row using local indices + + getGlobalRowView(matrix::RowMatrix{Data, GID, PID, LID},globalRow::Integer)::Tuple{AbstractArray{GID, 1}, AbstractArray{Data, 1}} +Returns a view to the given row using global indices + + getLocalRowView(matrix::RowMatrix{Data, GID, PID, LID},localRow::Integer)::Tuple{AbstractArray{GID, 1}, AbstractArray{Data, 1}} +Returns a view to the given row using local indices + + getLocalNumDiags(mat::RowMatrix) +Returns the number of diagonal element on the calling processor + + getLocalDiagCopy(matrix::RowMatrix{Data, GID, PID, LID})::MultiVector{Data, GID, PID, LID} +Returns a copy of the diagonal elements on the calling processor + + leftScale!(matrix::Impl{Data, GID, PID, LID}, X::AbstractArray{Data, 1}) +Scales matrix on the left with X + + 
rightScale!(matrix::Impl{Data, GID, PID, LID}, X::AbstractArray{Data, 1}) +Scales matrix on the right with X + + pack(::RowGraph{GID, PID, LID}, exportLIDs::AbstractArray{LID, 1}, distor::Distributor{GID, PID, LID})::AbstractArray{AbstractArray{LID, 1}} +Packs this object's data for import or export + + +Additionally, the following method must be implemented to fufil the operator interface: + + apply!(matrix::RowMatrix{Data, GID, PID, LID}, X::MultiVector{Data, GID, PID, LID}, Y::MultiVector{Data, GID, PID, LID}, mode::TransposeMode, alpha::Data, beta::Data) + +However, the following methods are implemented by redirecting the call to the matrix's graph by calling `getGraph(matrix)`. + domainMap(operator::RowMatrix{Data, GID, PID, LID})::BlockMap{GID, PID, LID} + + rangeMap(operator::RowMatrix{Data, GID, PID, LID})::BlockMap{GID, PID, LID} + +The required methods from DistObject must also be implemented. `map(...)`, as required by SrcDistObject, is implemented to forward the call to `rowMap(...)` + + +The following methods are currently implemented by redirecting the call to the matrix's graph by calling `getGraph(matrix)`. It is recommended that the implmenting class implements these more efficiently. 
+ + isFillComplete(mat::RowMatrix) +Whether `fillComplete(...)` has been called + getRowMap(mat::RowMatrix) +Returns the BlockMap associated with the rows of this matrix + hasColMap(mat::RowMatrix) +Whether the matrix has a column map + getColMap(mat::RowMatrix) +Returns the BlockMap associated with the columns of this matrix + isGloballyIndexed(mat::RowMatrix) +Whether the matrix stores indices with global indexes + getGlobalNumRows(mat::RowMatrix) +Returns the number of rows across all processors + getGlobalNumCols(mat::RowMatrix) +Returns the number of columns across all processors + getLocalNumRows(mat::RowMatrix) +Returns the number of rows on the calling processor + getLocalNumCols(mat::RowMatrix) +Returns the number of columns on the calling processor + getGlobalNumEntries(mat::RowMatrix) +Returns the number of entries across all processors + getLocalNumEntries(mat::RowMatrix) +Returns the number of entries on the calling processor + getNumEntriesInGlobalRow(mat::RowMatrix, globalRow) +Returns the number of entries on the local processor in the given row + getNumEntriesInLocalRow(mat::RowMatrix, localRow) +Returns the number of entries on the local processor in the given row + getGlobalNumDiags(mat::RowMatrix) +Returns the number of diagonal elements across all processors + getGlobalMaxNumRowEntries(mat::RowMatrix) +Returns the maximum number of row entries across all processors + getLocalMaxNumRowEntries(mat::RowMatrix) +Returns the maximum number of row entries on the calling processor + isLowerTriangular(mat::RowMatrix) +Whether the matrix is lower triangular + isUpperTriangular(mat::RowMatrix) +Whether the matrix is upper triangular +""" +const RowMatrix{Data <: Number, GID <: Integer, PID <: Integer, LID <: Integer} = Union{SrcDistRowMatrix{Data, GID, PID, LID}, DistRowMatrix{Data, GID, PID, LID}} + + + +function leftScale!(matrix::RowMatrix{Data, GID, PID, LID}, X::MultiVector{Data, GID, PID, LID}) where { + Data <: Number, GID <: Integer, PID <: 
Integer, LID <: Integer}
+    if numVectors(X) != 1
+        throw(InvalidArgumentError("Can only scale CRS matrix with column vector, not multi vector"))
+    end
+    leftScale!(matrix, X.data)
+end
+
+function rightScale!(matrix::RowMatrix{Data, GID, PID, LID}, X::MultiVector{Data, GID, PID, LID}) where {
+        Data <: Number, GID <: Integer, PID <: Integer, LID <: Integer}
+    if numVectors(X) != 1
+        throw(InvalidArgumentError("Can only scale CRS matrix with column vector, not multi vector"))
+    end
+    rightScale!(matrix, X.data)
+end
+
+isFillActive(matrix::RowMatrix) = !isFillComplete(matrix)
+isLocallyIndexed(matrix::RowMatrix) = !isGloballyIndexed(matrix)
+
+#for SrcDistObject
+function map(matrix::RowMatrix)
+    getRowMap(matrix)
+end
+
+
+"""
+    getLocalDiagCopyWithoutOffsetsNotFillComplete(A::RowMatrix{Data, GID, PID, LID})::MultiVector{Data, GID, PID, LID}
+
+Copies the locally owned diagonal entries of `A` into a single-column
+MultiVector distributed by `A`'s row map, without using precomputed
+offsets (for use before `fillComplete`).  Rows whose diagonal column is
+not owned locally, or whose diagonal entry is absent, are left as 0.
+"""
+function getLocalDiagCopyWithoutOffsetsNotFillComplete(A::RowMatrix{Data, GID, PID, LID})::MultiVector{Data, GID, PID, LID} where {Data, GID, PID, LID}
+
+    localRowMap = getLocalMap(getRowMap(A))
+    localColMap = getLocalMap(getColMap(A))
+    # was `isSorted(A.myGraph)`: RowMatrix is abstract and the file's contract
+    # is to reach the graph through getGraph(matrix)
+    sorted = isSorted(getGraph(A))
+
+    localNumRows = getLocalNumRows(A)
+    diag = MultiVector{Data, GID, PID, LID}(getRowMap(A), 1)
+    diagLocal1D = getVectorView(diag, 1)
+
+    range = 1:localNumRows
+    for localRowIndex in range
+        diagLocal1D[localRowIndex] = 0
+        globalIndex = gid(localRowMap, localRowIndex)
+        localColIndex = lid(localColMap, globalIndex)
+        if localColIndex != 0
+            indices, values = getLocalRowView(A, localRowIndex)
+
+            if !sorted
+                # fixed: was `localColumnIndex`, an undefined variable (UndefVarError)
+                offset = findfirst(indices, localColIndex)
+            else
+                # fixed: `searchsorted` returns a range; `searchsortedfirst`
+                # yields the index usable with `values[offset]` below
+                offset = searchsortedfirst(indices, localColIndex)
+            end
+
+            # findfirst returns 0 when the entry is missing, and
+            # searchsortedfirst may point at a different column; guard both
+            # before indexing `values`
+            if offset != 0 && offset <= length(indices) && indices[offset] == localColIndex
+                diagLocal1D[localRowIndex] = values[offset]
+            end
+        end
+    end
+    diag
+end
+
+
+#### default implementations using getGraph(...) 
#### +""" + isFillComplete(mat::RowMatrix) + +Whether `fillComplete(...)` has been called +""" +isFillComplete(mat::RowMatrix) = isFillComplete(getGraph(mat)) + +""" + getRowMap(::RowMatrix{Data, GID, PID, LID})::BlockMap{GID, PID, LID} + +Gets the row map for the container +""" +getRowMap(mat::RowMatrix) = getRowMap(getGraph(mat)) + +""" + getColMap(::RowMatrix{Data, GID, PID, LID})::BlockMap{GID, PID, LID} + +Gets the column map for the container +""" +getColMap(mat::RowMatrix) = getColMap(getGraph(mat)) + +""" + hasColMap(::RowMatrix)::Bool + +Whether the container has a well-defined column map +""" +hasColMap(mat::RowMatrix) = hasColMap(getGraph(mat)) + +""" + isGloballyIndexed(mat::RowMatrix) + +Whether the matrix stores indices with global indexes +""" +isGloballyIndexed(mat::RowMatrix) = isGloballyIndexed(getGraph(mat)) + +""" + getGlobalNumRows(mat::RowMatrix) + +Returns the number of rows across all processors +""" +getGlobalNumRows(mat::RowMatrix) = getGlobalNumRows(getGraph(mat)) + +""" + getGlobalNumCols(mat::RowMatrix) + +Returns the number of columns across all processors +""" +getGlobalNumCols(mat::RowMatrix) = getGlobalNumCols(getGraph(mat)) + +""" + getLocalNumRows(mat::RowMatrix) + +Returns the number of rows on the calling processor +""" +getLocalNumRows(mat::RowMatrix) = getLocalNumRows(getGraph(mat)) + +""" + getLocalNumCols(mat::RowMatrix) + +Returns the number of columns on the calling processor +""" +getLocalNumCols(mat::RowMatrix) = getLocalNumCols(getGraph(mat)) + +""" + getGlobalNumEntries(mat::RowMatrix) + +Returns the number of entries across all processors +""" +getGlobalNumEntries(mat::RowMatrix) = getGlobalNumEntries(getGraph(mat)) + +""" + getLocalNumEntries(mat::RowMatrix) + +Returns the number of entries on the calling processor +""" +getLocalNumEntries(mat::RowMatrix) = getLocalNumEntries(getGraph(mat)) + +""" + getNumEntriesInGlobalRow(mat::RowMatrix, globalRow) + +Returns the number of entries on the local processor in the 
given row
+"""
+getNumEntriesInGlobalRow(mat::RowMatrix, globalRow) = getNumEntriesInGlobalRow(getGraph(mat), globalRow)
+
+"""
+    getNumEntriesInLocalRow(mat::RowMatrix, localRow)
+
+Returns the number of entries on the local processor in the given row
+"""
+getNumEntriesInLocalRow(mat::RowMatrix, localRow) = getNumEntriesInLocalRow(getGraph(mat), localRow)
+
+"""
+    getGlobalNumDiags(mat::RowMatrix)
+
+Returns the number of diagonal elements across all processors
+"""
+# fixed: dropped the spurious `gRow` argument -- the RowGraph method takes only
+# the graph, and the documented signature takes only `mat`
+getGlobalNumDiags(mat::RowMatrix) = getGlobalNumDiags(getGraph(mat))
+
+"""
+    getLocalNumDiags(mat::RowMatrix)
+
+Returns the number of diagonal element on the calling processor
+"""
+# fixed: dropped the spurious `lRow` argument for the same reason as above
+getLocalNumDiags(mat::RowMatrix) = getLocalNumDiags(getGraph(mat))
+
+"""
+    getGlobalMaxNumRowEntries(mat::RowMatrix)
+
+Returns the maximum number of row entries across all processors
+"""
+getGlobalMaxNumRowEntries(mat::RowMatrix) = getGlobalMaxNumRowEntries(getGraph(mat))
+
+"""
+    getLocalMaxNumRowEntries(mat::RowMatrix)
+
+Returns the maximum number of row entries on the calling processor
+"""
+getLocalMaxNumRowEntries(mat::RowMatrix) = getLocalMaxNumRowEntries(getGraph(mat))
+
+"""
+    isLowerTriangular(mat::RowMatrix)
+
+Whether the matrix is lower triangular
+"""
+isLowerTriangular(mat::RowMatrix) = isLowerTriangular(getGraph(mat))
+
+"""
+    isUpperTriangular(mat::RowMatrix)
+
+Whether the matrix is upper triangular
+"""
+isUpperTriangular(mat::RowMatrix) = isUpperTriangular(getGraph(mat))
+
+"""
+    pack(::RowMatrix{Data, GID, PID, LID}, exportLIDs::AbstractArray{LID, 1}, distor::Distributor{GID, PID, LID})::AbstractArray{AbstractArray{GID, 1}, AbstractArray{Data, 1}}
+
+Packs this object's data for import or export
+"""
+function pack(mat::RowMatrix{Data, GID, PID, LID}, exportLIDs::AbstractArray{LID, 1},
+        distor::Distributor{GID, PID, LID}) where {Data, GID, PID, LID}
+    throw(InvalidStateError("No pack implementation for objects of type $(typeof(mat))"))
+end
+
+
+getDomainMap(mat::RowMatrix) = 
getDomainMap(getGraph(mat)) +getRangeMap(mat::RowMatrix) = getRangeMap(getGraph(mat)) + + +#### required method documentation stubs #### + +""" + getGraph(mat::RowMatrix) + +Returns the graph that represents the structure of the row matrix +""" +function getGraph end + +""" + getGlobalRowCopy(matrix::RowMatrix{Data, GID, PID, LID}, globalRow::Integer)::Tuple{AbstractArray{GID, 1}, AbstractArray{Data, 1}} + +Returns a copy of the given row using global indices +""" +function getGlobalRowCopy end + +""" + getLocalRowCopy(matrix::RowMatrix{Data, GID, PID, LID},localRow::Integer)::Tuple{AbstractArray{LID, 1}, AbstractArray{Data, 1}} + +Returns a copy of the given row using local indices +""" +function getLocalRowCopy end + +""" + getGlobalRowView(matrix::RowMatrix{Data, GID, PID, LID},globalRow::Integer)::Tuple{AbstractArray{GID, 1}, AbstractArray{Data, 1}} + +Returns a view to the given row using global indices +""" +function getGlobalRowView end + +""" + getLocalRowView(matrix::RowMatrix{Data, GID, PID, LID},localRow::Integer)::Tuple{AbstractArray{GID, 1}, AbstractArray{Data, 1}} + +Returns a view to the given row using local indices +""" +function getLocalRowView end + +""" + getLocalDiagCopy(matrix::RowMatrix{Data, GID, PID, LID})::MultiVector{Data, GID, PID, LID} + +Returns a copy of the diagonal elements on the calling processor +""" +function getLocalDiagCopy end + +""" + leftScale!(matrix::Impl{Data, GID, PID, LID}, X::AbstractArray{Data}) + +Scales matrix on the left with X +""" +function leftScale! end + +""" + rightScale!(matrix::Impl{Data, GID, PID, LID}, X::AbstractArray{Data}) + +Scales matrix on the right with X +""" +function rightScale! end \ No newline at end of file diff --git a/src/SerialComm.jl b/src/SerialComm.jl new file mode 100644 index 0000000..0611902 --- /dev/null +++ b/src/SerialComm.jl @@ -0,0 +1,58 @@ + +export SerialComm + +""" + SerialComm() + +Gets an serial communication instance. 
+Serial communication results in mostly no-ops for the communication operations +""" +struct SerialComm{GID <: Integer, PID <:Integer, LID <: Integer} <: Comm{GID, PID, LID} +end + + +# most of these functions are no-ops or identify functions since there is only +# one processor + +function barrier(comm::SerialComm) +end + + +function broadcastAll(comm::SerialComm, myVals::AbstractArray{T}, root::Integer)::Array{T} where T + if root != 1 + throw(InvalidArgumentError("SerialComm can only accept PID of 1")) + end + myVals +end + +function gatherAll(comm::SerialComm, myVals::AbstractArray{T})::Array{T} where T + myVals +end + +function sumAll(comm::SerialComm, partialsums::AbstractArray{T})::Array{T} where T + partialsums +end + +function maxAll(comm::SerialComm, partialmaxes::AbstractArray{T})::Array{T} where T + partialmaxes +end + +function minAll(comm::SerialComm, partialmins::AbstractArray{T})::Array{T} where T + partialmins +end + +function scanSum(comm::SerialComm, myvals::AbstractArray{T})::Array{T} where T + myvals +end + +function myPid(comm::SerialComm{GID, PID})::PID where GID <: Integer where PID <: Integer + 1 +end + +function numProc(comm::SerialComm{GID, PID})::PID where GID <: Integer where PID <: Integer + 1 +end + +function createDistributor(comm::SerialComm{GID, PID, LID})::SerialDistributor{GID, PID, LID} where GID <: Integer where PID <: Integer where LID <: Integer + SerialDistributor{GID, PID, LID}() +end diff --git a/src/SerialDistributor.jl b/src/SerialDistributor.jl new file mode 100644 index 0000000..94b0ad8 --- /dev/null +++ b/src/SerialDistributor.jl @@ -0,0 +1,74 @@ +export SerialDistributor + +""" + SerialDistributor() + +Creates a distributor to work with SerialComm +""" +type SerialDistributor{GID <: Integer, PID <:Integer, LID <: Integer} <: Distributor{GID, PID, LID} + post::Nullable{AbstractArray} + reversePost::Nullable{AbstractArray} + + function SerialDistributor{GID, PID, LID}() where GID <: Integer where PID <: Integer where 
LID <: Integer + new(nothing, nothing) + end +end + + +function createFromSends(dist::SerialDistributor{GID, PID, LID}, + exportPIDs::AbstractArray{PID})::Integer where GID <: Integer where PID <: Integer where LID <: Integer + for id in exportPIDs + if id != 1 + throw(InvalidArgumentError("SerialDistributor can only accept PID of 1")) + end + end + length(exportPIDs) +end + +function createFromRecvs( + dist::SerialDistributor{GID, PID, LID}, remoteGIDs::AbstractArray{GID}, remotePIDs::AbstractArray{PID} + )::Tuple{AbstractArray{GID}, AbstractArray{PID}} where GID <: Integer where PID <: Integer where LID <: Integer + for id in remotePIDs + if id != 1 + throw(InvalidArgumentError("SerialDistributor can only accept PID of 1")) + end + end + remoteGIDs,remotePIDs +end + +function resolve(dist::SerialDistributor, exportObjs::AbstractArray{T})::AbstractArray{T} where T + exportObjs +end + +function resolveReverse(dist::SerialDistributor, exportObjs::AbstractArray{T})::AbstractArray{T} where T + exportObjs +end + +function resolvePosts(dist::SerialDistributor, exportObjs::AbstractArray) + dist.post = Nullable(exportObjs) +end + +function resolveWaits(dist::SerialDistributor)::AbstractArray + if isnull(dist.post) + throw(InvalidStateError("Must post before waiting")) + end + + result = get(dist.post) + dist.post = Nullable{AbstractArray}() + result +end + +function resolveReversePosts(dist::SerialDistributor, exportObjs::AbstractArray) + dist.reversePost = Nullable(exportObjs) +end + +function resolveReverseWaits(dist::SerialDistributor)::AbstractArray + if isnull(dist.reversePost) + throw(InvalidStateError("Must reverse post before reverse waiting")) + end + + result = get(dist.reversePost) + dist.reversePost = Nullable{AbstractArray}() + result +end + diff --git a/src/SparseRowView.jl b/src/SparseRowView.jl new file mode 100644 index 0000000..681351d --- /dev/null +++ b/src/SparseRowView.jl @@ -0,0 +1,45 @@ + +export SparseRowView, vals, cols + +""" + 
SparseRowView(vals::AbstractArray{Data}, cols::AbstractArray{IndexType}, count::Integer=length(vals), start::Integer=1, stride::Integer=1) + +Creates a view of a sparse row +""" +struct SparseRowView{Data, IndexType <: Integer} + vals::AbstractArray{Data} + cols::AbstractArray{IndexType} + + function SparseRowView(vals::AbstractArray{Data}, cols::AbstractArray{IndexType} + ) where {Data, IndexType} + if length(vals) != length(cols) + throw(InvalidArgumentError("length(vals) = $(length(vals)) " + * "!= length(cols) = $(length(cols))")) + end + new{Data, IndexType}(vals, cols) + end +end + +function SparseRowView(vals::AbstractArray{Data}, cols::AbstractArray{IndexType}, + count::Integer, start::Integer=1, stride::Integer=1) where {Data, IndexType} + SparseRowView(view(vals, range(start, stride, count)), + view(cols, range(start, stride, count))) +end + +function Base.nnz(row::SparseRowView{Data, IndexType}) where{Data, IndexType} + IndexType(length(row.vals)) +end + +""" + vals(::SparseRowView{Data, IndexType})::AbstractArray{Data, 1} + +Gets the values of the row +""" +vals(row::SparseRowView) = row.vals + +""" + cols(::SparseRowView{Data, IndexType})::AbstractArray{IndexType, 1} + +Gets the column indices of the row +""" +cols(row::SparseRowView) = row.cols \ No newline at end of file diff --git a/src/SrcDistObject.jl b/src/SrcDistObject.jl new file mode 100644 index 0000000..6b5691c --- /dev/null +++ b/src/SrcDistObject.jl @@ -0,0 +1,34 @@ + +""" +A base type for supporting flexible source distributed objects for import/export operations. 
+ +Subtypes must implement a map(::Impl{GID, PID, LID})::BlockMap{GID, PID, LID} method, +where Impl is the subtype +""" +abstract type SrcDistObject{GID <: Integer, PID <: Integer, LID <: Integer} +end + +""" +Returns true if this object is a distributed global +""" +function distributedGlobal(obj::SrcDistObject) + distributedGlobal(map(obj)) +end + + +""" +Get's the Comm instance being used by this object +""" +function comm(obj::SrcDistObject) + comm(map(obj)) +end + + +#### required method documentation stubs #### + +""" + map(obj::SrcDistObject{GID, PID, LID})::BlockMap{GID, PID, LID} + +Gets the `BlockMap` associated with the given SrcDistObject +""" +function map end \ No newline at end of file diff --git a/test/BasicDirectoryTests.jl b/test/BasicDirectoryTests.jl new file mode 100644 index 0000000..fd6bd38 --- /dev/null +++ b/test/BasicDirectoryTests.jl @@ -0,0 +1,16 @@ + +function basicDirectoryTests(comm::Comm{GID, PID, LID}) where {GID, PID, LID} + const n = 8 + const nProc = numProc(comm) + const pid = myPid(comm) + + map = BlockMap(n*nProc, n, comm) + + dir = BasicDirectory{GID, PID, LID}(map) + @test isa(dir, BasicDirectory{GID, PID, LID}) + + dir = createDirectory(comm, map) + @test isa(dir, BasicDirectory{GID, PID, LID}) + @test gidsAllUniquelyOwned(dir) + @test (repeat(1:nProc, inner=n), repeat(1:n, outer=nProc)) == getDirectoryEntries(dir, map, AbstractArray{GID}(1:n*nProc)) +end \ No newline at end of file diff --git a/test/BlockMapTests.jl b/test/BlockMapTests.jl new file mode 100644 index 0000000..79edf7a --- /dev/null +++ b/test/BlockMapTests.jl @@ -0,0 +1,146 @@ +#### Test BlockMap with SerialComm #### + +function SerialMapTests(map::BlockMap{Int, Int, Int}, map2::BlockMap{Int, Int, Int}, diffMap::BlockMap{Int, Int, Int}) +# quote + mapCopy = BlockMap{Int, Int, Int}(map) + + @test uniqueGIDs(map) + + for i = 1:5 + @test myLID(map, i) + @test myGID(map, i) + @test i == lid(map, i) + @test i == gid(map, i) + end + + @test !myLID(map, -1) + 
@test !myLID(map, 0) + @test !myLID(map, 6) + @test !myLID(map, 30) + @test !myGID(map, -1) + @test !myGID(map, 0) + @test !myGID(map, 6) + @test !myGID(map, 30) + + @test 0 == lid(map, -1) + @test 0 == lid(map, 0) + @test 0 == lid(map, 6) + @test 0 == lid(map, 30) + @test 0 == gid(map, -1) + @test 0 == gid(map, 0) + @test 0 == gid(map, 6) + @test 0 == gid(map, 30) + + @test !distributedGlobal(map) + + @test 5 == numGlobalElements(map) + @test 5 == numMyElements(map) + + @test 1 == minMyGID(map) + @test 5 == maxMyGID(map) + @test 1 == minAllGID(map) + @test 5 == maxAllGID(map) + @test 1 == minLID(map) + @test 5 == maxLID(map) + + @test ([1, 1, 1, 1, 1], [1, 2, 3, 4, 5]) == remoteIDList(map, [1, 2, 3, 4, 5]) + + @test [1, 2, 3, 4, 5] == myGlobalElements(map) + + @test sameBlockMapDataAs(map, mapCopy) + @test sameBlockMapDataAs(mapCopy, map) + @test !sameBlockMapDataAs(map, map2) + @test !sameBlockMapDataAs(map2, map) + @test !sameBlockMapDataAs(map, diffMap) + @test !sameBlockMapDataAs(diffMap, map) + + @test sameAs(map, mapCopy) + @test sameAs(mapCopy, map) + @test sameAs(map, map2) + @test sameAs(map2, map) + @test !sameAs(map, diffMap) + @test !sameAs(diffMap, map) + @test !sameAs(map2, diffMap) + @test !sameAs(diffMap, map2) + + @test linearMap(map) + + @test [1, 2, 3, 4, 5] == myGlobalElementIDs(map) + + @test commVal == comm(map) +# end +end + +commVal = SerialComm{Int, Int, Int}() + + +## constructor 1 ## +@test_throws InvalidArgumentError BlockMap(-8, commVal) +@test_throws InvalidArgumentError BlockMap(-1, commVal) + +BlockMap(0, commVal) +BlockMap(1, commVal) + +map = BlockMap(5, commVal) +map2 = BlockMap(5, commVal) +diffMap = BlockMap(6, commVal) +#@SerialMapTests +SerialMapTests(map, map2, diffMap) + +## constructor 2 ## +@test_throws InvalidArgumentError BlockMap(-8, 4, commVal) +@test_throws InvalidArgumentError BlockMap(-2, 4, commVal) +@test_throws InvalidArgumentError BlockMap(5, -6, commVal) +@test_throws InvalidArgumentError BlockMap(4, -1, 
commVal) + +BlockMap(0, 0, commVal) +BlockMap(1, 1, commVal) + +map = BlockMap(5, 5, commVal) +map2 = BlockMap(5, 5, commVal) +diffMap = BlockMap(6, 6, commVal) +#@SerialMapTests +SerialMapTests(map, map2, diffMap) + +map = BlockMap(-1, 5, commVal) +map2 = BlockMap(-1, 5, commVal) +diffMap = BlockMap(-1, 6, commVal) +#@SerialMapTests +SerialMapTests(map, map2, diffMap) + + +## constructor 3 ## +BlockMap(Int[], commVal) +BlockMap([1], commVal) + +map = BlockMap([1, 2, 3, 4, 5], commVal) +map2 = BlockMap([1, 2, 3, 4, 5], commVal) +diffMap = BlockMap([1, 2, 3, 4, 5, 6], commVal) +#@SerialMapTests +SerialMapTests(map, map2, diffMap) + +## constructor 4 ## +@test_throws InvalidArgumentError BlockMap(-8, 4, [1, 2, 3, 4], false, 1, 4, commVal) +@test_throws InvalidArgumentError BlockMap(-2, 4, [1, 2, 3, 4], false, 1, 4, commVal) +@test_throws InvalidArgumentError BlockMap(5, -6, [1, 2, 3, 4, 5], false, 1, 5, commVal) +@test_throws InvalidArgumentError BlockMap(4, -1, [1, 2, 3, 4], false, 1, 4, commVal) + +BlockMap(0, 0, Int[], false, 1, 0, commVal) +BlockMap(1, 1, [1], false, 1, 1, commVal) + +map = BlockMap(5, 5, [1, 2, 3, 4, 5], false, 1, 5, commVal) +map2 = BlockMap(5, 5, [1, 2, 3, 4, 5], false, 1, 5, commVal) +diffMap = BlockMap(6, 6, [1, 2, 3, 4, 5, 6], false, 1, 6, commVal) +#@SerialMapTests +SerialMapTests(map, map2, diffMap) + + +#stability tests +for (GID, PID, LID) in Base.product(stableGIDs, stablePIDs, stableLIDs) + @test is_stable(check_method(gid, (BlockMap{GID, PID, LID}, LID))) + @test is_stable(check_method(lid, (BlockMap{GID, PID, LID}, GID))) + @test is_stable(check_method(myGID, (BlockMap{GID, PID, LID}, GID))) + @test is_stable(check_method(myLID, (BlockMap{GID, PID, LID}, LID))) + @test is_stable(check_method(numGlobalElements, (BlockMap{GID, PID, LID},))) + @test is_stable(check_method(numMyElements, (BlockMap{GID, PID, LID},))) +end diff --git a/test/CRSGraphTests.jl b/test/CRSGraphTests.jl new file mode 100644 index 0000000..e4e0008 --- /dev/null 
+++ b/test/CRSGraphTests.jl @@ -0,0 +1,46 @@ + +#a few light tests to catch basic issues + +commObj = SerialComm{UInt32, UInt8, UInt16}() +map = BlockMap(20, commObj) + +function basicTests(graph) + @test !isLocallyIndexed(graph) + @test isGloballyIndexed(graph) + @test isFillActive(graph) + @test !isFillComplete(graph) + @test !hasColMap(graph) + @test 0 == getNumEntriesInGlobalRow(graph, 1) +end + +graph = CRSGraph(map, UInt16(15), STATIC_PROFILE, Dict{Symbol, Any}()) +@test map == JuliaPetra.map(graph) +@test STATIC_PROFILE == getProfileType(graph) +basicTests(graph) + +graph = CRSGraph(map, UInt16(15), STATIC_PROFILE, Dict{Symbol, Any}(:debug=>true)) +@test map == JuliaPetra.map(graph) +@test STATIC_PROFILE == getProfileType(graph) +basicTests(graph) +insertGlobalIndices(graph, 1, [2, 3]) +@test 2 == getNumEntriesInGlobalRow(graph, 1) + +graph2 = CRSGraph(map, UInt16(15), DYNAMIC_PROFILE, Dict{Symbol, Any}(:debug=>true)) +@test map == JuliaPetra.map(graph2) +@test DYNAMIC_PROFILE == getProfileType(graph2) +basicTests(graph2) +@test 0 == getNumEntriesInGlobalRow(graph2, 1) + +impor = Import(map, map) +doImport(graph, graph2, impor, REPLACE) +@test 2 == getNumEntriesInGlobalRow(graph2, 1) + + +commObj = SerialComm{UInt8, Int8, UInt16}() +map = BlockMap(20, commObj) + +@test_throws InvalidArgumentError CRSGraph(map, UInt16(15), STATIC_PROFILE, Dict{Symbol, Any}()) + + +#TODO ensure result of CRSGraph(rowMap, colMap, localGraph, plist) is fill complete +#TODO ensure CRSGraph(rowMap, colMap, rowOffsets, entries, plist) sets up local graph correctly (same length and content in local graph as was given to constructor) diff --git a/test/CSRMatrixMPITests.jl b/test/CSRMatrixMPITests.jl new file mode 100644 index 0000000..4b5572c --- /dev/null +++ b/test/CSRMatrixMPITests.jl @@ -0,0 +1,62 @@ + +n = 2 +nProc = numProc(comm) +Data = Float32 + +######Build matrix###### +map = BlockMap(nProc*n, comm) + +numMyElts = numMyElements(map) +numGlobalElts = numGlobalElements(map) 
+myGlobalElts = myGlobalElements(map)
+
+numNz = Array{GID, 1}(numMyElts)
+for i = 1:numMyElts
+    if myGlobalElts[i] == 1 || myGlobalElts[i] == numGlobalElts
+        numNz[i] = 2
+    else
+        numNz[i] = 3
+    end
+end
+
+const A = CSRMatrix{Data}(map, numNz, STATIC_PROFILE)
+
+const values = Data[-1, -1]
+indices = Array{LID, 1}(2)
+two = Data[2]
+
+for i = 1:numMyElts
+    if myGlobalElts[i] == 1
+        indices = LID[2]
+    elseif myGlobalElts[i] == numGlobalElts
+        indices = LID[numGlobalElts-1] # fixed off-by-one: with 1-based indices the last row's neighbor is numGlobalElts-1 (was -2, yielding invalid index 0 when numGlobalElts == 2)
+    else
+        indices = LID[myGlobalElts[i]-1, myGlobalElts[i]+1]
+    end
+
+    insertGlobalValues(A, myGlobalElts[i], indices, values)
+    insertGlobalValues(A, myGlobalElts[i], LID[myGlobalElts[i]], two)
+end
+
+fillComplete(A, map, map)
+
+
+Y = MultiVector(map, diagm(Data(1):n))
+X = MultiVector(map, fill(Data(2), n, n))
+
+@test Y === apply!(Y, A, X, NO_TRANS, Float32(3), Float32(.5))
+
+@test fill(2, n, n) == X.data #ensure X isn't mutated
+
+exp = diagm(Data(1):n)*.5
+for i in 1:n
+    if i == 1 && pid == 1
+        exp[1, :] += 6
+    elseif i == n && pid == nProc
+        exp[i, :] += 6
+    #else
+        #exp[i, :] += -6 +12-6
+    end
+end
+
+@test exp == Y.data
\ No newline at end of file
diff --git a/test/CSRMatrixTests.jl b/test/CSRMatrixTests.jl
new file mode 100644
index 0000000..0d5f3a4
--- /dev/null
+++ b/test/CSRMatrixTests.jl
@@ -0,0 +1,121 @@
+
+#TODO write MPI tests
+
+#TODO ensure result of CSRMatrix(rowMap, colMap, localMatrix, plist) is fill complete
+
+#TODO make testing version of checkInternalState to run during testing
+
+#### Serial Tests####s
+
+n = 8
+m = 6
+
+Data = Float32
+GID = UInt16
+PID = Bool
+LID = UInt8
+
+commObj = SerialComm{GID, PID, LID}()
+rowMap = BlockMap(n, n, commObj)
+
+
+mat = CSRMatrix{Data}(rowMap, m, STATIC_PROFILE)
+@test isa(mat, CSRMatrix{Data, GID, PID, LID})
+
+mat = CSRMatrix{Data}(rowMap, m, STATIC_PROFILE, Dict{Symbol, Any}())
+@test isa(mat, CSRMatrix{Data, GID, PID, LID})
+@test STATIC_PROFILE == getProfileType(mat)
+@test isFillActive(mat)
+@test !isFillComplete(mat)
+@test isGloballyIndexed(mat) +@test !isLocallyIndexed(mat) +@test rowMap == getRowMap(mat) +@test !hasColMap(mat) +@test n == getGlobalNumRows(mat) +@test n == getLocalNumRows(mat) + + +@test 0 == getNumEntriesInLocalRow(mat, 2) +@test 0 == getNumEntriesInGlobalRow(mat, 2) +@test 0 == getLocalNumEntries(mat) +@test 0 == getGlobalNumEntries(mat) +@test 0 == getGlobalNumDiags(mat) +@test 0 == getLocalNumDiags(mat) +@test 0 == getGlobalMaxNumRowEntries(mat) +@test 0 == getLocalMaxNumRowEntries(mat) +insertGlobalValues(mat, 2, LID[1, 3, 4], Data[2.5, 6.21, 77]) +@test 3 == getNumEntriesInLocalRow(mat, 2) +@test 3 == getNumEntriesInGlobalRow(mat, 2) +@test 3 == getLocalNumEntries(mat) +@test 0 == getLocalNumDiags(mat) +rowInfo = JuliaPetra.getRowInfo(mat.myGraph, LID(2)) +@test 3 == rowInfo.numEntries +JuliaPetra.recycleRowInfo(rowInfo) +#skipped many of the global methods because those require re-generating and may not be up to date + +row = getGlobalRowCopy(mat, 2) +@test isa(row, Tuple{<: AbstractArray{GID, 1}, <: AbstractArray{Data, 1}}) +@test GID[1, 3, 4] == row[1] +@test Data[2.5, 6.21, 77] == row[2] + +row = getGlobalRowView(mat, 2) +@test isa(row, Tuple{<: AbstractArray{GID, 1}, <: AbstractArray{Data, 1}}) +@test GID[1, 3, 4] == row[1] +@test Data[2.5, 6.21, 77] == row[2] + +fillComplete(mat) + +row = getLocalRowCopy(mat, 2) +@test isa(row, Tuple{<: AbstractArray{LID, 1}, <: AbstractArray{Data, 1}}) +@test LID[1, 2, 3] == row[1] +@test Data[2.5, 6.21, 77] == row[2] + +row = getLocalRowView(mat, 2) +@test isa(row, Tuple{<: AbstractArray{LID, 1}, <: AbstractArray{Data, 1}}) +@test LID[1, 2, 3] == row[1] +@test Data[2.5, 6.21, 77] == row[2] + +#= + +getGlobalNumCols(mat::CSRMatrix) = -1#TODO figure out +getLocalNumCols(mat::CSRMatrix) = numCols(mat.localMatrix) +=# + +map = BlockMap(2, 2, commObj) + +mat = CSRMatrix{Data}(map, 2, STATIC_PROFILE) +insertGlobalValues(mat, 1, LID[1, 2], Data[2, 3]) +insertGlobalValues(mat, 2, LID[1, 2], Data[5, 7]) 
+fillComplete(mat) + +@test [1, 2, 1, 2] == (mat.myGraph.localIndices1D) +@test (LID[1, 2], Data[2, 3]) == getLocalRowCopy(mat, 1) +@test (LID[1, 2], Data[5, 7]) == getLocalRowCopy(mat, 2) +@test (LID[1, 2], Data[2, 3]) == getLocalRowView(mat, 1) +@test (LID[1, 2], Data[5, 7]) == getLocalRowView(mat, 2) + + + +Y = MultiVector(map, diagm(Data(1):2)) +X = MultiVector(map, fill(Data(2), 2, 2)) + +@test Y === apply!(Y, mat, X, NO_TRANS, Data(3), Data(.5)) + +@test fill(2, 2, 2) == X.data #ensure X isn't mutated +exp = Array{Data, 2}(2, 2) +exp[1, :] = [30.5, 30] +exp[2, :] = [72, 73] +@test exp == Y.data + + + +Y = MultiVector(map, diagm(Data(1):2)) +X = MultiVector(map, fill(Data(2), 2, 2)) #ensure bugs in the previous test don't affect this test + +@test Y === apply!(Y, mat, X, TRANS, Float32(3), Float32(.5)) + +@test fill(2, 2, 2) == X.data #ensure X isn't mutated +exp = Array{Data, 2}(2, 2) +exp[1, :] = [42.5, 42] +exp[2, :] = [60, 61] +@test exp == Y.data diff --git a/test/ComputeOffsetsTests.jl b/test/ComputeOffsetsTests.jl new file mode 100644 index 0000000..cb5e614 --- /dev/null +++ b/test/ComputeOffsetsTests.jl @@ -0,0 +1,13 @@ + +rowPtrs = Array{Int, 1}(30) +JuliaPetra.computeOffsets(rowPtrs, 9) +@test collect(1:9:9*30) == rowPtrs + +rowPtrs = Array{Int, 1}(30) +numEnts = collect(1:29) +@test sum(1:29) == JuliaPetra.computeOffsets(rowPtrs, numEnts) +@test [sum(numEnts[1:i-1])+1 for i = 1:30] == rowPtrs + +@test_throws InvalidArgumentError JuliaPetra.computeOffsets(rowPtrs, Array{Int, 1}(30)) +@test_throws InvalidArgumentError JuliaPetra.computeOffsets(rowPtrs, Array{Int, 1}(31)) +@test_throws InvalidArgumentError JuliaPetra.computeOffsets(rowPtrs, Array{Int, 1}(42)) diff --git a/test/Import-Export Tests.jl b/test/Import-Export Tests.jl new file mode 100644 index 0000000..5832f7c --- /dev/null +++ b/test/Import-Export Tests.jl @@ -0,0 +1,61 @@ +#mainly just exercies the basic constructors + +n = 8 + +serialComm = SerialComm{Int32, Bool, Int16}() +srcMap = 
BlockMap(n, n, serialComm) +desMap = BlockMap(n, n, serialComm) + +function basicTest(impor) + if isa(impor, Import) + data = impor.importData + else + # basic exports are about the same anyways + data = impor.exportData + end + @test srcMap == data.source + @test desMap == data.target + @test n == data.numSameIDs + @test isa(data.distributor, Distributor{Int32, Bool, Int16}) + @test [] == data.permuteToLIDs + @test [] == data.permuteFromLIDs + @test [] == data.remoteLIDs + @test [] == data.exportLIDs + @test [] == data.exportPIDs + @test true == data.isLocallyComplete +end + +#for scoping purposes +impor = Array{Import, 1}(1) +expor = Array{Export, 1}(1) + +#ensure at least a few lines, each starting with the PID +#Need to escape coloring: .* +debugregex = Regex("^(?:.*INFO: .*$(myPid(serialComm)): .+\n){2,}.*\$") + +# basic import +@test_warn debugregex impor[1] = Import(srcMap, desMap) +basicTest(impor[1]) + +@test_warn debugregex impor[1] = Import(srcMap, desMap, Nullable{AbstractArray{Bool}}()) +basicTest(impor[1]) + + +# import using Dicts +@test_warn debugregex impor[1] = Import(srcMap, desMap, Dict{Symbol, Any}()) +basicTest(impor[1]) +@test_warn debugregex impor[1] = Import(srcMap, desMap, Nullable{AbstractArray{Bool}}(), Dict{Symbol, Any}()) +basicTest(impor[1]) + + +# basic export +@test_warn debugregex expor[1] = Export(srcMap, desMap) +basicTest(expor[1]) +@test_warn debugregex expor[1] = Export(srcMap, desMap, Nullable{AbstractArray{Bool}}()) +basicTest(expor[1]) + +#export using Dicts +@test_warn debugregex expor[1] = Export(srcMap, desMap, Dict{Symbol, Any}()) +basicTest(expor[1]) +@test_warn debugregex expor[1] = Export(srcMap, desMap, Nullable{AbstractArray{Bool}}(), Dict{Symbol, Any}()) +basicTest(expor[1]) diff --git a/test/LocalCRSGraphTests.jl b/test/LocalCRSGraphTests.jl new file mode 100644 index 0000000..3cece7b --- /dev/null +++ b/test/LocalCRSGraphTests.jl @@ -0,0 +1,18 @@ + +graph = LocalCRSGraph{UInt16, UInt32}() +@test Array{UInt16, 
1}(0) == graph.entries +@test Array{UInt32, 1}(0) == graph.rowMap +@test 0 == numRows(graph) +@test_throws InvalidArgumentError maxEntry(graph) +@test_throws InvalidArgumentError minEntry(graph) + + +entries = UInt16[248, 230, 17, 26, 143, 101, 251, 13, 97, 380, + 28, 16, 139, 9, 820, 637, 879, 156, 42, 339] +rowMap = UInt32[1, 3, 8, 9, 15, 18, 21] +graph = LocalCRSGraph(entries, rowMap) +@test entries === graph.entries +@test rowMap === graph.rowMap +@test 6 == numRows(graph) +@test 879 == maxEntry(graph) +@test 9 == minEntry(graph) \ No newline at end of file diff --git a/test/LocalCSRMatrixTests.jl b/test/LocalCSRMatrixTests.jl new file mode 100644 index 0000000..19a2807 --- /dev/null +++ b/test/LocalCSRMatrixTests.jl @@ -0,0 +1,22 @@ + +default = LocalCSRMatrix{Float32, UInt32}() +@test 0 == numRows(default) +@test 0 == numCols(default) +@test_throws BoundsError getRowView(default, 1) +@test_throws BoundsError getRowView(default, 5) + + +rawVals = Float32[5, 8, 6, 2, 1, 6] +rawCols = UInt32[2, 4, 5, 2, 3, 1] +mat = LocalCSRMatrix(4, 5, rawVals, UInt32[1, 2, 4, 6, 7], rawCols) +@test isa(mat, LocalCSRMatrix{Float32, UInt32}) +@test 4 == numRows(mat) +@test 5 == numCols(mat) +@test Float32[5] == vals(getRowView(mat, 1)) +@test UInt32[2] == cols(getRowView(mat, 1)) +@test Float32[8, 6] == vals(getRowView(mat, 2)) +@test UInt32[4, 5] == cols(getRowView(mat, 2)) +@test Float32[2, 1] == vals(getRowView(mat, 3)) +@test UInt32[2, 3] == cols(getRowView(mat, 3)) +@test Float32[6] == vals(getRowView(mat, 4)) +@test UInt32[1] == cols(getRowView(mat, 4)) \ No newline at end of file diff --git a/test/LocalCommTests.jl b/test/LocalCommTests.jl new file mode 100644 index 0000000..29bcbd0 --- /dev/null +++ b/test/LocalCommTests.jl @@ -0,0 +1,29 @@ + +function runLocalCommTests(origComm::Comm{GID, PID, LID}) where{GID, PID, LID} + localComm = LocalComm(origComm) + + @test_throws InvalidStateError barrier(localComm) + @test_throws InvalidStateError broadcastAll(localComm, [1, 2, 
3], 1) + @test_throws InvalidStateError broadcastAll(localComm, 18, 1) + @test_throws InvalidStateError createDistributor(localComm) + + + @test_throws InvalidStateError gatherAll(localComm, [4, 5, 6, 7]) + @test_throws InvalidStateError gatherAll(localComm, 8) + + @test_throws InvalidStateError sumAll(localComm, [4, 5, 6, 7]) + @test_throws InvalidStateError sumAll(localComm, 8) + + @test_throws InvalidStateError minAll(localComm, [4, 5, 6, 7]) + @test_throws InvalidStateError minAll(localComm, 8) + + @test_throws InvalidStateError maxAll(localComm, [4, 5, 6, 7]) + @test_throws InvalidStateError maxAll(localComm, 8) + + @test_throws InvalidStateError scanSum(localComm, [4, 5, 6, 7]) + @test_throws InvalidStateError scanSum(localComm, 8) + + + @test myPid(origComm) == myPid(localComm) + @test numProc(origComm) == numProc(localComm) +end \ No newline at end of file diff --git a/test/MPIBlockMapTests.jl b/test/MPIBlockMapTests.jl new file mode 100644 index 0000000..d7dadcd --- /dev/null +++ b/test/MPIBlockMapTests.jl @@ -0,0 +1,73 @@ + +#created in including file +#comm = MPIComm{UInt64, UInt16, UInt32}() + +pid = myPid(comm) + +macro MPIMapTests() + @test isa(map, BlockMap{UInt64, UInt16, UInt32}) + + @test uniqueGIDs(map) + + for i = 1:5 + @test myLID(map, i) + @test pid*5 + i - 5 == gid(map, i) + end + + for i = 1:20 + if cld(i, 5) == pid + @test myGID(map, i) + @test (i-1)%5+1 == lid(map, i) + else + @test !myGID(map, i) + @test 0 == lid(map, i) + end + end + + + @test !myLID(map, -1) + @test !myLID(map, 0) + @test !myLID(map, 6) + @test !myLID(map, 46) + @test !myGID(map, -1) + @test !myGID(map, 0) + @test !myGID(map, 21) + @test !myGID(map, 46) + + @test 0 == lid(map, -1) + @test 0 == lid(map, 0) + @test 0 == lid(map, 21) + @test 0 == lid(map, 46) + @test 0 == gid(map, -1) + @test 0 == gid(map, 0) + @test 0 == gid(map, 6) + @test 0 == gid(map, 46) + + @test distributedGlobal(map) + + @test linearMap(map) + + @test 20 == numGlobalElements(map) + @test 5 == 
numMyElements(map) + + @test pid*5 - 4 == minMyGID(map) + @test pid*5 == maxMyGID(map) + @test 1 == minAllGID(map) + @test 20 == maxAllGID(map) + @test 1 == minLID(map) + @test 5 == maxLID(map) + + @test ([1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4], + [1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5]) == remoteIDList(map, collect(1:20)) + + @test collect((1:5) + 5*(pid - 1)) == myGlobalElementIDs(map) +end + +map = BlockMap(20, comm) +@MPIMapTests + +map = BlockMap(20, 5, comm) +@MPIMapTests + +map = BlockMap(collect((1:5) + 5*(pid - 1)), comm) +@MPIMapTests \ No newline at end of file diff --git a/test/MPICommTests.jl b/test/MPICommTests.jl new file mode 100644 index 0000000..c151e29 --- /dev/null +++ b/test/MPICommTests.jl @@ -0,0 +1,65 @@ + +#ensure multiple calls to MPIComm works +MPIComm(Bool, Bool, Bool) +MPIComm(UInt64, UInt16, UInt32) + + +@test 4 == numProc(comm) +@test isa(numProc(comm), UInt16) + +@test 1 <= myPid(comm) <= 4 +@test isa(myPid(comm), UInt16) + +@test [1, 8] == broadcastAll(comm, [myPid(comm), 8], 1) +@test [2, 5] == broadcastAll(comm, [myPid(comm), 5], 2) +@test [3, 7] == broadcastAll(comm, [myPid(comm), 7], 3) +@test [4, 6] == broadcastAll(comm, [myPid(comm), 6], 4) + +@test [1, 2, 3, 4] == gatherAll(comm, [myPid(comm)]) +@test ([1, 2, 3, 2, 4, 6, 3, 6, 9, 4, 8, 12] + == gatherAll(comm, [myPid(comm), myPid(comm)*2, myPid(comm)*3])) + +#check for hangs and such, hard to test if all processes are at the same spot +barrier(comm) + +@test [10] == sumAll(comm, [myPid(comm)]) +@test [32, 12, 10, 8] == sumAll(comm, [8, 3, myPid(comm), 2]) + +@test [4] == maxAll(comm, [myPid(comm)]) +@test [4, -1, 8] == maxAll(comm, [myPid(comm), -Int(myPid(comm)), 8]) + +@test [1] == minAll(comm, [myPid(comm)]) +@test [1, -4, 6] == minAll(comm, [myPid(comm), -Int(myPid(comm)), 6]) + +@test [sum(1:myPid(comm))] == scanSum(comm, [myPid(comm)]) +@test ([myPid(comm)*5, sum(-2:-2:(-2*Int(myPid(comm)))), myPid(comm)*3] + == scanSum(comm, 
[5, -2*Int(myPid(comm)), 3])) + +#test distributor + +dist = createDistributor(comm) +@test isa(dist, Distributor{UInt64, UInt16, UInt32}) + +#check for error when not waiting +@test_throws InvalidStateError resolveWaits(dist) +@test_throws InvalidStateError resolveReverseWaits(dist) + +@test 4 == createFromSends(dist, [1, 2, 3, 4]) + +resolvePosts(dist, [pid, 2*pid, 3*pid, 4*pid]) +@test_throws InvalidStateError resolveReverseWaits(dist) +@test pid*[1, 2, 3, 4] == resolveWaits(dist) + +#check for error when not waiting +@test_throws InvalidStateError resolveWaits(dist) +@test_throws InvalidStateError resolveReverseWaits(dist) + +#test distributor when elements not blocked by processor +dist2 = createDistributor(comm) +@test 8 == createFromSends(dist, [1, 2, 3, 4, 1, 2, 3, 4]) + +@test (reduce(vcat, [], [[(pid-1)*5+j, (pid+3)*5+j] for j in 1:4]) + == resolve(dist, [pid, 5+pid, 10+pid, 15+pid, 20+pid, 25+pid, 30+pid, 35+pid])) + +#test distributore createFromRecvs +@test ([(pid-1)*5+j for j in 1:4], [1, 2, 3, 4]) == createFromRecvs(dist, [pid, 5+pid, 10+pid, 15+pid], [1, 2, 3, 4]) \ No newline at end of file diff --git a/test/MPITestsStarter.jl b/test/MPITestsStarter.jl new file mode 100644 index 0000000..ef72015 --- /dev/null +++ b/test/MPITestsStarter.jl @@ -0,0 +1,7 @@ + +test_path = test_path = abspath(Pkg.dir(), "JuliaPetra", "test", "runMPITests.jl") +color = Base.have_color? "--color=yes" : "--color=no" +codecov = (Bool(Base.JLOptions().code_coverage)? ["--code-coverage=user"] : ["--code-coverage=none"]) +compilecache = "--compilecache=" * (Bool(Base.JLOptions().use_compilecache) ? 
"yes" : "no") +julia_exe = Base.julia_cmd() +run(`mpirun -np 4 $julia_exe --check-bounds=yes $codecov $color $compilecache $test_path`) diff --git a/test/MPIimport-export Tests.jl b/test/MPIimport-export Tests.jl new file mode 100644 index 0000000..2f262f7 --- /dev/null +++ b/test/MPIimport-export Tests.jl @@ -0,0 +1,55 @@ +n = 4 + +srcMap = BlockMap(4*n, comm) +desMap = BlockMap(collect((1:n) + n*(pid%4)), comm) + +function basicMPITest(impor) + if isa(impor, Import) + data = impor.importData + else + data = impor.exportData + end + @test srcMap == data.source + @test desMap == data.target + @test 0 == data.numSameIDs + @test isa(data.distributor, Distributor{UInt64, UInt16, UInt32}) + @test [] == data.permuteToLIDs + @test [] == data.permuteFromLIDs + #TODO test remoteLIDs, exportLIDs, exportPIDs + @test true == data.isLocallyComplete +end + +#for scoping purposes +impor = Array{Import, 1}(1) +expor = Array{Export, 1}(1) + +#ensure at least a few lines, each starting with the PID +#Need to escape coloring: .* +#"^(?:.*INFO: .*$pid: .+\n){2,}.*\$" +debugregex = Regex("^(?:.*INFO: .*$pid: .+\n){2,}.*\$") + +# basic import +@test_warn debugregex impor[1] = Import(srcMap, desMap) +basicMPITest(impor[1]) +@test_warn debugregex impor[1] = Import(srcMap, desMap, Nullable{AbstractArray{UInt16}}()) +basicMPITest(impor[1]) + + +# import using Dicts +@test_warn debugregex impor[1] = Import(srcMap, desMap, Dict{Symbol, Any}()) +basicMPITest(impor[1]) +@test_warn debugregex impor[1] = Import(srcMap, desMap, Nullable{AbstractArray{UInt16}}(), Dict{Symbol, Any}()) +basicMPITest(impor[1]) + + +# basic export +@test_warn debugregex expor[1] = Export(srcMap, desMap) +basicMPITest(expor[1]) +@test_warn debugregex expor[1] = Export(srcMap, desMap, Nullable{AbstractArray{UInt16}}()) +basicMPITest(expor[1]) + +#export using Dicts +@test_warn debugregex expor[1] = Export(srcMap, desMap, Dict{Symbol, Any}()) +basicMPITest(expor[1]) +@test_warn debugregex expor[1] = Export(srcMap, 
desMap, Nullable{AbstractArray{UInt16}}(), Dict{Symbol, Any}()) +basicMPITest(expor[1]) diff --git a/test/MacroTests.jl b/test/MacroTests.jl new file mode 100644 index 0000000..e805116 --- /dev/null +++ b/test/MacroTests.jl @@ -0,0 +1,4 @@ + + +#test with debug mode enabled +@test true == @macroexpand JuliaPetra.@debug diff --git a/test/MultiVectorTests.jl b/test/MultiVectorTests.jl new file mode 100644 index 0000000..f665d9c --- /dev/null +++ b/test/MultiVectorTests.jl @@ -0,0 +1,146 @@ +#these tests are used for test MultiVector under both serial and MPI comms + +function multiVectorTests(comm::Comm{UInt64, UInt16, UInt32}) + #number of elements in vectors + n = 8 + + pid = myPid(comm) + nProcs = numProc(comm) + + curMap = BlockMap(nProcs*n, n, comm) + + # test basic construction with setting data to zeros + vect = MultiVector{Float64, UInt64, UInt16, UInt32}(curMap, 3, true) + @test n == localLength(vect) + @test nProcs*n == globalLength(vect) + @test 3 == numVectors(vect) + @test curMap == JuliaPetra.map(vect) + @test zeros(Float64, (n, 3)) == vect.data + + # test basic construction without setting data to zeros + vect = MultiVector{Float64, UInt64, UInt16, UInt32}(curMap, 3, false) + @test n == localLength(vect) + @test nProcs*n == globalLength(vect) + @test 3 == numVectors(vect) + @test curMap == JuliaPetra.map(vect) + + # test wrapper constructor + arr = Array{Float64, 2}(n, 3) + vect = MultiVector(curMap, arr) + @test n == localLength(vect) + @test nProcs*n == globalLength(vect) + @test 3 == numVectors(vect) + @test curMap == JuliaPetra.map(vect) + @test arr === vect.data + + # test copy + vect2 = copy(vect) + @test n == localLength(vect) + @test nProcs*n == globalLength(vect) + @test 3 == numVectors(vect) + @test curMap == JuliaPetra.map(vect) + @test vect.data == vect2.data + @test vect.data !== vect2.data #ensure same contents, but different address + + vect2 = MultiVector{Float64, UInt64, UInt16, UInt32}(curMap, 3, false) + @test vect2 === copy!(vect2, 
vect) + @test localLength(vect) == localLength(vect2) + @test globalLength(vect) == globalLength(vect2) + @test numVectors(vect) == numVectors(vect2) + @test JuliaPetra.map(vect) == JuliaPetra.map(vect2) + @test vect.data == vect2.data + @test vect.data !== vect2.data + + + # test scale and scale! + vect = MultiVector(curMap, ones(Float64, n, 3)) + @test vect == scale!(vect, pid*5.0) + @test pid*5*ones(Float64, (n, 3)) == vect.data + + vect = MultiVector(curMap, ones(Float64, n, 3)) + vect2 = scale(vect, pid*5.0) + @test vect !== vect2 + @test pid*5*ones(Float64, (n, 3)) == vect2.data + + increase = pid*nProcs + + vect = MultiVector(curMap, ones(Float64, n, 3)) + @test vect == scale!(vect, increase+[2.0, 3.0, 4.0]) + @test hcat( (increase+2)*ones(Float64, n), + (increase+3)*ones(Float64, n), + (increase+4)*ones(Float64, n)) == vect.data + + for i = 1:3 + act = i+1+repeat(Float64[increase], inner=n) + @test act == getVectorView(vect, i) + @test act == getVectorCopy(vect, i) + end + + vect = MultiVector(curMap, ones(Float64, n, 3)) + vect2 = scale(vect, pid*nProcs+[2.0, 3.0, 4.0]) + @test vect !== vect2 + @test hcat( (pid*nProcs+2)*ones(Float64, n), + (pid*nProcs+3)*ones(Float64, n), + (pid*nProcs+4)*ones(Float64, n)) == vect2.data + + #test dot + vect = MultiVector(curMap, ones(Float64, n, 3)) + @test fill(n*nProcs, 3) == dot(vect, vect) + + #test fill! 
+ fill!(vect, 8) + @test 8*ones(Float64, (n, 3)) == vect.data + + + #test reduce + arr = (10^pid)*ones(Float64, n, 3) + vect = MultiVector(BlockMap(n, n, comm), arr) + commReduce(vect) + @test sum(10^i for i in 1:nProcs)*ones(Float64, n, 3) == vect.data + + + #test norm2 + arr = ones(Float64, n, 3) + vect = MultiVector(curMap, arr) + @test [sqrt(n*nProcs), sqrt(n*nProcs), sqrt(n*nProcs)] == norm2(vect) + + arr = 2*ones(Float64, n, 3) + vect = MultiVector(curMap, arr) + @test [sqrt(4*n*nProcs), sqrt(4*n*nProcs), sqrt(4*n*nProcs)] == norm2(vect) + + + + #test imports/exports + source = MultiVector(curMap, + Array{Float64, 2}(reshape(collect(1:(3*n)), (n, 3)))) + target = MultiVector{Float64, UInt64, UInt16, UInt32}(curMap, 3, false) + impor = Import(curMap, curMap) + doImport(source, target, impor, REPLACE) + @test reshape(Array{Float64, 1}(collect(1:(3*n))), (n, 3)) == target.data + + + source = MultiVector(curMap, + Array{Float64, 2}(reshape(collect(1:(3*n)), (n, 3)))) + target = MultiVector{Float64, UInt64, UInt16, UInt32}(curMap, 3, false) + expor = Export(curMap, curMap) + doExport(source, target, expor, REPLACE) + @test reshape(Array{Float64, 1}(collect(1:(3*n))), (n, 3)) == target.data + + source = MultiVector(curMap, + Array{Float64, 2}(reshape(collect(1:(3*n)), (n, 3)))) + target = MultiVector{Float64, UInt64, UInt16, UInt32}(curMap, 3, false) + impor = Import(curMap, curMap) + doExport(source, target, impor, REPLACE) + @test reshape(Array{Float64, 1}(collect(1:(3*n))), (n, 3)) == target.data + + + source = MultiVector(curMap, + Array{Float64, 2}(reshape(collect(1:(3*n)), (n, 3)))) + target = MultiVector{Float64, UInt64, UInt16, UInt32}(curMap, 3, false) + expor = Export(curMap, curMap) + doImport(source, target, expor, REPLACE) + @test reshape(Array{Float64, 1}(collect(1:(3*n))), (n, 3)) == target.data + + #TODO create import expor tests to test non trivial case + +end diff --git a/test/SerialCommTests.jl b/test/SerialCommTests.jl new file mode 100644 index 
0000000..5f9a714 --- /dev/null +++ b/test/SerialCommTests.jl @@ -0,0 +1,79 @@ + +### test Serial Comm ### + +serialComm = SerialComm{Int, Int, Int}() +@test typeof(serialComm) == SerialComm{Int, Int, Int} +@test typeof(SerialComm{Int, Int, Int}()) == SerialComm{Int, Int, Int} + +io = IOBuffer() +show(io, serialComm) +@test "SerialComm{$(String(Symbol(Int))),$(String(Symbol(Int))),$(String(Symbol(Int)))} with PID 1 and 1 processes" == String(take!(io)) + +# ensure no errors or hangs +barrier(serialComm) + +@test_throws InvalidArgumentError broadcastAll(serialComm, [1, 2, 3], 2) +@test [1, 2, 3] == broadcastAll(serialComm, [1, 2, 3], 1) +@test ['a', 'b', 'c'] == broadcastAll(serialComm, ['a', 'b', 'c'], 1) + +@test [1, 2, 3] == gatherAll(serialComm, [1, 2, 3]) +@test ['a', 'b', 'c'] == gatherAll(serialComm, ['a', 'b', 'c']) + +@test [1, 2, 3] == sumAll(serialComm, [1, 2, 3]) +@test ['a', 'b', 'c'] == sumAll(serialComm, ['a', 'b', 'c']) + +@test [1, 2, 3] == maxAll(serialComm, [1, 2, 3]) +@test ['a', 'b', 'c'] == maxAll(serialComm, ['a', 'b', 'c']) + +@test [1, 2, 3] == minAll(serialComm, [1, 2, 3]) +@test ['a', 'b', 'c'] == minAll(serialComm, ['a', 'b', 'c']) + +@test [1, 2, 3] == scanSum(serialComm, [1, 2, 3]) +@test ['a', 'b', 'c'] == scanSum(serialComm, ['a', 'b', 'c']) + +@test 1 == myPid(serialComm) +@test 1 == numProc(serialComm) + +serialDistributor = createDistributor(serialComm) +@test typeof(serialDistributor) <: Distributor + + +### test Serial Distributor ### + +@test_throws InvalidArgumentError createFromSends(serialDistributor, [1, 1, 1, 2]) +@test_throws InvalidArgumentError createFromSends(serialDistributor, [1, 1, 2, 1]) +@test_throws InvalidArgumentError createFromSends(serialDistributor, [2, 1, 1]) +@test 1 == createFromSends(serialDistributor, [1]) +@test 2 == createFromSends(serialDistributor, [1, 1]) +@test 5 == createFromSends(serialDistributor, [1, 1, 1, 1, 1]) +@test 8 == createFromSends(serialDistributor, [1, 1, 1, 1, 1, 1, 1, 1]) + + 
+@test_throws InvalidArgumentError createFromRecvs(serialDistributor, [1, 2, 3, 4], [1, 1, 1, 2]) +@test_throws InvalidArgumentError createFromRecvs(serialDistributor, [1, 2, 3, 4, 5], [2, 1, 1, 1, 1]) +@test_throws InvalidArgumentError createFromRecvs(serialDistributor, [1, 2, 3, 4, 5, 6], [1, 1, 1, 2, 1, 1]) +@test ([2], [1]) == createFromRecvs(serialDistributor, [2], [1]) +@test ([2, 3, 4, 5], [1, 1, 1, 1]) == createFromRecvs(serialDistributor, [2, 3, 4, 5], [1, 1, 1, 1]) + + +@test [3] == resolve(serialDistributor, [3]) +@test [3, 4, 5, 6, 7, 8, 9, 10] == resolve(serialDistributor, [3, 4, 5, 6, 7, 8, 9, 10]) +@test ['a', 'b', 'c', 'd'] == resolve(serialDistributor, ['a', 'b', 'c', 'd']) + + +@test [4] == resolveReverse(serialDistributor, [4]) +@test [11, 4, 5, 6, 7, 8, 9, 10] == resolveReverse(serialDistributor, [11, 4, 5, 6, 7, 8, 9, 10]) +@test ['a', 'b', 'c', 'k'] == resolveReverse(serialDistributor, ['a', 'b', 'c', 'k']) + +@test_throws InvalidStateError resolveWaits(serialDistributor) +@test_throws InvalidStateError resolveReverseWaits(serialDistributor) + +resolvePosts(serialDistributor, [1, 2, 3]) +resolvePosts(serialDistributor, [6, 7, 8, 9]) +@test [6, 7, 8, 9] == resolveWaits(serialDistributor) +@test_throws InvalidStateError resolveWaits(serialDistributor) + +resolveReversePosts(serialDistributor, [11, 12, 13]) +resolveReversePosts(serialDistributor, [16, 17, 81, 19]) +@test [16, 17, 81, 19] == resolveReverseWaits(serialDistributor) +@test_throws InvalidStateError resolveReverseWaits(serialDistributor) diff --git a/test/Slow Stability Tests.jl b/test/Slow Stability Tests.jl new file mode 100644 index 0000000..6142fdf --- /dev/null +++ b/test/Slow Stability Tests.jl @@ -0,0 +1,31 @@ + +#these tests take way too long to be included in the normal unit tests + +using JuliaPetra +include("TypeStability.jl") +include("TestUtil.jl") + +println("starting csrmatrix stability tests") +#stability tests +lidCount = length(stableLIDs) +pidCount = 
length(stablePIDs) +gidRange = 1:length(stableGIDs) +for Data in stableDatas + for i in gidRange + @inbounds GID = stableGIDs[i] + for j in i:lidCount + @inbounds LID = stableLIDs[j] + for k in j:pidCount + @inbounds PID = stablePIDs[k] + @test is_stable(check_method(apply!, (MultiVector{Data, GID, PID, LID}, + CSRMatrix{Data, GID, PID, LID}, + MultiVector{Data, GID, PID, LID}, + TransposeMode, + Data, + Data))) + end + end + end +end + +println("finished with csrmatrix stability tests") diff --git a/test/SparseRowViewTests.jl b/test/SparseRowViewTests.jl new file mode 100644 index 0000000..205f320 --- /dev/null +++ b/test/SparseRowViewTests.jl @@ -0,0 +1,32 @@ +rawVals = Float32[2, 5, 3, 4, 8, 6, 2, 4] +rawCols = Int32[1, 5, 8, 9, 11, 16, 20, 24] + +rowView = SparseRowView(rawVals, rawCols) +@test isa(rowView, SparseRowView{Float32, Int32}) +@test 8 == nnz(rowView) +@test rawVals == vals(rowView) +@test rawCols == cols(rowView) + +rowView = SparseRowView(rawVals, rawCols, 6) +@test isa(rowView, SparseRowView{Float32, Int32}) +@test 6 == nnz(rowView) +@test rawVals[1:6] == vals(rowView) +@test rawCols[1:6] == cols(rowView) + +rowView = SparseRowView(rawVals, rawCols, 5, 2) +@test isa(rowView, SparseRowView{Float32, Int32}) +@test 5 == nnz(rowView) +@test rawVals[2:6] == vals(rowView) +@test rawCols[2:6] == cols(rowView) + +rowView = SparseRowView(rawVals, rawCols, 3, 3, 2) +@test isa(rowView, SparseRowView{Float32, Int32}) +@test 3 == nnz(rowView) +@test rawVals[3:2:7] == vals(rowView) +@test rawCols[3:2:7] == cols(rowView) + +@test_throws InvalidArgumentError SparseRowView([1, 2, 3], [1]) +@test_throws BoundsError SparseRowView(rawVals, rawCols, 20) +@test_throws BoundsError SparseRowView(rawVals, rawCols, 9) +@test_throws BoundsError SparseRowView(rawVals, rawCols, 5, 5) +@test_throws BoundsError SparseRowView(rawVals, rawCols, 4, 1, 3) diff --git a/test/TestCommand.jl b/test/TestCommand.jl new file mode 100644 index 0000000..d362bd2 --- /dev/null +++ 
b/test/TestCommand.jl @@ -0,0 +1,30 @@ +function testJuliaPetra(flags...; coverage=false) +# function test!(pkg::AbstractString, +# errs::Vector{AbstractString}, +# nopkgs::Vector{AbstractString}, +# notests::Vector{AbstractString}; coverage::Bool=false) + + formattedFlags = ["--$flag" for flag in flags] + combinedFlags = join(formattedFlags, "\n") + + test_path = abspath(Pkg.dir(), "JuliaPetra", "test", "runtests.jl") + info("Testing JuliaPetra") + Base.cd(dirname(test_path)) do + try + cmd = ``` + $(Base.julia_cmd()) + --code-coverage=$(coverage ? "user" : "none") + --color=$(Base.have_color ? "yes" : "no") + --compilecache=$(Bool(Base.JLOptions().use_compilecache) ? "yes" : "no") + --check-bounds=yes + --startup-file=$(Base.JLOptions().startupfile != 2 ? "yes" : "no") + $test_path + $formattedFlags + ``` + run(cmd) + info("JuliaPetra tests passed") + catch err + Base.Pkg.Entry.warnbanner(err, label="[ ERROR: JuliaPetra ]") + end + end +end diff --git a/test/TestUtil.jl b/test/TestUtil.jl new file mode 100644 index 0000000..cc89e2f --- /dev/null +++ b/test/TestUtil.jl @@ -0,0 +1,7 @@ + +const stableGIDs = (UInt128, UInt64, UInt32, UInt16, UInt8, Int128, Int64, Int32, Int16, Int8) +const stablePIDs = (UInt128, UInt64, UInt32, UInt16, UInt8, Int128, Int64, Int32, Int16, Int8, Bool) +const stableLIDs = (UInt128, UInt64, UInt32, UInt16, UInt8, Int128, Int64, Int32, Int16, Int8) + +const stableReals= (Float64, Float32, Float16, UInt128, UInt64, UInt32, UInt16, UInt8, Int128, Int64, Int32, Int16, Int8, Bool) +const stableDatas= union(stableReals, [Complex{r} for r in stableReals]) diff --git a/test/TypeStability.jl b/test/TypeStability.jl new file mode 100644 index 0000000..1f107b4 --- /dev/null +++ b/test/TypeStability.jl @@ -0,0 +1,94 @@ + +#export check_function, check_method +#export StabilityReport, is_stable + +#TODO figure out IDE inter-op +#which(func, params).file gives Symbol containing the source file of the function +#which(func, params).line gives Int of 
line number of source of function


"""
    check_function(func, param_types; unstable_vars, unstable_return)

Run `check_method` on `func` for every parameter-type tuple in `param_types`.

Returns a `Vector{Tuple{Any, StabilityReport}}` pairing each parameter tuple
with the stability report for the matching method.
"""
function check_function(func, param_types; unstable_vars=Dict{Symbol, Type}(), unstable_return::Bool=false)
    result = Tuple{Any, StabilityReport}[]
    for params in param_types
        push!(result, (params, check_method(func, params;
                unstable_vars=unstable_vars, unstable_return=unstable_return)))
    end
    result
end

#Based off julia's code_warntype
"""
    check_method(func, param_types; unstable_vars, unstable_return)

Check the method of `func` matching `param_types` for type instability, using
`code_typed` the same way `code_warntype` does.

`unstable_vars` is a whitelist mapping variable names to a type they are
allowed to be (an unstable variable is only reported if its inferred type is
not a subtype of its whitelisted type).  Passing `unstable_return=true`
suppresses reporting of an unstable return type.

Returns a `StabilityReport` listing the unstable variables and, if unstable,
the return type.
"""
function check_method(func, param_types; unstable_vars=Dict{Symbol, Type}(), unstable_return::Bool=false)
    # mark which slots are actually referenced by the lowered code, so
    # never-used slots are not reported
    function slots_used(ci, slotnames)
        used = falses(length(slotnames))
        scan_exprs!(used, ci.code)
        return used
    end

    function scan_exprs!(used, exprs)
        for ex in exprs
            if isa(ex, Slot)
                used[ex.id] = true
            elseif isa(ex, Expr)
                scan_exprs!(used, ex.args)
            end
        end
    end

    #loop over possible methods for the given argument types
    code = code_typed(func, param_types)
    if length(code) != 1
        warn("mutliple methods for $func matching $param_types")
    end

    # accumulator for the unstable variables found.  Deliberately NOT named
    # `unstable_vars`: the original code shadowed the keyword-argument
    # whitelist with this array (breaking the `get` lookup below) and then
    # pushed to an undefined name `unstable_var`.
    found_vars = Array{Tuple{Symbol, Type}, 1}(0)
    unstable_ret = Nullable{Type}()

    for (src, rettyp) in code
        #check variables
        slotnames = Base.sourceinfo_slotnames(src)
        used_slotids = slots_used(src, slotnames)

        if isa(src.slottypes, Array)
            for i = 1:length(slotnames)
                if used_slotids[i]
                    name = Symbol(slotnames[i])
                    typ = src.slottypes[i]
                    # flag non-leaf (or boxed) slot types unless whitelisted;
                    # `Union{}` default means non-whitelisted names are always
                    # flagged (no type is a subtype of Union{} except itself)
                    if (!isleaftype(typ) || typ == Core.Box) && !(typ <: get(unstable_vars, name, Union{}))
                        push!(found_vars, (name, typ))
                    end

                    #else likely optimized out
                end
            end
        else
            warn("Can't access slot types of CodeInfo")
        end

        if !unstable_return && (!isleaftype(rettyp) || rettyp == Core.Box)
            unstable_ret = Nullable(rettyp)
        end

        #TODO check body
    end

    return StabilityReport(found_vars, unstable_ret)
end

"""
Record of the type instabilities found in one method: the unstable variables
as `(name, inferred type)` pairs, and the return type when it is unstable.
"""
struct StabilityReport
    unstable_variables::Array{Tuple{Symbol, Type}, 1}
    unstable_return::Nullable{Type}
end

# convenience constructor: a report with no instabilities
StabilityReport() = StabilityReport(Array{Tuple{Symbol, Type}, 1}(0), Nullable{Type}())
+is_stable(report::StabilityReport) = length(report.unstable_variables) == 0 && isnull(report.unstable_return) +is_stable(reports::Array{Tuple{Any, StabilityReport}}) = all(@. is_stable(getindex(reports, 2))) + + +function parameter_cartesian(typ::Type, params) + results = Type[] + for p in Base.product(params...) + push!(results, typ{p...}) + end + + results +end diff --git a/test/runMPITests.jl b/test/runMPITests.jl new file mode 100644 index 0000000..343de46 --- /dev/null +++ b/test/runMPITests.jl @@ -0,0 +1,68 @@ +#have debug enabled while running tests +globalDebug = true + +using JuliaPetra +using Base.Test + +include("TypeStability.jl") +include("TestUtil.jl") + +const GID = UInt64 +const PID = UInt16 +const LID = UInt32 + +#use distinct types +const comm = MPIComm(UInt64, UInt16, UInt32) + +const pid = myPid(comm) + +#only print errors from one process +if pid != 1 + #redirect_stdout() + #redirect_stderr() +end + + +#tries are to allow barriers to work correctly, even under erronious situtations +try + @testset "MPI Tests" begin + try + @testset "Comm MPI Tests" begin + include("MPICommTests.jl") + include("MPIBlockMapTests.jl") + include("MPIimport-export Tests.jl") + + include("LocalCommTests.jl") + runLocalCommTests(comm) + + include("BasicDirectoryTests.jl") + basicDirectoryTests(comm) + end + + @testset "Data MPI Tests" begin + include("MultiVectorTests.jl") + multiVectorTests(comm) + + include("CSRMatrixMPITests.jl") + end + + finally + #print results sequentially + for i in 1:pid + barrier(comm) + end + end + info("process $pid test results:") + end + +finally + #print results sequentially + for i in pid:4 + barrier(comm) + end +end + +#catch err +# sleep(10) +# throw(err) +#end diff --git a/test/runtests.jl b/test/runtests.jl new file mode 100644 index 0000000..1dd9b5e --- /dev/null +++ b/test/runtests.jl @@ -0,0 +1,66 @@ +#have debug enabled while running tests +globalDebug = true + +using TypeStability +#check stability while running tests 
+enable_inline_stability_checks(true) + +using JuliaPetra +using Base.Test + +@. ARGS = lowercase(ARGS) +# check for command line arguments requesting parts to not be tested +const noMPI = in("--mpi", ARGS) #don't run multi-process tests +const noComm = in("--comm", ARGS) #don't run comm framework tests +const noDataStructs = in("--data", ARGS) #don't run tests on data structures +const noUtil = in("--util", ARGS) #don't run tests on Misc Utils + +include("TypeStability.jl") +include("TestUtil.jl") + + +@testset "Serial tests" begin + + #a generic serial comm for the tests that need to be called with a comm object + const serialComm = SerialComm{UInt64, UInt16, UInt32}() + + if !noUtil + @testset "Util Tests" begin + include("MacroTests.jl") + include("ComputeOffsetsTests.jl") + end + end + + if !noComm + @testset "Comm Tests" begin + include("SerialCommTests.jl") + include("Import-Export Tests.jl") + include("BlockMapTests.jl") + + include("LocalCommTests.jl") + runLocalCommTests(serialComm) + + include("BasicDirectoryTests.jl") + basicDirectoryTests(serialComm) + end + end + + if !noDataStructs + @testset "Data Structure Tests" begin + include("MultiVectorTests.jl") + multiVectorTests(serialComm) + + include("SparseRowViewTests.jl") + include("LocalCRSGraphTests.jl") + include("LocalCSRMatrixTests.jl") + + include("CRSGraphTests.jl") + include("CSRMatrixTests.jl") + end + end +end + +# do MPI tests at the end so that other errors are found faster since the MPI tests take the longest +if !noMPI + include("MPITestsStarter.jl") +end