From d79184d2943a7db11a112044bda3e658f041442f Mon Sep 17 00:00:00 2001
From: marius
Date: Wed, 4 Jan 2023 01:34:44 -0800
Subject: [PATCH] muse implicit diff working

---
 src/base_fields.jl  | 23 +++++++++++++++--------
 src/field_tuples.jl | 21 ++++++++++++++++++++-
 src/generic.jl      | 13 ++++++++++---
 src/muse.jl         | 40 +++++++++++++++++++++++-----------------
 src/proj_lambert.jl | 10 +++++-----
 5 files changed, 73 insertions(+), 34 deletions(-)

diff --git a/src/base_fields.jl b/src/base_fields.jl
index e1c9b7cc..4a2dc56c 100644
--- a/src/base_fields.jl
+++ b/src/base_fields.jl
@@ -36,7 +36,9 @@ lastindex(f::BaseField, i::Int) = lastindex(f.arr, i)
 @propagate_inbounds getindex(f::BaseField, I::Union{Int,Colon,AbstractArray}...) = getindex(f.arr, I...)
 @propagate_inbounds setindex!(f::BaseField, X, I::Union{Int,Colon,AbstractArray}...) = (setindex!(f.arr, X, I...); f)
 similar(f::BaseField{B}, ::Type{T}) where {B,T} = BaseField{B}(similar(f.arr, T), f.metadata)
+similar(f::BaseField{B}, ::Type{T}, dims::Base.DimOrInd...) where {B,T} = similar(f.arr, T, dims...)
 copy(f::BaseField{B}) where {B} = BaseField{B}(copy(f.arr), f.metadata)
+copyto!(dst::AbstractArray, src::BaseField) = copyto!(dst, src.arr)
 (==)(f₁::BaseField, f₂::BaseField) = strict_compatible_metadata(f₁,f₂) && (f₁.arr == f₂.arr)

@@ -46,7 +48,9 @@ function promote(f₁::BaseField{B₁}, f₂::BaseField{B₂}) where {B₁,B₂}
     B = typeof(promote_basis_generic(B₁(), B₂()))
     B(f₁), B(f₂)
 end
-
+# allow very basic arithmetic with BaseField & AbstractArray
+promote(f::BaseField{B}, x::AbstractArray) where {B} = (f, BaseField{B}(reshape(x, size(f.arr)), f.proj))
+promote(x::AbstractArray, f::BaseField{B}) where {B} = reverse(promote(f, x))

 ## broadcasting

@@ -61,6 +65,7 @@ BroadcastStyle(::Type{F}) where {B,M,T,A,F<:BaseField{B,M,T,A}} =
 BroadcastStyle(::BaseFieldStyle{S₁,B₁}, ::BaseFieldStyle{S₂,B₂}) where {S₁,B₁,S₂,B₂} =
     BaseFieldStyle{typeof(result_style(S₁(), S₂())), typeof(promote_basis_strict(B₁(),B₂()))}()
 BroadcastStyle(S::BaseFieldStyle, ::DefaultArrayStyle{0}) = S
+BaseFieldStyle{S,B}(::Val{2}) where {S,B} = DefaultArrayStyle{2}()

 # with the Broadcasted object created, we now compute the answer
 function materialize(bc::Broadcasted{BaseFieldStyle{S,B}}) where {S,B}
@@ -101,10 +106,13 @@ function materialize!(dst::BaseField{B}, bc::Broadcasted{BaseFieldStyle{S,B′}}
 end

-# the default preprocessing, which just unwraps the underlying array.
-# this doesn't dispatch on the first argument, but custom BaseFields
-# are free to override this and dispatch on it if they need
-preprocess(::Any, f::BaseField) = f.arr
+# if broadcasting into a BaseField, the first method here is hit with
+# dest::Tuple{BaseFieldStyle,M}, in which case just unwrap the array,
+# since it will be fed into a downstream regular broadcast
+preprocess(::Tuple{BaseFieldStyle{S,B},M}, f::BaseField) where {S,B,M} = f.arr
+# if broadcasting into an Array (i.e. dropping the BaseField wrapper) we
+# need to return the vector representation
+preprocess(::AbstractArray, f::BaseField) = view(f.arr, :)

 # we re-wrap each Broadcasted object as we go through preprocessing
 # because some array types do special things here (e.g. CUDA wraps
@@ -135,8 +143,7 @@ function strict_compatible_metadata(f₁::BaseField, f₂::BaseField)
 end

 ## mapping
-
-# this comes up in Zygote.broadcast_forward, and the generic falls back to a regular Array
+# map over entries in the array like a true AbstractArray
 map(func, f::BaseField{B}) where {B} = BaseField{B}(map(func, f.arr), f.metadata)

@@ -169,4 +176,4 @@ getproperty(f::BaseField{B}, k::Union{typeof.(Val.((:I,:Q,:U,:E,:B)))...}) where
     BaseField{B₀}(_reshape_batch(view(getfield(f,:arr), pol_slice(f, pol_index(B(), k))...)), getfield(f,:metadata))
 getproperty(f::BaseS02{Basis3Prod{𝐈,B₂,B₀}}, ::Val{:P}) where {B₂,B₀} =
     BaseField{Basis2Prod{B₂,B₀}}(view(getfield(f,:arr), pol_slice(f, 2:3)...), getfield(f,:metadata))
-getproperty(f::BaseS2, ::Val{:P}) = f
\ No newline at end of file
+getproperty(f::BaseS2, ::Val{:P}) = f
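The base_fields.jl changes above give a BaseField basic interop with plain
arrays. A minimal usage sketch (the FlatMap constructor and the 4×4 size are
illustrative, not part of the patch):

    using CMBLensing
    f = FlatMap(rand(4, 4))            # a BaseField backed by a 4×4 array
    copyto!(zeros(16), f)              # new copyto!: spill the field into a plain vector
    similar(f, Float32, 8)             # new similar: a plain 8-element array, not a field
    promote(f, collect(1.0:16.0))[2]   # new promote: wrap a flat vector as a matching field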
diff --git a/src/field_tuples.jl b/src/field_tuples.jl
index 105bf2ab..d46eb607 100644
--- a/src/field_tuples.jl
+++ b/src/field_tuples.jl
@@ -28,6 +28,7 @@ typealias_def(::Type{<:FieldTuple{FS,T}}) where {FS<:Tuple,T} =
 ### array interface
 size(f::FieldTuple) = (mapreduce(length, +, f.fs, init=0),)
 copy(f::FieldTuple) = FieldTuple(map(copy,f.fs))
+copyto!(dst::AbstractArray, src::FieldTuple) = copyto!(dst, src[:]) # todo: memory optimization possible
 iterate(ft::FieldTuple, args...) = iterate(ft.fs, args...)
 getindex(f::FieldTuple, i::Union{Int,UnitRange}) = getindex(f.fs, i)
 fill!(ft::FieldTuple, x) = (map(f->fill!(f,x), ft.fs); ft)
@@ -35,6 +36,7 @@ get_storage(f::FieldTuple) = only(unique(map(get_storage, f.fs)))
 adapt_structure(to, f::FieldTuple) = FieldTuple(map(f->adapt(to,f),f.fs))
 similar(ft::FieldTuple) = FieldTuple(map(similar,ft.fs))
 similar(ft::FieldTuple, ::Type{T}) where {T<:Number} = FieldTuple(map(f->similar(f,T),ft.fs))
+similar(ft::FieldTuple, ::Type{T}, dims::Base.DimOrInd...) where {T} = similar(ft.fs[1].arr, T, dims...) # todo: make work for heterogeneous arrays?
 similar(ft::FieldTuple, Nbatch::Int) = FieldTuple(map(f->similar(f,Nbatch),ft.fs))
 sum(f::FieldTuple; dims=:) = dims == (:) ? sum(sum, f.fs) : error("sum(::FieldTuple, dims=$dims not supported")
@@ -54,6 +56,7 @@ function BroadcastStyle(::FieldTupleStyle{S₁,Names}, ::FieldTupleStyle{S₂,Na
     FieldTupleStyle{Tuple{map_tupleargs((s₁,s₂)->typeof(result_style(s₁(),s₂())), S₁, S₂)...}, Names}()
 end
 BroadcastStyle(S::FieldTupleStyle, ::DefaultArrayStyle{0}) = S
+FieldTupleStyle{S,Names}(::Val{2}) where {S,Names} = DefaultArrayStyle{2}()


 @generated function materialize(bc::Broadcasted{FieldTupleStyle{S,Names}}) where {S,Names}
@@ -73,13 +76,29 @@ end

 struct FieldTupleComponent{i} end

 preprocess(::Tuple{<:Any,FieldTupleComponent{i}}, ft::FieldTuple) where {i} = ft.fs[i]
+preprocess(::AbstractArray, ft::FieldTuple) = vcat((view(f.arr, :) for f in ft.fs)...)
+
+### mapping
+# map over entries in the component fields like a true AbstractArray
+map(func, ft::FieldTuple) = FieldTuple(map(f -> map(func, f), ft.fs))
+
 ### promotion

 function promote(ft1::FieldTuple, ft2::FieldTuple)
     fts = map(promote, ft1.fs, ft2.fs)
     FieldTuple(map(first,fts)), FieldTuple(map(last,fts))
 end
+# allow very basic arithmetic with FieldTuple & AbstractArray
+function promote(ft::FieldTuple, x::AbstractVector)
+    lens = map(length, ft.fs)
+    offsets = typeof(lens)((cumsum([1; lens...])[1:end-1]...,))
+    x_ft = FieldTuple(map(ft.fs, offsets, lens) do f, offset, len
+        promote(f, view(x, offset:offset+len-1))[2]
+    end)
+    (ft, x_ft)
+end
+promote(x::AbstractVector, ft::FieldTuple) = reverse(promote(ft, x))
+
 ### conversion

 Basis(ft::FieldTuple) = ft
@@ -120,4 +139,4 @@ tr(L::Diagonal{<:Union{Real,Complex}, <:FieldTuple}) = reduce(+, map(tr∘Diagon
 batch_length(ft::FieldTuple) = only(unique(map(batch_length, ft.fs)))
 batch_index(ft::FieldTuple, I) = FieldTuple(map(f -> batch_index(f, I), ft.fs))
 getindex(ft::FieldTuple, k::Symbol) = ft.fs[k]
-haskey(ft::FieldTuple, k::Symbol) = haskey(ft.fs, k)
\ No newline at end of file
+haskey(ft::FieldTuple, k::Symbol) = haskey(ft.fs, k)
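Usage sketch for the new FieldTuple/vector interop (component names and the
2×2 sizes are illustrative):

    ft = FieldTuple(a = FlatMap(rand(2, 2)), b = FlatMap(rand(2, 2)))
    v  = zeros(length(ft))   # total length 8
    copyto!(v, ft)           # new copyto!: flatten both components into v
    _, ft′ = promote(ft, v)  # new promote: split v back into a matching FieldTuple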
diff --git a/src/generic.jl b/src/generic.jl
index 020fab32..891e99b9 100644
--- a/src/generic.jl
+++ b/src/generic.jl
@@ -330,9 +330,16 @@ show_vector(io::IO, f::Field) = !isempty(f) && show_vector(io, f[:])

 Base.has_offset_axes(::Field) = false # needed for Diagonal(::Field) if the Field is implicitly-sized

-# addition/subtraction works between any fields and scalars, promotion is done
-# automatically if fields are in different bases
-for op in (:+,:-), (T1,T2,promote) in ((:Field,:Scalar,false),(:Scalar,:Field,false),(:Field,:Field,true))
+# addition/subtraction works between fields, scalars, and
+# AbstractArrays. promotion is done automatically if fields are in
+# different bases; AbstractArrays are wrapped as fields of matching size
+for op in (:+,:-), (T1,T2,promote) in [
+    (:Field,         :Scalar,        false),
+    (:Scalar,        :Field,         false),
+    (:Field,         :Field,         true),
+    (:Field,         :AbstractArray, true),
+    (:AbstractArray, :Field,         true)
+]
     @eval ($op)(a::$T1, b::$T2) = broadcast($op, ($promote ? promote(a,b) : (a,b))...)
 end
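Combined with the promote methods above, the expanded op table makes `+`/`-`
work directly between any Field and a plain array. A sketch (sizes
illustrative):

    f = FlatMap(rand(4, 4))   # length-16 viewed as a vector
    v = rand(16)
    f + v                     # v is wrapped as a field via promote, then broadcast
    v - f                     # also works, via the (:AbstractArray, :Field) entry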
diff --git a/src/muse.jl b/src/muse.jl
index 92ef301b..382583ec 100644
--- a/src/muse.jl
+++ b/src/muse.jl
@@ -2,7 +2,8 @@
 # interface with MuseInference.jl

 using .MuseInference: AbstractMuseProblem, MuseResult
-import .MuseInference: ∇θ_logLike, sample_x_z, ẑ_at_θ, muse!, standardizeθ
+using .MuseInference.AbstractDifferentiation
+import .MuseInference: logLike, ∇θ_logLike, sample_x_z, ẑ_at_θ, muse!, standardizeθ

 export CMBLensingMuseProblem

@@ -14,10 +15,20 @@ struct CMBLensingMuseProblem{DS<:DataSet,DS_SIM<:DataSet} <: AbstractMuseProblem
     θ_fixed
     x
     latent_vars
+    autodiff
 end

-function CMBLensingMuseProblem(ds, ds_for_sims=ds; parameterization=0, MAP_joint_kwargs=(;), θ_fixed=(;), latent_vars=nothing)
-    CMBLensingMuseProblem(ds, ds_for_sims, parameterization, MAP_joint_kwargs, θ_fixed, ds.d, latent_vars)
+function CMBLensingMuseProblem(
+    ds,
+    ds_for_sims = ds;
+    parameterization = 0,
+    MAP_joint_kwargs = (;),
+    θ_fixed = (;),
+    latent_vars = nothing,
+    autodiff = AD.HigherOrderBackend((AD.ForwardDiffBackend(tag=false), AD.ZygoteBackend())),
+)
+    parameterization == 0 || error("only parameterization=0 (unlensed parameterization) currently implemented")
+    CMBLensingMuseProblem(ds, ds_for_sims, parameterization, MAP_joint_kwargs, θ_fixed, ds.d, latent_vars, autodiff)
 end

 mergeθ(prob::CMBLensingMuseProblem, θ) = isempty(prob.θ_fixed) ? θ : (;prob.θ_fixed..., θ...)
@@ -27,26 +38,21 @@ function standardizeθ(prob::CMBLensingMuseProblem, θ)
     1f0 * ComponentVector(θ) # ensure component vector and float
 end

+function MuseInference.logLike(prob::CMBLensingMuseProblem, d, z, θ)
+    logpdf(prob.ds; z..., θ = mergeθ(prob, θ), d)
+end
+
 function ∇θ_logLike(prob::CMBLensingMuseProblem, d, z, θ)
-    @unpack ds, parameterization = prob
-    @set! ds.d = d
-    if parameterization == 0
-        gradient(θ -> logpdf(ds; z..., θ = mergeθ(prob, θ)), θ)[1]
-    elseif parameterization == :mix
-        z° = mix(ds; z..., θ = mergeθ(prob, θ))
-        gradient(θ -> logpdf(Mixed(ds); z°..., θ = mergeθ(prob, θ)), θ)[1]
-    else
-        error("parameterization should be 0 or :mix")
-    end
+    AD.gradient(prob.autodiff, θ -> logLike(prob, d, z, θ), θ)[1]
 end

 function sample_x_z(prob::CMBLensingMuseProblem, rng::AbstractRNG, θ)
     sim = simulate(rng, prob.ds_for_sims, θ = mergeθ(prob, θ))
     if prob.latent_vars == nothing
         # this is a guess which might not work for everything necessarily
-        z = FieldTuple(delete(sim, (:f̃, :d, :μ)))
+        z = LenseBasis(FieldTuple(delete(sim, (:f̃, :d, :μ))))
     else
-        z = FieldTuple(select(sim, prob.latent_vars))
+        z = LenseBasis(FieldTuple(select(sim, prob.latent_vars)))
     end
     x = sim.d
     (;x, z)
@@ -56,12 +62,12 @@ function ẑ_at_θ(prob::CMBLensingMuseProblem, d, zguess, θ; ∇z_logLike_atol
     @unpack ds = prob
     Ωstart = delete(NamedTuple(zguess), :f)
     MAP = MAP_joint(mergeθ(prob, θ), @set(ds.d=d), Ωstart; fstart=zguess.f, prob.MAP_joint_kwargs...)
-    FieldTuple(;delete(MAP, :history)...), MAP.history
+    LenseBasis(FieldTuple(;delete(MAP, :history)...)), MAP.history
 end

 function ẑ_at_θ(prob::CMBLensingMuseProblem{<:NoLensingDataSet}, d, (f₀,), θ; ∇z_logLike_atol=nothing)
     @unpack ds = prob
-    FieldTuple(f=argmaxf_logpdf(I, mergeθ(prob, θ), @set(ds.d=d); fstart=f₀, prob.MAP_joint_kwargs...)), nothing
+    LenseBasis(FieldTuple(f=argmaxf_logpdf(I, mergeθ(prob, θ), @set(ds.d=d); fstart=f₀, prob.MAP_joint_kwargs...))), nothing
 end

 function muse!(result::MuseResult, ds::DataSet, θ₀=nothing; parameterization=0, MAP_joint_kwargs=(;), kwargs...)
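The problem now carries its own AD backend, so the θ-gradient used by MUSE is
configurable rather than hard-coded to Zygote. A sketch of constructing and
running a problem (the dataset `ds`, the latent variable names, and the
starting point are placeholders):

    prob = CMBLensingMuseProblem(
        ds;
        latent_vars = (:f, :ϕ),
        autodiff    = AD.ZygoteBackend(),  # or keep the default ForwardDiff-over-Zygote
    )
    result = MuseResult()
    muse!(result, prob, (r = 0.2,))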
diff --git a/src/proj_lambert.jl b/src/proj_lambert.jl
index c86e0ee6..8cddbf96 100644
--- a/src/proj_lambert.jl
+++ b/src/proj_lambert.jl
@@ -131,17 +131,17 @@ promote_metadata_generic(metadata₁::ProjLambert, metadata₂::ProjLambert) =
 # return `Broadcasted` objects which are spliced into the final
 # broadcast, thus avoiding allocating any temporary arrays.

-function preprocess((_,proj)::Tuple{<:Any,<:ProjLambert{T,V}}, r::Real) where {T,V}
+function preprocess((_,proj)::Tuple{<:BaseFieldStyle,<:ProjLambert{T,V}}, r::Real) where {T,V}
     r isa BatchedReal ? adapt(V, reshape(r.vals, 1, 1, 1, :)) : r
 end

 # need custom adjoint here bc Δ can come back batched from the
 # backward pass even though r was not batched on the forward pass
-@adjoint function preprocess(m::Tuple{<:Any,<:ProjLambert{T,V}}, r::Real) where {T,V}
+@adjoint function preprocess(m::Tuple{<:BaseFieldStyle,<:ProjLambert{T,V}}, r::Real) where {T,V}
     preprocess(m, r), Δ -> (nothing, Δ isa AbstractArray ? batch(real.(Δ[:])) : Δ)
 end

-function preprocess((_,proj)::Tuple{BaseFieldStyle{S,B},<:ProjLambert}, ∇d::∇diag) where {S,B}
+function preprocess((_,proj)::Tuple{<:BaseFieldStyle{S,B},<:ProjLambert}, ∇d::∇diag) where {S,B}

     (B <: Union{Fourier,QUFourier,IQUFourier}) ||
         error("Can't broadcast ∇[$(∇d.coord)] as a $(typealias(B)), its not diagonal in this basis.")
@@ -156,7 +156,7 @@ function preprocess((_,proj)::Tuple{BaseFieldStyle{S,B},<:ProjLambert}, ∇d::
     end
 end

-function preprocess((_,proj)::Tuple{BaseFieldStyle{S,B},<:ProjLambert}, ::∇²diag) where {S,B}
+function preprocess((_,proj)::Tuple{<:BaseFieldStyle{S,B},<:ProjLambert}, ::∇²diag) where {S,B}
     (B <: Union{Fourier,<:Basis2Prod{<:Any,Fourier},<:Basis3Prod{<:Any,<:Any,Fourier}}) || error("Can't broadcast a BandPass as a $(typealias(B)), its not diagonal in this basis.")
     broadcasted(+, broadcasted(^, proj.ℓx', 2), broadcasted(^, proj.ℓy, 2))
@@ -164,7 +164,7 @@ function preprocess((_,proj)::Tuple{BaseFieldStyle{S,B},<:ProjLambert}, ::∇²d
 end

-function preprocess((_,proj)::Tuple{<:Any,<:ProjLambert}, bp::BandPass)
+function preprocess((_,proj)::Tuple{<:BaseFieldStyle,<:ProjLambert}, bp::BandPass)
     Cℓ_to_2D(bp.Wℓ, proj)
 end
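The signature changes above all narrow `Tuple{<:Any,...}` to
`Tuple{<:BaseFieldStyle,...}`: with the new plain-array broadcasting paths,
`preprocess` can now also receive an AbstractArray destination, and these
ProjLambert-specific methods should only apply when the destination is still a
field. A sketch of the two dispatch paths (constructor and sizes
illustrative):

    f = FlatMap(rand(4, 4))
    f .+ 1          # dest is a (BaseFieldStyle, metadata) tuple → methods above apply
    out = zeros(16)
    out .= f        # dest is the Array itself → hits preprocess(::AbstractArray, ::BaseField)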