diff --git a/Project.toml b/Project.toml
index 41a2b900..9c66bb8a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "CMBLensing"
 uuid = "b60c06c0-7e54-11e8-3788-4bd722d65317"
-authors = ["marius "]
-version = "0.5.1"
+authors = ["Marius Millea "]
+version = "0.6.0"
 
 [deps]
 AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -56,7 +56,7 @@ Adapt = "1.0.1, 2, 3"
 CUDA = "3"
 Combinatorics = "1"
 DataStructures = "0.17.9, 0.18"
-FFTW = "1.2"
+FFTW = "1.2 - 1.3"
 FileIO = "1.2.2"
 Formatting = "0.4"
 ImageFiltering = "0.6.14"
@@ -70,7 +70,7 @@ Loess = "0.5"
 MacroTools = "0.5"
 Match = "1.1"
 Measurements = "2"
-Memoization = "0.1.4"
+Memoization = "0.1.8"
 NamedTupleTools = "0.13"
 OptimKit = "0.3.1"
 ProgressMeter = "1.2"
diff --git a/src/autodiff.jl b/src/autodiff.jl
index e6617a12..f7e52021 100644
--- a/src/autodiff.jl
+++ b/src/autodiff.jl
@@ -1,6 +1,7 @@
 # this does basis promotion, unlike Zygote's default for AbstractArrays
 Zygote.accum(a::Field, b::Field) = a+b
+Zygote.accum(a::FieldTuple, b::FieldTuple) = Zygote.accum.(a,b)
 
 # this may create a LazyBinaryOp, unlike Zygote's
 Zygote.accum(a::FieldOp, b::FieldOp) = a+b
 
diff --git a/src/dataset.jl b/src/dataset.jl
index dcd281e8..28c46c66 100644
--- a/src/dataset.jl
+++ b/src/dataset.jl
@@ -70,7 +70,7 @@ function subblock(ds::DS, block) where {DS<:DataSet}
     end...)
 end
 
-function (ds::DataSet)(θ::NamedTuple)
+function (ds::DataSet)(θ)
     DS = typeof(ds)
     DS(map(fieldvalues(ds)) do v
         (v isa Union{ParamDependentOp,DataSet}) ? v(θ) : v
@@ -230,7 +230,7 @@ function load_sim(;
        @warn "`rfid` will be removed in a future version. Use `fiducial_θ=(r=...,)` instead."
        fiducial_θ = merge(fiducial_θ,(r=rfid,))
    end
-    Aϕ₀ = get(fiducial_θ, :Aϕ, 1)
+    Aϕ₀ = T(get(fiducial_θ, :Aϕ, 1))
     fiducial_θ = Base.structdiff(fiducial_θ, NamedTuple{(:Aϕ,)}) # remove Aϕ key if present
     if (Cℓ == nothing)
         Cℓ = camb(;fiducial_θ..., ℓmax=ℓmax)
@@ -241,7 +241,7 @@ function load_sim(;
             error("ℓmax of `Cℓ` argument should be higher than $ℓmax for this configuration.")
         end
     end
-    r₀ = Cℓ.params.r
+    r₀ = T(Cℓ.params.r)
 
     # noise Cℓs (these are non-debeamed, hence beamFWHM=0 below; the beam comes in via the B operator)
     if (Cℓn == nothing)
@@ -264,7 +264,7 @@ function load_sim(;
     Cf̃ = Cℓ_to_Cov(pol, proj, (Cℓ.total[k] for k in ks)...)
     Cn̂ = Cℓ_to_Cov(pol, proj, (Cℓn[k] for k in ks)...)
     if (Cn == nothing); Cn = Cn̂; end
-    Cf = ParamDependentOp((;r=r₀, _...)->(Cfs + T(r/r₀)*Cft))
+    Cf = ParamDependentOp((;r=r₀, _...)->(Cfs + (T(r)/r₀)*Cft))
     Cϕ = ParamDependentOp((;Aϕ=Aϕ₀, _...)->(T(Aϕ) * Cϕ₀))
 
     # data mask
diff --git a/src/generic.jl b/src/generic.jl
index 0f391b21..f027d013 100644
--- a/src/generic.jl
+++ b/src/generic.jl
@@ -135,7 +135,7 @@ Basis(f::Field) = f
 basis(f::F) where {F<:Field} = basis(F)
 basis(::Type{<:Field{B}}) where {B<:Basis} = B
 basis(::Type{<:Field}) = Basis
-
+basis(::AbstractVector) = Basis
 
 ### printing
 typealias(::Type{B}) where {B<:Basis} = string(B.name.name)
diff --git a/src/maximization.jl b/src/maximization.jl
index c5afcd34..908d98ea 100644
--- a/src/maximization.jl
+++ b/src/maximization.jl
@@ -3,7 +3,7 @@
 
 @doc doc"""
     argmaxf_lnP(ϕ, ds::DataSet; kwargs...)
-    argmaxf_lnP(ϕ, θ::NamedTuple, ds::DataSet; kwargs...)
+    argmaxf_lnP(ϕ, θ, ds::DataSet; kwargs...)
     argmaxf_lnP(Lϕ, ds::DataSet; kwargs...)
 
 Computes either the Wiener filter at fixed $\phi$, or a sample from this slice
@@ -20,11 +20,11 @@ Keyword arguments:
 
 """
 argmaxf_lnP(ϕ::Field, ds::DataSet; kwargs...) = argmaxf_lnP(cache(ds.L(ϕ),ds.d), NamedTuple(), ds; kwargs...)
-argmaxf_lnP(ϕ::Field, θ::NamedTuple, ds::DataSet; kwargs...) = argmaxf_lnP(cache(ds.L(ϕ),ds.d), θ, ds; kwargs...)
+argmaxf_lnP(ϕ::Field, θ, ds::DataSet; kwargs...) = argmaxf_lnP(cache(ds.L(ϕ),ds.d), θ, ds; kwargs...)
 
 function argmaxf_lnP(
     Lϕ,
-    θ::NamedTuple,
+    θ,
     ds::DataSet;
     which = :wf,
     fstart = nothing,
diff --git a/src/posterior.jl b/src/posterior.jl
index a4b1c94e..e899acfa 100644
--- a/src/posterior.jl
+++ b/src/posterior.jl
@@ -1,29 +1,29 @@
 """
-    mix(f, ϕ, ds::DataSet)
-    mix(f, ϕ, θ::NamedTuple, ds::DataSet)
+    mix(f, ϕ, ds::DataSet)
+    mix(f, ϕ, θ, ds::DataSet)
 
 Compute the mixed `(f°, ϕ°)` from the unlensed field `f` and lensing potential
 `ϕ`, given the definition of the mixing matrices in `ds` evaluated at parameters
 `θ` (or at fiducial values if no `θ` provided).
 """
 mix(f, ϕ, ds::DataSet) = mix(f,ϕ,NamedTuple(),ds)
-function mix(f, ϕ, θ::NamedTuple, ds::DataSet)
+function mix(f, ϕ, θ, ds::DataSet)
     @unpack D,G,L = ds(θ)
     L(ϕ)*D*f, G*ϕ
 end
 
 
 """
-    unmix(f°, ϕ°, ds::DataSet)
-    unmix(f°, ϕ°, θ::NamedTuple, ds::DataSet)
+    unmix(f°, ϕ°, ds::DataSet)
+    unmix(f°, ϕ°, θ, ds::DataSet)
 
 Compute the unmixed/unlensed `(f, ϕ)` from the mixed field `f°` and mixed
 lensing potential `ϕ°`, given the definition of the mixing matrices in `ds`
 evaluated at parameters `θ` (or at fiducial values if no `θ` provided).
 """
 unmix(f°, ϕ°, ds::DataSet) = unmix(f°,ϕ°,NamedTuple(),ds)
-function unmix(f°, ϕ°, θ::NamedTuple, ds::DataSet)
+function unmix(f°, ϕ°, θ, ds::DataSet)
     @unpack D,G,L = ds(θ)
     ϕ = G\ϕ°
     D\(L(ϕ)\f°), ϕ
@@ -31,8 +31,8 @@ end
 
 
 @doc doc"""
-    lnP(t, fₜ, ϕₜ, ds::DataSet)
-    lnP(t, fₜ, ϕₜ, θ::NamedTuple, ds::DataSet)
+    lnP(t, fₜ, ϕₜ, ds::DataSet)
+    lnP(t, fₜ, ϕₜ, θ, ds::DataSet)
 
 Compute the log posterior probability in the joint parameterization as a
 function of the field, $f_t$, the lensing potential, $\phi_t$, and possibly some
@@ -49,10 +49,10 @@ also include any Jacobian determinant terms that depend on $\theta$.
 The argument `ds` should be a `DataSet` and stores the masks, data, etc...
 needed to construct the posterior.
 """
-lnP(t, fₜ, ϕₜ, ds::DataSet) = lnP(Val(t), fₜ, ϕₜ, NamedTuple(), ds)
-lnP(t, fₜ, ϕₜ, θ::NamedTuple, ds::DataSet) = lnP(Val(t), fₜ, ϕₜ, θ, ds)
+lnP(t, fₜ, ϕₜ, ds::DataSet) = lnP(Val(t), fₜ, ϕₜ, NamedTuple(), ds)
+lnP(t, fₜ, ϕₜ, θ, ds::DataSet) = lnP(Val(t), fₜ, ϕₜ, θ, ds)
 
-function lnP(::Val{t}, fₜ, ϕ, θ::NamedTuple, ds::DataSet) where {t}
+function lnP(::Val{t}, fₜ, ϕ, θ, ds::DataSet) where {t}
 
     @unpack Cn,Cf,Cϕ,L,M,B,d = ds
 
@@ -69,7 +69,7 @@ function lnP(::Val{t}, fₜ, ϕ, θ::NamedTuple, ds::DataSet) where {t}
 
 end
 
 
-function lnP(::Val{:mix}, f°, ϕ°, θ::NamedTuple, ds::DataSet)
+function lnP(::Val{:mix}, f°, ϕ°, θ, ds::DataSet)
     lnP(Val(0), unmix(f°,ϕ°,θ,ds)..., θ, ds) - logdet(ds.D,θ) - logdet(ds.G,θ)
 end
diff --git a/src/sampling.jl b/src/sampling.jl
index 9d8775d9..22a1f1ba 100644
--- a/src/sampling.jl
+++ b/src/sampling.jl
@@ -386,12 +386,12 @@ end
     @pack! state = ϕ°, ΔH, accept
 end
 
-function hmc_step(U::Function, x, Λ; symp_kwargs, progress, always_accept)
+function hmc_step(U::Function, x, Λ, δUδx=x->gradient(U, x)[1]; symp_kwargs, progress, always_accept)
     local ΔH, accept
     for kwargs in symp_kwargs
         p = simulate(Λ)
         (ΔH, xtest) = symplectic_integrate(
-            x, p, Λ, U;
+            x, p, Λ, U, δUδx;
             progress = (progress==:verbose),
             kwargs...
         )
diff --git a/src/specialops.jl b/src/specialops.jl
index 92837320..93b5c4d8 100644
--- a/src/specialops.jl
+++ b/src/specialops.jl
@@ -227,6 +227,12 @@ function (L::ParamDependentOp)(θ::NamedTuple)
     end
 end
 (L::ParamDependentOp)(;θ...) = L((;θ...))
+@init @require ComponentArrays="b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" begin
+    using ComponentArrays
+    (L::ParamDependentOp)(θ::ComponentArray) = L(convert(NamedTuple, θ))
+    (L::Union{FieldOp,UniformScaling})(::ComponentArray) = L
+end
+
 @auto_adjoint *(L::ParamDependentOp, f::Field) = L.op * f
 @auto_adjoint \(L::ParamDependentOp, f::Field) = L.op \ f
 for F in (:inv, :pinv, :sqrt, :adjoint, :Diagonal, :diag, :simulate, :zero, :one, :logdet, :global_rng_for)
diff --git a/src/util.jl b/src/util.jl
index db03ccba..b0dc0ffb 100644
--- a/src/util.jl
+++ b/src/util.jl
@@ -427,3 +427,11 @@ string_trunc(x) = Base._truncate_at_width_or_chars(string(x), displaysize(stdout
 
 import NamedTupleTools
 NamedTupleTools.select(d::Dict, keys) = (;(k=>d[k] for k in keys)...)
+
+@init @require ComponentArrays="b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" begin
+    using ComponentArrays
+    # a Zygote-compatible conversion of ComponentVector to a NamedTuple
+    Base.convert(::Type{NamedTuple}, x::ComponentVector) = NamedTuple{keys(x)}([x[k] for k in keys(x)])
+    @adjoint Base.convert(::Type{NamedTuple}, x::ComponentVector) = convert(NamedTuple, x), Δ -> (nothing, ComponentArray(Δ))
+end
+
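
For reference, a minimal usage sketch (not part of the patch) of the ComponentArrays interop added above. It assumes CMBLensing, ComponentArrays, and Zygote are loaded, and that `f`, `ϕ`, and `ds` are an unlensed field, lensing potential, and `DataSet` from an existing simulation (e.g. one built with `load_sim`); the parameter names `r` and `Aϕ` match the `ParamDependentOp`s defined there, and the exact call pattern may differ.

    using CMBLensing, ComponentArrays, Zygote   # loading ComponentArrays activates the @require blocks above

    # parameters as a flat ComponentVector rather than a NamedTuple
    θ = ComponentVector(r = 0.05, Aϕ = 1.0)

    # ParamDependentOps and DataSets can now be evaluated directly at θ via the new
    # (L::ParamDependentOp)(θ::ComponentArray) method, which converts θ to a NamedTuple
    Cf_θ = ds.Cf(θ)
    ds_θ = ds(θ)

    # the @adjoint on convert(NamedTuple, ::ComponentVector) lets a Zygote gradient
    # with respect to θ flow back as a ComponentArray, e.g. through the posterior
    gθ, = gradient(θ -> lnP(0, f, ϕ, θ, ds), θ)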