Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Dec 4, 2020
1 parent ba49d16 commit 2915481
Showing 1 changed file with 70 additions and 14 deletions.
84 changes: 70 additions & 14 deletions src/fillalgebra.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
const FillVector{F,A} = Fill{F,1,A}
const FillMatrix{F,A} = Fill{F,2,A}
const OnesVector{F,A} = Ones{F,1,A}
const OnesMatrix{F,A} = Ones{F,2,A}
const ZerosVector{F,A} = Zeros{F,1,A}
const ZerosMatrix{F,A} = Zeros{F,2,A}

## vec

vec(a::Ones{T}) where T = Ones{T}(length(a))
Expand Down Expand Up @@ -87,11 +94,22 @@ end
*(a::Zeros{<:Any,2}, b::Diagonal) = mult_zeros(a, b)
*(a::Diagonal, b::Zeros{<:Any,1}) = mult_zeros(a, b)
*(a::Diagonal, b::Zeros{<:Any,2}) = mult_zeros(a, b)
function *(a::Diagonal, b::AbstractFill{<:Any,2})

# Cannot unify following methods for Diagonal
# due to ambiguity with general array mult. with fill
function *(a::Diagonal, b::FillMatrix)
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
a.diag .* b # use special broadcast
end
function *(a::FillMatrix, b::Diagonal)
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
a .* permutedims(b.diag) # use special broadcast
end
function *(a::Diagonal, b::OnesMatrix)
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
a.diag .* b # use special broadcast
end
function *(a::AbstractFill{<:Any,2}, b::Diagonal)
function *(a::OnesMatrix, b::Diagonal)
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
a .* permutedims(b.diag) # use special broadcast
end
Expand All @@ -100,23 +118,61 @@ end
*(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T = reshape(sum(parent(a); dims=1) .* b.value, size(parent(a), 2))
*(a::StridedMatrix{T}, b::Fill{T, 1}) where T = reshape(sum(a; dims=2) .* b.value, size(a, 1))

function *(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 2}) where T
fB = similar(parent(a), size(b, 1), size(b, 2))
fill!(fB, b.value)
return a*fB
function *(x::AbstractMatrix, f::FillMatrix)
axes(x, 2) axes(f, 1) &&
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
m = size(f, 2)
repeat(sum(x, dims=2) * f.value, 1, m)
end

function *(f::FillMatrix, x::AbstractMatrix)
axes(f, 2) axes(x, 1) &&
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
m = size(f, 1)
repeat(sum(x, dims=1) * f.value, m, 1)
end

function *(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 2}) where T
fB = similar(parent(a), size(b, 1), size(b, 2))
fill!(fB, b.value)
return a*fB
function *(x::AbstractMatrix, f::OnesMatrix)
axes(x, 2) axes(f, 1) &&
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
m = size(f, 2)
repeat(sum(x, dims=2) * one(eltype(f)), 1, m)
end

function *(a::StridedMatrix{T}, b::Fill{T, 2}) where T
fB = similar(a, size(b, 1), size(b, 2))
fill!(fB, b.value)
return a*fB
function *(f::OnesMatrix, x::AbstractMatrix)
axes(f, 2) axes(x, 1) &&
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
m = size(f, 1)
repeat(sum(x, dims=1) * one(eltype(f)), m, 1)
end

*(x::FillMatrix, y::FillMatrix) = mult_fill(x, y)
*(x::FillMatrix, y::OnesMatrix) = mult_fill(x, y)
*(x::OnesMatrix, y::FillMatrix) = mult_fill(x, y)
*(x::OnesMatrix, y::OnesMatrix) = mult_fill(x, y)
*(x::ZerosMatrix, y::OnesMatrix) = mult_zeros(x, y)
*(x::ZerosMatrix, y::FillMatrix) = mult_zeros(x, y)
*(x::FillMatrix, y::ZerosMatrix) = mult_zeros(x, y)
*(x::OnesMatrix, y::ZerosMatrix) = mult_zeros(x, y)

# function *(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 2}) where T
# fB = similar(parent(a), size(b, 1), size(b, 2))
# fill!(fB, b.value)
# return a*fB
# end

# function *(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 2}) where T
# fB = similar(parent(a), size(b, 1), size(b, 2))
# fill!(fB, b.value)
# return a*fB
# end

# function *(a::StridedMatrix{T}, b::Fill{T, 2}) where T
# fB = similar(a, size(b, 1), size(b, 2))
# fill!(fB, b.value)
# return a*fB
# end

function _adjvec_mul_zeros(a::Adjoint{T}, b::Zeros{S, 1}) where {T, S}
la, lb = length(a), length(b)
if la lb
Expand Down

0 comments on commit 2915481

Please sign in to comment.