diff --git a/Project.toml b/Project.toml index 3c0453fe2..f3f63e5ab 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,6 @@ version = "2.6.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" @@ -31,9 +30,9 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" -ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" @@ -41,6 +40,7 @@ RCall = "6f49c342-dc21-5d91-9882-a32aef131414" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] +SciMLBaseChainRulesCoreExt = "ChainRulesCore" SciMLBasePartialFunctionsExt = "PartialFunctions" SciMLBasePyCallExt = "PyCall" SciMLBasePythonCallExt = "PythonCall" @@ -78,11 +78,12 @@ Statistics = "1" SymbolicIndexingInterface = "0.2" Tables = "1" TruncatedStacktraces = "1" -ZygoteRules = "0.2" +Zygote = "0.6" julia = "1.6" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" diff --git a/src/solutions/chainrules.jl b/ext/SciMLBaseChainRulesCoreExt.jl similarity index 57% rename from src/solutions/chainrules.jl rename to ext/SciMLBaseChainRulesCoreExt.jl index 899f58f19..e59bc55d6 100644 --- a/src/solutions/chainrules.jl +++ b/ext/SciMLBaseChainRulesCoreExt.jl @@ -1,3 +1,8 @@ +module SciMLBaseChainRulesCoreExt + +import ChainRulesCore +import ChainRulesCore: NoTangent, @non_differentiable + function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{ >:ChainRulesCore.HasReverseMode, }, @@ -70,3 +75,60 @@ function ChainRulesCore.rrule(::typeof(getindex), VA::ODESolution, sym) end VA[sym], ODESolution_getindex_pullback end + +function ChainRulesCore.rrule(::Type{ODEProblem}, args...; kwargs...) + function ODEProblemAdjoint(ȳ) + (NoTangent(), ȳ.f, ȳ.u0, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type) + end + + ODEProblem(args...; kwargs...), ODEProblemAdjoint +end + +function ChainRulesCore.rrule(::Type{SDEProblem}, args...; kwargs...) + function SDEProblemAdjoint(ȳ) + (NoTangent(), ȳ.f, ȳ.g, ȳ.u0, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type) + end + + SDEProblem(args...; kwargs...), SDEProblemAdjoint +end + +function ChainRulesCore.rrule(::Type{ + <:ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, + T11, T12, + }}, u, + args...) where {T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, + T12} + function ODESolutionAdjoint(ȳ) + (NoTangent(), ȳ, ntuple(_ -> NoTangent(), length(args))...) + end + + ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12}(u, args...), + ODESolutionAdjoint +end + +function ChainRulesCore.rrule(::Type{ + <:ODESolution{uType, tType, isinplace, P, NP, F, G, K, + ND, + }}, u, + args...) where {uType, tType, isinplace, P, NP, F, G, K, ND} + function SDESolutionAdjoint(ȳ) + (NoTangent(), ȳ, ntuple(_ -> NoTangent(), length(args))...) + end + + SDESolution{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...), SDESolutionAdjoint +end + +function ChainRulesCore.rrule(::DiffEqBase.EnsembleSolution, sim, time, converged) + out = EnsembleSolution(sim, time, converged) + function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N} + arrarr = [[p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i] + for j in 1:size(p̄)[end - 1]] for i in 1:size(p̄)[end]] + (NoTangent(), EnsembleSolution(arrarr, 0.0, true), NoTangent(), NoTangent()) + end + function EnsembleSolution_adjoint(p̄::EnsembleSolution) + (NoTangent(), p̄, NoTangent(), NoTangent()) + end + out, EnsembleSolution_adjoint +end + +end \ No newline at end of file diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index a401e34f9..09ce9bdb5 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -1,9 +1,10 @@ module SciMLBaseZygoteExt -using Zygote: pullback +using Zygote +using Zygote: pullback, ZygoteRules using ZygoteRules: @adjoint -import ZygoteRules -using SciMLBase: EnsembleSolution, ODESolution, issymbollike, sym_to_index, remake, getobserved +using SciMLBase +using SciMLBase: ODESolution, issymbollike, sym_to_index, remake, getobserved # This method resolves the ambiguity with the pullback defined in # RecursiveArrayToolsZygoteExt @@ -56,15 +57,15 @@ end VA[sym, j], ODESolution_getindex_pullback end -ZygoteRules.@adjoint function EnsembleSolution(sim, time, converged, stats) - out = EnsembleSolution(sim, time, converged) +@adjoint function EnsembleSolution(sim, time, converged, stats) + out = EnsembleSolution(sim, time, converged, stats) function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N} arrarr = [[p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i] for j in 1:size(p̄)[end - 1]] for i in 1:size(p̄)[end]] - (EnsembleSolution(arrarr, 0.0, true), nothing, nothing, nothing) + (EnsembleSolution(arrarr, 0.0, true, stats), nothing, nothing, nothing) end function EnsembleSolution_adjoint(p̄::AbstractArray{<:AbstractArray, 1}) - (EnsembleSolution(p̄, 0.0, true), nothing, nothing, nothing) + (EnsembleSolution(p̄, 0.0, true, stats), nothing, nothing, nothing) end function EnsembleSolution_adjoint(p̄::EnsembleSolution) (p̄, nothing, nothing, nothing) @@ -72,9 +73,148 @@ ZygoteRules.@adjoint function EnsembleSolution(sim, time, converged, stats) out, EnsembleSolution_adjoint end -ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(sim::EnsembleSolution, +@adjoint function getindex(VA::ODESolution, i::Int) + function ODESolution_getindex_pullback(Δ) + Δ′ = [(i == j ? Δ : FillArrays.Fill(zero(eltype(x)), size(x))) + for (x, j) in zip(VA.u, 1:length(VA))] + (Δ′, nothing) + end + VA[i], ODESolution_getindex_pullback +end + +@adjoint function ZygoteRules.literal_getproperty(sim::EnsembleSolution, ::Val{:u}) - sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true),) + sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true, sim.stats),) +end + +@adjoint function getindex(VA::ODESolution, sym) + function ODESolution_getindex_pullback(Δ) + i = issymbollike(sym) ? sym_to_index(sym, VA) : sym + if i === nothing + throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated.")) + else + Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)] + for (x, j) in zip(VA.u, 1:length(VA))] + (Δ′, nothing) + end + end + VA[sym], ODESolution_getindex_pullback +end + +@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12 + }(u, + args...) where {T1, T2, T3, T4, T5, T6, T7, T8, + T9, T10, T11, T12} + function ODESolutionAdjoint(ȳ) + (ȳ, ntuple(_ -> nothing, length(args))...) + end + + ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12}(u, args...), + ODESolutionAdjoint +end + +@adjoint function SDEProblem{uType, tType, isinplace, P, NP, F, G, K, ND}(u, + args...) where + {uType, tType, isinplace, P, NP, F, G, K, ND} + function SDESolutionAdjoint(ȳ) + (ȳ, ntuple(_ -> nothing, length(args))...) + end + + SDESolution{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...), SDESolutionAdjoint +end + +@adjoint function NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u, + args...) where { + T, + N, + uType, + R, + P, + A, + O, + uType2, +} + function NonlinearSolutionAdjoint(ȳ) + (ȳ, ntuple(_ -> nothing, length(args))...) + end + NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u, args...), NonlinearSolutionAdjoint +end + +@adjoint function ZygoteRules.literal_getproperty(sol::AbstractTimeseriesSolution, + ::Val{:u}) + function solu_adjoint(Δ) + zerou = zero(sol.prob.u0) + _Δ = @. ifelse(Δ === nothing, (zerou,), Δ) + (DiffEqBase.build_solution(sol.prob, sol.alg, sol.t, _Δ),) + end + sol.u, solu_adjoint +end + +@adjoint function ZygoteRules.literal_getproperty(sol::AbstractNoTimeSolution, + ::Val{:u}) + function solu_adjoint(Δ) + zerou = zero(sol.prob.u0) + _Δ = @. ifelse(Δ === nothing, zerou, Δ) + (DiffEqBase.build_solution(sol.prob, sol.alg, _Δ, sol.resid),) + end + sol.u, solu_adjoint +end + +@adjoint function ZygoteRules.literal_getproperty(sol::SciMLBase.OptimizationSolution, + ::Val{:u}) + function solu_adjoint(Δ) + zerou = zero(sol.u) + _Δ = @. ifelse(Δ === nothing, zerou, Δ) + (DiffEqBase.build_solution(sol.cache, sol.alg, _Δ, sol.objective),) + end + sol.u, solu_adjoint +end + +function ∇tmap(cx, f, args...) + ys_and_backs = SciMLBase.tmap((args...) -> Zygote._pullback(cx, f, args...), args...) + if isempty(ys_and_backs) + ys_and_backs, _ -> (NoTangent(), NoTangent()) + else + ys, backs = Zygote.unzip(ys_and_backs) + function ∇tmap_internal(Δ) + Δf_and_args_zipped = SciMLBase.tmap((f, δ) -> f(δ), backs, Δ) + Δf_and_args = Zygote.unzip(Δf_and_args_zipped) + Δf = reduce(Zygote.accum, Δf_and_args[1]) + (Δf, Δf_and_args[2:end]...) + end + ys, ∇tmap_internal + end +end + +function ∇responsible_map(cx, f, args...) + ys_and_backs = SciMLBase.responsible_map((args...) -> Zygote._pullback(cx, f, args...), + args...) + if isempty(ys_and_backs) + ys_and_backs, _ -> (NoTangent(), NoTangent()) + else + ys, backs = Zygote.unzip(ys_and_backs) + ys, + function ∇responsible_map_internal(Δ) + # Apply pullbacks in reverse order. Needed for correctness if `f` is stateful. + Δf_and_args_zipped = SciMLBase.responsible_map((f, δ) -> f(δ), + Zygote._tryreverse(SciMLBase.responsible_map, + backs, Δ)...) + Δf_and_args = Zygote.unzip(Zygote._tryreverse(SciMLBase.responsible_map, + Δf_and_args_zipped)) + Δf = reduce(Zygote.accum, Δf_and_args[1]) + (Δf, Δf_and_args[2:end]...) + end + end +end + +@adjoint function SciMLBase.tmap(f, args::Union{AbstractArray, Tuple}...) + ∇tmap(__context__, f, args...) +end + +@adjoint function SciMLBase.responsible_map(f, + args::Union{AbstractArray, Tuple + }...) + ∇responsible_map(__context__, f, args...) end end diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index 72c887d77..b7401404b 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -22,8 +22,6 @@ import RuntimeGeneratedFunctions import EnumX import TruncatedStacktraces import ADTypes: AbstractADType -import ChainRulesCore -import ZygoteRules: @adjoint import FillArrays using Reexport @@ -716,7 +714,6 @@ include("solutions/optimization_solutions.jl") include("solutions/dae_solutions.jl") include("solutions/pde_solutions.jl") include("solutions/solution_interface.jl") -include("solutions/zygote.jl") include("ensemble/ensemble_solutions.jl") include("ensemble/ensemble_problems.jl") diff --git a/src/solutions/zygote.jl b/src/solutions/zygote.jl deleted file mode 100644 index d41d07e0f..000000000 --- a/src/solutions/zygote.jl +++ /dev/null @@ -1,22 +0,0 @@ -@adjoint function getindex(VA::ODESolution, i::Int) - function ODESolution_getindex_pullback(Δ) - Δ′ = [(i == j ? Δ : FillArrays.Fill(zero(eltype(x)), size(x))) - for (x, j) in zip(VA.u, 1:length(VA))] - (Δ′, nothing) - end - VA[i], ODESolution_getindex_pullback -end - -@adjoint function getindex(VA::ODESolution, sym) - function ODESolution_getindex_pullback(Δ) - i = issymbollike(sym) ? sym_to_index(sym, VA) : sym - if i === nothing - throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated.")) - else - Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)] - for (x, j) in zip(VA.u, 1:length(VA))] - (Δ′, nothing) - end - end - VA[sym], ODESolution_getindex_pullback -end