Skip to content
This repository has been archived by the owner on Dec 18, 2021. It is now read-only.

Commit

Permalink
Fix issue 170 (#171)
Browse files Browse the repository at this point in the history
* fix issue 170

* fix chainrules patch
  • Loading branch information
GiggleLiu authored Nov 15, 2021
1 parent 37aea0e commit 9f48b62
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 9 deletions.
33 changes: 26 additions & 7 deletions src/autodiff/chainrules_patch.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
import ChainRulesCore: rrule, @non_differentiable, NoTangent
import ChainRulesCore: rrule, @non_differentiable, NoTangent, Tangent

function rrule(::typeof(apply), reg::ArrayReg, block::AbstractBlock)
out = apply(reg, block)
out, function (outδ)
(in, inδ), paramsδ = apply_back((copy(out), outδ), block)
return (NoTangent(), inδ, paramsδ)
return (NoTangent(), inδ, dispatch(block, paramsδ))
end
end

function rrule(::typeof(apply), reg::ArrayReg, block::Add)
out = apply(reg, block)
out, function (outδ)
(in, inδ), paramsδ = apply_back((copy(out), outδ), block; in = reg)
return (NoTangent(), inδ, paramsδ)
return (NoTangent(), inδ, dispatch(block, paramsδ))
end
end

function rrule(::typeof(dispatch), block::AbstractBlock, params)
out = dispatch(block, params)
out, function (outδ)
(NoTangent(), NoTangent(), outδ)
(NoTangent(), NoTangent(), parameters(outδ))
end
end

Expand All @@ -34,11 +34,30 @@ function rrule(::typeof(expect), op::AbstractBlock, reg::AbstractRegister{B}) wh
end
end

function rrule(::Type{Matrix}, block::AbstractBlock)
out = Matrix(block)
function rrule(::typeof(expect), op::AbstractBlock, reg_and_circuit::Pair{<:ArrayReg{B},<:AbstractBlock}) where {B}
out = expect(op, reg_and_circuit)
out, function (outδ)
greg, gcircuit = expect_g(op, reg_and_circuit)
for b in 1:B
viewbatch(greg, b).state .*= 2 * outδ[b]
end
return (NoTangent(), NoTangent(), Tangent{typeof(reg_and_circuit)}(; first=greg, second=dispatch(reg_and_circuit.second, gcircuit)))
end
end

function rrule(::Type{T}, block::AbstractBlock) where T<:Matrix
out = T(block)
out, function (outδ)
paramsδ = mat_back(block, outδ)
return (NoTangent(), dispatch(block, paramsδ))
end
end

function rrule(::typeof(mat), ::Type{T}, block::AbstractBlock) where T
out = mat(T, block)
out, function (outδ)
paramsδ = mat_back(block, outδ)
return (NoTangent(), paramsδ)
return (NoTangent(), NoTangent(), dispatch(block, paramsδ))
end
end

Expand Down
51 changes: 49 additions & 2 deletions test/autodiff/chainrules_patch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ import Zygote, ForwardDiff
using Random, Test
using YaoBlocks, YaoArrayRegister

function Zygote.accum(a::AbstractBlock, b::AbstractBlock)
dispatch(a, parameters(a) + parameters(b))
end

@testset "rules" begin
h = put(5, 3 => Z) + put(5, 2 => X)
c = chain(put(5, 2 => chain(Rx(1.4), Rx(0.5))), cnot(5, 3, 1), put(5, 3 => Rx(-0.5)))
Expand Down Expand Up @@ -31,8 +35,7 @@ using YaoBlocks, YaoArrayRegister
@test Zygote.gradient(x -> real(sum(abs2, statevec(x'))), r)[1].state g1
# zygote does not work if `sin` is not here,
# because it gives an adjoint of different type as the output matrix type.
# do not modify the data type please! Zygote
@test Zygote.gradient(x -> real(sum(sin, Matrix(x))), c)[1]
@test parameters(Zygote.gradient(x -> real(sum(sin, Matrix(x))), c)[1])
ForwardDiff.gradient(x -> real(sum(sin, Matrix(dispatch(c, x)))), parameters(c))
end

Expand All @@ -46,6 +49,7 @@ end
sum(real(st .* st))
end

# apply
reg0 = zero_state(5)
params = rand!(parameters(c))
paramsδ = Zygote.gradient(params -> loss(reg0, dispatch(c, params)), params)[1]
Expand All @@ -64,6 +68,49 @@ end
)
@test fregδ reinterpret(Float64, regδ.state)
@test fparamsδ paramsδ

# expect and fidelity
c = chain(put(5, 5=>Rx(1.5)), put(5,1=>Rx(0.4)), put(5,4=>Rx(0.2)), put(5, 2 => chain(Rx(0.4), Rx(0.5))), cnot(5, 3, 1), put(5, 3 => Rx(-0.5)))
h = chain(repeat(5, X, 1:5))
reg = rand_state(5)
function loss2(reg::AbstractRegister, circuit::AbstractBlock{N}) where {N}
return 5*real(expect(h, copy(reg) => circuit) + fidelity(reg, apply(reg, circuit)))
end
params = rand!(parameters(c))
fδc = ForwardDiff.gradient(
params ->
loss2(ArrayReg(Matrix{Complex{eltype(params)}}(reg.state)), dispatch(c, params)),
params,
)
δr, δc = Zygote.gradient((reg, params)->loss2(reg, dispatch(c, params)),reg, params)
@test δc fδc

fregδ = ForwardDiff.gradient(
x -> loss2(
ArrayReg([Complex(x[2i-1], x[2i]) for i in 1:length(x)÷2]),
dispatch(c, Vector{real(eltype(x))}(params)),
),
reinterpret(Float64, reg.state),
)
@test fregδ reinterpret(Float64, δr.state)

# operator fidelity
c = chain(put(5, 5=>Rx(1.5)), put(5,1=>Rx(0.4)), put(5,4=>Rx(0.2)), put(5, 2 => chain(Rx(0.4), Rx(0.5))), cnot(5, 3, 1), put(5, 3 => Rx(-0.5)))
h = chain(repeat(5, X, 1:5))
function loss3(circuit::AbstractBlock{N}, h) where {N}
return operator_fidelity(circuit, h)
end
params = rand!(parameters(c))
fδc = ForwardDiff.gradient(
params ->
loss3(dispatch(c, params), h),
params,
)
δc, = Zygote.gradient(p->loss3(dispatch(c, p), h), params)
@test δc fδc

# NOTE: operator back propagation in expect is not implemented!
# to differentiate operators, we need to use the expensive `mat_back` function.
end

@testset "add block" begin
Expand Down

0 comments on commit 9f48b62

Please sign in to comment.