From 92d786ed29eeb4e0231a5453f8e9f850e2485087 Mon Sep 17 00:00:00 2001 From: Zack Li Date: Thu, 31 Oct 2024 08:27:53 -0700 Subject: [PATCH 1/5] remove tag=false for ForwardDiff in MUSE ext --- ext/CMBLensingMuseInferenceExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/CMBLensingMuseInferenceExt.jl b/ext/CMBLensingMuseInferenceExt.jl index d30e22b7..ef195da5 100644 --- a/ext/CMBLensingMuseInferenceExt.jl +++ b/ext/CMBLensingMuseInferenceExt.jl @@ -26,7 +26,7 @@ using Setfield θ_fixed = (;) x = ds.d latent_vars = nothing - autodiff = AD.HigherOrderBackend((AD.ForwardDiffBackend(tag=false), AD.ZygoteBackend())) + autodiff = AD.HigherOrderBackend((AD.ForwardDiffBackend(), AD.ZygoteBackend())) transform_θ = identity inv_transform_θ = identity end @@ -91,4 +91,4 @@ function MuseInference.muse!(result::MuseResult, ds::DataSet, θ₀=nothing; par muse!(result, CMBLensingMuseProblem(ds; parameterization, MAP_joint_kwargs), θ₀; kwargs...) end -end \ No newline at end of file +end From 254efcb9a9db51a4d6970d4ae3fc13bb8e204e27 Mon Sep 17 00:00:00 2001 From: xzackli Date: Thu, 31 Oct 2024 08:42:56 -0700 Subject: [PATCH 2/5] paste in AD fix for tag=false --- ext/CMBLensingMuseInferenceExt.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/ext/CMBLensingMuseInferenceExt.jl b/ext/CMBLensingMuseInferenceExt.jl index ef195da5..1c3f8dbe 100644 --- a/ext/CMBLensingMuseInferenceExt.jl +++ b/ext/CMBLensingMuseInferenceExt.jl @@ -18,6 +18,13 @@ using Random using Requires using Setfield + +_ADgetchunksize(::Nothing) = Nothing # can't access extension unexported methods +_ADgetchunksize(::Val{N}) where {N} = N +function _AD_ForwardDiffBackend(; chunksize::Union{Val,Nothing}=nothing, tag=true) + return AD.ForwardDiffBackend{_ADgetchunksize(chunksize), tag}() +end + @kwdef struct CMBLensingMuseProblem{DS<:DataSet,DS_SIM<:DataSet} <: AbstractMuseProblem ds :: DS ds_for_sims :: DS_SIM = ds @@ -26,7 +33,7 @@ using Setfield θ_fixed = (;) x = ds.d latent_vars = nothing - autodiff = AD.HigherOrderBackend((AD.ForwardDiffBackend(), AD.ZygoteBackend())) + autodiff = AD.HigherOrderBackend((_AD_ForwardDiffBackend(tag=false), AD.ZygoteBackend())) transform_θ = identity inv_transform_θ = identity end From ca1882022b7a4dfd82522ada98a7e9bca3069857 Mon Sep 17 00:00:00 2001 From: xzackli Date: Thu, 31 Oct 2024 10:04:16 -0700 Subject: [PATCH 3/5] do it ourselves --- ext/CMBLensingMuseInferenceExt.jl | 50 +++++++++++++++++++++++++++---- 1 file changed, 45 insertions(+), 5 deletions(-) diff --git a/ext/CMBLensingMuseInferenceExt.jl b/ext/CMBLensingMuseInferenceExt.jl index 1c3f8dbe..55672787 100644 --- a/ext/CMBLensingMuseInferenceExt.jl +++ b/ext/CMBLensingMuseInferenceExt.jl @@ -6,9 +6,11 @@ using CMBLensing if isdefined(Base, :get_extension) using MuseInference using MuseInference: AD, AbstractMuseProblem, MuseResult, Transformedθ, UnTransformedθ + import AD: pushforward_function, gradient, jacobian, hessian, value_and_gradient, value_and_gradient else using ..MuseInference using ..MuseInference: AD, AbstractMuseProblem, MuseResult, Transformedθ, UnTransformedθ + import AD: pushforward_function, gradient, jacobian, hessian, value_and_gradient, value_and_gradient end using Base: @kwdef @@ -17,14 +19,52 @@ using NamedTupleTools using Random using Requires using Setfield +using ForwardDiff +# we're going to make our own backend +struct ForwardDiffNoTagBackend{CS} <: AD.AbstractForwardMode end +chunk(::ForwardDiffNoTagBackend{Nothing}, x) = ForwardDiff.Chunk(x) +chunk(::ForwardDiffNoTagBackend{N}, _) where {N} = ForwardDiff.Chunk{N}() -_ADgetchunksize(::Nothing) = Nothing # can't access extension unexported methods -_ADgetchunksize(::Val{N}) where {N} = N -function _AD_ForwardDiffBackend(; chunksize::Union{Val,Nothing}=nothing, tag=true) - return AD.ForwardDiffBackend{_ADgetchunksize(chunksize), tag}() +function pushforward_function(ba::ForwardDiffNoTagBackend{CS}, f, xs...) + pushforward_function(AD.ForwardDiffBackend{CS}(), f, xs...) end +function AD.gradient(ba::ForwardDiffNoTagBackend, f, x::AbstractArray) + cfg = ForwardDiff.GradientConfig(nothing, x, chunk(ba, x)) + return (ForwardDiff.gradient(f, x, cfg),) +end + +function AD.jacobian(ba::ForwardDiffNoTagBackend, f, x::AbstractArray) + cfg = ForwardDiff.JacobianConfig(nothing, x, chunk(ba, x)) + return (ForwardDiff.jacobian(AD.asarray ∘ f, x, cfg),) +end + +function AD.jacobian(ba::ForwardDiffNoTagBackend, f, x::R) where {R <: Number} + T = typeof(ForwardDiff.Tag(nothing, R)) + return (ForwardDiff.extract_derivative(T, f(ForwardDiff.Dual{T}(x, one(x)))),) +end + +function AD.hessian(ba::ForwardDiffNoTagBackend, f, x::AbstractArray) + cfg = ForwardDiff.HessianConfig(nothing, x, chunk(ba, x)) + return (ForwardDiff.hessian(f, x, cfg),) +end + +function AD.value_and_gradient(ba::ForwardDiffNoTagBackend, f, x::AbstractArray) + result = DiffResults.GradientResult(x) + cfg = ForwardDiff.GradientConfig(nothing, x, chunk(ba, x)) + ForwardDiff.gradient!(result, f, x, cfg) + return DiffResults.value(result), (DiffResults.derivative(result),) +end + +function AD.value_and_hessian(ba::ForwardDiffNoTagBackend, f, x) + result = DiffResults.HessianResult(x) + cfg = ForwardDiff.HessianConfig(nothing, result, x, chunk(ba, x)) + ForwardDiff.hessian!(result, f, x, cfg) + return DiffResults.value(result), (DiffResults.hessian(result),) +end + + @kwdef struct CMBLensingMuseProblem{DS<:DataSet,DS_SIM<:DataSet} <: AbstractMuseProblem ds :: DS ds_for_sims :: DS_SIM = ds @@ -33,7 +73,7 @@ end θ_fixed = (;) x = ds.d latent_vars = nothing - autodiff = AD.HigherOrderBackend((_AD_ForwardDiffBackend(tag=false), AD.ZygoteBackend())) + autodiff = AD.HigherOrderBackend((ForwardDiffNoTagBackend(), AD.ZygoteBackend())) transform_θ = identity inv_transform_θ = identity end From d45cf360191a7a050155944b3b438b7069a6b108 Mon Sep 17 00:00:00 2001 From: xzackli Date: Thu, 31 Oct 2024 10:27:48 -0700 Subject: [PATCH 4/5] add AD as a weak dep for the MUSE extension --- Project.toml | 3 ++- ext/CMBLensingMuseInferenceExt.jl | 14 +++++++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 8b833bba..c1dfaafb 100644 --- a/Project.toml +++ b/Project.toml @@ -69,6 +69,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [weakdeps] +AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" @@ -78,7 +79,7 @@ PythonPlot = "274fc56d-3b97-40fa-a1cd-1b4a50311bf9" [extensions] CMBLensingCUDAExt = "CUDA" -CMBLensingMuseInferenceExt = "MuseInference" +CMBLensingMuseInferenceExt = ["MuseInference", "AbstractDifferentiation"] CMBLensingPythonCallExt = "PythonCall" CMBLensingPythonPlotExt = "PythonPlot" diff --git a/ext/CMBLensingMuseInferenceExt.jl b/ext/CMBLensingMuseInferenceExt.jl index 55672787..3be807f3 100644 --- a/ext/CMBLensingMuseInferenceExt.jl +++ b/ext/CMBLensingMuseInferenceExt.jl @@ -5,14 +5,18 @@ using CMBLensing if isdefined(Base, :get_extension) using MuseInference - using MuseInference: AD, AbstractMuseProblem, MuseResult, Transformedθ, UnTransformedθ - import AD: pushforward_function, gradient, jacobian, hessian, value_and_gradient, value_and_gradient + using MuseInference: AbstractMuseProblem, MuseResult, Transformedθ, UnTransformedθ + import AbstractDifferentiation + import AbstractDifferentiation: pushforward_function, gradient, jacobian, hessian, value_and_gradient, value_and_gradient else using ..MuseInference - using ..MuseInference: AD, AbstractMuseProblem, MuseResult, Transformedθ, UnTransformedθ - import AD: pushforward_function, gradient, jacobian, hessian, value_and_gradient, value_and_gradient + using ..MuseInference: AbstractMuseProblem, MuseResult, Transformedθ, UnTransformedθ + import ..AbstractDifferentiation + import ..AbstractDifferentiation: pushforward_function, gradient, jacobian, hessian, value_and_gradient, value_and_gradient end +const AD = AbstractDifferentiation + using Base: @kwdef using ComponentArrays using NamedTupleTools @@ -26,7 +30,7 @@ struct ForwardDiffNoTagBackend{CS} <: AD.AbstractForwardMode end chunk(::ForwardDiffNoTagBackend{Nothing}, x) = ForwardDiff.Chunk(x) chunk(::ForwardDiffNoTagBackend{N}, _) where {N} = ForwardDiff.Chunk{N}() -function pushforward_function(ba::ForwardDiffNoTagBackend{CS}, f, xs...) +function pushforward_function(ba::ForwardDiffNoTagBackend{CS}, f, xs...) where CS pushforward_function(AD.ForwardDiffBackend{CS}(), f, xs...) end From dd25a5890b012125204803f4825780122297935e Mon Sep 17 00:00:00 2001 From: xzackli Date: Thu, 31 Oct 2024 15:00:15 -0700 Subject: [PATCH 5/5] inject structure into Main --- ext/CMBLensingMuseInferenceExt.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ext/CMBLensingMuseInferenceExt.jl b/ext/CMBLensingMuseInferenceExt.jl index 3be807f3..955ac759 100644 --- a/ext/CMBLensingMuseInferenceExt.jl +++ b/ext/CMBLensingMuseInferenceExt.jl @@ -27,6 +27,8 @@ using ForwardDiff # we're going to make our own backend struct ForwardDiffNoTagBackend{CS} <: AD.AbstractForwardMode end +const CMBLensing.ForwardDiffNoTagBackend = ForwardDiffNoTagBackend + chunk(::ForwardDiffNoTagBackend{Nothing}, x) = ForwardDiff.Chunk(x) chunk(::ForwardDiffNoTagBackend{N}, _) where {N} = ForwardDiff.Chunk{N}()