diff --git a/src/kronecker.jl b/src/kronecker.jl index 17fa1d0..dcaaee3 100644 --- a/src/kronecker.jl +++ b/src/kronecker.jl @@ -36,14 +36,14 @@ for MT in [:AbstractMatrix, :PermMatrix, :SparseMatrixCSC, :Diagonal] end ####### diagonal kron ######## -kron(A::Diagonal, B::Diagonal) = Diagonal(kron(A.diag, B.diag)) -kron(A::StridedMatrix, B::Diagonal) = kron(A, PermMatrix(B)) -kron(A::Diagonal, B::StridedMatrix) = kron(PermMatrix(A), B) -kron(A::Diagonal, B::SparseMatrixCSC) = kron(PermMatrix(A), B) -kron(A::SparseMatrixCSC, B::Diagonal) = kron(A, PermMatrix(B)) +kron(A::Diagonal{<:Number}, B::Diagonal{<:Number}) = Diagonal(kron(A.diag, B.diag)) +kron(A::StridedMatrix{<:Number}, B::Diagonal{<:Number}) = kron(A, PermMatrix(B)) +kron(A::Diagonal{<:Number}, B::StridedMatrix{<:Number}) = kron(PermMatrix(A), B) +kron(A::Diagonal{<:Number}, B::SparseMatrixCSC{<:Number}) = kron(PermMatrix(A), B) +kron(A::SparseMatrixCSC{<:Number}, B::Diagonal{<:Number}) = kron(A, PermMatrix(B)) -function kron(A::AbstractMatrix{Tv}, B::IMatrix{Nb}) where {Nb, Tv} +function kron(A::AbstractMatrix{Tv}, B::IMatrix{Nb}) where {Nb, Tv<:Number} mA, nA = size(A) nzval = Vector{Tv}(undef, Nb*mA*nA) rowval = Vector{Int}(undef, Nb*mA*nA) @@ -63,7 +63,7 @@ function kron(A::AbstractMatrix{Tv}, B::IMatrix{Nb}) where {Nb, Tv} SparseMatrixCSC(mA*Nb, nA*Nb, colptr, rowval, nzval) end -function kron(A::IMatrix{Na}, B::AbstractMatrix{Tv}) where {Na, Tv} +function kron(A::IMatrix{Na}, B::AbstractMatrix{Tv}) where {Na, Tv<:Number} mB, nB = size(B) rowval = Vector{Int}(undef, nB*mB*Na) nzval = Vector{Tv}(undef, nB*mB*Na) @@ -81,7 +81,7 @@ function kron(A::IMatrix{Na}, B::AbstractMatrix{Tv}) where {Na, Tv} SparseMatrixCSC(mB*Na, Na*nB, colptr, rowval, nzval) end -function kron(A::IMatrix{Na}, B::SparseMatrixCSC{T}) where {Na, T} +function kron(A::IMatrix{Na}, B::SparseMatrixCSC{T}) where {Na, T<:Number} mB, nB = size(B) nV = nnz(B) nzval = Vector{T}(undef, Na*nV) @@ -104,7 +104,7 @@ function kron(A::IMatrix{Na}, B::SparseMatrixCSC{T}) where {Na, T} SparseMatrixCSC(mB*Na, nB*Na, colptr, rowval, nzval) end -function kron(A::SparseMatrixCSC{T}, B::IMatrix{Nb}) where {T, Nb} +function kron(A::SparseMatrixCSC{T}, B::IMatrix{Nb}) where {T<:Number, Nb} mA, nA = size(A) nV = nnz(A) rowval = Vector{Int}(undef, Nb*nV) @@ -129,7 +129,7 @@ function kron(A::SparseMatrixCSC{T}, B::IMatrix{Nb}) where {T, Nb} SparseMatrixCSC(mA*Nb, nA*Nb, colptr, rowval, nzval) end -function kron(A::PermMatrix{T}, B::IMatrix) where T +function kron(A::PermMatrix{T}, B::IMatrix) where T<:Number nA = size(A, 1) nB = size(B, 1) vals = Vector{T}(undef, nB*nA) @@ -146,7 +146,7 @@ function kron(A::PermMatrix{T}, B::IMatrix) where T PermMatrix(perm, vals) end -function kron(A::IMatrix, B::PermMatrix{Tv, Ti}) where {Tv, Ti <: Integer} +function kron(A::IMatrix, B::PermMatrix{Tv, Ti}) where {Tv<:Number, Ti <: Integer} nA = size(A, 1) nB = size(B, 1) perm = Vector{Int}(undef, nB*nA) @@ -162,7 +162,7 @@ function kron(A::IMatrix, B::PermMatrix{Tv, Ti}) where {Tv, Ti <: Integer} end -function kron(A::StridedMatrix{Tv}, B::PermMatrix{Tb}) where {Tv, Tb} +function kron(A::StridedMatrix{Tv}, B::PermMatrix{Tb}) where {Tv<:Number, Tb<:Number} mA, nA = size(A) nB = size(B, 1) perm = fast_invperm(B.perm) @@ -186,7 +186,7 @@ function kron(A::StridedMatrix{Tv}, B::PermMatrix{Tb}) where {Tv, Tb} SparseMatrixCSC(mA*nB, nA*nB, colptr, rowval, nzval) end -function kron(A::PermMatrix{Ta}, B::StridedMatrix{Tb}) where {Tb, Ta} +function kron(A::PermMatrix{Ta}, B::StridedMatrix{Tb}) where {Tb<:Number, Ta<:Number} mB, nB = size(B) nA = size(A, 1) perm = fast_invperm(A.perm) @@ -210,7 +210,7 @@ function kron(A::PermMatrix{Ta}, B::StridedMatrix{Tb}) where {Tb, Ta} SparseMatrixCSC(nA*mB, nA*nB, colptr, rowval, nzval) end -function kron(A::PermMatrix, B::PermMatrix) +function kron(A::PermMatrix{<:Number}, B::PermMatrix{<:Number}) nA = size(A, 1) nB = size(B, 1) vals = kron(A.vals, B.vals) @@ -225,10 +225,10 @@ function kron(A::PermMatrix, B::PermMatrix) PermMatrix(perm, vals) end -kron(A::PermMatrix, B::Diagonal) = kron(A, PermMatrix(B)) -kron(A::Diagonal, B::PermMatrix) = kron(PermMatrix(A), B) +kron(A::PermMatrix{<:Number}, B::Diagonal{<:Number}) = kron(A, PermMatrix(B)) +kron(A::Diagonal{<:Number}, B::PermMatrix{<:Number}) = kron(PermMatrix(A), B) -function kron(A::PermMatrix{Ta}, B::SparseMatrixCSC{Tb}) where {Ta, Tb} +function kron(A::PermMatrix{Ta}, B::SparseMatrixCSC{Tb}) where {Ta<:Number, Tb<:Number} nA = size(A, 1) mB, nB = size(B) nV = nnz(B) @@ -254,7 +254,7 @@ function kron(A::PermMatrix{Ta}, B::SparseMatrixCSC{Tb}) where {Ta, Tb} SparseMatrixCSC(mB*nA, nB*nA, colptr, rowval, nzval) end -function kron(A::SparseMatrixCSC{T}, B::PermMatrix{Tb}) where {T, Tb} +function kron(A::SparseMatrixCSC{T}, B::PermMatrix{Tb}) where {T<:Number, Tb<:Number} nB = size(B, 1) mA, nA = size(A) nV = nnz(A)