Skip to content

Commit

Permalink
Type Variation for Comp and Math Methods
Browse files Browse the repository at this point in the history
Make the computational methods in CSRMatrix.jl and the mathematical methods in MultiVector.jl compatible for multiple data types
  • Loading branch information
fsmith001 committed Aug 4, 2022
1 parent 975f7a8 commit 9bf5656
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 44 deletions.
48 changes: 21 additions & 27 deletions src/CSRMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1133,33 +1133,27 @@ function localApply(Y::MultiVector{Data, GID, PID, LID},
for vect = LID(1):numVectors(Y)
for row = LID(1):numRows
sum::Data = Data(0)
#@inbounds (indices, values, len) = getLocalRowViewPtr(A, row)
@boundscheck((indices, values, len) = getLocalRowViewPtr(A, row))
@inbounds (indices, values, len) = getLocalRowViewPtr(A, row)
for i in LID(1):LID(len)
ind::LID = unsafe_load(indices, i)
val::Data = unsafe_load(values, i)
#@inbounds sum += val*rawX[ind, vect]
@boundscheck (sum += val*rawX[ind, vect])
@inbounds sum += val*rawX[ind, vect]
end
sum = applyConjugation(mode, sum*alpha)
#@inbounds rawY[row, vect] *= beta
@boundscheck (rawY[row, vect] *= beta)
#@inbounds rawY[row, vect] += sum
@boundscheck (rawY[row, vect] += sum)
@inbounds rawY[row, vect] *= beta
@inbounds rawY[row, vect] += sum
end
end
else
rawY[:, :] *= beta
numRows = getLocalNumRows(A)
for vect = LID(1):numVectors(Y)
for mRow in LID(1):numRows
#@inbounds (indices, values, len) = getLocalRowViewPtr(A, mRow)
@boundscheck ((indices, values, len) = getLocalRowViewPtr(A, mRow))
@inbounds (indices, values, len) = getLocalRowViewPtr(A, mRow)
for i in LID(1):LID(len)
ind::LID = unsafe_load(indices, i)
val::Data = unsafe_load(values, i)
#@inbounds rawY[ind, vect] += applyConjugation(mode, alpha*rawX[mRow, vect]*val)
@boundscheck (rawY[ind, vect] += applyConjugation(mode, alpha*rawX[mRow, vect]*val))
@inbounds rawY[ind, vect] += applyConjugation(mode, alpha*rawX[mRow, vect]*val)
end
end
end
Expand All @@ -1172,7 +1166,7 @@ end
Finds the product of matrix-vector multiplication of A and X
"""
function matVecMult(Y::MultiVector{Data, GID, PID, LID}, A::CSRMatrix{Data, GID, PID, LID}, X::MultiVector{Data, GID, PID, LID})
function matVecMult(Y::MultiVector{}, A::CSRMatrix{}, X::MultiVector{})

rawY = getLocalArray(Y)
rawX = getLocalArray(X)
Expand All @@ -1192,20 +1186,20 @@ function matVecMult(Y::MultiVector{Data, GID, PID, LID}, A::CSRMatrix{Data, GID,
end

"""
power1(A::CSRMatrix, niter::Integer, tol)
power(A::CSRMatrix, niter::Integer, tol)
Power method on a CSRMatrix to solve for the dominant eigenvalue and eigenvector
"""
function power(A::CSRMatrix{}, niter::Integer, tol)
function power(A::CSRMatrix{Data}, niter, tol)
rows = getNumEntriesInLocalRow(A, 1)
comm = SerialComm{Int, Int, Int}()
Data = Float64
map = BlockMap(rows, rows, comm)
x = DenseMultiVector(map, Matrix{Float64}(ones(rows,1)))
x = DenseMultiVector(map, Matrix{Data}(ones(rows,1)))
nold = undef
n = undef

y = DenseMultiVector(map, Matrix{Float64}(undef, rows, rows))
y = DenseMultiVector(map, Matrix{Data}(undef, rows, rows))
alpha = Data(1)
beta = Data(1)

Expand Down Expand Up @@ -1240,9 +1234,9 @@ end
Returns a vector of the inverse of the maximum absolute values of each row of A
"""
function invRowMax(A::CSRMatrix{})
function invRowMax(A::CSRMatrix{Data})
rows = getGlobalNumRows(A)
inv = Vector{Float64}(undef, rows)
inv = Vector{Data}(undef, rows)

for i = 1:rows
(inds, vals) = getGlobalRowView(A, i)
Expand All @@ -1268,9 +1262,9 @@ end
Returns a vector of the inverse of the sums of each row of A
"""
function invRowSum(A::CSRMatrix{})
function invRowSum(A::CSRMatrix{Data})
rows = getGlobalNumRows(A)
inv = Vector{Float64}(undef, rows)
inv = Vector{Data}(undef, rows)

for i = 1:rows
(inds, vals) = getGlobalRowView(A, i)
Expand All @@ -1293,10 +1287,10 @@ end
Returns a vector of the inverse of the maximum absolute values of each column of A
"""
function invColMax(A::CSRMatrix{})
function invColMax(A::CSRMatrix{Data})
rows = getGlobalNumRows(A)
inv = Vector{Float64}(undef, rows)
thisVal = Vector{Float64}(undef, rows)
inv = Vector{Data}(undef, rows)
thisVal = Vector{Data}(undef, rows)
cols = 0

for i = 1:rows
Expand Down Expand Up @@ -1330,10 +1324,10 @@ end
Returns a vector of the inverse of the sums of each column of A
"""
function invColSum(A::CSRMatrix{})
function invColSum(A::CSRMatrix{Data})
rows = getGlobalNumRows(A)
inv = Vector{Float64}(undef, rows)
thisVal = Vector{Float64}(undef, rows)
inv = Vector{Data}(undef, rows)
thisVal = Vector{Data}(undef, rows)
cols = 0

for i = 1:rows
Expand Down
24 changes: 12 additions & 12 deletions src/MultiVector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,10 @@ end
Returns a MultiVector where each entry is the absolute value of the corresponding entry in the given MultiVector
"""
function multiAbs(v::MultiVector)
function multiAbs(v::MultiVector{Data})
numVects = numVectors(v)
length = localLength(v)
result = Array{Float64}(undef, (numVects, length))
result = Array{Data}(undef, (numVects, length))

for i=1:numVects
for j=1:length
Expand All @@ -200,10 +200,10 @@ end
Returns a MultiVector where each entry is the reciprocal of the corresponding entry in the given MultiVector
"""
function reciprocal(v::MultiVector)
function reciprocal(v::MultiVector{Data})
numVects = numVectors(v)
length = localLength(v)
result = Array{Float64}(undef, (numVects, length))
result = Array{Data}(undef, (numVects, length))

for i=1:numVects
for j=1:length
Expand All @@ -219,10 +219,10 @@ end
Returns the minimum value of each vector in the given MultiVector
"""
function minValue(v::MultiVector{})
function minValue(v::MultiVector{Data})
numVects = numVectors(v)
length = localLength(v)
min = Vector{Float64}(undef, numVects)
min = Vector{Data}(undef, numVects)

for i=1:numVects #each iteration of outer for loop is for each vector in the multivector
min[i] = v[i,1]#set the min for this vector as the first element
Expand All @@ -241,10 +241,10 @@ end
Returns the maximum value of each vector in the given MultiVector
"""
function maxValue(v::MultiVector{})
function maxValue(v::MultiVector{Data})
numVects = numVectors(v)
length = localLength(v)
max = Vector{Float64}(undef, numVects)
max = Vector{Data}(undef, numVects)

for i=1:numVects #each iteration of outer for loop is for each vector in the multivector
max[i] = v[i,1] #set the max for this vector as the first element
Expand All @@ -263,10 +263,10 @@ end
Returns the mean value of each vector in the given MultiVector
"""
function meanValue(v::MultiVector{})
function meanValue(v::MultiVector{Data})
numVects = numVectors(v)
length = localLength(v)
mean = Vector{Float64}(undef, numVects)
mean = Vector{Data}(undef, numVects)

for i=1:numVects #each iteration of outer for loop is for each vector in the multivector
sum = 0
Expand All @@ -283,12 +283,12 @@ end
Returns a MultiVector of the element-by-element wise product of the given MultiVectors, then scaled by the given scalar
"""
function multiply(scalar::Float64, A::MultiVector, B::MultiVector)
function multiply(scalar, A::MultiVector{Data}, B::MultiVector{Data})
comm = SerialComm{Int, Int, Int}()
rows = localLength(A)
cols = numVectors(B)
myMap = BlockMap(rows, cols, comm)
result = DenseMultiVector(myMap, Matrix{Float64}(undef, rows, cols))
result = DenseMultiVector(myMap, Matrix{Data}(undef, rows, cols))

result = A.*B
result = scale!(result, scalar)
Expand Down
10 changes: 5 additions & 5 deletions test/CSRMatrixTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,22 +137,22 @@ fillComplete(mat)
@test (LID[1, 2], Data[2, 3]) == getLocalRowView(mat, 1)
@test (LID[1, 2], Data[5, 7]) == getLocalRowView(mat, 2)



#the following is commented out because the purpose is not clear - we have below a test for apply!
#=
Y = DenseMultiVector(map, Data[1 0; 0 2])
X = DenseMultiVector(map, Data[2 2; 2 2])

@test Y === !(Y, mat, X, NO_TRANS, Data(3), Data(.5))
@test Y === (Y, mat, X, NO_TRANS, Data(3), Data(.5))
@test Data[2 2; 2 2] == X.data #ensure X isn't mutated
@test Data[30.5 30; 72 73] == Y.data
=#


Y = DenseMultiVector(map, Data[1 0; 0 2])
X = DenseMultiVector(map, Data[2 2; 2 2]) #ensure bugs in the previous test don't affect this test

@test Y === apply!(Y, mat, X, TRANS, Float32(3), Float32(.5))
@test Y == apply!(Y, mat, X, TRANS, Float32(3), Float32(.5))

@test Data[2 2; 2 2] == X.data #ensure X isn't mutated
@test [42.5 42; 60 61] == Y.data
Expand Down

0 comments on commit 9bf5656

Please sign in to comment.