Skip to content

Commit

Permalink
Fixed the way pqrfact was being called. Improved orthonormalize_gener…
Browse files Browse the repository at this point in the history
…ators! to use the rank-revealing QR of LowRankApprox.jl.
  • Loading branch information
bonevbs committed Mar 30, 2021
1 parent 5c1b551 commit ba9c64d
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 61 deletions.
102 changes: 75 additions & 27 deletions src/generators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,46 +59,94 @@ end

function orthonormalize_generators!(hssA::HssNode{T}) where T
if isleaf(hssA.A11)
U1 = qr(hssA.A11.U); hssA.A11.U = Matrix(U1.Q)
V1 = qr(hssA.A11.V); hssA.A11.V = Matrix(V1.Q)
U1 = pqrfact(hssA.A11.U, sketch=:none); hssA.A11.U = Matrix(U1.Q)
V1 = pqrfact(hssA.A11.V, sketch=:none); hssA.A11.V = Matrix(V1.Q)
else
orthonormalize_generators!(hssA.A11)
U1 = qr([hssA.A11.R1; hssA.A11.R2])
V1 = qr([hssA.A11.W1; hssA.A11.W2])
U1 = pqrfact([hssA.A11.R1; hssA.A11.R2], sketch=:none)
V1 = pqrfact([hssA.A11.W1; hssA.A11.W2], sketch=:none)
rm1 = size(hssA.A11.R1, 1)
R = Matrix(U1.Q)
hssA.A11.R1 = R[1:rm1,:]
hssA.A11.R2 = R[rm1+1:end,:]
#R = Matrix(U1.Q)
hssA.A11.R1 = U1.Q[1:rm1,:]
hssA.A11.R2 = U1.Q[rm1+1:end,:]
rn1 = size(hssA.A11.W1, 1)
W = Matrix(V1.Q)
hssA.A11.W1 = W[1:rn1,:]
hssA.A11.W2 = W[rn1+1:end,:]
#W = Matrix(V1.Q)
hssA.A11.W1 = V1.Q[1:rn1,:]
hssA.A11.W2 = V1.Q[rn1+1:end,:]
end

if isleaf(hssA.A22)
U2 = qr(hssA.A22.U); hssA.A22.U = Matrix(U2.Q)
V2 = qr(hssA.A22.V); hssA.A22.V = Matrix(V2.Q)
U2 = pqrfact(hssA.A22.U, sketch=:none); hssA.A22.U = Matrix(U2.Q)
V2 = pqrfact(hssA.A22.V, sketch=:none); hssA.A22.V = Matrix(V2.Q)
else
orthonormalize_generators!(hssA.A22)
U2 = qr([hssA.A22.R1; hssA.A22.R2])
V2 = qr([hssA.A22.W1; hssA.A22.W2])
U2 = pqrfact([hssA.A22.R1; hssA.A22.R2], sketch=:none)
V2 = pqrfact([hssA.A22.W1; hssA.A22.W2], sketch=:none)
rm1 = size(hssA.A22.R1, 1)
R = Matrix(U2.Q)
hssA.A22.R1 = R[1:rm1,:]
hssA.A22.R2 = R[rm1+1:end,:]
#R = Matrix(U2.Q)
hssA.A22.R1 = U2.Q[1:rm1,:]
hssA.A22.R2 = U2.Q[rm1+1:end,:]
rn1 = size(hssA.A22.W1, 1)
W = Matrix(V2.Q)
hssA.A22.W1 = W[1:rn1,:]
hssA.A22.W2 = W[rn1+1:end,:]
#W = Matrix(V2.Q)
hssA.A22.W1 = V2.Q[1:rn1,:]
hssA.A22.W2 = V2.Q[rn1+1:end,:]
end
ipU1 = invperm(U1.p); ipV1 = invperm(V1.p)
ipU2 = invperm(U2.p); ipV2 = invperm(V2.p)

hssA.B12 = U1.R*hssA.B12*V2.R'
hssA.B21 = U2.R*hssA.B21*V1.R'
hssA.B12 = U1.R[:, ipU1]*hssA.B12*V2.R[:, ipV2]'
hssA.B21 = U2.R[:, ipU2]*hssA.B21*V1.R[:, ipV1]'

hssA.R1 = U1.R*hssA.R1
hssA.R2 = U2.R*hssA.R2
hssA.W1 = V1.R*hssA.W1
hssA.W2 = V2.R*hssA.W2
hssA.R1 = U1.R[:, ipU1]*hssA.R1
hssA.R2 = U2.R[:, ipU2]*hssA.R2
hssA.W1 = V1.R[:, ipV1]*hssA.W1
hssA.W2 = V2.R[:, ipV2]*hssA.W2

return hssA
end
end

# function orthonormalize_generators!(hssA::HssNode{T}) where T
# if isleaf(hssA.A11)
# U1 = qr!(hssA.A11.U); hssA.A11.U = Matrix(U1.Q)
# V1 = qr!(hssA.A11.V); hssA.A11.V = Matrix(V1.Q)
# else
# orthonormalize_generators!(hssA.A11)
# U1 = qr!([hssA.A11.R1; hssA.A11.R2])
# V1 = qr!([hssA.A11.W1; hssA.A11.W2])
# rm1 = size(hssA.A11.R1, 1)
# R = Matrix(U1.Q)
# hssA.A11.R1 = R[1:rm1,:]
# hssA.A11.R2 = R[rm1+1:end,:]
# rn1 = size(hssA.A11.W1, 1)
# W = Matrix(V1.Q)
# hssA.A11.W1 = W[1:rn1,:]
# hssA.A11.W2 = W[rn1+1:end,:]
# end

# if isleaf(hssA.A22)
# U2 = qr!(hssA.A22.U); hssA.A22.U = Matrix(U2.Q)
# V2 = qr!(hssA.A22.V); hssA.A22.V = Matrix(V2.Q)
# else
# orthonormalize_generators!(hssA.A22)
# U2 = qr!([hssA.A22.R1; hssA.A22.R2])
# V2 = qr!([hssA.A22.W1; hssA.A22.W2])
# rm1 = size(hssA.A22.R1, 1)
# R = Matrix(U2.Q)
# hssA.A22.R1 = R[1:rm1,:]
# hssA.A22.R2 = R[rm1+1:end,:]
# rn1 = size(hssA.A22.W1, 1)
# W = Matrix(V2.Q)
# hssA.A22.W1 = W[1:rn1,:]
# hssA.A22.W2 = W[rn1+1:end,:]
# end

# hssA.B12 = U1.R*hssA.B12*V2.R'
# hssA.B21 = U2.R*hssA.B21*V1.R'

# hssA.R1 = U1.R*hssA.R1
# hssA.R2 = U2.R*hssA.R2
# hssA.W1 = V1.R*hssA.W1
# hssA.W2 = V2.R*hssA.W2

# return hssA
# end
5 changes: 3 additions & 2 deletions src/hssmatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,9 @@ transpose(hssA::HssLeaf) = HssLeaf(copy(transpose(hssA.D)), copy(hssA.V), copy(h
transpose(hssA::HssNode) = HssNode(transpose(hssA.A11), transpose(hssA.A22), copy(transpose(hssA.B21)), copy(transpose(hssA.B12)), copy(hssA.W1), copy(hssA.R1), copy(hssA.W2), copy(hssA.R2))

# Define Matlab-like convenience functions, which are used throughout the library
blkdiagm(A::Matrix, B::Matrix) = [A zeros(size(A,1), size(B,2)); zeros(size(B,1), size(A,2)) B]
blkdiagm(A::Matrix... ) = blkdiagm(A[1], blkdiagm(A[2:end]...))
#blkdiagm(A::Matrix, B::Matrix) = [A zeros(size(A,1), size(B,2)); zeros(size(B,1), size(A,2)) B]
#blkdiagm(A::Matrix... ) = blkdiagm(A[1], blkdiagm(A[2:end]...))
blkdiagm(A::Matrix...) = cat(A[1:end]..., dims=(1,2))

## basic algebraic operations (taken and modified from LowRankApprox.jl)
for op in (:+,:-)
Expand Down
2 changes: 1 addition & 1 deletion src/prrqr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function _compress_block!(A::AbstractMatrix{T}, atol::Float64, rtol::Float64) wh
#rk = min(size(R)...)
#return Q[:,1:rk], R[1:rk, invperm(p)]
# temporarily using prrqr of LowRankApprox.jl
F = pqrfact(A; atol = atol, rtol = rtol)
F = pqrfact(A; atol = atol, rtol = rtol, sketch=:none)
rk = min(size(F.R)...)
return F.Q[:,1:rk], F.R[1:rk, invperm(F.p)]
end
Expand Down
6 changes: 3 additions & 3 deletions src/visualization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ function plotranks(hssA::HssMatrix; cutoff_level=3)
aspect = m/n
plot(yflip=true, showaxis=false, size = (400, 400*aspect))
xticks = [1]; yticks = [1]
yticks, xticks = _plotranks!(hssA,1,1,xticks,yticks,0,cutoff_level)
yticks, xticks = _plotranks!(hssA,1,1,xticks,yticks,1,cutoff_level)
append!(yticks, n+1)
append!(xticks, m+1)
plot!(aspect_ratio=:equal)
Expand Down Expand Up @@ -38,9 +38,9 @@ function _plotranks!(hssA::HssMatrix, co, ro, cticks, rticks, level, cl)
_plotranks!(hssA.A22, co+n1, ro+m1, cticks, rticks, level+1, cl)
# plot off-diagonal blocks
plot!(_rectangle(n2, m1, co+n1, ro), color=:aliceblue, label=false)
if 0 < level cl annotate!((co+n1+0.5*n2, ro+0.5*m1, text(max(size(hssA.B12)...), 8))) end
if level cl annotate!((co+n1+0.5*n2, ro+0.5*m1, text(max(size(hssA.B12)...), 8))) end
plot!(_rectangle(n1, m2, co, ro+m1), color=:aliceblue, label=false)
if 0 < level cl annotate!((co+0.5*n1, ro+m1+0.5*m2, text(max(size(hssA.B21)...), 8))) end
if level cl annotate!((co+0.5*n1, ro+m1+0.5*m2, text(max(size(hssA.B21)...), 8))) end
end
return rticks, cticks
end
Expand Down
73 changes: 45 additions & 28 deletions test/ulvtest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,45 +3,62 @@ using .HssMatrices
using LinearAlgebra
using AbstractTrees
using Plots
using BenchmarkTools

K(x,y) = (x-y) != 0 ? 1/(x-y) : 10000.
A = [ K(x,y) for x=-1:0.001:1, y=-1:0.001:1];
b = randn(size(A,2), 5);
n = 1000
K(x,y) = (x-y) != 0 ? 1/(x-y) : 1000.
A = [ K(x,y) for x=0:1/(n-1):1, y=0:1/(n-1):1];
b = randn(n, 1);

# test the simple implementation of cluster trees
m, n = size(A)
lsz = 64;
rcl = bisection_cluster(1:m, leafsize=lsz)
ccl = bisection_cluster(1:n, leafsize=lsz)

# test ULV normally
hssA = compress(A, rcl, ccl);
x = ulvfactsolve(hssA, b);
xcor = A\b;
println(norm(x-xcor)/norm(xcor))

# test on schewed cluster trees
lsz = 701;
rcl = bisection_cluster(1:m, leafsize=lsz)
ccl = bisection_cluster(1:n, leafsize=lsz)
rcl.left.left.data = 1:700
rcl.left.right.data = 701:1001
#print_tree(rcl)
hssA = compress(A, rcl, ccl);
x = ulvfactsolve(hssA, b);
xcor = A\b;
println(norm(x-xcor)/norm(xcor))
#hssA\hssA

n = 20000
K(x,y) = (x-y) != 0 ? 1/(x-y) : 1000.
A = [ K(x,y) for x=0:1/(n-1):1, y=0:1/(n-1):1];
b = randn(n, 1);

# test on schewed cluster trees
lsz = 701;
# test the simple implementation of cluster trees
m, n = size(A)
lsz = 64;
rcl = bisection_cluster(1:m, leafsize=lsz)
ccl = bisection_cluster(1:n, leafsize=lsz)
ccl.left.left.data = 1:700
ccl.left.right.data = 701:1001
#print_tree(ccl)
hssA = compress(A, rcl, ccl);
x = ulvfactsolve(hssA, b);
xcor = A\b;
println(norm(x-xcor)/norm(xcor))

@btime hssA\hssA

# # test ULV normally
# x = ulvfactsolve(hssA, b);
# xcor = A\b;
# println(norm(x-xcor)/norm(xcor))

# # test on schewed cluster trees
# lsz = 701;
# rcl = bisection_cluster(1:m, leafsize=lsz)
# ccl = bisection_cluster(1:n, leafsize=lsz)
# rcl.left.left.data = 1:700
# rcl.left.right.data = 701:1001
# #print_tree(rcl)
# hssA = compress(A, rcl, ccl);
# x = ulvfactsolve(hssA, b);
# xcor = A\b;
# println(norm(x-xcor)/norm(xcor))

# # test on schewed cluster trees
# lsz = 701;
# rcl = bisection_cluster(1:m, leafsize=lsz)
# ccl = bisection_cluster(1:n, leafsize=lsz)
# ccl.left.left.data = 1:700
# ccl.left.right.data = 701:1001
# #print_tree(ccl)
# hssA = compress(A, rcl, ccl);
# x = ulvfactsolve(hssA, b);
# xcor = A\b;
# println(norm(x-xcor)/norm(xcor))

plot = plotranks(hssA)

2 comments on commit ba9c64d

@bonevbs
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

  • Performance improvements due to the correct use of pqrfact
  • Bugfixes in cluster
  • Bugfixes in plotranks

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/33161

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.1 -m "<description of version>" ba9c64d11fdf6fb8c91c5e55550a30d3ed949267
git push origin v0.1.1

Please sign in to comment.