Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Spline AD #26

Merged
merged 8 commits into from
Sep 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ on:
branches:
- main
- develop
- abstractcosmoemu
push:
branches:
- '**'
Expand Down
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224"
Integrals = "de52edbc-65ea-441a-8357-d3a637375a31"
LegendrePolynomials = "3db4a2ba-fc88-11e8-3e01-49c72059a882"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractCosmologicalEmulators = "0.3.3"
Expand Down
3 changes: 3 additions & 0 deletions src/Effort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@ import AbstractCosmologicalEmulators.get_emulator_description
using ChainRulesCore
using DataInterpolations
using FastGaussQuadrature
using FindFirstFunctions
using LegendrePolynomials
using LoopVectorization
using Memoization
using OrdinaryDiffEq
using Integrals
using LinearAlgebra
using SparseArrays
using Zygote

const c_0 = 2.99792458e5

Expand Down
67 changes: 67 additions & 0 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,2 +1,69 @@
@non_differentiable LinRange(a,b,n)
@non_differentiable _transformed_weights(quadrature_rule, order, a,b)

Zygote.@adjoint function _create_d(u, t, s, typed_zero)
y = _create_d(u, t, s, typed_zero)
function _create_d_pullback(ȳ)
∂u = Tridiagonal(zeros(eltype(typed_zero), s-1),
map(i -> i == 1 ? typed_zero : 2 / (t[i] - t[i - 1]), 1:s),
map(i -> - 2 / (t[i+1] - t[i]), 1:s-1)) * ȳ
∂t = Tridiagonal(zeros(eltype(typed_zero), s-1),
map(i -> i == 1 ? typed_zero : -2 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]) ^ 2, 1:s),
map(i -> 2 * (u[i+1] - u[i]) / (t[i+1] - t[i]) ^ 2, 1:s-1)) * ȳ
return (∂u, ∂t, NoTangent(), NoTangent())
end
return y, _create_d_pullback
end

Zygote.@adjoint function _create_σ(z, x, i_list)
y = _create_σ(z, x, i_list)
function _create_σ_pullback(ȳ)
s = length(z)
s1 = length(i_list)

runner = 0.5 ./ (x[i_list] - x[i_list .- 1])
runner_bis = 2. .* (z[i_list] - z[i_list .- 1])

∂z = (sparse(i_list, 1:s1 ,runner, s, s1) -
sparse(i_list .- 1, 1:s1 ,runner, s, s1)) * ȳ
∂x = (-sparse(i_list, 1:s1 ,runner_bis .* runner .^2, s, s1) +
sparse(i_list .- 1, 1:s1 , runner_bis .* runner .^2, s, s1)) * ȳ
return (∂z, ∂x, NoTangent())
end
return y, _create_σ_pullback
end

Zygote.@adjoint function _compose(z, t, new_t, Cᵢ_list, s_new, i_list, σ)
y = _compose(z, t, new_t, Cᵢ_list, s_new, i_list, σ)
function _compose_pullback(ȳ)
s = length(z)
s1 = length(i_list)

∂z = sparse(i_list .-1, 1:s1, [new_t[j] - t[i_list[j] - 1] for j in 1:s_new], s, s1) * ȳ
∂t = sparse(i_list .-1, 1:s1, map(j -> -z[i_list[j] - 1] - 2σ[j] * (new_t[j] - t[i_list[j] - 1]), 1:s_new), s, s1) * ȳ
∂t1 = Diagonal([+z[i_list[j] - 1] + 2σ[j] * (new_t[j] - t[i_list[j] - 1]) for j in 1:s1]) * ȳ
∂σ = Diagonal(map(i -> (new_t[i] - t[i_list[i] - 1])^2, 1:s_new)) * ȳ
∂Cᵢ_list = Diagonal(ones(s1)) * ȳ
return (∂z, ∂t, ∂t1, ∂Cᵢ_list, NoTangent(), NoTangent(), ∂σ)
end
return y, _compose_pullback
end

Zygote.@adjoint function _create_Cᵢ_list(u, i_list)
y = _create_Cᵢ_list(u, i_list)
function _create_Cᵢ_list_pullback(ȳ)
s = length(u)
s1 = length(i_list)
∂Cᵢ_list = sparse(i_list .-1, 1:s1 ,ones(s1), s, s1) * ȳ
return (∂Cᵢ_list, NoTangent())
end
return y, _create_Cᵢ_list_pullback
end

Zygote.@adjoint function _create_i_list(t, new_t, s_new)
y = _create_i_list(t, new_t, s_new)
function _create_i_list_pullback(ȳ)
return (NoTangent(), NoTangent(), NoTangent())
end
return y, _create_i_list_pullback
end
59 changes: 59 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,62 @@ function _transformed_weights(quadrature_rule, order, a,b)
w = (b-a)/2. .* w
return x, w
end

function _quadratic_spline(u, t, new_t::Number)
s = length(t)
dl = ones(eltype(t), s - 1)
d_tmp = ones(eltype(t), s)
du = zeros(eltype(t), s - 1)
tA = Tridiagonal(dl, d_tmp, du)

# zero for element type of d, which we don't know yet
typed_zero = zero(2 // 1 * (u[begin + 1] - u[begin]) / (t[begin + 1] - t[begin]))

d = map(i -> i == 1 ? typed_zero : 2 // 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), 1:s)
z = tA \ d
i = min(max(2, FindFirstFunctions.searchsortedfirstcorrelated(t, new_t, firstindex(t) - 1)), length(t))
Cᵢ = u[i - 1]
σ = 1 // 2 * (z[i] - z[i - 1]) / (t[i] - t[i - 1])
return z[i - 1] * (new_t - t[i - 1]) + σ * (new_t - t[i - 1])^2 + Cᵢ
end

function _quadratic_spline(u, t, new_t::AbstractArray)
s = length(t)
s_new = length(new_t)
dl = ones(eltype(t), s - 1)
d_tmp = ones(eltype(t), s)
du = zeros(eltype(t), s - 1)
tA = Tridiagonal(dl, d_tmp, du)

# zero for element type of d, which we don't know yet
typed_zero = zero(2 // 1 * (u[begin + 1] - u[begin]) / (t[begin + 1] - t[begin]))

d = _create_d(u, t, s, typed_zero)
z = tA \ d
i_list = _create_i_list(t, new_t, s_new)
Cᵢ_list = _create_Cᵢ_list(u, i_list)
σ = _create_σ(z, t, i_list)
return _compose(z, t, new_t, Cᵢ_list, s_new, i_list, σ)
end

function _compose(z, t, new_t, Cᵢ_list, s_new, i_list, σ)
return map(i -> z[i_list[i] - 1] * (new_t[i] - t[i_list[i] - 1]) +
σ[i] * (new_t[i] - t[i_list[i] - 1])^2 + Cᵢ_list[i], 1:s_new)
end

function _create_σ(z, t, i_list)
return map(i -> 1 / 2 * (z[i] - z[i - 1]) / (t[i] - t[i - 1]), i_list)
end

function _create_Cᵢ_list(u, i_list)
return map(i-> u[i - 1], i_list)
end

function _create_i_list(t, new_t, s_new)
return map(i-> min(max(2, FindFirstFunctions.searchsortedfirstcorrelated(t, new_t[i],
firstindex(t) - 1)), length(t)), 1:s_new)
end

function _create_d(u, t, s, typed_zero)
return map(i -> i == 1 ? typed_zero : 2 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), 1:s)
end
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
[deps]
AbstractCosmologicalEmulators = "c83c1981-e5c4-4837-9eb8-c9b1572acfc6"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Expand Down
23 changes: 19 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ using Static
using Effort
using ForwardDiff
using Zygote
using FiniteDiff
using FiniteDifferences
using SciMLSensitivity
using DataInterpolations

mlpd = SimpleChain(
static(6),
Expand All @@ -32,6 +33,16 @@ effort_emu = Effort.P11Emulator(TrainedEmulator = emu, kgrid=k_test, InMinMax =

x = [Ωm0, h, mν, w0, wa]

n = 64
x1 = vcat([0.], sort(rand(n-2)), [1.])
x2 = 2 .* vcat([0.], sort(rand(n-2)), [1.])
y = rand(n)

function di_spline(y,x,xn)
spline = QuadraticSpline(y,x, extrapolate = true)
return spline.(xn)
end

function D_z_x(z, x)
Ωm0, h, mν, w0, wa = x
sum(Effort._D_z(z, Ωm0, h; mν =mν, w0=w0, wa=wa))
Expand All @@ -58,11 +69,15 @@ end
@test isapprox(Effort._D_z_old(z, Ωm0, h), Effort._D_z(z, Ωm0, h), rtol=1e-9)
@test isapprox(Effort._f_z_old(0.4, Ωm0, h), Effort._f_z(0.4, Ωm0, h)[1], rtol=1e-9)
@test isapprox(Zygote.gradient(x->D_z_x(z, x), x)[1], ForwardDiff.gradient(x->D_z_x(z, x), x), rtol=1e-5)
@test isapprox(FiniteDiff.finite_difference_gradient(x->D_z_x(z, x), x), ForwardDiff.gradient(x->D_z_x(z, x), x), rtol=1e-5)
@test isapprox(grad(central_fdm(5,1), x->D_z_x(z, x), x)[1], ForwardDiff.gradient(x->D_z_x(z, x), x), rtol=1e-5)
@test isapprox(Zygote.gradient(x->f_z_x(z, x), x)[1], ForwardDiff.gradient(x->f_z_x(z, x), x), rtol=1e-5)
@test isapprox(FiniteDiff.finite_difference_gradient(x->f_z_x(z, x), x), ForwardDiff.gradient(x->f_z_x(z, x), x), rtol=1e-4)
@test isapprox(FiniteDiff.finite_difference_gradient(x->r_z_x(3., x), x), ForwardDiff.gradient(x->r_z_x(3., x), x), rtol=1e-7)
@test isapprox(grad(central_fdm(5,1), x->f_z_x(z, x), x)[1], ForwardDiff.gradient(x->f_z_x(z, x), x), rtol=1e-4)
@test isapprox(grad(central_fdm(5,1), x->r_z_x(3., x), x)[1], ForwardDiff.gradient(x->r_z_x(3., x), x), rtol=1e-7)
@test isapprox(Zygote.gradient(x->r_z_x(3., x), x)[1], ForwardDiff.gradient(x->r_z_x(3., x), x), rtol=1e-6)
@test isapprox(Zygote.gradient(x->r_z_x(3., x), x)[1], Zygote.gradient(x->r_z_check_x(3., x), x)[1], rtol=1e-7)
@test isapprox(Effort._r_z(3., Ωm0, h; mν =mν, w0=w0, wa=wa), Effort._r_z_check(3., Ωm0, h; mν =mν, w0=w0, wa=wa), rtol=1e-6)
@test isapprox(Effort._quadratic_spline(y, x1, x2), di_spline(y, x1, x2), rtol=1e-9)
@test isapprox(grad(central_fdm(5,1), y->sum(Effort._quadratic_spline(y,x1,x2)), y)[1], Zygote.gradient(y->sum(Effort._quadratic_spline(y,x1,x2)), y)[1], rtol=1e-6)
#@test isapprox(grad(central_fdm(6,1), x1->sum(Effort._quadratic_spline(y,x1,x2)), x1)[1], Zygote.gradient(x1->sum(Effort._quadratic_spline(y,x1,x2)), x1)[1], rtol=1e-6)
@test isapprox(grad(central_fdm(5,1), x2->sum(Effort._quadratic_spline(y,x1,x2)), x2)[1], Zygote.gradient(x2->sum(Effort._quadratic_spline(y,x1,x2)), x2)[1], rtol=1e-6)
end