Skip to content

Commit

Permalink
add batch inverse BilinearLens
Browse files Browse the repository at this point in the history
  • Loading branch information
marius311 committed Sep 9, 2020
1 parent 0e0362a commit 2aadf79
Showing 1 changed file with 29 additions and 18 deletions.
47 changes: 29 additions & 18 deletions src/bilinearlens.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,28 +150,39 @@ function *(Lϕ::Adjoint{<:Any,<:BilinearLens}, f::FlatS0{P}) where {N,D,P<:Flat{
end

function \(Lϕ::BilinearLens, f::FlatS0{P}) where {N,P<:Flat{N}}
FlatMap{P}(reshape(gmres(
.sparse_repr, view(f[:Ix],:),
Pl = get_anti_lensing_sparse_repr!(Lϕ), maxiter = 5
), N, N))
function \(Lϕ::BilinearLens, f̃::FlatS0{P}) where {N,D,P<:Flat{N,<:Any,<:Any,D}}
Łf̃ = Ł(f̃)
f = similar(Łf̃)
ds = (D == 1 ? ((),) : tuple.(1:D))
for d in ds
@views(f.Ix[:,:,d...][:]) .= gmres(
.sparse_repr, @views(Łf̃.Ix[:,:,d...][:]),
Pl = get_anti_lensing_sparse_repr!(Lϕ), maxiter = 5
)
end
f
end

function \(Lϕ::Adjoint{<:Any,<:BilinearLens}, f::FlatS0{P}) where {N,P<:Flat{N}}
FlatMap{P}(reshape(gmres(
parent(Lϕ).sparse_repr', view(f[:Ix],:),
Pl = get_anti_lensing_sparse_repr!(parent(Lϕ))', maxiter = 5
), N, N))
function \(Lϕ::Adjoint{<:Any,<:BilinearLens}, f̃::FlatS0{P}) where {N,D,P<:Flat{N,<:Any,<:Any,D}}
Łf̃ = Ł(f̃)
f = similar(Łf̃)
ds = (D == 1 ? ((),) : tuple.(1:D))
for d in ds
@views(f.Ix[:,:,d...][:]) .= gmres(
parent(Lϕ).sparse_repr', @views(Łf̃.Ix[:,:,d...][:]),
Pl = get_anti_lensing_sparse_repr!(parent(Lϕ))', maxiter = 5
)
end
f
end

# special cases for BilinearLens(0ϕ), which don't work with bicgstabl,
# see https://github.com/JuliaMath/IterativeSolvers.jl/issues/271
function \(Lϕ::BilinearLens{<:Any,<:UniformScaling}, f::FlatS0{P}) where {N,P<:Flat{N}}
.sparse_repr \ f
end
function \(Lϕ::Adjoint{<:Any,<:BilinearLens{<:Any,<:UniformScaling}}, f::FlatS0{P}) where {N,P<:Flat{N}}
parent(Lϕ).sparse_repr \ f
end

# optimizations for BilinearLens(0ϕ)
\(Lϕ::BilinearLens{<:Any,<:UniformScaling}, f::FlatS0{P}) where {N,P<:Flat{N}} = f
*(Lϕ::BilinearLens{<:Any,<:UniformScaling}, f::FlatS0{P}) where {N,P<:Flat{N}} = f
\(Lϕ::Adjoint{<:Any,<:BilinearLens{<:Any,<:UniformScaling}}, f::FlatS0{P}) where {N,P<:Flat{N}} = f
*(Lϕ::Adjoint{<:Any,<:BilinearLens{<:Any,<:UniformScaling}}, f::FlatS0{P}) where {N,P<:Flat{N}} = f


for op in (:*, :\)
@eval function ($op)(Lϕ::Union{BilinearLens, Adjoint{<:Any,<:BilinearLens}}, f::FieldTuple)
Expand Down

0 comments on commit 2aadf79

Please sign in to comment.