diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e95add1..38767f6 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,7 +6,6 @@ on: branches: - main - develop - - abstractcosmoemu push: branches: - '**' diff --git a/Project.toml b/Project.toml index 7577e4c..a75cdd4 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838" +FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224" Integrals = "de52edbc-65ea-441a-8357-d3a637375a31" LegendrePolynomials = "3db4a2ba-fc88-11e8-3e01-49c72059a882" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -16,6 +17,8 @@ LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] AbstractCosmologicalEmulators = "0.3.3" diff --git a/src/Effort.jl b/src/Effort.jl index 1e60994..5de6c5f 100644 --- a/src/Effort.jl +++ b/src/Effort.jl @@ -6,12 +6,15 @@ import AbstractCosmologicalEmulators.get_emulator_description using ChainRulesCore using DataInterpolations using FastGaussQuadrature +using FindFirstFunctions using LegendrePolynomials using LoopVectorization using Memoization using OrdinaryDiffEq using Integrals using LinearAlgebra +using SparseArrays +using Zygote const c_0 = 2.99792458e5 diff --git a/src/chainrules.jl b/src/chainrules.jl index f827d6c..0862e9e 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -1,2 +1,69 @@ @non_differentiable LinRange(a,b,n) @non_differentiable _transformed_weights(quadrature_rule, order, a,b) + +Zygote.@adjoint function _create_d(u, t, s, typed_zero) + y = _create_d(u, t, s, typed_zero) + function _create_d_pullback(ȳ) + ∂u = Tridiagonal(zeros(eltype(typed_zero), s-1), + map(i -> i == 1 ? typed_zero : 2 / (t[i] - t[i - 1]), 1:s), + map(i -> - 2 / (t[i+1] - t[i]), 1:s-1)) * ȳ + ∂t = Tridiagonal(zeros(eltype(typed_zero), s-1), + map(i -> i == 1 ? typed_zero : -2 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]) ^ 2, 1:s), + map(i -> 2 * (u[i+1] - u[i]) / (t[i+1] - t[i]) ^ 2, 1:s-1)) * ȳ + return (∂u, ∂t, NoTangent(), NoTangent()) + end + return y, _create_d_pullback +end + +Zygote.@adjoint function _create_σ(z, x, i_list) + y = _create_σ(z, x, i_list) + function _create_σ_pullback(ȳ) + s = length(z) + s1 = length(i_list) + + runner = 0.5 ./ (x[i_list] - x[i_list .- 1]) + runner_bis = 2. .* (z[i_list] - z[i_list .- 1]) + + ∂z = (sparse(i_list, 1:s1 ,runner, s, s1) - + sparse(i_list .- 1, 1:s1 ,runner, s, s1)) * ȳ + ∂x = (-sparse(i_list, 1:s1 ,runner_bis .* runner .^2, s, s1) + + sparse(i_list .- 1, 1:s1 , runner_bis .* runner .^2, s, s1)) * ȳ + return (∂z, ∂x, NoTangent()) + end + return y, _create_σ_pullback +end + +Zygote.@adjoint function _compose(z, t, new_t, Cᵢ_list, s_new, i_list, σ) + y = _compose(z, t, new_t, Cᵢ_list, s_new, i_list, σ) + function _compose_pullback(ȳ) + s = length(z) + s1 = length(i_list) + + ∂z = sparse(i_list .-1, 1:s1, [new_t[j] - t[i_list[j] - 1] for j in 1:s_new], s, s1) * ȳ + ∂t = sparse(i_list .-1, 1:s1, map(j -> -z[i_list[j] - 1] - 2σ[j] * (new_t[j] - t[i_list[j] - 1]), 1:s_new), s, s1) * ȳ + ∂t1 = Diagonal([+z[i_list[j] - 1] + 2σ[j] * (new_t[j] - t[i_list[j] - 1]) for j in 1:s1]) * ȳ + ∂σ = Diagonal(map(i -> (new_t[i] - t[i_list[i] - 1])^2, 1:s_new)) * ȳ + ∂Cᵢ_list = Diagonal(ones(s1)) * ȳ + return (∂z, ∂t, ∂t1, ∂Cᵢ_list, NoTangent(), NoTangent(), ∂σ) + end + return y, _compose_pullback +end + +Zygote.@adjoint function _create_Cᵢ_list(u, i_list) + y = _create_Cᵢ_list(u, i_list) + function _create_Cᵢ_list_pullback(ȳ) + s = length(u) + s1 = length(i_list) + ∂Cᵢ_list = sparse(i_list .-1, 1:s1 ,ones(s1), s, s1) * ȳ + return (∂Cᵢ_list, NoTangent()) + end + return y, _create_Cᵢ_list_pullback +end + +Zygote.@adjoint function _create_i_list(t, new_t, s_new) + y = _create_i_list(t, new_t, s_new) + function _create_i_list_pullback(ȳ) + return (NoTangent(), NoTangent(), NoTangent()) + end + return y, _create_i_list_pullback +end diff --git a/src/utils.jl b/src/utils.jl index b492c2f..4c8637e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -4,3 +4,62 @@ function _transformed_weights(quadrature_rule, order, a,b) w = (b-a)/2. .* w return x, w end + +function _quadratic_spline(u, t, new_t::Number) + s = length(t) + dl = ones(eltype(t), s - 1) + d_tmp = ones(eltype(t), s) + du = zeros(eltype(t), s - 1) + tA = Tridiagonal(dl, d_tmp, du) + + # zero for element type of d, which we don't know yet + typed_zero = zero(2 // 1 * (u[begin + 1] - u[begin]) / (t[begin + 1] - t[begin])) + + d = map(i -> i == 1 ? typed_zero : 2 // 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), 1:s) + z = tA \ d + i = min(max(2, FindFirstFunctions.searchsortedfirstcorrelated(t, new_t, firstindex(t) - 1)), length(t)) + Cᵢ = u[i - 1] + σ = 1 // 2 * (z[i] - z[i - 1]) / (t[i] - t[i - 1]) + return z[i - 1] * (new_t - t[i - 1]) + σ * (new_t - t[i - 1])^2 + Cᵢ +end + +function _quadratic_spline(u, t, new_t::AbstractArray) + s = length(t) + s_new = length(new_t) + dl = ones(eltype(t), s - 1) + d_tmp = ones(eltype(t), s) + du = zeros(eltype(t), s - 1) + tA = Tridiagonal(dl, d_tmp, du) + + # zero for element type of d, which we don't know yet + typed_zero = zero(2 // 1 * (u[begin + 1] - u[begin]) / (t[begin + 1] - t[begin])) + + d = _create_d(u, t, s, typed_zero) + z = tA \ d + i_list = _create_i_list(t, new_t, s_new) + Cᵢ_list = _create_Cᵢ_list(u, i_list) + σ = _create_σ(z, t, i_list) + return _compose(z, t, new_t, Cᵢ_list, s_new, i_list, σ) +end + +function _compose(z, t, new_t, Cᵢ_list, s_new, i_list, σ) + return map(i -> z[i_list[i] - 1] * (new_t[i] - t[i_list[i] - 1]) + + σ[i] * (new_t[i] - t[i_list[i] - 1])^2 + Cᵢ_list[i], 1:s_new) +end + +function _create_σ(z, t, i_list) + return map(i -> 1 / 2 * (z[i] - z[i - 1]) / (t[i] - t[i - 1]), i_list) +end + +function _create_Cᵢ_list(u, i_list) + return map(i-> u[i - 1], i_list) +end + +function _create_i_list(t, new_t, s_new) + return map(i-> min(max(2, FindFirstFunctions.searchsortedfirstcorrelated(t, new_t[i], + firstindex(t) - 1)), length(t)), 1:s_new) +end + +function _create_d(u, t, s, typed_zero) + return map(i -> i == 1 ? typed_zero : 2 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), 1:s) +end diff --git a/test/Project.toml b/test/Project.toml index ea13341..14ede01 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,8 @@ [deps] AbstractCosmologicalEmulators = "c83c1981-e5c4-4837-9eb8-c9b1572acfc6" +DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" diff --git a/test/runtests.jl b/test/runtests.jl index 1db17ee..0878737 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,8 +5,9 @@ using Static using Effort using ForwardDiff using Zygote -using FiniteDiff +using FiniteDifferences using SciMLSensitivity +using DataInterpolations mlpd = SimpleChain( static(6), @@ -32,6 +33,16 @@ effort_emu = Effort.P11Emulator(TrainedEmulator = emu, kgrid=k_test, InMinMax = x = [Ωm0, h, mν, w0, wa] +n = 64 +x1 = vcat([0.], sort(rand(n-2)), [1.]) +x2 = 2 .* vcat([0.], sort(rand(n-2)), [1.]) +y = rand(n) + +function di_spline(y,x,xn) + spline = QuadraticSpline(y,x, extrapolate = true) + return spline.(xn) +end + function D_z_x(z, x) Ωm0, h, mν, w0, wa = x sum(Effort._D_z(z, Ωm0, h; mν =mν, w0=w0, wa=wa)) @@ -58,11 +69,15 @@ end @test isapprox(Effort._D_z_old(z, Ωm0, h), Effort._D_z(z, Ωm0, h), rtol=1e-9) @test isapprox(Effort._f_z_old(0.4, Ωm0, h), Effort._f_z(0.4, Ωm0, h)[1], rtol=1e-9) @test isapprox(Zygote.gradient(x->D_z_x(z, x), x)[1], ForwardDiff.gradient(x->D_z_x(z, x), x), rtol=1e-5) - @test isapprox(FiniteDiff.finite_difference_gradient(x->D_z_x(z, x), x), ForwardDiff.gradient(x->D_z_x(z, x), x), rtol=1e-5) + @test isapprox(grad(central_fdm(5,1), x->D_z_x(z, x), x)[1], ForwardDiff.gradient(x->D_z_x(z, x), x), rtol=1e-5) @test isapprox(Zygote.gradient(x->f_z_x(z, x), x)[1], ForwardDiff.gradient(x->f_z_x(z, x), x), rtol=1e-5) - @test isapprox(FiniteDiff.finite_difference_gradient(x->f_z_x(z, x), x), ForwardDiff.gradient(x->f_z_x(z, x), x), rtol=1e-4) - @test isapprox(FiniteDiff.finite_difference_gradient(x->r_z_x(3., x), x), ForwardDiff.gradient(x->r_z_x(3., x), x), rtol=1e-7) + @test isapprox(grad(central_fdm(5,1), x->f_z_x(z, x), x)[1], ForwardDiff.gradient(x->f_z_x(z, x), x), rtol=1e-4) + @test isapprox(grad(central_fdm(5,1), x->r_z_x(3., x), x)[1], ForwardDiff.gradient(x->r_z_x(3., x), x), rtol=1e-7) @test isapprox(Zygote.gradient(x->r_z_x(3., x), x)[1], ForwardDiff.gradient(x->r_z_x(3., x), x), rtol=1e-6) @test isapprox(Zygote.gradient(x->r_z_x(3., x), x)[1], Zygote.gradient(x->r_z_check_x(3., x), x)[1], rtol=1e-7) @test isapprox(Effort._r_z(3., Ωm0, h; mν =mν, w0=w0, wa=wa), Effort._r_z_check(3., Ωm0, h; mν =mν, w0=w0, wa=wa), rtol=1e-6) + @test isapprox(Effort._quadratic_spline(y, x1, x2), di_spline(y, x1, x2), rtol=1e-9) + @test isapprox(grad(central_fdm(5,1), y->sum(Effort._quadratic_spline(y,x1,x2)), y)[1], Zygote.gradient(y->sum(Effort._quadratic_spline(y,x1,x2)), y)[1], rtol=1e-6) + #@test isapprox(grad(central_fdm(6,1), x1->sum(Effort._quadratic_spline(y,x1,x2)), x1)[1], Zygote.gradient(x1->sum(Effort._quadratic_spline(y,x1,x2)), x1)[1], rtol=1e-6) + @test isapprox(grad(central_fdm(5,1), x2->sum(Effort._quadratic_spline(y,x1,x2)), x2)[1], Zygote.gradient(x2->sum(Effort._quadratic_spline(y,x1,x2)), x2)[1], rtol=1e-6) end