Skip to content

Commit

Permalink
CSRMatrix power
Browse files Browse the repository at this point in the history
functions needed to implement power method for CSRMatrices
  • Loading branch information
fsmith001 committed Aug 3, 2022
1 parent 0dd93af commit 975f7a8
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 18 deletions.
91 changes: 79 additions & 12 deletions src/CSRMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1127,44 +1127,114 @@ function localApply(Y::MultiVector{Data, GID, PID, LID},
rawY = getLocalArray(Y)
rawX = getLocalArray(X)


#TODO implement this better, can BLAS be used?
#TODO implement this better
if !isTransposed(mode)
numRows = getLocalNumRows(A)
for vect = LID(1):numVectors(Y)
for row = LID(1):numRows
sum::Data = Data(0)
@inbounds (indices, values, len) = getLocalRowViewPtr(A, row)
#@inbounds (indices, values, len) = getLocalRowViewPtr(A, row)
@boundscheck((indices, values, len) = getLocalRowViewPtr(A, row))
for i in LID(1):LID(len)
ind::LID = unsafe_load(indices, i)
val::Data = unsafe_load(values, i)
#@inbounds sum += val*rawX[ind, vect]
sum += val*rawX[ind, vect]
@boundscheck (sum += val*rawX[ind, vect])
end
sum = applyConjugation(mode, sum*alpha)
#@inbounds rawY[row, vect] *= beta
rawY[row, vect] *= beta
@boundscheck (rawY[row, vect] *= beta)
#@inbounds rawY[row, vect] += sum
rawY[row, vect] += sum
@boundscheck (rawY[row, vect] += sum)
end
end
else
rawY[:, :] *= beta
numRows = getLocalNumRows(A)
for vect = LID(1):numVectors(Y)
for mRow in LID(1):numRows
@inbounds (indices, values, len) = getLocalRowViewPtr(A, mRow)
#@inbounds (indices, values, len) = getLocalRowViewPtr(A, mRow)
@boundscheck ((indices, values, len) = getLocalRowViewPtr(A, mRow))
for i in LID(1):LID(len)
ind::LID = unsafe_load(indices, i)
val::Data = unsafe_load(values, i)
@inbounds rawY[ind, vect] += applyConjugation(mode, alpha*rawX[mRow, vect]*val)
#@inbounds rawY[ind, vect] += applyConjugation(mode, alpha*rawX[mRow, vect]*val)
@boundscheck (rawY[ind, vect] += applyConjugation(mode, alpha*rawX[mRow, vect]*val))
end
end
end
end
Y
end

"""
matVecMult(Y::MultiVector, A::CSRMatrix, X::MultiVector)
Finds the product of matrix-vector multiplication of A and X
"""
function matVecMult(Y::MultiVector{Data, GID, PID, LID}, A::CSRMatrix{Data, GID, PID, LID}, X::MultiVector{Data, GID, PID, LID})

rawY = getLocalArray(Y)
rawX = getLocalArray(X)

m = getLocalNumRows(A)
n = localLength(X)

for row = LID(1):m
thisSum = 0
(inds, vals) = getGlobalRowView(A, row)
for col = LID(1):n
thisSum = thisSum + (vals[col] * rawX[col,1])
end
rawY[row,1] = thisSum
end
return Y
end

"""
power1(A::CSRMatrix, niter::Integer, tol)
Power method on a CSRMatrix to solve for the dominant eigenvalue and eigenvector
"""
function power(A::CSRMatrix{}, niter::Integer, tol)
rows = getNumEntriesInLocalRow(A, 1)
comm = SerialComm{Int, Int, Int}()
Data = Float64
map = BlockMap(rows, rows, comm)
x = DenseMultiVector(map, Matrix{Float64}(ones(rows,1)))
nold = undef
n = undef

y = DenseMultiVector(map, Matrix{Float64}(undef, rows, rows))
alpha = Data(1)
beta = Data(1)

for k = 1:niter
y = matVecMult(y, A, x)
n = y[1,1]
for i = 2:numVectors(y)
for j = 2:localLength(y)
if abs(y[i,1]) > n
n = y[i,1]
end
end
end

if k > 1
if abs(nold - n) <= tol
break
end
end

for j = 1:rows
x[j, 1] = y[j, 1]/n
end
nold = n
end
return x, n
end


"""
invRowMax(A::CSRMatrix{})
Expand Down Expand Up @@ -1288,7 +1358,4 @@ function invColSum(A::CSRMatrix{})
inv[i] = 1/sum
end
return inv
end


#### TODO: Computational methods####
end
12 changes: 6 additions & 6 deletions src/RowMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -667,12 +667,12 @@ Base.axes(A::RowMatrix{GID}) where GID = if hasColMap(A)

function Base.getindex(A::RowMatrix, I::Vararg{Int, 2})
if isGloballyIndexed(A)
# @boundscheck begin
# (n, m) = size(A)
# if I[1] > n || I[1] < 1 || I[2] > m || I[2] < 1
# throw(BoundsError(A, I))
# end
#end
@boundscheck begin
(n, m) = size(A)
if I[1] > n || I[1] < 1 || I[2] > m || I[2] < 1
throw(BoundsError(A, I))
end
end
(rowInds, rowVals) = getGlobalRowView(A, I[0])
for i in 1:length(rowInds)
if rowInds[i] == I[1]
Expand Down

0 comments on commit 975f7a8

Please sign in to comment.