From 0ea7f1de01c59e9589e635f04cc7c6939070c768 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 30 Jan 2025 11:52:28 +0100 Subject: [PATCH] fix: use Enzyme's native Jacobian in forward mode with constant contexts (#710) * fix: use Enzyme's native Jacobian in forward mode with constant contexts * Add tests --- DifferentiationInterface/Project.toml | 2 +- .../docs/src/explanation/backends.md | 2 +- .../forward_onearg.jl | 76 +++++++++++++------ .../utils.jl | 10 +++ .../test/Back/Enzyme/test.jl | 7 ++ 5 files changed, 70 insertions(+), 27 deletions(-) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 75391593a..9446f551c 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.36" +version = "0.6.37" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/DifferentiationInterface/docs/src/explanation/backends.md b/DifferentiationInterface/docs/src/explanation/backends.md index 350d04716..b6e6d8aed 100644 --- a/DifferentiationInterface/docs/src/explanation/backends.md +++ b/DifferentiationInterface/docs/src/explanation/backends.md @@ -61,7 +61,7 @@ Moreover, each context type is supported by a specific subset of backends: | `AutoChainRules` | ✅ | ❌ | | `AutoDiffractor` | ❌ | ❌ | | `AutoEnzyme` (forward) | ✅ | ✅ | -| `AutoEnzyme` (reverse) | ✅ | ✅ | +| `AutoEnzyme` (reverse) | ✅ | ❌ (soon) | | `AutoFastDifferentiation` | ✅ | ✅ | | `AutoFiniteDiff` | ✅ | ✅ | | `AutoFiniteDifferences` | ✅ | ✅ | diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl index 00b66808a..86c292606 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl @@ -119,8 +119,11 @@ function EnzymeForwardGradientPrep(::Val{B}, shadows::O) where {B,O} end function DI.prepare_gradient( - f::F, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x -) where {F} + f::F, + backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, + x, + contexts::Vararg{DI.Constant,C}, +) where {F,C} valB = to_val(DI.pick_batchsize(backend, x)) shadows = create_shadows(valB, x) return EnzymeForwardGradientPrep(valB, shadows) @@ -131,11 +134,15 @@ function DI.gradient( prep::EnzymeForwardGradientPrep{B}, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x, -) where {F,B} + contexts::Vararg{DI.Constant,C}, +) where {F,B,C} mode = forward_noprimal(backend) f_and_df = get_f_and_df(f, backend, mode) - derivs = gradient(mode, f_and_df, x; chunk=Val(B), shadows=prep.shadows) - return only(derivs) + annotated_contexts = translate(backend, mode, Val(B), contexts...) + derivs = gradient( + mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=prep.shadows + ) + return first(derivs) end function DI.value_and_gradient( @@ -143,11 +150,15 @@ function DI.value_and_gradient( prep::EnzymeForwardGradientPrep{B}, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x, -) where {F,B} + contexts::Vararg{DI.Constant,C}, +) where {F,B,C} mode = forward_withprimal(backend) f_and_df = get_f_and_df(f, backend, mode) - (; derivs, val) = gradient(mode, f_and_df, x; chunk=Val(B), shadows=prep.shadows) - return val, only(derivs) + annotated_contexts = translate(backend, mode, Val(B), contexts...) + (; derivs, val) = gradient( + mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=prep.shadows + ) + return val, first(derivs) end function DI.gradient!( @@ -156,8 +167,9 @@ function DI.gradient!( prep::EnzymeForwardGradientPrep{B}, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x, -) where {F,B} - return copyto!(grad, DI.gradient(f, prep, backend, x)) + contexts::Vararg{DI.Constant,C}, +) where {F,B,C} + return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...)) end function DI.value_and_gradient!( @@ -166,8 +178,9 @@ function DI.value_and_gradient!( prep::EnzymeForwardGradientPrep{B}, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x, -) where {F,B} - y, new_grad = DI.value_and_gradient(f, prep, backend, x) + contexts::Vararg{DI.Constant,C}, +) where {F,B,C} + y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...) return y, copyto!(grad, new_grad) end @@ -185,9 +198,12 @@ function EnzymeForwardOneArgJacobianPrep( end function DI.prepare_jacobian( - f::F, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x -) where {F} - y = f(x) + f::F, + backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, + x, + contexts::Vararg{DI.Constant,C}, +) where {F,C} + y = f(x, map(DI.unwrap, contexts)...) valB = to_val(DI.pick_batchsize(backend, x)) shadows = create_shadows(valB, x) return EnzymeForwardOneArgJacobianPrep(valB, shadows, length(y)) @@ -198,11 +214,15 @@ function DI.jacobian( prep::EnzymeForwardOneArgJacobianPrep{B}, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x, -) where {F,B} + contexts::Vararg{DI.Constant,C}, +) where {F,B,C} mode = forward_noprimal(backend) f_and_df = get_f_and_df(f, backend, mode) - derivs = jacobian(mode, f_and_df, x; chunk=Val(B), shadows=prep.shadows) - jac_tensor = only(derivs) + annotated_contexts = translate(backend, mode, Val(B), contexts...) + derivs = jacobian( + mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=prep.shadows + ) + jac_tensor = first(derivs) return maybe_reshape(jac_tensor, prep.output_length, length(x)) end @@ -211,11 +231,15 @@ function DI.value_and_jacobian( prep::EnzymeForwardOneArgJacobianPrep{B}, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x, -) where {F,B} + contexts::Vararg{DI.Constant,C}, +) where {F,B,C} mode = forward_withprimal(backend) f_and_df = get_f_and_df(f, backend, mode) - (; derivs, val) = jacobian(mode, f_and_df, x; chunk=Val(B), shadows=prep.shadows) - jac_tensor = only(derivs) + annotated_contexts = translate(backend, mode, Val(B), contexts...) + (; derivs, val) = jacobian( + mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=prep.shadows + ) + jac_tensor = first(derivs) return val, maybe_reshape(jac_tensor, prep.output_length, length(x)) end @@ -225,8 +249,9 @@ function DI.jacobian!( prep::EnzymeForwardOneArgJacobianPrep, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x, -) where {F} - return copyto!(jac, DI.jacobian(f, prep, backend, x)) + contexts::Vararg{DI.Constant,C}, +) where {F,C} + return copyto!(jac, DI.jacobian(f, prep, backend, x, contexts...)) end function DI.value_and_jacobian!( @@ -235,7 +260,8 @@ function DI.value_and_jacobian!( prep::EnzymeForwardOneArgJacobianPrep, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x, -) where {F} - y, new_jac = DI.value_and_jacobian(f, prep, backend, x) + contexts::Vararg{DI.Constant,C}, +) where {F,C} + y, new_jac = DI.value_and_jacobian(f, prep, backend, x, contexts...) return y, copyto!(jac, new_jac) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl index 9fb59d868..a6ad7cbc4 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl @@ -53,6 +53,16 @@ force_annotation(f::F) where {F} = Const(f) return Const(DI.unwrap(c)) end +@inline function _translate( + backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.Cache +) where {B} + if B == 1 + return Duplicated(DI.unwrap(c), make_zero(DI.unwrap(c))) + else + return BatchDuplicated(DI.unwrap(c), ntuple(_ -> make_zero(DI.unwrap(c)), Val(B))) + end +end + @inline function _translate( backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.FunctionContext ) where {B} diff --git a/DifferentiationInterface/test/Back/Enzyme/test.jl b/DifferentiationInterface/test/Back/Enzyme/test.jl index f9953fdf8..f4803ccb3 100644 --- a/DifferentiationInterface/test/Back/Enzyme/test.jl +++ b/DifferentiationInterface/test/Back/Enzyme/test.jl @@ -53,6 +53,13 @@ end; logging=LOGGING, ) + test_differentiation( + backends[2], + default_scenarios(; include_normal=false, include_cachified=true); + excluded=SECOND_ORDER, + logging=LOGGING, + ) + test_differentiation( duplicated_backends, default_scenarios(; include_normal=false, include_closurified=true);