Skip to content

Commit

Permalink
Feature: add ForwardDiff and ChainRules extension (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
longemen3000 authored Jun 8, 2024
1 parent 542c6fe commit 281a6bf
Show file tree
Hide file tree
Showing 9 changed files with 121 additions and 1 deletion.
14 changes: 13 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,22 @@ version = "2.4.2"

[compat]
julia = "1.0"
ChainRulesCore = "1"
ForwardDiff = "0.10,0.11"

[extras]
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"

[weakdeps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

[extensions]
PolyLogForwardDiffExt = "ForwardDiff"
PolyLogChainRulesExt = "ChainRulesCore"

[targets]
test = ["Test", "DelimitedFiles"]
test = ["Test", "DelimitedFiles", "ForwardDiff", "ChainRulesTestUtils"]
11 changes: 11 additions & 0 deletions ext/PolyLogChainRulesExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
module PolyLogChainRulesExt

using PolyLog
using ChainRulesCore
ChainRulesCore.@scalar_rule PolyLog.reli4(x) ifelse(iszero(x), one(x), PolyLog.reli3(x)/x)
ChainRulesCore.@scalar_rule PolyLog.reli3(x) ifelse(iszero(x), one(x), PolyLog.reli2(x)/x)
ChainRulesCore.@scalar_rule PolyLog.reli2(x) ifelse(iszero(x), one(x), PolyLog.reli1(x)/x)
ChainRulesCore.@scalar_rule PolyLog.reli1(x) one(x)/(one(x)-x)
ChainRulesCore.@scalar_rule(PolyLog.reli(n, x),(ChainRulesCore.NoTangent(),ifelse(iszero(x), one(x), PolyLog.reli(n-1,x)/x)))

end #module
53 changes: 53 additions & 0 deletions ext/PolyLogForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
module PolyLogForwardDiffExt

using PolyLog
using ForwardDiff
using ForwardDiff: Dual, partials
function PolyLog.reli4(d::Dual{T}) where T
val = ForwardDiff.value(d)
if iszero(val)
return Dual{T}(one(val), partials(d))
end
x = reli4(val)
dx = reli3(val)/val
return Dual{T}(x, dx * partials(d))
end

function PolyLog.reli3(d::Dual{T}) where T
val = ForwardDiff.value(d)
if iszero(val)
return Dual{T}(one(val), partials(d))
end
x = reli3(val)
dx = reli2(val)/val
return Dual{T}(x, dx * partials(d))
end

function PolyLog.reli2(d::Dual{T}) where T
val = ForwardDiff.value(d)
if iszero(val)
return Dual{T}(one(val), partials(d))
end
x = reli2(val)
dx = reli1(val)/val
return Dual{T}(x, dx * partials(d))
end

function PolyLog.reli1(d::Dual{T}) where T
val = ForwardDiff.value(d)
x = reli1(val)
dx = one(val) / (one(val) - val)
return Dual{T}(x, dx*partials(d))
end

function PolyLog.reli(n::Integer,d::Dual{T}) where T
val = ForwardDiff.value(d)
if iszero(val)
return Dual{T}(one(val), partials(d))
end
x = reli(n,val)
dx = PolyLog.reli(n-1,val)/val
return Dual{T}(x, dx*partials(d))
end

end #module
8 changes: 8 additions & 0 deletions test/Li.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,14 @@ end
@test PolyLog.li(n, 1//1 + 0//1im) zeta
@test PolyLog.li(n, 1 + 0im) zeta
@test PolyLog.li(n, BigFloat("1.0") + 0im) == PolyLog.zeta(n, BigFloat)

#ForwardDiff Test
if isdefined(Base,:get_extension)
@test ForwardDiff.derivative(Base.Fix1(PolyLog.reli,n),float(pi)) == PolyLog.reli(n-1,float(pi))/float(pi)
@test ForwardDiff.derivative(Base.Fix1(PolyLog.reli,n),0.0) == 1.0
ChainRulesTestUtils.test_frule(PolyLog.reli, n, 0.0)
ChainRulesTestUtils.test_rrule(PolyLog.reli, n, float(pi))
end
end

# value close to boundary between series 1 and 2 in arXiv:2010.09860
Expand Down
8 changes: 8 additions & 0 deletions test/Li1.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,12 @@
# test value that causes overflow if squared
@test PolyLog.li1(1e300 + 1im) -690.77552789821371 + 3.14159265358979im rtol=eps(Float64)
@test PolyLog.li1(1.0 + 1e300im) -690.77552789821371 + 1.5707963267948966im rtol=eps(Float64)

#ForwardDiff Test
if isdefined(Base,:get_extension)
@test ForwardDiff.derivative(PolyLog.reli1,float(pi)) == 1/(1 - pi)
@test ForwardDiff.derivative(PolyLog.reli1,0.0) == 1.0
ChainRulesTestUtils.test_frule(PolyLog.reli1, 0.0)
ChainRulesTestUtils.test_rrule(PolyLog.reli1, float(pi))
end
end
8 changes: 8 additions & 0 deletions test/Li2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,12 @@
# test value that causes overflow if squared
@test PolyLog.li2(1e300 + 1im) -238582.12510339421 + 2170.13532372464im rtol=eps(Float64)
@test PolyLog.li2(1.0 + 1e300im) -238585.82620504462 + 1085.06766186232im rtol=eps(Float64)

#ForwardDiff Test
if isdefined(Base,:get_extension)
@test ForwardDiff.derivative(PolyLog.reli2,float(pi)) == PolyLog.reli1(float(pi))/float(pi)
@test ForwardDiff.derivative(PolyLog.reli2,0.0) == 1.0
ChainRulesTestUtils.test_frule(PolyLog.reli2, 0.0)
ChainRulesTestUtils.test_rrule(PolyLog.reli2, float(pi))
end
end
8 changes: 8 additions & 0 deletions test/Li3.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,12 @@
# test value that causes overflow if squared
@test PolyLog.li3(1e300 + 1im) -5.4934049431527088e7 + 749538.186928224im rtol=eps(Float64)
@test PolyLog.li3(1.0 + 1e300im) -5.4936606061973454e7 + 374771.031356405im rtol=eps(Float64)

#ForwardDiff Test
if isdefined(Base, :get_extension)
@test ForwardDiff.derivative(PolyLog.reli3,float(pi)) == PolyLog.reli2(float(pi))/float(pi)
@test ForwardDiff.derivative(PolyLog.reli3,0.0) == 1.0
ChainRulesTestUtils.test_frule(PolyLog.reli3, 0.0)
ChainRulesTestUtils.test_rrule(PolyLog.reli3, float(pi))
end
end
8 changes: 8 additions & 0 deletions test/Li4.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,12 @@
# test value that causes overflow if squared
@test PolyLog.li4(1e300 + 1im) -9.4863817894708364e9 + 1.725875455850714e8im rtol=eps(Float64)
@test PolyLog.li4(1.0 + 1e300im) -9.4872648206269765e9 + 8.62951114411071e7im rtol=eps(Float64)

#ForwardDiff Test
if isdefined(Base,:get_extension)
@test ForwardDiff.derivative(PolyLog.reli4,float(pi)) == PolyLog.reli3(float(pi))/float(pi)
@test ForwardDiff.derivative(PolyLog.reli4,0.0) == 1.0
ChainRulesTestUtils.test_frule(PolyLog.reli4, 0.0)
ChainRulesTestUtils.test_rrule(PolyLog.reli4, float(pi))
end
end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
using Test
import PolyLog
if isdefined(Base,:get_extension)
import ForwardDiff
import ChainRulesTestUtils
end

include("TestPrecision.jl")
include("DataReader.jl")
Expand Down

0 comments on commit 281a6bf

Please sign in to comment.