Skip to content

Commit

Permalink
Merge pull request #105 from vpuri3/methods
Browse files Browse the repository at this point in the history
some missing methods
  • Loading branch information
ChrisRackauckas authored Sep 24, 2022
2 parents 318cc97 + 3f36ae2 commit f703846
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 35 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SciMLOperators"
uuid = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
authors = ["xtalax <[email protected]>"]
version = "0.1.13"
version = "0.1.14"

[deps]
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
Expand Down
4 changes: 2 additions & 2 deletions src/SciMLOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import Base: +, -, *, /, \, ∘, ==, conj, exp, kron
import Base: iszero, inv, adjoint, transpose, size, convert
import LinearAlgebra: mul!, ldiv!, lmul!, rmul!, factorize
import LinearAlgebra: Matrix, Diagonal
import SparseArrays: sparse
import SparseArrays: sparse, issparse

"""
$(TYPEDEF)
Expand All @@ -40,8 +40,8 @@ include("left.jl")
include("multidim.jl")

include("scalar.jl")
include("basic.jl")
include("matrix.jl")
include("basic.jl")
include("batch.jl")
include("func.jl")
include("tensor.jl")
Expand Down
19 changes: 14 additions & 5 deletions src/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ end
Base.:-(L::AbstractSciMLOperator) = ScaledOperator(-true, L)
Base.:+(L::AbstractSciMLOperator) = L

Base.convert(::Type{AbstractMatrix}, L::ScaledOperator) = L.λ.val * convert(AbstractMatrix, L.L)
Base.convert(::Type{AbstractMatrix}, L::ScaledOperator) = convert(Number,L.λ) * convert(AbstractMatrix, L.L)
SparseArrays.sparse(L::ScaledOperator) = L.λ * sparse(L.L)

# traits
Expand Down Expand Up @@ -311,7 +311,14 @@ end
AddedOperator(L::AbstractSciMLOperator) = L

# constructors
Base.:+(ops::AbstractSciMLOperator...) = AddedOperator(ops...)
function Base.:+(ops::Union{AbstractSciMLOperator, AbstractMatrix}...)
ops_ = ()
for op in ops
op = op isa AbstractMatrix ? MatrixOperator(op) : op
ops_ = (ops_..., op)
end
AddedOperator(ops_...)
end
Base.:+(A::AbstractSciMLOperator, B::AddedOperator) = AddedOperator(A, B.ops...)
Base.:+(A::AddedOperator, B::AbstractSciMLOperator) = AddedOperator(A.ops..., B)
Base.:+(A::AddedOperator, B::AddedOperator) = AddedOperator(A.ops..., B.ops...)
Expand All @@ -327,6 +334,8 @@ function Base.:+(Z::NullOperator, A::AddedOperator)
end

Base.:-(A::AbstractSciMLOperator, B::AbstractSciMLOperator) = AddedOperator(A, -B)
Base.:-(A::AbstractSciMLOperator, B::AbstractMatrix) = A - MatrixOperator(B)
Base.:-(A::AbstractMatrix, B::AbstractSciMLOperator) = MatrixOperator(A) - B

for op in (
:+, :-,
Expand All @@ -350,7 +359,7 @@ for op in (
end

Base.convert(::Type{AbstractMatrix}, L::AddedOperator) = sum(op -> convert(AbstractMatrix, op), L.ops)
SparseArrays.sparse(L::AddedOperator) = sum(_sparse, L.ops)
SparseArrays.sparse(L::AddedOperator) = sum(sparse, L.ops)

# traits
Base.size(L::AddedOperator) = size(first(L.ops))
Expand Down Expand Up @@ -482,7 +491,7 @@ for op in (
end

Base.convert(::Type{AbstractMatrix}, L::ComposedOperator) = prod(op -> convert(AbstractMatrix, op), L.ops)
SparseArrays.sparse(L::ComposedOperator) = prod(_sparse, L.ops)
SparseArrays.sparse(L::ComposedOperator) = prod(sparse, L.ops)

# traits
Base.size(L::ComposedOperator) = (size(first(L.ops), 1), size(last(L.ops),2))
Expand Down Expand Up @@ -686,7 +695,7 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::InvertedOperator, u::Abstrac
axpy!(β, L.cache, v)
end

function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::InvertedOperator, u)
function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::InvertedOperator, u::AbstractVecOrMat)
mul!(v, L.L, u)
end

Expand Down
33 changes: 12 additions & 21 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,12 @@ issquare(::Union{
issquare(A...) = @. (&)(issquare(A)...)

Base.isreal(L::AbstractSciMLOperator{T}) where{T} = T <: Real
Base.Matrix(L::AbstractSciMLLinearOperator) = Matrix(convert(AbstractMatrix, L))
Base.Matrix(L::AbstractSciMLOperator) = Matrix(convert(AbstractMatrix, L))

LinearAlgebra.exp(L::AbstractSciMLLinearOperator,t) = exp(t*L)
LinearAlgebra.exp(L::AbstractSciMLOperator,t) = exp(t*L)
has_exp(L::AbstractSciMLLinearOperator) = true
expmv(L::AbstractSciMLLinearOperator,u,p,t) = exp(L,t)*u
expmv!(v,L::AbstractSciMLLinearOperator,u,p,t) = mul!(v,exp(L,t),u)
expmv(L::AbstractSciMLOperator,u,p,t) = exp(L,t)*u
expmv!(v,L::AbstractSciMLOperator,u,p,t) = mul!(v,exp(L,t),u)

###
# fallback implementations
Expand All @@ -188,46 +188,37 @@ function Base.:(==)(L1::AbstractSciMLOperator, L2::AbstractSciMLOperator)
convert(AbstractMatrix, L1) == convert(AbstractMatrix, L1)
end

Base.@propagate_inbounds function Base.getindex(L::AbstractSciMLLinearOperator, I::Vararg{Any,N}) where {N}
Base.@propagate_inbounds function Base.getindex(L::AbstractSciMLOperator, I::Vararg{Any,N}) where {N}
convert(AbstractMatrix, L)[I...]
end
function Base.getindex(L::AbstractSciMLLinearOperator, I::Vararg{Int, N}) where {N}
function Base.getindex(L::AbstractSciMLOperator, I::Vararg{Int, N}) where {N}
convert(AbstractMatrix,L)[I...]
end

LinearAlgebra.exp(L::AbstractSciMLLinearOperator) = exp(Matrix(L))
LinearAlgebra.opnorm(L::AbstractSciMLLinearOperator, p::Real=2) = opnorm(convert(AbstractMatrix,L), p)
LinearAlgebra.exp(L::AbstractSciMLOperator) = exp(Matrix(L))
LinearAlgebra.opnorm(L::AbstractSciMLOperator, p::Real=2) = opnorm(convert(AbstractMatrix,L), p)
for pred in (
:issymmetric,
:ishermitian,
:isposdef,
)
@eval function LinearAlgebra.$pred(L::AbstractSciMLLinearOperator)
@eval function LinearAlgebra.$pred(L::AbstractSciMLOperator)
$pred(convert(AbstractMatrix, L))
end
end
for op in (
:sum,:prod
)
@eval function LinearAlgebra.$op(L::AbstractSciMLLinearOperator; kwargs...)
@eval function LinearAlgebra.$op(L::AbstractSciMLOperator; kwargs...)
$op(convert(AbstractMatrix, L); kwargs...)
end
end

for op in (
:+, :-,
)

@eval function Base.$op(L::AbstractSciMLLinearOperator, u::AbstractVecOrMat)
$op(convert(AbstractMatrix,L), u)
end
end

function LinearAlgebra.mul!(v::AbstractVecOrMat, L::AbstractSciMLLinearOperator, u::AbstractVecOrMat)
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::AbstractSciMLOperator, u::AbstractVecOrMat)
mul!(v, convert(AbstractMatrix,L), u)
end

function LinearAlgebra.mul!(v::AbstractVecOrMat, L::AbstractSciMLLinearOperator, u::AbstractVecOrMat, α, β)
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::AbstractSciMLOperator, u::AbstractVecOrMat, α, β)
mul!(v, convert(AbstractMatrix,L), u, α, β)
end
#
9 changes: 7 additions & 2 deletions src/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ isconstant(L::MatrixOperator) = L.update_func == DEFAULT_UPDATE_FUNC
Base.iszero(L::MatrixOperator) = iszero(L.A)

SparseArrays.sparse(L::MatrixOperator) = sparse(L.A)
SparseArrays.issparse(L::MatrixOperator) = issparse(L.A)

# TODO - add tests for MatrixOperator indexing
# propagate_inbounds here for the getindex fallback
Expand Down Expand Up @@ -142,12 +143,16 @@ for fact in (
:svd, :svd!,
)

@eval LinearAlgebra.$fact(L::AbstractSciMLLinearOperator, args...) =
@eval LinearAlgebra.$fact(L::AbstractSciMLOperator, args...) =
InvertibleOperator($fact(convert(AbstractMatrix, L), args...))
@eval LinearAlgebra.$fact(L::AbstractSciMLLinearOperator; kwargs...) =
@eval LinearAlgebra.$fact(L::AbstractSciMLOperator; kwargs...) =
InvertibleOperator($fact(convert(AbstractMatrix, L); kwargs...))
end

function Base.convert(::Type{<:Factorization}, L::InvertibleOperator{T,<:Factorization}) where{T}
L.F
end

function Base.convert(::Type{AbstractMatrix}, L::InvertibleOperator)
if L.F isa Adjoint
convert(AbstractMatrix,L.F')'
Expand Down
8 changes: 4 additions & 4 deletions src/multidim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ for op in (
end
end

function LinearAlgebra.mul!(v::AbstractArray, L::AbstractSciMLLinearOperator, u::AbstractArray)
function LinearAlgebra.mul!(v::AbstractArray, L::AbstractSciMLOperator, u::AbstractArray)
u isa AbstractVecOrMat && @error "LinearAlgebra.mul! not defined for $(typeof(L)), $(typeof(u))."

sizes = _mat_sizes(L, u)
Expand All @@ -34,7 +34,7 @@ function LinearAlgebra.mul!(v::AbstractArray, L::AbstractSciMLLinearOperator, u:
v
end

function LinearAlgebra.mul!(v::AbstractArray, L::AbstractSciMLLinearOperator, u::AbstractArray, α, β)
function LinearAlgebra.mul!(v::AbstractArray, L::AbstractSciMLOperator, u::AbstractArray, α, β)
u isa AbstractVecOrMat && @error "LinearAlgebra.mul! not defined for $(typeof(L)), $(typeof(u))."

sizes = _mat_sizes(L, u)
Expand All @@ -47,7 +47,7 @@ function LinearAlgebra.mul!(v::AbstractArray, L::AbstractSciMLLinearOperator, u:
v
end

function LinearAlgebra.ldiv!(v::AbstractArray, L::AbstractSciMLLinearOperator, u::AbstractArray)
function LinearAlgebra.ldiv!(v::AbstractArray, L::AbstractSciMLOperator, u::AbstractArray)
u isa AbstractVecOrMat && @error "LinearAlgebra.ldiv! not defined for $(typeof(L)), $(typeof(u))."

sizes = _mat_sizes(L, u)
Expand All @@ -60,7 +60,7 @@ function LinearAlgebra.ldiv!(v::AbstractArray, L::AbstractSciMLLinearOperator, u
v
end

function LinearAlgebra.ldiv!(L::AbstractSciMLLinearOperator, u::AbstractArray)
function LinearAlgebra.ldiv!(L::AbstractSciMLOperator, u::AbstractArray)
u isa AbstractVecOrMat && @error "LinearAlgebra.ldiv! not defined for $(typeof(L)), $(typeof(u))."

sizes = _mat_sizes(L, u)
Expand Down

0 comments on commit f703846

Please sign in to comment.