Skip to content

Commit

Permalink
Adding rules to my spline
Browse files Browse the repository at this point in the history
  • Loading branch information
marcobonici committed Sep 1, 2024
1 parent 456be33 commit 65bf79b
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractCosmologicalEmulators = "0.3.3"
Expand Down
1 change: 1 addition & 0 deletions src/Effort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ using Memoization
using OrdinaryDiffEq
using Integrals
using LinearAlgebra
using SparseArrays

const c_0 = 2.99792458e5

Expand Down
67 changes: 67 additions & 0 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,2 +1,69 @@
@non_differentiable LinRange(a,b,n)
@non_differentiable _transformed_weights(quadrature_rule, order, a,b)

Zygote.@adjoint function _create_d(u, t, s, typed_zero)
y = _create_d(u, t, s, typed_zero)
function _create_d_pullback(ȳ)
∂u = Tridiagonal(zeros(eltype(typed_zero), s-1),
map(i -> i == 1 ? typed_zero : 2 / (t[i] - t[i - 1]), 1:s),
map(i -> - 2 / (t[i+1] - t[i]), 1:s-1)) *
∂t = Tridiagonal(zeros(eltype(typed_zero), s-1),
map(i -> i == 1 ? typed_zero : -2 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]) ^ 2, 1:s),
map(i -> 2 * (u[i+1] - u[i]) / (t[i+1] - t[i]) ^ 2, 1:s-1)) *
return (∂u, ∂t, NoTangent(), NoTangent())
end
return y, _create_d_pullback
end

Zygote.@adjoint function _create_σ(z, x, i_list)
y = _create_σ(z, x, i_list)
function _create_σ_pullback(ȳ)
s = length(z)
s1 = length(i_list)

runner = 0.5 ./ (x[i_list] - x[i_list .- 1])
runner_bis = 2. .* (z[i_list] - z[i_list .- 1])

∂z = (sparse(i_list, 1:s1 ,runner, s, s1) -
sparse(i_list .- 1, 1:s1 ,runner, s, s1)) *
∂x = (-sparse(i_list, 1:s1 ,runner_bis .* runner .^2, s, s1) +
sparse(i_list .- 1, 1:s1 , runner_bis .* runner .^2, s, s1)) *
return (∂z, ∂x, NoTangent())
end
return y, _create_σ_pullback
end

Zygote.@adjoint function _compose(z, t, new_t, Cᵢ_list, s_new, i_list, σ)
y = _compose(z, t, new_t, Cᵢ_list, s_new, i_list, σ)
function _compose_pullback(ȳ)
s = length(z)
s1 = length(i_list)

∂z = sparse(i_list .-1, 1:s1, [new_t[j] - t[i_list[j] - 1] for j in 1:s_new], s, s1) *
∂t = sparse(i_list .-1, 1:s1, map(j -> -z[i_list[j] - 1] - 2σ[j] * (new_t[j] - t[i_list[j] - 1]), 1:s_new), s, s1) *
∂t1 = Diagonal([+z[i_list[j] - 1] + 2σ[j] * (new_t[j] - t[i_list[j] - 1]) for j in 1:s1]) *
∂σ = Diagonal(map(i -> (new_t[i] - t[i_list[i] - 1])^2, 1:s_new)) *
∂Cᵢ_list = Diagonal(ones(s1)) *
return (∂z, ∂t, ∂t1, ∂Cᵢ_list, NoTangent(), NoTangent(), ∂σ)
end
return y, _compose_pullback
end

Zygote.@adjoint function _create_Cᵢ_list(u, i_list)
y = _create_Cᵢ_list(u, i_list)
function _create_Cᵢ_list_pullback(ȳ)
s = length(z)
s1 = length(i_list)
∂Cᵢ_list = sparse(i_list .-1, 1:s1 ,ones(s1), s, s1) *
return (∂Cᵢ_list, NoTangent())
end
return y, _create_Cᵢ_list_pullback
end

Zygote.@adjoint function _create_i_list(t, new_t, s_new)
y = _create_i_list(t, new_t, s_new)
function _create_i_list_pullback(ȳ)
return (NoTangent(), NoTangent(), NoTangent())
end
return y, _create_i_list_pullback
end

0 comments on commit 65bf79b

Please sign in to comment.