diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index 3a2459a4..185f58c4 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -84,8 +84,10 @@ end *(a::Zeros{<:Any,1}, b::AbstractMatrix) = mult_zeros(a, b) *(a::Zeros{<:Any,2}, b::AbstractMatrix) = mult_zeros(a, b) +*(a::Zeros{<:Any,2}, b::AbstractTriangular) = mult_zeros(a, b) *(a::AbstractMatrix, b::Zeros{<:Any,1}) = mult_zeros(a, b) *(a::AbstractMatrix, b::Zeros{<:Any,2}) = mult_zeros(a, b) +*(a::AbstractTriangular, b::Zeros{<:Any,2}) = mult_zeros(a, b) *(a::Zeros{<:Any,1}, b::AbstractVector) = mult_zeros(a, b) *(a::Zeros{<:Any,2}, b::AbstractVector) = mult_zeros(a, b) *(a::AbstractVector, b::Zeros{<:Any,2}) = mult_zeros(a, b) @@ -95,65 +97,51 @@ end *(a::Diagonal, b::Zeros{<:Any,1}) = mult_zeros(a, b) *(a::Diagonal, b::Zeros{<:Any,2}) = mult_zeros(a, b) -# Cannot unify following methods for Diagonal -# due to ambiguity with general array mult. with fill -function *(a::Diagonal, b::FillMatrix) +# # Cannot unify following methods for Diagonal +# # due to ambiguity with general array mult. with fill +function *(a::Diagonal, b::AbstractFill{T,2}) where T size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))")) a.diag .* b # use special broadcast end -function *(a::FillMatrix, b::Diagonal) +function *(a::AbstractFill{T,2}, b::Diagonal) where T size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))")) a .* permutedims(b.diag) # use special broadcast end -function *(a::Diagonal, b::OnesMatrix) - size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))")) - a.diag .* b # use special broadcast -end -function *(a::OnesMatrix, b::Diagonal) - size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))")) - a .* permutedims(b.diag) # use special broadcast -end - -*(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T = reshape(sum(conj.(parent(a)); dims=1) .* b.value, size(parent(a), 2)) -*(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T = reshape(sum(parent(a); dims=1) .* b.value, size(parent(a), 2)) -*(a::StridedMatrix{T}, b::Fill{T, 1}) where T = reshape(sum(a; dims=2) .* b.value, size(a, 1)) -function *(x::AbstractMatrix, f::FillMatrix) +function mult_sum2(x::AbstractMatrix, f::AbstractFill{T,2}) where T axes(x, 2) ≠ axes(f, 1) && throw(DimensionMismatch("Incompatible matrix multiplication dimensions")) m = size(f, 2) - repeat(sum(x, dims=2) * f.value, 1, m) + repeat(sum(x, dims=2) * getindex_value(f), 1, m) end -function *(f::FillMatrix, x::AbstractMatrix) +function mult_sum1(f::AbstractFill{T,2}, x::AbstractMatrix) where T axes(f, 2) ≠ axes(x, 1) && throw(DimensionMismatch("Incompatible matrix multiplication dimensions")) m = size(f, 1) - repeat(sum(x, dims=1) * f.value, m, 1) + repeat(sum(x, dims=1) * getindex_value(f), m, 1) end -function *(x::AbstractMatrix, f::OnesMatrix) - axes(x, 2) ≠ axes(f, 1) && - throw(DimensionMismatch("Incompatible matrix multiplication dimensions")) - m = size(f, 2) - repeat(sum(x, dims=2) * one(eltype(f)), 1, m) -end +*(x::AbstractMatrix, y::AbstractFill{<:Any,2}) = mult_sum2(x, y) +*(x::AbstractTriangular, y::AbstractFill{<:Any,2}) = mult_sum2(x, y) -function *(f::OnesMatrix, x::AbstractMatrix) - axes(f, 2) ≠ axes(x, 1) && - throw(DimensionMismatch("Incompatible matrix multiplication dimensions")) - m = size(f, 1) - repeat(sum(x, dims=1) * one(eltype(f)), m, 1) -end +# *(x::Diagonal, y::AbstractFill{<:Any,2}) = mult_sum2(x, y) +# *(x::Transpose{T,AbstractMatrix{T}}, y::AbstractFill{<:Any,2}) where T = mult_sum2(x, y) +*(x::AbstractFill{<:Any,2}, y::AbstractMatrix) = mult_sum1(x, y) +*(x::AbstractFill{<:Any,2}, y::AbstractTriangular) = mult_sum1(x, y) -*(x::FillMatrix, y::FillMatrix) = mult_fill(x, y) -*(x::FillMatrix, y::OnesMatrix) = mult_fill(x, y) -*(x::OnesMatrix, y::FillMatrix) = mult_fill(x, y) -*(x::OnesMatrix, y::OnesMatrix) = mult_fill(x, y) -*(x::ZerosMatrix, y::OnesMatrix) = mult_zeros(x, y) -*(x::ZerosMatrix, y::FillMatrix) = mult_zeros(x, y) -*(x::FillMatrix, y::ZerosMatrix) = mult_zeros(x, y) -*(x::OnesMatrix, y::ZerosMatrix) = mult_zeros(x, y) +# *(x::AbstractFill{<:Any,2}, y::Diagonal) = mult_sum1(x, y) +# *(x::AbstractFill{<:Any,2}, y::Transpose{T,AbstractMatrix{T}}) where T = mult_sum1(x, y) + + +# *(x::FillMatrix, y::FillMatrix) = mult_fill(x, y) +# *(x::FillMatrix, y::OnesMatrix) = mult_fill(x, y) +# *(x::OnesMatrix, y::FillMatrix) = mult_fill(x, y) +# *(x::OnesMatrix, y::OnesMatrix) = mult_fill(x, y) +# *(x::ZerosMatrix, y::OnesMatrix) = mult_zeros(x, y) +# *(x::ZerosMatrix, y::FillMatrix) = mult_zeros(x, y) +# *(x::FillMatrix, y::ZerosMatrix) = mult_zeros(x, y) +# *(x::OnesMatrix, y::ZerosMatrix) = mult_zeros(x, y) # function *(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 2}) where T # fB = similar(parent(a), size(b, 1), size(b, 2)) @@ -173,6 +161,16 @@ end # return a*fB # end +## Matrix-Vector multiplication + +*(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T = + reshape(sum(conj.(parent(a)); dims=1) .* b.value, size(parent(a), 2)) +*(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T = + reshape(sum(parent(a); dims=1) .* b.value, size(parent(a), 2)) +*(a::StridedMatrix{T}, b::Fill{T, 1}) where T = + reshape(sum(a; dims=2) .* b.value, size(a, 1)) + + function _adjvec_mul_zeros(a::Adjoint{T}, b::Zeros{S, 1}) where {T, S} la, lb = length(a), length(b) if la ≠ lb diff --git a/test/runtests.jl b/test/runtests.jl index 74f01d69..d8437c05 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1015,6 +1015,9 @@ end @test D*Zeros(1) ≡ Zeros(1) D = Diagonal(Fill(2,10)) + # @show D * Ones(10) + # @show D * Ones(10,5) + # @show Ones(5,10) * D @test D * Ones(10) ≡ Fill(2.0,10) @test D * Ones(10,5) ≡ Fill(2.0,10,5) @test Ones(5,10) * D ≡ Fill(2.0,5,10) @@ -1028,6 +1031,19 @@ end @test E*(1:5) ≡ 1.0:5.0 @test (1:5)'E == (1.0:5)' @test E*E ≡ E + + # Adjoint / Transpose / Triangular / Symmetric + for x in [transpose(rand(2, 2)), + adjoint(rand(2,2)), + UpperTriangular(rand(2,2)), + Symmetric(rand(2,2))] + @test x * Ones(2, 2) isa Matrix + @test Ones(2, 2) * x isa Matrix + @test x * Zeros(2, 2) isa Zeros + @test Zeros(2, 2) * x isa Zeros + @test x * Fill(1., 2, 2) isa Matrix + @test Fill(1., 2, 2) * x isa Matrix + end end @testset "count" begin