Automatic Differentiation

Classical back propagation

Back propagation obtains all gradients with $O(M)$ complexity, where $M$ is the number of circuit parameters. We can use autodiff(:BP) to mark differentiable units in a circuit. Let's see an example.

Example: Classical back propagation

julia> using Yao

julia> circuit = chain(4, repeat(4, H, 1:4), put(4, 3=>Rz(0.5)), control(2, 1=>X), put(4, 4=>Ry(0.2)))
Total: 4, DataType: Complex{Float64}
chain
├─ repeat on (1, 2, 3, 4)
│  └─ H gate
├─ put on (3)
│  └─ Rot Z gate: 0.5
├─ control(2)
│  └─ (1,)=>X gate
└─ put on (4)
   └─ Rot Y gate: 0.2

julia> circuit = circuit |> autodiff(:BP)
Total: 4, DataType: Complex{Float64}
chain
├─ repeat on (1, 2, 3, 4)
│  └─ H gate
├─ [∂] put on (3)
│  └─ Rot Z gate: 0.5
├─ control(2)
│  └─ (1,)=>X gate
└─ [∂] put on (4)
   └─ Rot Y gate: 0.2

From the output, we can see that the parameters of blocks marked with [∂] will be differentiated automatically.

julia> op = put(4, 3=>Y);  # loss is defined as its expectation.

julia> ψ = rand_state(4);

julia> ψ |> circuit;

julia> δ = ψ |> op;     # ∂f/∂ψ*

julia> backward!(δ, circuit);    # classical back propagation!

Here, the loss is $f = \langle\psi|\mathrm{op}|\psi\rangle$, and $\delta = \partial f/\partial\psi^*$ is the error to be back propagated. The gradient is related to $\delta$ as $\frac{\partial f}{\partial\theta} = 2\Re\left[\frac{\partial f}{\partial\psi^*}\frac{\partial \psi^*}{\partial\theta}\right]$.

In fact, backward!(δ, circuit) on a wave function is equivalent to computing δ |> circuit' (i.e. apply!(reg, Daggered{<:BPDiff})). This function is overloaded so that the gradients with respect to the parameters are computed and stored in the BPDiff blocks at the same time.
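
For illustration, the backward step could equivalently be written with the pipe syntax (a sketch only; run either this or backward!, not both, since they perform the same mutation):

julia> δ |> circuit';    # same as backward!(δ, circuit): the error propagates through the daggered circuit while gradients accumulate in the BPDiff blocks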

Finally, we use gradient to collect the gradients in the circuit.

julia> g1 = gradient(circuit)  # collect gradient
2-element Array{Float64,1}:
 -0.01441540767478676
  2.7755575615628914e-17
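
To sanity-check these numbers, one can compare against a central finite-difference estimate of the same loss. The sketch below uses ad-hoc names (ψ0, build, f, ε) and a fresh random state; to compare directly with g1, the same input state must be used in both computations.

julia> ψ0 = rand_state(4);    # input state, kept unevolved

julia> build(θ) = chain(4, repeat(4, H, 1:4), put(4, 3=>Rz(θ)), control(2, 1=>X), put(4, 4=>Ry(0.2)));    # same circuit, with the Rz angle left free

julia> f(θ) = real(expect(op, copy(ψ0) |> build(θ)));    # loss ⟨ψ|op|ψ⟩ as a function of the Rz angle

julia> ε = 1e-5;

julia> (f(0.5 + ε) - f(0.5 - ε)) / (2ε);    # central finite difference ≈ ∂f/∂θ for the Rz parameter
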
Note

In real quantum devices, gradients cannot be back propagated; this is why we need the following section.

Quantum circuit differentiation

Experimentally applicable differentiation strategies are based on the following two papers:

- Quantum Circuit Learning, K. Mitarai, M. Negoro, M. Kitagawa, and K. Fujii (arXiv:1803.00745)
- Differentiable Learning of Quantum Circuit Born Machine, Jin-Guo Liu and Lei Wang (arXiv:1804.04168)

The former differentiation scheme is for observables, and the latter is for statistic functionals (U statistics). One may find the derivation of both schemes in this post.

Gradient-finding algorithms that are realizable on a quantum device have $O(M^2)$ complexity: each of the $M$ parameters requires its own circuit evaluations, and each evaluation costs $O(M)$.

Example: Practical quantum differentiation

We use the QDiff block to mark differentiable circuits.

julia> using Yao, Yao.Blocks

julia> c = chain(put(4, 1=>Rx(0.5)), control(4, 1, 2=>Ry(0.5)), kron(4, 2=>Rz(0.3), 3=>Rx(0.7))) |> autodiff(:QC)  # automatically mark differentiable blocks
Total: 4, DataType: Complex{Float64}
chain
├─ put on (1)
│  └─ [̂∂] Rot X gate: 0.5
├─ control(1)
│  └─ (2,)=>Rot Y gate: 0.5
└─ kron
   ├─ 2=>[̂∂] Rot Z gate: 0.3
   └─ 3=>[̂∂] Rot X gate: 0.7

Blocks marked by [̂∂] will be differentiated.

julia> dbs = collect(c, QDiff)  # collect all QDiff blocks
Sequence
├─ [̂∂] Rot X gate: 0.5
├─ [̂∂] Rot Z gate: 0.3
└─ [̂∂] Rot X gate: 0.7

Here, we recommend collecting the QDiff blocks into a sequence with the collect API for later use. Then we can get the gradients one by one using opdiff.

julia> ed = opdiff(dbs[1], put(4, 1=>Z)) do   # the exact differentiation with respect to first QDiff block.
           zero_state(4) |> c
       end
-0.4794255386042028

Here, the do-block returns the loss; it must be the expectation value of an observable.

To check the result, we get the numerical gradient using numdiff.

julia> ed = numdiff(dbs[1]) do    # compare with numerical differentiation
          expect(put(4, 1=>Z), zero_state(4) |> c) |> real
       end
-0.4794175482185137

This numerical differentiation scheme is always applicable (even when the loss is not the expectation of an observable), but it introduces numerical errors due to the finite step size.
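
For intuition, the scheme for observables amounts to a parameter-shift rule: for a rotation gate such as Rx(θ), the exact derivative of an expectation value equals half the difference of the expectations evaluated at θ ± π/2. The sketch below reproduces the derivative with respect to the first Rx angle by hand; c₊ and c₋ are ad-hoc copies of c with that angle shifted, and the result should agree with the exact value obtained with opdiff above.

julia> c₊ = chain(put(4, 1=>Rx(0.5 + π/2)), control(4, 1, 2=>Ry(0.5)), kron(4, 2=>Rz(0.3), 3=>Rx(0.7)));

julia> c₋ = chain(put(4, 1=>Rx(0.5 - π/2)), control(4, 1, 2=>Ry(0.5)), kron(4, 2=>Rz(0.3), 3=>Rx(0.7)));

julia> (real(expect(put(4, 1=>Z), zero_state(4) |> c₊)) - real(expect(put(4, 1=>Z), zero_state(4) |> c₋))) / 2;    # parameter-shift estimate of the derivative w.r.t. the first Rx angle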

We can also get all gradients using broadcasting

julia> ed = opdiff.(()->zero_state(4) |> c, dbs, Ref(kron(4, 1=>Z, 2=>X)))   # using broadcast to get all gradients.
3-element Array{Float64,1}:
 -0.109791495292767
  0.008672047291031427
  0.0
Note

Since BP is not implemented for QDiff blocks, memory consumption is much lower: intermediate results no longer need to be cached.