-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathLocalCSRMatrix.jl
78 lines (61 loc) · 2.45 KB
/
LocalCSRMatrix.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
export LocalCSRMatrix, numRows, numCols, getRowView
"""
    LocalCSRMatrix{Data, IndexType <: Integer}

A sparse matrix stored in compressed sparse row (CSR) form.

# Fields
- `graph`: the sparsity pattern of the matrix; row offsets are read from
  `graph.rowMap` and column entries from `graph.entries`.
- `values`: the stored values, either an owned `Vector{Data}` or a contiguous
  (`UnitRange`-indexed) view into a `Vector{Data}`.
- `numCols`: the number of columns of the matrix.
"""
struct LocalCSRMatrix{Data, IndexType <: Integer}
graph::LocalCSRGraph{IndexType, IndexType}
values::Union{Vector{Data}, SubArray{Data, 1, Vector{Data}, Tuple{UnitRange{IndexType}}, true}}
numCols::IndexType
end
"""
LocalCSRMatrix{Data, IndexType}()
Creates an empty LocalCSRMatrix
"""
function LocalCSRMatrix{Data, IndexType}() where {Data, IndexType}
LocalCSRMatrix(LocalCSRGraph{IndexType, IndexType}(), Data[], IndexType(0))
end
"""
LocalCSRMatrix(nRows::Integer, nCols::Integer, vals::AbstractArray{Data, 1}, rows::AbstractArray{IndexType, 1}, cols::AbstractArray{IndexType, 1}}
Creates the specified LocalCSRMatrix
"""
function LocalCSRMatrix(nRows::Integer, nCols::Integer,
vals::AbstractArray{Data, 1}, rows::AbstractArray{IndexType, 1},
cols::AbstractArray{IndexType, 1}) where {Data, IndexType}
if length(rows) != nRows + 1
throw(InvalidArgumentError("length(rows) = $(length(rows)) != nRows+1 "
* "= $(nRows + 1)"))
end
LocalCSRMatrix(LocalCSRGraph(cols, rows), vals, IndexType(nCols))
end
"""
LocalCSRMatrix(numCols::IndexType, values::AbstractArray{Data, 1}, localGraph::LocalCSRGraph{IndexType, IndexType}) where {IndexType, Data <: Number}
Creates the specified LocalCSRMatrix
"""
function LocalCSRMatrix(numCols::IndexType, values::AbstractArray{Data, 1},
localGraph::LocalCSRGraph{IndexType, IndexType}) where {IndexType, Data <: Number}
if numCols < 0
throw(InvalidArgumentError("Cannot have a negative number of rows"))
end
LocalCSRMatrix(localGraph, values, numCols)
end
"""
numRows(::LocalCSRMatrix{Data, IndexType})::IndexType
Gets the number of rows in the matrix
"""
numRows(matrix::LocalCSRMatrix) = numRows(matrix.graph)
"""
numCols(::LocalCSRMatrix{Data, IndexType})::IndexType
Gets the number of columns in the matrix
"""
numCols(matrix::LocalCSRMatrix) = matrix.numCols
"""
getRowView((matrix::LocalCSRMatrix{Data, IndexType}, row::Integer)::SparseRowView{Data, IndexType}
Gets a view of the requested row
"""
function getRowView(matrix::LocalCSRMatrix{Data, IndexType},
row::Integer)::SparseRowView{Data, IndexType} where {Data, IndexType}
start = matrix.graph.rowMap[row]
count = matrix.graph.rowMap[row+1] - start
if count == 0
SparseRowView(Data[], IndexType[])
else
SparseRowView(matrix.values, matrix.graph.entries, count, start)
end
end