Quantum Amplitude Amplification

amplitude_amplification(args: QuantumVariable | QuantumArray | Sequence[QuantumVariable | QuantumArray], state_function: Callable, oracle_function: Callable, kwargs_oracle: dict[str, Any] | None = None, iter: int = 1, reflection_indices: list[int] | None = None) → None

This method performs quantum amplitude amplification.

The problem of quantum amplitude amplification is described as follows:

  • Given a unitary operator \(\mathcal{A}\), let \(\ket{\Psi}=\mathcal{A}\ket{0}\).

  • Write \(\ket{\Psi}=\ket{\Psi_1}+\ket{\Psi_0}\) as a superposition of the orthogonal good and bad components of \(\ket{\Psi}\).

  • Enhance the probability \(a=\langle\Psi_1|\Psi_1\rangle\) that a measurement of \(\ket{\Psi}\) yields a good state.
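The decomposition into good and bad components, and the quantity \(a\), can be illustrated on a small classical statevector. This is a plain-Python sketch under stated assumptions: the state is a dictionary from basis values to amplitudes, and a hypothetical predicate is_good marks the good basis states. Because the two components have disjoint support, they are orthogonal, and their squared norms add up to 1.

```python
def split_good_bad(state, is_good):
    """Split a statevector (dict: basis value -> amplitude) into |Psi_1> and |Psi_0>."""
    psi1 = {k: amp for k, amp in state.items() if is_good(k)}
    psi0 = {k: amp for k, amp in state.items() if not is_good(k)}
    return psi1, psi0

def norm_squared(component):
    """<Psi_i|Psi_i>: sum of the squared magnitudes of the amplitudes."""
    return sum(abs(amp) ** 2 for amp in component.values())

# Hypothetical example: uniform superposition over the 2-bit values 0..3,
# where only the value 3 is considered good.
psi = {k: 0.5 for k in range(4)}
psi1, psi0 = split_good_bad(psi, lambda k: k == 3)
a = norm_squared(psi1)
print(a, norm_squared(psi0))  # 0.25 0.75
```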

Let \(\theta_a\in [0,\pi/2]\) be such that \(\sin^2(\theta_a)=a\). Then the amplitude amplification operator \(\mathcal Q\) acts as

\[\mathcal Q^j\ket{\Psi}=\frac{1}{\sqrt{a}}\sin((2j+1)\theta_a)\ket{\Psi_1}+\frac{1}{\sqrt{1-a}}\cos((2j+1)\theta_a)\ket{\Psi_0}.\]

Therefore, after \(m\) iterations the probability of measuring a good state is \(\sin^2((2m+1)\theta_a)\).
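The closed form can be checked numerically in the single-qubit case where the good state is \(\ket{1}\): then \(\ket{\Psi_1}=\sin(\theta_a)\ket{1}\) and \(\ket{\Psi_0}=\cos(\theta_a)\ket{0}\), so \(\mathcal Q^j\ket{\Psi}\) should have amplitudes \(\cos((2j+1)\theta_a)\) and \(\sin((2j+1)\theta_a)\). The pure-Python sketch below assumes the textbook form \(\mathcal Q=(2|\Psi\rangle\langle\Psi|-I)\,Z\), with Z as the oracle; it is an illustration, not qrisp's implementation. It also includes a hypothetical helper, suggested_iter, for picking a value of iter when \(a\) is known.

```python
import math

def apply_q(theta_a, j):
    """Return the (|0>, |1>) amplitudes of Q^j |Psi>, |Psi> = cos(theta_a)|0> + sin(theta_a)|1>.

    Assumed textbook form: Q = (2|Psi><Psi| - I) Z, where Z is the oracle
    that flips the sign of the good state |1>.
    """
    c, s = math.cos(theta_a), math.sin(theta_a)
    # Reflection about |Psi> as the real 2x2 matrix 2|Psi><Psi| - I.
    refl = [[2 * c * c - 1, 2 * c * s],
            [2 * c * s, 2 * s * s - 1]]
    bad, good = c, s  # amplitudes of |0> and |1> in |Psi>
    for _ in range(j):
        good = -good  # oracle: phase flip on the good component
        bad, good = (refl[0][0] * bad + refl[0][1] * good,
                     refl[1][0] * bad + refl[1][1] * good)
    return bad, good

def suggested_iter(a):
    """Hypothetical helper: the common choice m = floor(pi / (4 theta_a)) for 0 < a < 1."""
    theta_a = math.asin(math.sqrt(a))
    return math.floor(math.pi / (4 * theta_a))

a = 0.1
theta_a = math.asin(math.sqrt(a))
for m in range(5):
    bad, good = apply_q(theta_a, m)
    # The simulated amplitudes agree with cos/sin((2m+1) theta_a) ...
    assert math.isclose(bad, math.cos((2 * m + 1) * theta_a), abs_tol=1e-12)
    assert math.isclose(good, math.sin((2 * m + 1) * theta_a), abs_tol=1e-12)
    # ... so the good-state probability is sin^2((2m+1) theta_a).
    print(m, round(good ** 2, 5))

print(suggested_iter(a))  # 2
```

For \(a=0.1\) the good-state probability climbs from 0.1 to 0.676 and then 0.99856 at \(m=2\), and falls again at \(m=3\): applying \(\mathcal Q\) too often over-rotates the state. The floor rule places \((2m+1)\theta_a\) within \(\theta_a\) of \(\pi/2\), which keeps the success probability at or above \(\cos^2(\theta_a)=1-a\).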

Parameters:
args : QuantumVariable | QuantumArray | Sequence[QuantumVariable | QuantumArray]

The quantum variable, array, or collection thereof on which amplitude amplification is performed. These variables must already be prepared in the initial state \(\ket{\Psi}\) before calling this method (i.e., the user is responsible for applying the state_function to the zero state prior to execution).

state_function : Callable

A Python function preparing the state \(\ket{\Psi}\) from the zero state. The required signature of this function depends on the input args:

  • if args is a single variable or array, it receives that single object.

  • if args is a list, the elements are unpacked and passed as separate positional arguments (e.g., for args=[qv1, qv2], the signature must be state_function(qv1, qv2)).

Although args must already be in the state \(\ket{\Psi}\) upon input, this function is still needed internally to construct the amplitude amplification operator \(\mathcal{Q}\), specifically to implement the reflection about the initial state.

oracle_function : Callable

A Python function tagging the good state \(\ket{\Psi_1}\). Like state_function, its required signature matches the structure of args: it takes a single argument if args is a single object, or unpacked positional arguments if args is a list.

kwargs_oracle : dict, optional

A dictionary containing keyword arguments for the oracle. The default is None.

iter : int, optional

The exact number of amplitude amplification iterations (applications of \(\mathcal Q\)) to perform. The default is 1.

reflection_indices : list[int], optional

A list of indices specifying the variables with respect to which the reflection is performed; restricting the reflection to a subset of the variables yields oblivious amplitude amplification. Indices refer to the flattened args (e.g., if args = QuantumArray(QuantumFloat(3), (6,)), reflection_indices=[0,1,2,3] corresponds to the first four variables in the array). By default, the reflection is performed with respect to all variables in args (standard amplitude amplification).
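The routing rules described above for args, kwargs_oracle and reflection_indices can be summarized in a plain-Python sketch. Everything in it is hypothetical: call_with_args and flatten_args are illustrative helpers rather than qrisp internals, strings stand in for quantum variables, a tuple stands in for a QuantumArray, and both forwarding kwargs_oracle with ** and flattening args in order are assumptions.

```python
def call_with_args(func, args, kwargs=None):
    """Hypothetical helper: call state_function/oracle_function as documented.

    A list is unpacked into positional arguments; any other object is passed
    as the single argument. Keyword arguments (e.g. kwargs_oracle) are
    assumed to be forwarded with **.
    """
    kwargs = kwargs or {}
    if isinstance(args, list):
        return func(*args, **kwargs)
    return func(args, **kwargs)

def flatten_args(args):
    """Hypothetical flattening: lists/tuples are expanded in order,
    every other object counts as one variable."""
    if not isinstance(args, (list, tuple)):
        return [args]
    flat = []
    for item in args:
        flat.extend(flatten_args(item))
    return flat

# Toy oracle taking two variables and one keyword argument.
def oracle(qv1, qv2, flip=False):
    return (qv1, qv2, flip)

print(call_with_args(oracle, ["qv1", "qv2"], {"flip": True}))  # ('qv1', 'qv2', True)
print(call_with_args(lambda qb: qb, ["qb"]))  # qb  (a one-element list is unpacked)

# Tuple standing in for QuantumArray(QuantumFloat(3), (6,)).
qarray = ("qf0", "qf1", "qf2", "qf3", "qf4", "qf5")
reflection_indices = [0, 1, 2, 3]
flat = flatten_args(qarray)
print([flat[i] for i in reflection_indices])  # ['qf0', 'qf1', 'qf2', 'qf3']
```

Under these assumptions, the variables whose indices are not listed (here qf4 and qf5) take no part in the reflection about the initial state.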

Examples

We define a function that prepares the state \(\ket{\Psi}=\cos(\frac{\pi}{16})\ket{0}+\sin(\frac{\pi}{16})\ket{1}\) and an oracle that tags the good state \(\ket{1}\). In this case, \(\theta_a=\frac{\pi}{16}\), so the good state has amplitude \(\sin(\frac{\pi}{16})\approx 0.19509\) and \(a=\sin^2(\frac{\pi}{16})\approx 0.03806\).

from qrisp import z, ry, QuantumBool, amplitude_amplification
import numpy as np

def state_function(qb):
    ry(np.pi/8,qb)

def oracle_function(qb):
    z(qb)

qb = QuantumBool()

state_function(qb)
>>> qb.qs.statevector(decimals=5)
0.98079*|False> + 0.19509*|True>

We can enhance the probability of measuring the good state with amplitude amplification:

>>> amplitude_amplification([qb], state_function, oracle_function)
>>> qb.qs.statevector(decimals=5)
0.83147*|False> + 0.55557*|True>
>>> amplitude_amplification([qb], state_function, oracle_function)
>>> qb.qs.statevector(decimals=5)
0.55557*|False> + 0.83147*|True>
>>> amplitude_amplification([qb], state_function, oracle_function)
>>> qb.qs.statevector(decimals=5)
0.19509*|False> + 0.98079*|True>
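The printed amplitudes can be cross-checked against the closed-form expression for \(\mathcal Q^j\ket{\Psi}\) with \(\theta_a=\pi/16\). This pure-Python sketch only evaluates that formula; it does not call qrisp, so it says nothing about the global phase qrisp would print.

```python
import math

theta_a = math.pi / 16      # ry(np.pi/8, qb) yields sin(theta_a) = sin(pi/16)
a = math.sin(theta_a) ** 2  # initial good-state probability, about 0.03806

for j in range(6):
    amp_false = math.cos((2 * j + 1) * theta_a)  # coefficient of |False>
    amp_true = math.sin((2 * j + 1) * theta_a)   # coefficient of |True>
    print(j, round(amp_false, 5), round(amp_true, 5), round(amp_true ** 2, 5))
```

Rows j = 1, 2, 3 reproduce the three state vectors printed above. A fourth call would overshoot: the formula gives \(\cos(\frac{9\pi}{16})\approx -0.19509\) and \(\sin(\frac{9\pi}{16})\approx 0.98079\), so the good-state probability stays at about 0.96194 instead of increasing, and a fifth call lowers it to \(\sin^2(\frac{11\pi}{16})\approx 0.69134\).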