Skip to content

Commit

Permalink
Fix complex dense-perm matmul (#69)
Browse files Browse the repository at this point in the history
* fix complex dense-perm matmul

* bump version

* fix test

* fix nightly test
  • Loading branch information
GiggleLiu authored May 29, 2022
1 parent ffc1184 commit 961c04c
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 33 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
name = "LuxurySparse"
uuid = "d05aeea4-b7d4-55ac-b691-9e7fabb07ba2"
authors = ["GiggleLiu <[email protected]>", "Roger-luo <[email protected]>"]
version = "0.6.11"
version = "0.6.12"

[deps]
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand Down
2 changes: 1 addition & 1 deletion src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ end
function *(X::AbstractMatrix, A::PermMatrix)
mX, nX = size(X)
nX == size(A, 1) || throw(DimensionMismatch())
return @views (A.vals'.*X)[:, fast_invperm(A.perm)]
return @views (transpose(A.vals) .* X)[:, fast_invperm(A.perm)]
end

# NOTE: this is just a temperory fix for v0.7. We should overload mul! in
Expand Down
6 changes: 6 additions & 0 deletions test/PermMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,10 @@ end
pm = PermMatrix([3, 2, 4, 1], [0.2, 0.6, 0.1, 0.3])
res = pm .* 3im
@test res == PermMatrix([3, 2, 4, 1], [0.2, 0.6, 0.1, 0.3] .* 3im) && res isa PermMatrix
end

@testset "fix dense-perm multiplication" begin
A = randn(ComplexF64, 4, 4)
pm = PermMatrix([3, 2, 4, 1], [0.2im, 0.6im, 0.1, 0.3])
@test A * pm A * Matrix(pm)
end
46 changes: 24 additions & 22 deletions test/kronecker.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,30 @@
using Test, Random, SparseArrays, LinearAlgebra
import LuxurySparse: IMatrix, PermMatrix

Random.seed!(2)
@testset "kron" begin
Random.seed!(2)

p1 = IMatrix{4}()
sp = sprand(ComplexF64, 4, 4, 0.5)
ds = rand(ComplexF64, 4, 4)
pm = PermMatrix([2, 3, 4, 1], randn(4))
v = [0.5, 0.3im, 0.2, 1.0]
dv = Diagonal(v)
p1 = IMatrix{4}()
sp = sprand(ComplexF64, 4, 4, 0.5)
ds = rand(ComplexF64, 4, 4)
pm = PermMatrix([2, 3, 4, 1], randn(4))
pm = PermMatrix([2, 3, 4, 1], randn(4))
v = [0.5, 0.3im, 0.2, 1.0]
dv = Diagonal(v)


@testset "kron(::$(typeof(source)), ::$(typeof(target)))" for source in [p1, sp, ds, dv, pm],
target in [p1, sp, ds, dv, pm]
lres = kron(source, target)
rres = kron(target, source)
flres = kron(Matrix(source), Matrix(target))
frres = kron(Matrix(target), Matrix(source))
@test lres == flres
@test rres == frres
@test eltype(lres) == eltype(flres)
@test eltype(rres) == eltype(frres)
if !(target === ds && source === ds)
@test !(typeof(lres) <: StridedMatrix)
@test !(typeof(rres) <: StridedMatrix)
for source in Any[p1, sp, ds, dv, pm],
target in Any[p1, sp, ds, dv, pm]
lres = kron(source, target)
rres = kron(target, source)
flres = kron(Matrix(source), Matrix(target))
frres = kron(Matrix(target), Matrix(source))
@test lres == flres
@test rres == frres
@test eltype(lres) == eltype(flres)
@test eltype(rres) == eltype(frres)
if !(target === ds && source === ds)
@test !(typeof(lres) <: StridedMatrix)
@test !(typeof(rres) <: StridedMatrix)
end
end
end
end
18 changes: 9 additions & 9 deletions test/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,22 @@ dv = Diagonal(v)
@test logdet(p1) == 0
@test inv(pm) == inv(Matrix(pm))

for m in [pm, sp, p1, dv]
for m in Any[pm, sp, p1, dv]
@test !(m |> isdense)
@test !(m' |> isdense)
@test !(transpose(m) |> isdense)
end
for m in [ds, v]
for m in Any[ds, v]
@test m |> isdense
@test m' |> isdense
@test transpose(m) |> isdense
end
end

@testset "multiply" begin
for source_ in [p1, sp, ds, dv, pm]
for target in [p1, sp, ds, dv, pm]
for source in [source_, source_', transpose(source_)]
for source_ in Any[p1, sp, ds, dv, pm]
for target in Any[p1, sp, ds, dv, pm]
for source in Any[source_, source_', transpose(source_)]
lres = source * target
rres = target * source
flres = Matrix(source) * Matrix(target)
Expand Down Expand Up @@ -82,7 +82,7 @@ end
@testset "randn" begin
Random.seed!(2)
T = ComplexF64
for m in [sprand(T, 5, 5, 0.5)]
for m in Any[sprand(T, 5, 5, 0.5)]
zm = zero(m)
@test zm zeros(T, 5, 5)
if VERSION < v"1.4.0"
Expand All @@ -93,7 +93,7 @@ end
@test !(zm zeros(T, 5, 5))
end
end
for m in [pmrand(T, 5), Diagonal(randn(T, 5))]
for m in Any[pmrand(T, 5), Diagonal(randn(T, 5))]
zm = zero(m)
@test zm zeros(T, 5, 5)
rand!(zm)
Expand All @@ -112,8 +112,8 @@ end
end

@testset "findnz" begin
for m in [p1, sp, ds, dv, pm]
for _m in [m, staticize(m)]
for m in Any[p1, sp, ds, dv, pm]
for _m in Any[m, staticize(m)]
out = zeros(eltype(m), size(m)...)
for (i, j, v) in zip(LuxurySparse.findnz(_m)...)
out[i, j] = v
Expand Down

0 comments on commit 961c04c

Please sign in to comment.