From 65bf79b071efad1d76d39b23031016663da53706 Mon Sep 17 00:00:00 2001 From: marcobonici Date: Sun, 1 Sep 2024 15:29:36 -0400 Subject: [PATCH] Adding rules to my spline --- Project.toml | 2 ++ src/Effort.jl | 1 + src/chainrules.jl | 67 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 70 insertions(+) diff --git a/Project.toml b/Project.toml index c17ed2c..a75cdd4 100644 --- a/Project.toml +++ b/Project.toml @@ -17,6 +17,8 @@ LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] AbstractCosmologicalEmulators = "0.3.3" diff --git a/src/Effort.jl b/src/Effort.jl index a46902b..6c7fa67 100644 --- a/src/Effort.jl +++ b/src/Effort.jl @@ -13,6 +13,7 @@ using Memoization using OrdinaryDiffEq using Integrals using LinearAlgebra +using SparseArrays const c_0 = 2.99792458e5 diff --git a/src/chainrules.jl b/src/chainrules.jl index f827d6c..069ef52 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -1,2 +1,69 @@ @non_differentiable LinRange(a,b,n) @non_differentiable _transformed_weights(quadrature_rule, order, a,b) + +Zygote.@adjoint function _create_d(u, t, s, typed_zero) + y = _create_d(u, t, s, typed_zero) + function _create_d_pullback(ȳ) + ∂u = Tridiagonal(zeros(eltype(typed_zero), s-1), + map(i -> i == 1 ? typed_zero : 2 / (t[i] - t[i - 1]), 1:s), + map(i -> - 2 / (t[i+1] - t[i]), 1:s-1)) * ȳ + ∂t = Tridiagonal(zeros(eltype(typed_zero), s-1), + map(i -> i == 1 ? typed_zero : -2 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]) ^ 2, 1:s), + map(i -> 2 * (u[i+1] - u[i]) / (t[i+1] - t[i]) ^ 2, 1:s-1)) * ȳ + return (∂u, ∂t, NoTangent(), NoTangent()) + end + return y, _create_d_pullback +end + +Zygote.@adjoint function _create_σ(z, x, i_list) + y = _create_σ(z, x, i_list) + function _create_σ_pullback(ȳ) + s = length(z) + s1 = length(i_list) + + runner = 0.5 ./ (x[i_list] - x[i_list .- 1]) + runner_bis = 2. .* (z[i_list] - z[i_list .- 1]) + + ∂z = (sparse(i_list, 1:s1 ,runner, s, s1) - + sparse(i_list .- 1, 1:s1 ,runner, s, s1)) * ȳ + ∂x = (-sparse(i_list, 1:s1 ,runner_bis .* runner .^2, s, s1) + + sparse(i_list .- 1, 1:s1 , runner_bis .* runner .^2, s, s1)) * ȳ + return (∂z, ∂x, NoTangent()) + end + return y, _create_σ_pullback +end + +Zygote.@adjoint function _compose(z, t, new_t, Cᵢ_list, s_new, i_list, σ) + y = _compose(z, t, new_t, Cᵢ_list, s_new, i_list, σ) + function _compose_pullback(ȳ) + s = length(z) + s1 = length(i_list) + + ∂z = sparse(i_list .-1, 1:s1, [new_t[j] - t[i_list[j] - 1] for j in 1:s_new], s, s1) * ȳ + ∂t = sparse(i_list .-1, 1:s1, map(j -> -z[i_list[j] - 1] - 2σ[j] * (new_t[j] - t[i_list[j] - 1]), 1:s_new), s, s1) * ȳ + ∂t1 = Diagonal([+z[i_list[j] - 1] + 2σ[j] * (new_t[j] - t[i_list[j] - 1]) for j in 1:s1]) * ȳ + ∂σ = Diagonal(map(i -> (new_t[i] - t[i_list[i] - 1])^2, 1:s_new)) * ȳ + ∂Cᵢ_list = Diagonal(ones(s1)) * ȳ + return (∂z, ∂t, ∂t1, ∂Cᵢ_list, NoTangent(), NoTangent(), ∂σ) + end + return y, _compose_pullback +end + +Zygote.@adjoint function _create_Cᵢ_list(u, i_list) + y = _create_Cᵢ_list(u, i_list) + function _create_Cᵢ_list_pullback(ȳ) + s = length(z) + s1 = length(i_list) + ∂Cᵢ_list = sparse(i_list .-1, 1:s1 ,ones(s1), s, s1) * ȳ + return (∂Cᵢ_list, NoTangent()) + end + return y, _create_Cᵢ_list_pullback +end + +Zygote.@adjoint function _create_i_list(t, new_t, s_new) + y = _create_i_list(t, new_t, s_new) + function _create_i_list_pullback(ȳ) + return (NoTangent(), NoTangent(), NoTangent()) + end + return y, _create_i_list_pullback +end