Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Biquadlens #19

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/CMBLensing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ include("flat_batch.jl")
include("masking.jl")
include("taylens.jl")
include("bilinearlens.jl")
include("biquadlens.jl")

# plotting
isjuno = false
Expand Down
10 changes: 5 additions & 5 deletions src/bilinearlens.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ function BilinearLens(ϕ::FlatS0)
function compute_sparse_repr(is_gpu_backed::Val{false})
K = Vector{Int32}(getK(Nside))
M = similar(K)
V = similar(K,Float32)
V = similar(K,T)
for I in 1:length(ĩs)
compute_row!(I, ĩs[I], j̃s[I], M, V)
end
Expand All @@ -92,7 +92,7 @@ function BilinearLens(ϕ::FlatS0)
function compute_sparse_repr(is_gpu_backed::Val{true})
K = CuVector{Cint}(getK(Nside))
M = similar(K)
V = similar(K,Float32)
V = similar(K,T)
cuda(ĩs, j̃s, M, V; threads=256) do ĩs, j̃s, M, V
index = threadIdx().x
stride = blockDim().x
Expand All @@ -104,7 +104,7 @@ function BilinearLens(ϕ::FlatS0)
if !Base.isdefined(CUSPARSE,:CuSparseMatrixCOO)
error("To use BilinearLens on GPU, run `using Pkg; pkg\"add https://github.com/marius311/CUDA.jl#coo\"` and restart Julia.")
end
switch2csr(CUSPARSE.CuSparseMatrixCOO{Float32}(K,M,V,(Nside^2,Nside^2)))
switch2csr(CUSPARSE.CuSparseMatrixCOO{T}(K,M,V,(Nside^2,Nside^2)))
end


Expand All @@ -129,7 +129,7 @@ getϕ(Lϕ::BilinearLens) = Lϕ.ϕ
# applying various forms of the operator

function *(Lϕ::BilinearLens, f::FlatS0{P}) where {N,D,P<:Flat{N,<:Any,<:Any,D}}
Lϕ.sparse_repr==I && return f
Lϕ.sparse_repr===I && return f
Łf = Ł(f)
f̃ = similar(Łf)
ds = (D == 1 ? ((),) : tuple.(1:D))
Expand All @@ -140,7 +140,7 @@ function *(Lϕ::BilinearLens, f::FlatS0{P}) where {N,D,P<:Flat{N,<:Any,<:Any,D}}
end

function *(Lϕ::Adjoint{<:Any,<:BilinearLens}, f::FlatS0{P}) where {N,D,P<:Flat{N,<:Any,<:Any,D}}
parent(Lϕ).sparse_repr==I && return f
parent(Lϕ).sparse_repr===I && return f
Łf = Ł(f)
f̃ = similar(Łf)
ds = (D == 1 ? ((),) : tuple.(1:D))
Expand Down
215 changes: 215 additions & 0 deletions src/biquadlens.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@

export BiquadLens

@doc doc"""

    BiquadLens(ϕ)

BiquadLens is a lensing operator that computes lensing with biquadratic
interpolation. The action of the operator, as well as its adjoint, inverse,
inverse-adjoint, and gradient w.r.t. ϕ can all be computed. The log-determinant
of the operation is non-zero and can't be computed.

Internally, BiquadLens forms a sparse matrix with the interpolation weights,
which can be applied and adjoint-ed extremely fast (e.g. at least an order of
magnitude faster than LenseFlow). Inverse and inverse-adjoint lensing is
somewhat slower as it is implemented with several steps of the [preconditioned
generalized minimal residual](https://en.wikipedia.org/wiki/Generalized_minimal_residual_method)
algorithm, taking anti-lensing as the preconditioner.

!!! warning

    Due to [this bug](https://github.com/JuliaLang/PackageCompiler.jl/issues/379)
    in PackageCompiler, currently you have to run `using SparseArrays` by hand
    in your Julia session before `BiquadLens` is available.

"""
mutable struct BiquadLens{Φ,S} <: ImplicitOp{Basis,Spin,Pix}
    ϕ :: Φ                                        # lensing potential the operator was built from
    sparse_repr :: S                              # sparse matrix of interpolation weights (or `I` when ϕ == 0)
    anti_lensing_sparse_repr :: Union{S, Nothing} # sparse matrix for lensing by -ϕ; `nothing` until lazily computed
end

"""
    BiquadLens(ϕ::FlatS0)

Construct a `BiquadLens` operator from the lensing potential `ϕ`, precomputing
the sparse matrix of biquadratic interpolation weights (on CPU or GPU,
depending on where `ϕ` is stored). The anti-lensing sparse representation is
left as `nothing` and computed lazily on first use.
"""
function BiquadLens(ϕ::FlatS0)

    # if ϕ == 0 then lensing is a no-op, so just return an identity operator
    if norm(ϕ) == 0
        return BiquadLens(ϕ,I,I)
    end

    @unpack Nside,Δx,T = fieldinfo(ϕ)

    # the (i,j)-th pixel is deflected to (ĩs[i],j̃s[j]): the gradient of ϕ in
    # pixel units gives the deflection, to which each pixel's own index is added
    j̃s,ĩs = getindex.((∇*ϕ)./Δx, :Ix)
    ĩs .= ĩs .+ (1:Nside)
    j̃s .= (j̃s' .+ (1:Nside))'

    # sub2ind converts a 2D index to 1D index, including wrapping at edges
    # (periodic boundary conditions on the flat map)
    indexwrap(i) = mod(i - 1, Nside) + 1
    sub2ind(i,j) = Base._sub2ind((Nside,Nside),indexwrap(i),indexwrap(j))

    # compute the 9 non-zero entries in L[I,:] (ie the Ith row of the sparse
    # lensing representation, L) and add these to the sparse constructor
    # matrices, M, and V, accordingly. this function is split off so it can be
    # called directly or used as a CUDA kernel
    function compute_row!(I, ĩ, j̃, M, V)

        # (i,j) indices of the 9 nearest neighbors (a 3×3 stencil around the
        # deflected position (ĩ,j̃))
        x₋,x₀,x₊ = floor(Int,ĩ) .+ (-1, 0, 1)
        y₋,y₀,y₊ = floor(Int,j̃) .+ (-1, 0, 1)

        # 1-D indices of the 9 nearest neighbors
        M[9I-8:9I] .= @SVector[
            sub2ind(x₋, y₋),
            sub2ind(x₀, y₋),
            sub2ind(x₊, y₋),
            sub2ind(x₋, y₀),
            sub2ind(x₀, y₀),
            sub2ind(x₊, y₀),
            sub2ind(x₋, y₊),
            sub2ind(x₀, y₊),
            sub2ind(x₊, y₊),
        ]

        # weights of these neighbors in the biquadratic interpolation. A is the
        # 9×8 design matrix of the 8-term biquadratic basis evaluated at the
        # stencil offsets (Δx,Δy) from the deflected point. The least-squares
        # map from the 9 neighbor values to the 8 polynomial coefficients is
        # (A'A)⁻¹A'; its first row (the constant-term coefficient) is the
        # fitted value at offset (0,0), i.e. at the deflected point itself,
        # which is exactly the vector of interpolation weights.
        Δx₋, Δx₀, Δx₊ = ((x₋,x₀,x₊) .- ĩ)
        Δy₋, Δy₀, Δy₊ = ((y₋,y₀,y₊) .- j̃)
        A = @SMatrix[
            1 Δx₋ Δx₋^2 Δy₋ Δy₋^2 Δx₋*Δy₋ Δx₋*Δy₋^2 Δx₋^2*Δy₋;
            1 Δx₀ Δx₀^2 Δy₋ Δy₋^2 Δx₀*Δy₋ Δx₀*Δy₋^2 Δx₀^2*Δy₋;
            1 Δx₊ Δx₊^2 Δy₋ Δy₋^2 Δx₊*Δy₋ Δx₊*Δy₋^2 Δx₊^2*Δy₋;
            1 Δx₋ Δx₋^2 Δy₀ Δy₀^2 Δx₋*Δy₀ Δx₋*Δy₀^2 Δx₋^2*Δy₀;
            1 Δx₀ Δx₀^2 Δy₀ Δy₀^2 Δx₀*Δy₀ Δx₀*Δy₀^2 Δx₀^2*Δy₀;
            1 Δx₊ Δx₊^2 Δy₀ Δy₀^2 Δx₊*Δy₀ Δx₊*Δy₀^2 Δx₊^2*Δy₀;
            1 Δx₋ Δx₋^2 Δy₊ Δy₊^2 Δx₋*Δy₊ Δx₋*Δy₊^2 Δx₋^2*Δy₊;
            1 Δx₀ Δx₀^2 Δy₊ Δy₊^2 Δx₀*Δy₊ Δx₀*Δy₊^2 Δx₀^2*Δy₊;
            1 Δx₊ Δx₊^2 Δy₊ Δy₊^2 Δx₊*Δy₊ Δx₊*Δy₊^2 Δx₊^2*Δy₊;
        ]
        V[9I-8:9I] .= (inv(A'*A)*A')[1,:]

    end

    # K holds the COO row index of each of the 9 entries per row
    # (1,1,…,1,2,2,…). computing it is a surprisingly large fraction of the
    # computation for large Nside, so memoize it:
    @memoize getK(Nside) = Int32.((9:9Nside^2+8) .÷ 9)

    # CPU: fill M and V row-by-row in a plain loop, then build a SparseArrays CSC matrix
    function compute_sparse_repr(is_gpu_backed::Val{false})
        K = Vector{Int32}(getK(Nside))
        M = similar(K)
        V = similar(K,T)
        for I in 1:length(ĩs)
            compute_row!(I, ĩs[I], j̃s[I], M, V)
        end
        sparse(K,M,V,Nside^2,Nside^2)
    end

    # GPU: same computation as a CUDA kernel (each thread strides over rows),
    # then build a CUSPARSE matrix via the COO format
    function compute_sparse_repr(is_gpu_backed::Val{true})
        K = CuVector{Cint}(getK(Nside))
        M = similar(K)
        V = similar(K,T)
        cuda(ĩs, j̃s, M, V; threads=256) do ĩs, j̃s, M, V
            index = threadIdx().x
            stride = blockDim().x
            for I in index:stride:length(ĩs)
                compute_row!(I, ĩs[I], j̃s[I], M, V)
            end
        end
        # remove once CuSparseMatrixCOO makes it into official CUDA.jl:
        if !Base.isdefined(CUSPARSE,:CuSparseMatrixCOO)
            error("To use BiquadLens on GPU, run `using Pkg; pkg\"add https://github.com/marius311/CUDA.jl#coo\"` and restart Julia.")
        end
        switch2csr(CUSPARSE.CuSparseMatrixCOO{T}(K,M,V,(Nside^2,Nside^2)))
    end

    # anti-lensing sparse repr starts as `nothing`; see get_anti_lensing_sparse_repr!
    BiquadLens(ϕ, compute_sparse_repr(Val(is_gpu_backed(ϕ))), nothing)

end


# lazily computing the sparse representation for anti-lensing

"""
    get_anti_lensing_sparse_repr!(Lϕ::BiquadLens)

Return the sparse representation of anti-lensing (i.e. lensing by `-Lϕ.ϕ`),
computing it on first use and caching it in `Lϕ.anti_lensing_sparse_repr`.
"""
function get_anti_lensing_sparse_repr!(Lϕ::BiquadLens)
    # `isnothing` (identity test) rather than `==`: comparing against the
    # `nothing` sentinel with generic `==` is non-idiomatic and can dispatch
    # to arbitrary equality methods of the cached matrix type
    if isnothing(Lϕ.anti_lensing_sparse_repr)
        Lϕ.anti_lensing_sparse_repr = BiquadLens(-Lϕ.ϕ).sparse_repr
    end
    return Lϕ.anti_lensing_sparse_repr
end


# accessor for the potential this operator was constructed from
function getϕ(Lϕ::BiquadLens)
    return Lϕ.ϕ
end

# calling an existing BiquadLens with a new potential builds a fresh operator
function (Lϕ::BiquadLens)(ϕ)
    return BiquadLens(ϕ)
end

# applying various forms of the operator

# Apply biquadratic lensing to the map `f` by multiplying its flattened pixel
# vector by the precomputed sparse matrix, one batch slice `d` at a time.
function *(Lϕ::BiquadLens, f::FlatS0{P}) where {N,D,P<:Flat{N,<:Any,<:Any,D}}
    # when ϕ == 0 the sparse_repr is the `I` UniformScaling, so lensing is a no-op
    Lϕ.sparse_repr === I && return f
    Łf = Ł(f)          # ensure f is in the pixel basis
    f̃ = similar(Łf)
    # iterate over batch dimensions, if any
    ds = (D == 1 ? ((),) : tuple.(1:D))
    for d in ds
        mul!(@views(f̃.Ix[:,:,d...][:]), Lϕ.sparse_repr, @views(Łf.Ix[:,:,d...][:]))
    end
    # bug fix: the `for` loop was previously the last expression, making the
    # function return `nothing`; callers (e.g. the @adjoint rule) need the
    # lensed field
    return f̃
end

# Apply the adjoint of biquadratic lensing to `f`, using the (lazy) adjoint of
# the sparse matrix, one batch slice `d` at a time.
function *(Lϕ::Adjoint{<:Any,<:BiquadLens}, f::FlatS0{P}) where {N,D,P<:Flat{N,<:Any,<:Any,D}}
    # when ϕ == 0 the sparse_repr is the `I` UniformScaling, so lensing is a no-op
    parent(Lϕ).sparse_repr === I && return f
    Łf = Ł(f)          # ensure f is in the pixel basis
    f̃ = similar(Łf)
    # iterate over batch dimensions, if any
    ds = (D == 1 ? ((),) : tuple.(1:D))
    for d in ds
        mul!(@views(f̃.Ix[:,:,d...][:]), parent(Lϕ).sparse_repr', @views(Łf.Ix[:,:,d...][:]))
    end
    # bug fix: the `for` loop was previously the last expression, making the
    # function return `nothing`; return the adjoint-lensed field
    return f̃
end

# Inverse lensing: solve Lϕ * x = f with a few GMRES iterations, using the
# anti-lensing sparse matrix as the (left) preconditioner.
function \(Lϕ::BiquadLens, f::FlatS0{P}) where {N,P<:Flat{N}}
    rhs = view(f[:Ix], :)
    precond = get_anti_lensing_sparse_repr!(Lϕ)
    sol = gmres(Lϕ.sparse_repr, rhs, Pl = precond, maxiter = 5)
    return FlatMap{P}(reshape(sol, N, N))
end

# Inverse-adjoint lensing: solve Lϕ' * x = f with a few GMRES iterations,
# preconditioned by the adjoint of the anti-lensing sparse matrix.
function \(Lϕ::Adjoint{<:Any,<:BiquadLens}, f::FlatS0{P}) where {N,P<:Flat{N}}
    rhs = view(f[:Ix], :)
    precond = get_anti_lensing_sparse_repr!(parent(Lϕ))'
    sol = gmres(parent(Lϕ).sparse_repr', rhs, Pl = precond, maxiter = 5)
    return FlatMap{P}(reshape(sol, N, N))
end

# Special cases for BiquadLens(0ϕ), whose sparse_repr is the `I`
# UniformScaling and doesn't work with bicgstabl, see
# https://github.com/JuliaMath/IterativeSolvers.jl/issues/271
\(Lϕ::BiquadLens{<:Any,<:UniformScaling}, f::FlatS0{P}) where {N,P<:Flat{N}} =
    Lϕ.sparse_repr \ f
\(Lϕ::Adjoint{<:Any,<:BiquadLens{<:Any,<:UniformScaling}}, f::FlatS0{P}) where {N,P<:Flat{N}} =
    parent(Lϕ).sparse_repr \ f

# Lift `*` and `\` to FieldTuples: apply the operator to each constituent
# field (in the pixel basis) and rebuild the same tuple type. Generated with
# @eval so both operators share one definition.
for op in (:*, :\)
    @eval function ($op)(Lϕ::Union{BiquadLens, Adjoint{<:Any,<:BiquadLens}}, f::FieldTuple)
        Łf = Ł(f)
        F = typeof(Łf)
        F(map(f->($op)(Lϕ,f), Łf.fs))
    end
end


# gradients

# Zygote rule for the constructor: the pullback passes Δ straight through as
# the gradient w.r.t. ϕ (the sparse-matrix construction itself is not
# differentiated here)
@adjoint BiquadLens(ϕ) = BiquadLens(ϕ), Δ -> (Δ,)

# Zygote rule for applying the operator: forward pass lenses f, and the
# pullback returns (gradient w.r.t. ϕ, gradient w.r.t. f). The ϕ gradient is
# built from the gradient of the lensed field (∇f̃) contracted with Δ —
# NOTE(review): this appears to use the ∇ϕ-dependence of the deflected sample
# positions only; confirm against the analogous BilinearLens rule.
@adjoint function *(Lϕ::BiquadLens, f::Field{B}) where {B}
    f̃ = Lϕ * f
    function back(Δ)
        (∇' * (Ref(tuple_adjoint(Ł(Δ))) .* Ł(∇*f̃))), B(Lϕ*Δ)
    end
    f̃, back
end


# gpu

# move the operator between storage backends (e.g. to/from CuArray) by
# adapting each of its fields and reconstructing the BiquadLens
adapt_structure(storage, Lϕ::BiquadLens) = BiquadLens(adapt(storage, fieldvalues(Lϕ))...)
12 changes: 8 additions & 4 deletions src/dataset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,19 @@ Zygote.grad_mut(ds::DataSet) = Ref{Any}((;(propertynames(ds) .=> nothing)...))
Cf # unlensed field covariance
Cf̃ = nothing # lensed field covariance (not always needed)
Cn # noise covariance
Cn̂ = Cn # approximate noise covariance, diagonal in same basis as Cf
M = 1 # user mask
M̂ = M # approximate user mask, diagonal in same basis as Cf
B = 1 # beam and instrumental transfer functions
B̂ = B # approximate beam and instrumental transfer functions, diagonal in same basis as Cf
D = 1 # mixing matrix for mixed parametrization
G = 1 # reparametrization for ϕ
P = 1 # pixelization operator (if estimating field on higher res than data)
L = LenseFlow # lensing operator, possibly cached for memory reuse
P = 1 # pixelization operator (if estimating field on higher res than data)

# preconditioners (denoted by \hat). these should all approximate the
# non-\hat version, but be fast to calculate:
Cn̂ = Cn
M̂ = M
B̂ = B
L̂ = 1
end

function subblock(ds::DS, block) where {DS<:DataSet}
Expand Down
2 changes: 1 addition & 1 deletion src/maximization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ function MAP_marg(
ϕstart = nothing,
Nϕ = :qe,
nsteps = 10,
nsteps_with_meanfield_update = 4,
nsteps_with_meanfield_update = nsteps,
conjgrad_kwargs = (nsteps=500, tol=1e-1),
α = 0.2,
weights = :unlensed,
Expand Down