Skip to content

Commit

Permalink
Merge pull request #537 from SciML/rules
Browse files Browse the repository at this point in the history
Move over the rest of pirating rules
  • Loading branch information
ChrisRackauckas authored Nov 3, 2023
2 parents 5a7771d + b960eae commit ecdcc33
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 37 deletions.
7 changes: 4 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ version = "2.6.0"
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand All @@ -31,16 +30,17 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
RCall = "6f49c342-dc21-5d91-9882-a32aef131414"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
SciMLBaseChainRulesCoreExt = "ChainRulesCore"
SciMLBasePartialFunctionsExt = "PartialFunctions"
SciMLBasePyCallExt = "PyCall"
SciMLBasePythonCallExt = "PythonCall"
Expand Down Expand Up @@ -78,11 +78,12 @@ Statistics = "1"
SymbolicIndexingInterface = "0.2"
Tables = "1"
TruncatedStacktraces = "1"
ZygoteRules = "0.2"
Zygote = "0.6"
julia = "1.6"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
Expand Down
62 changes: 62 additions & 0 deletions src/solutions/chainrules.jl → ext/SciMLBaseChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
module SciMLBaseChainRulesCoreExt

import ChainRulesCore
import ChainRulesCore: NoTangent, @non_differentiable

function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{
>:ChainRulesCore.HasReverseMode,
},
Expand Down Expand Up @@ -70,3 +75,60 @@ function ChainRulesCore.rrule(::typeof(getindex), VA::ODESolution, sym)
end
VA[sym], ODESolution_getindex_pullback
end

function ChainRulesCore.rrule(::Type{ODEProblem}, args...; kwargs...)
function ODEProblemAdjoint(ȳ)
(NoTangent(), ȳ.f, ȳ.u0, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type)
end

ODEProblem(args...; kwargs...), ODEProblemAdjoint
end

function ChainRulesCore.rrule(::Type{SDEProblem}, args...; kwargs...)
function SDEProblemAdjoint(ȳ)
(NoTangent(), ȳ.f, ȳ.g, ȳ.u0, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type)
end

SDEProblem(args...; kwargs...), SDEProblemAdjoint
end

function ChainRulesCore.rrule(::Type{
<:ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
T11, T12,
}}, u,
args...) where {T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11,
T12}
function ODESolutionAdjoint(ȳ)
(NoTangent(), ȳ, ntuple(_ -> NoTangent(), length(args))...)
end

ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12}(u, args...),
ODESolutionAdjoint
end

function ChainRulesCore.rrule(::Type{
<:ODESolution{uType, tType, isinplace, P, NP, F, G, K,
ND,
}}, u,
args...) where {uType, tType, isinplace, P, NP, F, G, K, ND}
function SDESolutionAdjoint(ȳ)
(NoTangent(), ȳ, ntuple(_ -> NoTangent(), length(args))...)
end

SDESolution{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...), SDESolutionAdjoint
end

function ChainRulesCore.rrule(::DiffEqBase.EnsembleSolution, sim, time, converged)
out = EnsembleSolution(sim, time, converged)
function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N}
arrarr = [[p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i]
for j in 1:size(p̄)[end - 1]] for i in 1:size(p̄)[end]]
(NoTangent(), EnsembleSolution(arrarr, 0.0, true), NoTangent(), NoTangent())
end
function EnsembleSolution_adjoint(p̄::EnsembleSolution)
(NoTangent(), p̄, NoTangent(), NoTangent())
end
out, EnsembleSolution_adjoint
end

end
158 changes: 149 additions & 9 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
module SciMLBaseZygoteExt

using Zygote: pullback
using Zygote
using Zygote: pullback, ZygoteRules
using ZygoteRules: @adjoint
import ZygoteRules
using SciMLBase: EnsembleSolution, ODESolution, issymbollike, sym_to_index, remake, getobserved
using SciMLBase
using SciMLBase: ODESolution, issymbollike, sym_to_index, remake, getobserved

# This method resolves the ambiguity with the pullback defined in
# RecursiveArrayToolsZygoteExt
Expand Down Expand Up @@ -56,25 +57,164 @@ end
VA[sym, j], ODESolution_getindex_pullback
end

ZygoteRules.@adjoint function EnsembleSolution(sim, time, converged, stats)
out = EnsembleSolution(sim, time, converged)
@adjoint function EnsembleSolution(sim, time, converged, stats)
out = EnsembleSolution(sim, time, converged, stats)
function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N}
arrarr = [[p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i]
for j in 1:size(p̄)[end - 1]] for i in 1:size(p̄)[end]]
(EnsembleSolution(arrarr, 0.0, true), nothing, nothing, nothing)
(EnsembleSolution(arrarr, 0.0, true, stats), nothing, nothing, nothing)
end
function EnsembleSolution_adjoint(p̄::AbstractArray{<:AbstractArray, 1})
(EnsembleSolution(p̄, 0.0, true), nothing, nothing, nothing)
(EnsembleSolution(p̄, 0.0, true, stats), nothing, nothing, nothing)
end
function EnsembleSolution_adjoint(p̄::EnsembleSolution)
(p̄, nothing, nothing, nothing)
end
out, EnsembleSolution_adjoint
end

ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(sim::EnsembleSolution,
@adjoint function getindex(VA::ODESolution, i::Int)
function ODESolution_getindex_pullback(Δ)
Δ′ = [(i == j ? Δ : FillArrays.Fill(zero(eltype(x)), size(x)))
for (x, j) in zip(VA.u, 1:length(VA))]
(Δ′, nothing)
end
VA[i], ODESolution_getindex_pullback
end

@adjoint function ZygoteRules.literal_getproperty(sim::EnsembleSolution,
::Val{:u})
sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true),)
sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true, sim.stats),)
end

@adjoint function getindex(VA::ODESolution, sym)
function ODESolution_getindex_pullback(Δ)
i = issymbollike(sym) ? sym_to_index(sym, VA) : sym
if i === nothing
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
else
Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)]
for (x, j) in zip(VA.u, 1:length(VA))]
(Δ′, nothing)
end
end
VA[sym], ODESolution_getindex_pullback
end

@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12
}(u,
args...) where {T1, T2, T3, T4, T5, T6, T7, T8,
T9, T10, T11, T12}
function ODESolutionAdjoint(ȳ)
(ȳ, ntuple(_ -> nothing, length(args))...)
end

ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12}(u, args...),
ODESolutionAdjoint
end

@adjoint function SDEProblem{uType, tType, isinplace, P, NP, F, G, K, ND}(u,
args...) where
{uType, tType, isinplace, P, NP, F, G, K, ND}
function SDESolutionAdjoint(ȳ)
(ȳ, ntuple(_ -> nothing, length(args))...)
end

SDESolution{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...), SDESolutionAdjoint
end

@adjoint function NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u,
args...) where {
T,
N,
uType,
R,
P,
A,
O,
uType2,
}
function NonlinearSolutionAdjoint(ȳ)
(ȳ, ntuple(_ -> nothing, length(args))...)
end
NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u, args...), NonlinearSolutionAdjoint
end

@adjoint function ZygoteRules.literal_getproperty(sol::AbstractTimeseriesSolution,
::Val{:u})
function solu_adjoint(Δ)
zerou = zero(sol.prob.u0)
= @. ifelse=== nothing, (zerou,), Δ)
(DiffEqBase.build_solution(sol.prob, sol.alg, sol.t, _Δ),)
end
sol.u, solu_adjoint
end

@adjoint function ZygoteRules.literal_getproperty(sol::AbstractNoTimeSolution,
::Val{:u})
function solu_adjoint(Δ)
zerou = zero(sol.prob.u0)
= @. ifelse=== nothing, zerou, Δ)
(DiffEqBase.build_solution(sol.prob, sol.alg, _Δ, sol.resid),)
end
sol.u, solu_adjoint
end

@adjoint function ZygoteRules.literal_getproperty(sol::SciMLBase.OptimizationSolution,
::Val{:u})
function solu_adjoint(Δ)
zerou = zero(sol.u)
= @. ifelse=== nothing, zerou, Δ)
(DiffEqBase.build_solution(sol.cache, sol.alg, _Δ, sol.objective),)
end
sol.u, solu_adjoint
end

function ∇tmap(cx, f, args...)
ys_and_backs = SciMLBase.tmap((args...) -> Zygote._pullback(cx, f, args...), args...)
if isempty(ys_and_backs)
ys_and_backs, _ -> (NoTangent(), NoTangent())
else
ys, backs = Zygote.unzip(ys_and_backs)
function ∇tmap_internal(Δ)
Δf_and_args_zipped = SciMLBase.tmap((f, δ) -> f(δ), backs, Δ)
Δf_and_args = Zygote.unzip(Δf_and_args_zipped)
Δf = reduce(Zygote.accum, Δf_and_args[1])
(Δf, Δf_and_args[2:end]...)
end
ys, ∇tmap_internal
end
end

function ∇responsible_map(cx, f, args...)
ys_and_backs = SciMLBase.responsible_map((args...) -> Zygote._pullback(cx, f, args...),
args...)
if isempty(ys_and_backs)
ys_and_backs, _ -> (NoTangent(), NoTangent())
else
ys, backs = Zygote.unzip(ys_and_backs)
ys,
function ∇responsible_map_internal(Δ)
# Apply pullbacks in reverse order. Needed for correctness if `f` is stateful.
Δf_and_args_zipped = SciMLBase.responsible_map((f, δ) -> f(δ),
Zygote._tryreverse(SciMLBase.responsible_map,
backs, Δ)...)
Δf_and_args = Zygote.unzip(Zygote._tryreverse(SciMLBase.responsible_map,
Δf_and_args_zipped))
Δf = reduce(Zygote.accum, Δf_and_args[1])
(Δf, Δf_and_args[2:end]...)
end
end
end

@adjoint function SciMLBase.tmap(f, args::Union{AbstractArray, Tuple}...)
∇tmap(__context__, f, args...)
end

@adjoint function SciMLBase.responsible_map(f,
args::Union{AbstractArray, Tuple
}...)
∇responsible_map(__context__, f, args...)
end

end
3 changes: 0 additions & 3 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ import RuntimeGeneratedFunctions
import EnumX
import TruncatedStacktraces
import ADTypes: AbstractADType
import ChainRulesCore
import ZygoteRules: @adjoint
import FillArrays

using Reexport
Expand Down Expand Up @@ -716,7 +714,6 @@ include("solutions/optimization_solutions.jl")
include("solutions/dae_solutions.jl")
include("solutions/pde_solutions.jl")
include("solutions/solution_interface.jl")
include("solutions/zygote.jl")

include("ensemble/ensemble_solutions.jl")
include("ensemble/ensemble_problems.jl")
Expand Down
22 changes: 0 additions & 22 deletions src/solutions/zygote.jl

This file was deleted.

0 comments on commit ecdcc33

Please sign in to comment.