diff --git a/Project.toml b/Project.toml index 8f4fd2c..3fd4251 100644 --- a/Project.toml +++ b/Project.toml @@ -5,10 +5,22 @@ version = "2.4.2" [compat] julia = "1.0" +ChainRulesCore = "1" +ForwardDiff = "0.10,0.11" [extras] DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" + +[weakdeps] +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + +[extensions] +PolyLogForwardDiffExt = "ForwardDiff" +PolyLogChainRulesExt = "ChainRulesCore" [targets] -test = ["Test", "DelimitedFiles"] +test = ["Test", "DelimitedFiles", "ForwardDiff", "ChainRulesTestUtils"] diff --git a/ext/PolyLogChainRulesExt.jl b/ext/PolyLogChainRulesExt.jl new file mode 100644 index 0000000..bb58fec --- /dev/null +++ b/ext/PolyLogChainRulesExt.jl @@ -0,0 +1,11 @@ +module PolyLogChainRulesExt + +using PolyLog +using ChainRulesCore +ChainRulesCore.@scalar_rule PolyLog.reli4(x) ifelse(iszero(x), one(x), PolyLog.reli3(x)/x) +ChainRulesCore.@scalar_rule PolyLog.reli3(x) ifelse(iszero(x), one(x), PolyLog.reli2(x)/x) +ChainRulesCore.@scalar_rule PolyLog.reli2(x) ifelse(iszero(x), one(x), PolyLog.reli1(x)/x) +ChainRulesCore.@scalar_rule PolyLog.reli1(x) one(x)/(one(x)-x) +ChainRulesCore.@scalar_rule(PolyLog.reli(n, x),(ChainRulesCore.NoTangent(),ifelse(iszero(x), one(x), PolyLog.reli(n-1,x)/x))) + +end #module diff --git a/ext/PolyLogForwardDiffExt.jl b/ext/PolyLogForwardDiffExt.jl new file mode 100644 index 0000000..8831ad0 --- /dev/null +++ b/ext/PolyLogForwardDiffExt.jl @@ -0,0 +1,53 @@ +module PolyLogForwardDiffExt + +using PolyLog +using ForwardDiff +using ForwardDiff: Dual, partials +function PolyLog.reli4(d::Dual{T}) where T + val = ForwardDiff.value(d) + if iszero(val) + return Dual{T}(one(val), partials(d)) + end + x = reli4(val) + dx = reli3(val)/val + return Dual{T}(x, dx * partials(d)) +end + +function PolyLog.reli3(d::Dual{T}) where T + val = ForwardDiff.value(d) + if iszero(val) + return Dual{T}(one(val), partials(d)) + end + x = reli3(val) + dx = reli2(val)/val + return Dual{T}(x, dx * partials(d)) +end + +function PolyLog.reli2(d::Dual{T}) where T + val = ForwardDiff.value(d) + if iszero(val) + return Dual{T}(one(val), partials(d)) + end + x = reli2(val) + dx = reli1(val)/val + return Dual{T}(x, dx * partials(d)) +end + +function PolyLog.reli1(d::Dual{T}) where T + val = ForwardDiff.value(d) + x = reli1(val) + dx = one(val) / (one(val) - val) + return Dual{T}(x, dx*partials(d)) +end + +function PolyLog.reli(n::Integer,d::Dual{T}) where T + val = ForwardDiff.value(d) + if iszero(val) + return Dual{T}(one(val), partials(d)) + end + x = reli(n,val) + dx = PolyLog.reli(n-1,val)/val + return Dual{T}(x, dx*partials(d)) +end + +end #module diff --git a/test/Li.jl b/test/Li.jl index 4b93d5e..2419d41 100644 --- a/test/Li.jl +++ b/test/Li.jl @@ -93,6 +93,14 @@ end @test PolyLog.li(n, 1//1 + 0//1im) ≈ zeta @test PolyLog.li(n, 1 + 0im) ≈ zeta @test PolyLog.li(n, BigFloat("1.0") + 0im) == PolyLog.zeta(n, BigFloat) + + #ForwardDiff Test + if isdefined(Base,:get_extension) + @test ForwardDiff.derivative(Base.Fix1(PolyLog.reli,n),float(pi)) == PolyLog.reli(n-1,float(pi))/float(pi) + @test ForwardDiff.derivative(Base.Fix1(PolyLog.reli,n),0.0) == 1.0 + ChainRulesTestUtils.test_frule(PolyLog.reli, n, 0.0) + ChainRulesTestUtils.test_rrule(PolyLog.reli, n, float(pi)) + end end # value close to boundary between series 1 and 2 in arXiv:2010.09860 diff --git a/test/Li1.jl b/test/Li1.jl index d3dc8fd..307d07c 100644 --- a/test/Li1.jl +++ b/test/Li1.jl @@ -59,4 +59,12 @@ # test value that causes overflow if squared @test PolyLog.li1(1e300 + 1im) ≈ -690.77552789821371 + 3.14159265358979im rtol=eps(Float64) @test PolyLog.li1(1.0 + 1e300im) ≈ -690.77552789821371 + 1.5707963267948966im rtol=eps(Float64) + + #ForwardDiff Test + if isdefined(Base,:get_extension) + @test ForwardDiff.derivative(PolyLog.reli1,float(pi)) == 1/(1 - pi) + @test ForwardDiff.derivative(PolyLog.reli1,0.0) == 1.0 + ChainRulesTestUtils.test_frule(PolyLog.reli1, 0.0) + ChainRulesTestUtils.test_rrule(PolyLog.reli1, float(pi)) + end end diff --git a/test/Li2.jl b/test/Li2.jl index 580d937..2213df9 100644 --- a/test/Li2.jl +++ b/test/Li2.jl @@ -66,4 +66,12 @@ # test value that causes overflow if squared @test PolyLog.li2(1e300 + 1im) ≈ -238582.12510339421 + 2170.13532372464im rtol=eps(Float64) @test PolyLog.li2(1.0 + 1e300im) ≈ -238585.82620504462 + 1085.06766186232im rtol=eps(Float64) + + #ForwardDiff Test + if isdefined(Base,:get_extension) + @test ForwardDiff.derivative(PolyLog.reli2,float(pi)) == PolyLog.reli1(float(pi))/float(pi) + @test ForwardDiff.derivative(PolyLog.reli2,0.0) == 1.0 + ChainRulesTestUtils.test_frule(PolyLog.reli2, 0.0) + ChainRulesTestUtils.test_rrule(PolyLog.reli2, float(pi)) + end end diff --git a/test/Li3.jl b/test/Li3.jl index dc0dddb..1f815c0 100644 --- a/test/Li3.jl +++ b/test/Li3.jl @@ -60,4 +60,12 @@ # test value that causes overflow if squared @test PolyLog.li3(1e300 + 1im) ≈ -5.4934049431527088e7 + 749538.186928224im rtol=eps(Float64) @test PolyLog.li3(1.0 + 1e300im) ≈ -5.4936606061973454e7 + 374771.031356405im rtol=eps(Float64) + + #ForwardDiff Test + if isdefined(Base, :get_extension) + @test ForwardDiff.derivative(PolyLog.reli3,float(pi)) == PolyLog.reli2(float(pi))/float(pi) + @test ForwardDiff.derivative(PolyLog.reli3,0.0) == 1.0 + ChainRulesTestUtils.test_frule(PolyLog.reli3, 0.0) + ChainRulesTestUtils.test_rrule(PolyLog.reli3, float(pi)) + end end diff --git a/test/Li4.jl b/test/Li4.jl index 70e6aaf..fb04112 100644 --- a/test/Li4.jl +++ b/test/Li4.jl @@ -52,4 +52,12 @@ # test value that causes overflow if squared @test PolyLog.li4(1e300 + 1im) ≈ -9.4863817894708364e9 + 1.725875455850714e8im rtol=eps(Float64) @test PolyLog.li4(1.0 + 1e300im) ≈ -9.4872648206269765e9 + 8.62951114411071e7im rtol=eps(Float64) + + #ForwardDiff Test + if isdefined(Base,:get_extension) + @test ForwardDiff.derivative(PolyLog.reli4,float(pi)) == PolyLog.reli3(float(pi))/float(pi) + @test ForwardDiff.derivative(PolyLog.reli4,0.0) == 1.0 + ChainRulesTestUtils.test_frule(PolyLog.reli4, 0.0) + ChainRulesTestUtils.test_rrule(PolyLog.reli4, float(pi)) + end end diff --git a/test/runtests.jl b/test/runtests.jl index 5a8a7fc..98f6556 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,9 @@ using Test import PolyLog +if isdefined(Base,:get_extension) + import ForwardDiff + import ChainRulesTestUtils +end include("TestPrecision.jl") include("DataReader.jl")