Skip to content
This repository has been archived by the owner on Dec 18, 2021. It is now read-only.

Commit

Permalink
Format .jl files (#169)
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
github-actions[bot] authored Nov 9, 2021
1 parent 46351be commit 37aea0e
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 47 deletions.
34 changes: 19 additions & 15 deletions src/autodiff/chainrules_patch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ end
function rrule(::typeof(apply), reg::ArrayReg, block::Add)
out = apply(reg, block)
out, function (outδ)
(in, inδ), paramsδ = apply_back((copy(out), outδ), block; in=reg)
(in, inδ), paramsδ = apply_back((copy(out), outδ), block; in = reg)
return (NoTangent(), inδ, paramsδ)
end
end
Expand All @@ -27,8 +27,8 @@ function rrule(::typeof(expect), op::AbstractBlock, reg::AbstractRegister{B}) wh
out = expect(op, reg)
out, function (outδ)
greg = expect_g(op, reg)
for b=1:B
viewbatch(greg, b).state .*= 2*outδ[b]
for b in 1:B
viewbatch(greg, b).state .*= 2 * outδ[b]
end
return (NoTangent(), NoTangent(), greg)
end
Expand All @@ -42,26 +42,30 @@ function rrule(::Type{Matrix}, block::AbstractBlock)
end
end

function rrule(::Type{ArrayReg{B}}, raw::AbstractArray) where B
ArrayReg{B}(raw), adjy->(NoTangent(), reshape(adjy.state, size(raw)))
function rrule(::Type{ArrayReg{B}}, raw::AbstractArray) where {B}
ArrayReg{B}(raw), adjy -> (NoTangent(), reshape(adjy.state, size(raw)))
end

function rrule(::Type{ArrayReg}, raw::AbstractArray)
ArrayReg(raw), adjy->(NoTangent(), reshape(adjy.state, size(raw)))
ArrayReg(raw), adjy -> (NoTangent(), reshape(adjy.state, size(raw)))
end

function rrule(::typeof(copy), reg::ArrayReg) where B
copy(reg), adjy->(NoTangent(), adjy)
function rrule(::typeof(copy), reg::ArrayReg) where {B}
copy(reg), adjy -> (NoTangent(), adjy)
end

_totype(::Type{T}, x::AbstractArray{T}) where T = x
_totype(::Type{T}, x::AbstractArray{T}) where {T} = x
_totype(::Type{T}, x::AbstractArray{T2}) where {T,T2} = convert.(T, x)
rrule(::typeof(state), reg::ArrayReg{B,T}) where {B,T} = state(reg), adjy->(NoTangent(), ArrayReg(_totype(T, adjy)))
rrule(::typeof(statevec), reg::ArrayReg{B,T}) where {B,T} = statevec(reg), adjy->(NoTangent(), ArrayReg(_totype(T, adjy)))
rrule(::typeof(state), reg::AdjointArrayReg{B,T}) where {B,T} = state(reg), adjy->(NoTangent(), ArrayReg(_totype(T, adjy)')')
rrule(::typeof(statevec), reg::AdjointArrayReg{B,T}) where {B,T} = statevec(reg), adjy->(NoTangent(), ArrayReg(_totype(T, adjy)')')
rrule(::typeof(parent), reg::AdjointArrayReg) = parent(reg), adjy->(NoTangent(), adjy')
rrule(::typeof(Base.adjoint), reg::ArrayReg) = Base.adjoint(reg), adjy->(NoTangent(), parent(adjy))
rrule(::typeof(state), reg::ArrayReg{B,T}) where {B,T} =
state(reg), adjy -> (NoTangent(), ArrayReg(_totype(T, adjy)))
rrule(::typeof(statevec), reg::ArrayReg{B,T}) where {B,T} =
statevec(reg), adjy -> (NoTangent(), ArrayReg(_totype(T, adjy)))
rrule(::typeof(state), reg::AdjointArrayReg{B,T}) where {B,T} =
state(reg), adjy -> (NoTangent(), ArrayReg(_totype(T, adjy)')')
rrule(::typeof(statevec), reg::AdjointArrayReg{B,T}) where {B,T} =
statevec(reg), adjy -> (NoTangent(), ArrayReg(_totype(T, adjy)')')
rrule(::typeof(parent), reg::AdjointArrayReg) = parent(reg), adjy -> (NoTangent(), adjy')
rrule(::typeof(Base.adjoint), reg::ArrayReg) = Base.adjoint(reg), adjy -> (NoTangent(), parent(adjy))
@non_differentiable nparameters(::Any)
@non_differentiable zero_state(args...)
@non_differentiable rand_state(args...)
Expand Down
98 changes: 66 additions & 32 deletions test/autodiff/chainrules_patch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,64 +3,98 @@ using Random, Test
using YaoBlocks, YaoArrayRegister

@testset "rules" begin
h = put(5, 3=>Z) + put(5, 2=>X)
c = chain(put(5, 2=>chain(Rx(1.4), Rx(0.5))), cnot(5, 3, 1), put(5, 3=>Rx(-0.5)))
h = put(5, 3 => Z) + put(5, 2 => X)
c = chain(put(5, 2 => chain(Rx(1.4), Rx(0.5))), cnot(5, 3, 1), put(5, 3 => Rx(-0.5)))
r = rand_state(5)
g0 = reinterpret(ComplexF64, ForwardDiff.gradient(x->real(expect(h, ArrayReg([Complex(x[2i-1],x[2i]) for i=1:length(x)÷2]))), reinterpret(Float64,r.state)))
@test Zygote.gradient(x->real(expect(h, ArrayReg(x))), r.state)[1] g0
@test Zygote.gradient(x->real(expect(h, ArrayReg{1}(reshape(statevec(x),:,1)))), r)[1].state g0
@test Zygote.gradient(x->real(expect(h, ArrayReg(reshape(state(x),:,1)))), r)[1].state g0
@test Zygote.gradient(x->real(expect(h, copy(x))), r)[1].state g0
@test Zygote.gradient(x->real(expect(h, parent(x'))), r)[1].state g0
g0 = reinterpret(
ComplexF64,
ForwardDiff.gradient(
x -> real(expect(h, ArrayReg([Complex(x[2i-1], x[2i]) for i in 1:length(x)÷2]))),
reinterpret(Float64, r.state),
),
)
@test Zygote.gradient(x -> real(expect(h, ArrayReg(x))), r.state)[1] g0
@test Zygote.gradient(x -> real(expect(h, ArrayReg{1}(reshape(statevec(x), :, 1)))), r)[1].state
g0
@test Zygote.gradient(x -> real(expect(h, ArrayReg(reshape(state(x), :, 1)))), r)[1].state g0
@test Zygote.gradient(x -> real(expect(h, copy(x))), r)[1].state g0
@test Zygote.gradient(x -> real(expect(h, parent(x'))), r)[1].state g0

g1 = reinterpret(ComplexF64, ForwardDiff.gradient(x->real(sum(abs2, [Complex(x[2i-1],x[2i]) for i=1:length(x)÷2])), reinterpret(Float64,r.state)))
@test Zygote.gradient(x->real(sum(abs2, state(x'))), r)[1].state g1
@test Zygote.gradient(x->real(sum(abs2, statevec(x'))), r)[1].state g1
g1 = reinterpret(
ComplexF64,
ForwardDiff.gradient(
x -> real(sum(abs2, [Complex(x[2i-1], x[2i]) for i in 1:length(x)÷2])),
reinterpret(Float64, r.state),
),
)
@test Zygote.gradient(x -> real(sum(abs2, state(x'))), r)[1].state g1
@test Zygote.gradient(x -> real(sum(abs2, statevec(x'))), r)[1].state g1
# zygote does not work if `sin` is not here,
# because it gives an adjoint of different type as the output matrix type.
# do not modify the data type please! Zygote
@test Zygote.gradient(x->real(sum(sin, Matrix(x))), c)[1] ForwardDiff.gradient(x->real(sum(sin, Matrix(dispatch(c, x)))), parameters(c))
@test Zygote.gradient(x -> real(sum(sin, Matrix(x))), c)[1]
ForwardDiff.gradient(x -> real(sum(sin, Matrix(dispatch(c, x)))), parameters(c))
end

@testset "adwith zygote" begin
c = chain(put(5, 2=>chain(Rx(0.4), Rx(0.5))), cnot(5, 3, 1), put(5, 3=>Rx(-0.5)))
c = chain(put(5, 2 => chain(Rx(0.4), Rx(0.5))), cnot(5, 3, 1), put(5, 3 => Rx(-0.5)))
dispatch!(c, :random)

function loss(reg::AbstractRegister, circuit::AbstractBlock{N}) where N
function loss(reg::AbstractRegister, circuit::AbstractBlock{N}) where {N}
reg = apply(copy(reg), circuit)
st = state(reg)
sum(real(st.*st))
sum(real(st .* st))
end

reg0 = zero_state(5)
params = rand!(parameters(c))
paramsδ = Zygote.gradient(params->loss(reg0, dispatch(c, params)), params)[1]
regδ = Zygote.gradient(reg->loss(reg, c), reg0)[1]
fparamsδ = ForwardDiff.gradient(params->loss(ArrayReg(Matrix{Complex{eltype(params)}}(reg0.state)), dispatch(c, params)), params)
fregδ = ForwardDiff.gradient(x->loss(ArrayReg([Complex(x[2i-1],x[2i]) for i=1:length(x)÷2]), dispatch(c, Vector{real(eltype(x))}(parameters(c)))), reinterpret(Float64,reg0.state))
paramsδ = Zygote.gradient(params -> loss(reg0, dispatch(c, params)), params)[1]
regδ = Zygote.gradient(reg -> loss(reg, c), reg0)[1]
fparamsδ = ForwardDiff.gradient(
params ->
loss(ArrayReg(Matrix{Complex{eltype(params)}}(reg0.state)), dispatch(c, params)),
params,
)
fregδ = ForwardDiff.gradient(
x -> loss(
ArrayReg([Complex(x[2i-1], x[2i]) for i in 1:length(x)÷2]),
dispatch(c, Vector{real(eltype(x))}(parameters(c))),
),
reinterpret(Float64, reg0.state),
)
@test fregδ reinterpret(Float64, regδ.state)
@test fparamsδ paramsδ
end

@testset "add block" begin
H = sum([chain(5, put(k=>Z)) for k=1:5])
c = chain(put(5, 2=>chain(Rx(0.4), Rx(0.5))), cnot(5, 3, 1), put(5, 3=>Rx(-0.5)))
H = sum([chain(5, put(k => Z)) for k in 1:5])
c = chain(put(5, 2 => chain(Rx(0.4), Rx(0.5))), cnot(5, 3, 1), put(5, 3 => Rx(-0.5)))
dispatch!(c, :random)
function loss(reg::AbstractRegister, circuit::AbstractBlock{N}) where N
reg = apply(copy(reg), circuit)
st = state(reg)
reg2 = apply(copy(reg), H)
st2 = state(reg2)
sum(real(st.*st2))
function loss(reg::AbstractRegister, circuit::AbstractBlock{N}) where {N}
reg = apply(copy(reg), circuit)
st = state(reg)
reg2 = apply(copy(reg), H)
st2 = state(reg2)
sum(real(st .* st2))
end

reg0 = zero_state(5)
params = rand!(parameters(c))
paramsδ = Zygote.gradient(params->loss(reg0, dispatch(c, params)), params)[1]
fparamsδ = ForwardDiff.gradient(params->loss(ArrayReg(Matrix{Complex{eltype(params)}}(reg0.state)), dispatch(c, params)), params)
paramsδ = Zygote.gradient(params -> loss(reg0, dispatch(c, params)), params)[1]
fparamsδ = ForwardDiff.gradient(
params ->
loss(ArrayReg(Matrix{Complex{eltype(params)}}(reg0.state)), dispatch(c, params)),
params,
)
@test fparamsδ paramsδ

regδ = Zygote.gradient(reg->loss(reg, c), reg0)[1]
fregδ = ForwardDiff.gradient(x->loss(ArrayReg([Complex(x[2i-1],x[2i]) for i=1:length(x)÷2]), dispatch(c, Vector{real(eltype(x))}(parameters(c)))), reinterpret(Float64,reg0.state))
regδ = Zygote.gradient(reg -> loss(reg, c), reg0)[1]
fregδ = ForwardDiff.gradient(
x -> loss(
ArrayReg([Complex(x[2i-1], x[2i]) for i in 1:length(x)÷2]),
dispatch(c, Vector{real(eltype(x))}(parameters(c))),
),
reinterpret(Float64, reg0.state),
)
@test fregδ reinterpret(Float64, regδ.state)
end
end

0 comments on commit 37aea0e

Please sign in to comment.