diff --git a/src/bilinearlens.jl b/src/bilinearlens.jl index b512799b..32109efe 100644 --- a/src/bilinearlens.jl +++ b/src/bilinearlens.jl @@ -150,28 +150,39 @@ function *(Lϕ::Adjoint{<:Any,<:BilinearLens}, f::FlatS0{P}) where {N,D,P<:Flat{ f̃ end -function \(Lϕ::BilinearLens, f::FlatS0{P}) where {N,P<:Flat{N}} - FlatMap{P}(reshape(gmres( - Lϕ.sparse_repr, view(f[:Ix],:), - Pl = get_anti_lensing_sparse_repr!(Lϕ), maxiter = 5 - ), N, N)) +function \(Lϕ::BilinearLens, f̃::FlatS0{P}) where {N,D,P<:Flat{N,<:Any,<:Any,D}} + Łf̃ = Ł(f̃) + f = similar(Łf̃) + ds = (D == 1 ? ((),) : tuple.(1:D)) + for d in ds + @views(f.Ix[:,:,d...][:]) .= gmres( + Lϕ.sparse_repr, @views(Łf̃.Ix[:,:,d...][:]), + Pl = get_anti_lensing_sparse_repr!(Lϕ), maxiter = 5 + ) + end + f end -function \(Lϕ::Adjoint{<:Any,<:BilinearLens}, f::FlatS0{P}) where {N,P<:Flat{N}} - FlatMap{P}(reshape(gmres( - parent(Lϕ).sparse_repr', view(f[:Ix],:), - Pl = get_anti_lensing_sparse_repr!(parent(Lϕ))', maxiter = 5 - ), N, N)) +function \(Lϕ::Adjoint{<:Any,<:BilinearLens}, f̃::FlatS0{P}) where {N,D,P<:Flat{N,<:Any,<:Any,D}} + Łf̃ = Ł(f̃) + f = similar(Łf̃) + ds = (D == 1 ? ((),) : tuple.(1:D)) + for d in ds + @views(f.Ix[:,:,d...][:]) .= gmres( + parent(Lϕ).sparse_repr', @views(Łf̃.Ix[:,:,d...][:]), + Pl = get_anti_lensing_sparse_repr!(parent(Lϕ))', maxiter = 5 + ) + end + f end -# special cases for BilinearLens(0ϕ), which don't work with bicgstabl, -# see https://github.com/JuliaMath/IterativeSolvers.jl/issues/271 -function \(Lϕ::BilinearLens{<:Any,<:UniformScaling}, f::FlatS0{P}) where {N,P<:Flat{N}} - Lϕ.sparse_repr \ f -end -function \(Lϕ::Adjoint{<:Any,<:BilinearLens{<:Any,<:UniformScaling}}, f::FlatS0{P}) where {N,P<:Flat{N}} - parent(Lϕ).sparse_repr \ f -end + +# optimizations for BilinearLens(0ϕ) +\(Lϕ::BilinearLens{<:Any,<:UniformScaling}, f::FlatS0{P}) where {N,P<:Flat{N}} = f +*(Lϕ::BilinearLens{<:Any,<:UniformScaling}, f::FlatS0{P}) where {N,P<:Flat{N}} = f +\(Lϕ::Adjoint{<:Any,<:BilinearLens{<:Any,<:UniformScaling}}, f::FlatS0{P}) where {N,P<:Flat{N}} = f +*(Lϕ::Adjoint{<:Any,<:BilinearLens{<:Any,<:UniformScaling}}, f::FlatS0{P}) where {N,P<:Flat{N}} = f + for op in (:*, :\) @eval function ($op)(Lϕ::Union{BilinearLens, Adjoint{<:Any,<:BilinearLens}}, f::FieldTuple)