Skip to content

Commit

Permalink
feat: use JVPCache for FiniteDiff pushforwards (#705)
Browse files Browse the repository at this point in the history
* feat: use `JVPCache` for FiniteDiff pushforwards

* test both fdtypes

* Up

* adapt to recent release

* Add benchmarks

* define logging
  • Loading branch information
gdalle authored Jan 31, 2025
1 parent 0ea7f1d commit 2896511
Show file tree
Hide file tree
Showing 8 changed files with 290 additions and 89 deletions.
4 changes: 2 additions & 2 deletions DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.6.37"
version = "0.6.38"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -56,7 +56,7 @@ Enzyme = "0.13.17"
EnzymeCore = "0.8.8"
ExplicitImports = "1.10.1"
FastDifferentiation = "0.4.3"
FiniteDiff = "2.23.1"
FiniteDiff = "2.27.0"
FiniteDifferences = "0.12.31"
ForwardDiff = "0.10.36"
GTPSA = "1.4.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@ using FiniteDiff:
GradientCache,
HessianCache,
JacobianCache,
JVPCache,
finite_difference_derivative,
finite_difference_gradient,
finite_difference_gradient!,
finite_difference_hessian,
finite_difference_hessian!,
finite_difference_jacobian,
finite_difference_jacobian!,
finite_difference_jvp,
finite_difference_jvp!,
default_relstep
using LinearAlgebra: dot, mul!

Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
## Pushforward

struct FiniteDiffOneArgPushforwardPrep{R,A} <: DI.PushforwardPrep
struct FiniteDiffOneArgPushforwardPrep{C,R,A} <: DI.PushforwardPrep
cache::C
relstep::R
absstep::A
end

function DI.prepare_pushforward(
f, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}
) where {C}
fc = DI.with_contexts(f, contexts...)
y = fc(x)
cache = if x isa Number || y isa Number
nothing
else
JVPCache(similar(x), y, fdtype(backend))
end
relstep = if isnothing(backend.relstep)
default_relstep(fdtype(backend), eltype(x))
else
Expand All @@ -18,12 +26,12 @@ function DI.prepare_pushforward(
else
backend.relstep
end
return FiniteDiffOneArgPushforwardPrep(relstep, absstep)
return FiniteDiffOneArgPushforwardPrep(cache, relstep, absstep)
end

function DI.pushforward(
f,
prep::FiniteDiffOneArgPushforwardPrep,
prep::FiniteDiffOneArgPushforwardPrep{Nothing},
backend::AutoFiniteDiff,
x,
tx::NTuple,
Expand All @@ -41,7 +49,7 @@ end

function DI.value_and_pushforward(
f,
prep::FiniteDiffOneArgPushforwardPrep,
prep::FiniteDiffOneArgPushforwardPrep{Nothing},
backend::AutoFiniteDiff,
x,
tx::NTuple,
Expand All @@ -64,6 +72,39 @@ function DI.value_and_pushforward(
return y, ty
end

function DI.pushforward(
f,
prep::FiniteDiffOneArgPushforwardPrep{<:JVPCache},
::AutoFiniteDiff,
x,
tx::NTuple,
contexts::Vararg{DI.Context,C},
) where {C}
(; relstep, absstep) = prep
fc = DI.with_contexts(f, contexts...)
ty = map(tx) do dx
finite_difference_jvp(fc, x, dx, prep.cache; relstep, absstep)
end
return ty
end

function DI.value_and_pushforward(
f,
prep::FiniteDiffOneArgPushforwardPrep{<:JVPCache},
::AutoFiniteDiff,
x,
tx::NTuple,
contexts::Vararg{DI.Context,C},
) where {C}
(; relstep, absstep) = prep
fc = DI.with_contexts(f, contexts...)
y = fc(x)
ty = map(tx) do dx
finite_difference_jvp(fc, x, dx, prep.cache, y; relstep, absstep)
end
return y, ty
end

## Derivative

struct FiniteDiffOneArgDerivativePrep{C,R,A} <: DI.DerivativePrep
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
## Pushforward

struct FiniteDiffTwoArgPushforwardPrep{R,A} <: DI.PushforwardPrep
struct FiniteDiffTwoArgPushforwardPrep{C,R,A} <: DI.PushforwardPrep
cache::C
relstep::R
absstep::A
end

function DI.prepare_pushforward(
f!, y, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}
) where {C}
cache = if x isa Number
nothing
else
JVPCache(similar(x), similar(y), fdtype(backend))
end
relstep = if isnothing(backend.relstep)
default_relstep(fdtype(backend), eltype(x))
else
Expand All @@ -18,14 +24,13 @@ function DI.prepare_pushforward(
else
backend.relstep
end
return FiniteDiffTwoArgPushforwardPrep(relstep, absstep)
return DI.NoPushforwardPrep()
return FiniteDiffTwoArgPushforwardPrep(cache, relstep, absstep)
end

function DI.value_and_pushforward(
f!,
y,
prep::FiniteDiffTwoArgPushforwardPrep,
prep::FiniteDiffTwoArgPushforwardPrep{Nothing},
backend::AutoFiniteDiff,
x,
tx::NTuple,
Expand All @@ -52,6 +57,84 @@ function DI.value_and_pushforward(
return y, ty
end

function DI.pushforward(
f!,
y,
prep::FiniteDiffTwoArgPushforwardPrep{<:JVPCache},
::AutoFiniteDiff,
x,
tx::NTuple,
contexts::Vararg{DI.Context,C},
) where {C}
(; relstep, absstep) = prep
fc! = DI.with_contexts(f!, contexts...)
ty = map(tx) do dx
dy = similar(y)
finite_difference_jvp!(dy, fc!, x, dx, prep.cache; relstep, absstep)
dy
end
return ty
end

function DI.value_and_pushforward(
f!,
y,
prep::FiniteDiffTwoArgPushforwardPrep{<:JVPCache},
::AutoFiniteDiff,
x,
tx::NTuple,
contexts::Vararg{DI.Context,C},
) where {C}
(; relstep, absstep) = prep
fc! = DI.with_contexts(f!, contexts...)
ty = map(tx) do dx
dy = similar(y)
finite_difference_jvp!(dy, fc!, x, dx, prep.cache; relstep, absstep)
dy
end
fc!(y, x)
return y, ty
end

function DI.pushforward!(
f!,
y,
ty::NTuple,
prep::FiniteDiffTwoArgPushforwardPrep{<:JVPCache},
::AutoFiniteDiff,
x,
tx::NTuple,
contexts::Vararg{DI.Context,C},
) where {C}
(; relstep, absstep) = prep
fc! = DI.with_contexts(f!, contexts...)
for b in eachindex(tx, ty)
dx, dy = tx[b], ty[b]
finite_difference_jvp!(dy, fc!, x, dx, prep.cache; relstep, absstep)
end
return ty
end

function DI.value_and_pushforward!(
f!,
y,
ty::NTuple,
prep::FiniteDiffTwoArgPushforwardPrep{<:JVPCache},
::AutoFiniteDiff,
x,
tx::NTuple,
contexts::Vararg{DI.Context,C},
) where {C}
(; relstep, absstep) = prep
fc! = DI.with_contexts(f!, contexts...)
for b in eachindex(tx, ty)
dx, dy = tx[b], ty[b]
finite_difference_jvp!(dy, fc!, x, dx, prep.cache; relstep, absstep)
end
fc!(y, x)
return y, ty
end

## Derivative

struct FiniteDiffTwoArgDerivativePrep{C,R,A} <: DI.DerivativePrep
Expand Down
33 changes: 33 additions & 0 deletions DifferentiationInterface/test/Back/FiniteDiff/benchmark.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
using Pkg
Pkg.add("FiniteDiff")

using ADTypes: ADTypes
using DifferentiationInterface, DifferentiationInterfaceTest
import DifferentiationInterface as DI
import DifferentiationInterfaceTest as DIT
using FiniteDiff: FiniteDiff
using Test

LOGGING = get(ENV, "CI", "false") == "false"

@testset "Benchmarking sparse" begin
filtered_sparse_scenarios = filter(sparse_scenarios(; band_sizes=[])) do scen
DIT.function_place(scen) == :in &&
DIT.operator_place(scen) == :in &&
scen.x isa AbstractVector &&
scen.y isa AbstractVector
end

data = benchmark_differentiation(
MyAutoSparse(AutoFiniteDiff()),
filtered_sparse_scenarios;
benchmark=:prepared,
excluded=SECOND_ORDER,
logging=LOGGING,
)
@testset "Analyzing benchmark results" begin
@testset "$(row[:scenario])" for row in eachrow(data)
@test row[:allocs] == 0
end
end
end
41 changes: 26 additions & 15 deletions DifferentiationInterface/test/Back/FiniteDiff/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,32 @@ for backend in [AutoFiniteDiff()]
@test check_inplace(backend)
end

test_differentiation(
AutoFiniteDiff(),
default_scenarios(; include_constantified=true, include_cachified=true);
excluded=[:second_derivative, :hvp],
logging=LOGGING,
);

test_differentiation(
[
AutoFiniteDiff(; relstep=cbrt(eps(Float64))),
AutoFiniteDiff(; relstep=cbrt(eps(Float64)), absstep=cbrt(eps(Float64))),
];
excluded=[:second_derivative, :hvp],
logging=LOGGING,
);
@testset "Dense" begin
test_differentiation(
AutoFiniteDiff(),
default_scenarios(; include_constantified=true, include_cachified=true);
excluded=[:second_derivative, :hvp],
logging=LOGGING,
)

test_differentiation(
[
AutoFiniteDiff(; relstep=cbrt(eps(Float64))),
AutoFiniteDiff(; relstep=cbrt(eps(Float64)), absstep=cbrt(eps(Float64))),
];
excluded=[:second_derivative, :hvp],
logging=LOGGING,
)
end

@testset "Sparse" begin
test_differentiation(
MyAutoSparse(AutoFiniteDiff()),
sparse_scenarios();
excluded=SECOND_ORDER,
logging=LOGGING,
)
end

@testset "Complex" begin
test_differentiation(AutoFiniteDiff(), complex_scenarios(); logging=LOGGING)
Expand Down
52 changes: 52 additions & 0 deletions DifferentiationInterface/test/Back/ForwardDiff/benchmark.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
using Pkg
Pkg.add("ForwardDiff")

using ADTypes: ADTypes
using DifferentiationInterface, DifferentiationInterfaceTest
import DifferentiationInterface as DI
import DifferentiationInterfaceTest as DIT
using ForwardDiff: ForwardDiff
using StaticArrays: StaticArrays, @SVector
using Test

LOGGING = get(ENV, "CI", "false") == "false"

@testset verbose = true "Benchmarking static" begin
filtered_static_scenarios = filter(static_scenarios(; include_batchified=false)) do scen
DIT.function_place(scen) == :out && DIT.operator_place(scen) == :out
end
data = benchmark_differentiation(
AutoForwardDiff(),
filtered_static_scenarios;
benchmark=:prepared,
excluded=[:hessian, :pullback], # TODO: figure this out
logging=LOGGING,
)
@testset "Analyzing benchmark results" begin
@testset "$(row[:scenario])" for row in eachrow(data)
@test row[:allocs] == 0
end
end
end

@testset "Benchmarking sparse" begin
filtered_sparse_scenarios = filter(sparse_scenarios(; band_sizes=[])) do scen
DIT.function_place(scen) == :in &&
DIT.operator_place(scen) == :in &&
scen.x isa AbstractVector &&
scen.y isa AbstractVector
end

data = benchmark_differentiation(
MyAutoSparse(AutoForwardDiff()),
filtered_sparse_scenarios;
benchmark=:prepared,
excluded=SECOND_ORDER,
logging=LOGGING,
)
@testset "Analyzing benchmark results" begin
@testset "$(row[:scenario])" for row in eachrow(data)
@test row[:allocs] == 0
end
end
end
Loading

2 comments on commit 2896511

@gdalle
Copy link
Member Author

@gdalle gdalle commented on 2896511 Jan 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register subdir=DifferentiationInterface

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/124077

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a DifferentiationInterface-v0.6.38 -m "<description of version>" 2896511bb6171c002c17684c815a540d1956ee83
git push origin DifferentiationInterface-v0.6.38

Please sign in to comment.