Automatic Differentiation
Classical back propagation
Back propagation obtains gradients with $O(M)$ complexity, where $M$ is the number of circuit parameters. We can use autodiff(:BP)
to mark differentiable units in a circuit. Let's see an example.
Example: Classical back propagation
julia> using Yao
julia> circuit = chain(4, repeat(4, H, 1:4), put(4, 3=>Rz(0.5)), control(2, 1=>X), put(4, 4=>Ry(0.2)))
Total: 4, DataType: Complex{Float64}
chain
├─ repeat on (1, 2, 3, 4)
│ └─ H gate
├─ put on (3)
│ └─ Rot Z gate: 0.5
├─ control(2)
│ └─ (1,)=>X gate
└─ put on (4)
└─ Rot Y gate: 0.2
julia> circuit = circuit |> autodiff(:BP)
Total: 4, DataType: Complex{Float64}
chain
├─ repeat on (1, 2, 3, 4)
│ └─ H gate
├─ [∂] put on (3)
│ └─ Rot Z gate: 0.5
├─ control(2)
│ └─ (1,)=>X gate
└─ [∂] put on (4)
└─ Rot Y gate: 0.2
From the output, we can see that the parameters of blocks marked by [∂]
will be differentiated automatically.
julia> op = put(4, 3=>Y); # loss is defined as its expectation.
julia> ψ = rand_state(4);
julia> ψ |> circuit;
julia> δ = ψ |> op; # ∂L/∂ψ*
julia> backward!(δ, circuit); # classical back propagation!
Here, the loss is $L = \langle\psi|\mathrm{op}|\psi\rangle$, and $\delta = \frac{\partial L}{\partial\psi^*}$ is the error to be back propagated. The gradient is related to $\delta$ as $\frac{\partial L}{\partial\theta} = 2\Re\left[\frac{\partial L}{\partial\psi^*}\frac{\partial \psi^*}{\partial\theta}\right]$.
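To see why the code can compute $\delta$ as ψ |> op: treating $\psi$ and $\psi^*$ as independent variables, $\delta = \frac{\partial L}{\partial\psi^*} = \frac{\partial}{\partial\psi^*}\langle\psi|\mathrm{op}|\psi\rangle = \mathrm{op}|\psi\rangle$, i.e. the observable applied to the circuit's output state.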
In fact, backward!(δ, circuit)
on the wave function is equivalent to calculating δ |> circuit'
(apply!(reg, Daggered{<:BPDiff})
). This function is overloaded so that the gradients for the parameters are also calculated and stored in the BPDiff
blocks at the same time.
Finally, we use gradient
to collect the gradients from the circuit.
julia> g1 = gradient(circuit) # collect gradient
2-element Array{Float64,1}:
-0.01441540767478676
2.7755575615628914e-17
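With the gradients collected, one can update the circuit parameters, e.g. by a plain gradient-descent step. A minimal sketch, assuming parameters and dispatch! read and write the parameter vector in the same order as gradient (an assumption to check against your version of Yao):

julia> params = parameters(circuit); # current parameters, assumed aligned with g1

julia> dispatch!(circuit, params .- 0.1 .* g1); # hypothetical descent step, learning rate 0.1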
On real quantum devices, gradients cannot be back propagated; this is why we need the following section.
Quantum circuit differentiation
Experimentally applicable differentiation strategies are based on the following two papers
- Quantum Circuit Learning, Kosuke Mitarai, Makoto Negoro, Masahiro Kitagawa, Keisuke Fujii
- Differentiable Learning of Quantum Circuit Born Machine, Jin-Guo Liu, Lei Wang
The former differentiation scheme is for observables, and the latter is for statistical functionals (U statistics). One may find the derivation of both schemes in this post.
Gradient-finding algorithms realizable on quantum hardware have complexity $O(M^2)$.
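For a rotation gate $R(\theta) = e^{-i\theta\Sigma/2}$ with $\Sigma^2 = I$ (e.g. a Pauli rotation), the gradient of an expectation value can be obtained exactly from two shifted evaluations: $\frac{\partial\langle B\rangle_\theta}{\partial\theta} = \frac{1}{2}\left[\langle B\rangle_{\theta+\pi/2} - \langle B\rangle_{\theta-\pi/2}\right]$. Below is a minimal hand-rolled check on a toy loss $\langle Z_1\rangle_\theta = \cos\theta$; this is a sketch for illustration, not Yao's internal implementation.

julia> loss(θ) = expect(put(4, 1=>Z), zero_state(4) |> put(4, 1=>Rx(θ))) |> real; # toy loss, equals cos(θ)

julia> θ = 0.5;

julia> g = (loss(θ + π/2) - loss(θ - π/2)) / 2; # exact gradient; analytically -sin(0.5) ≈ -0.4794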
Example: Practical quantum differentiation
We use the QDiff
block to mark differentiable circuits
julia> using Yao, Yao.Blocks
julia> c = chain(put(4, 1=>Rx(0.5)), control(4, 1, 2=>Ry(0.5)), kron(4, 2=>Rz(0.3), 3=>Rx(0.7))) |> autodiff(:QC) # automatically mark differentiable blocks
Total: 4, DataType: Complex{Float64}
chain
├─ put on (1)
│ └─ [̂∂] Rot X gate: 0.5
├─ control(1)
│ └─ (2,)=>Rot Y gate: 0.5
└─ kron
├─ 2=>[̂∂] Rot Z gate: 0.3
└─ 3=>[̂∂] Rot X gate: 0.7
Blocks marked by [̂∂]
will be differentiated.
julia> dbs = collect(c, QDiff) # collect all QDiff blocks
Sequence
├─ [̂∂] Rot X gate: 0.5
├─ [̂∂] Rot Z gate: 0.3
└─ [̂∂] Rot X gate: 0.7
Here, we recommend collecting QDiff
blocks into a sequence using the collect
API for later calculations. Then we can get the gradients one by one, using opdiff
julia> ed = opdiff(dbs[1], put(4, 1=>Z)) do # exact differentiation with respect to the first QDiff block
zero_state(4) |> c
end
-0.4794255386042028
Here, the do-block returns the register on which the loss is evaluated; the loss itself is the expectation value of the observable passed as the second argument, and opdiff applies only to losses of this form.
To check the result, we get the numerical gradient using numdiff
julia> ed = numdiff(dbs[1]) do # compare with numerical differentiation
expect(put(4, 1=>Z), zero_state(4) |> c) |> real
end
-0.4794175482185137
This numerical differentiation scheme is always applicable (even when the loss is not the expectation of an observable), but it introduces numerical errors due to the finite step size.
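For intuition, the idea behind numdiff can be hand-rolled as a central finite difference. A sketch (the step size 1e-5 and the toy loss are our own choices here; numdiff's internal step may differ):

julia> fd(f, x; δ=1e-5) = (f(x + δ) - f(x - δ)) / 2δ; # central difference, O(δ²) truncation error

julia> loss(θ) = expect(put(4, 1=>Z), zero_state(4) |> put(4, 1=>Rx(θ))) |> real; # same toy loss as in the sketch above

julia> g = fd(loss, 0.5); # ≈ -sin(0.5), up to finite-step error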
We can also get all gradients at once using broadcasting
julia> ed = opdiff.(()->zero_state(4) |> c, dbs, Ref(kron(4, 1=>Z, 2=>X))) # using broadcast to get all gradients.
3-element Array{Float64,1}:
-0.109791495292767
0.008672047291031427
0.0
Since BP is not implemented for QDiff
blocks, memory consumption is much lower: intermediate results are no longer cached.
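As a final sanity check, the two schemes can be compared on the same toy circuit. A sketch using only the APIs demonstrated above; the two gradients should agree up to floating-point error:

julia> cb = chain(4, put(4, 1=>Rx(0.5))) |> autodiff(:BP); # back-propagation version

julia> cq = chain(4, put(4, 1=>Rx(0.5))) |> autodiff(:QC); # quantum-differentiable version

julia> op = put(4, 1=>Z);

julia> ψ = zero_state(4) |> cb;

julia> backward!(ψ |> op, cb); # classical back propagation

julia> g_bp = gradient(cb); # gradient from the BP scheme

julia> g_qc = opdiff(()->zero_state(4) |> cq, collect(cq, QDiff)[1], op); # should match g_bp[1]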