# Source code for qrisp.alg_primitives.reflection

"""
********************************************************************************
* Copyright (c) 2025 the Qrisp authors
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* This Source Code may also be made available under the following Secondary
* Licenses when the conditions for such availability set forth in the Eclipse
* Public License, v. 2.0 are satisfied: GNU General Public License, version 2
* with the GNU Classpath Exception which is
* available at https://www.gnu.org/software/classpath/license.html.
*
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
********************************************************************************
"""

import numpy as np
from qrisp import (
    QuantumArray,
    QuantumVariable,
    gate_wrap,
    gphase,
    h,
    mcx,
    mcp,
    x,
    z,
    conjugate,
    invert,
    control,
)
from qrisp.jasp import jlen, qache


#@qache
@gate_wrap(permeability=[], is_qfree=False)
def reflection(
    qargs,
    state_function,
    args=(),
    kwargs=None,
    phase=np.pi,
    reflection_indices=None,
):
    r"""
    Applies a reflection around a state $\ket{\psi}$ of (multiple)
    QuantumVariables, i.e., applies the operator

    .. math::

        R = ((1-e^{i\phi})\ket{\psi}\bra{\psi}-\mathbb I)
          = U((1-e^{i\phi})\ket{0}\bra{0}-\mathbb I)U^{\dagger},

    where $\ket{\psi} = U\ket{0}$. As a circuit this means: first un-prepare
    the state ($U^{\dagger}$), then reflect around $\ket{0}$, then prepare the
    state again ($U$).

    Parameters
    ----------
    qargs : QuantumVariable | QuantumArray | list[QuantumVariable | QuantumArray]
        The (list of) QuantumVariables representing the state to apply the
        reflection on.
    state_function : function
        A Python function ``state_function(*qargs, *args, **kwargs)`` preparing
        the state $\ket{\psi}$ in variables ``qargs`` around which to reflect.
    args : tuple, optional
        Additional arguments for the state function.
    kwargs : dict, optional
        Keyword arguments for the state function. By default, no keyword
        arguments are passed.
    phase : float or sympy.Symbol, optional
        Specifies the phase shift. The default is $\pi$.
    reflection_indices : list[int], optional
        A list of indices indicating with respect to which variables the
        reflection is performed. This is used for
        `oblivious amplitude amplification <https://arxiv.org/pdf/1312.1414>`_.
        Indices correspond to the flattened ``qargs``, e.g., if
        ``qargs = QuantumArray(QuantumFloat(3), (6,))``,
        ``reflection_indices=[0,1,2,3]`` corresponds to the first four
        variables in the array.
        By default, the reflection is performed with respect to all variables
        in ``qargs``.

    Raises
    ------
    TypeError
        If an element of ``qargs`` is neither a QuantumVariable nor a
        QuantumArray.

    Examples
    --------

    We prepare a QuantumVariable in state $\ket{1}^{\otimes n}$, and reflect
    around the GHZ state
    $\frac{1}{\sqrt{2}}(\ket{0}^{\otimes n} + \ket{1}^{\otimes n})$.
    The resulting state is $\ket{0}^{\otimes n}$.

    ::

        from qrisp import QuantumVariable, QuantumArray, h, x, cx, reflection

        def ghz(qv):
            h(qv[0])

            for i in range(1, qv.size):
                cx(qv[0], qv[i])

        # Prepare |1> state
        qv = QuantumVariable(5)
        x(qv)
        print(qv)
        # {'11111': 1.0}

        # Reflection around GHZ state
        reflection(qv, ghz)
        print(qv)
        # {'00000': 1.0}

    The reflection can also be applied to lists of QuantumVariables and
    QuantumArrays:

    ::

        from qrisp import QuantumVariable, QuantumArray, h, x, cx, reflection, multi_measurement

        def ghz(qv, qa):
            h(qv[0])

            for i in range(1, qv.size):
                cx(qv[0], qv[i])

            for var in qa:
                for i in range(var.size):
                    cx(qv[0], var[i])

        # Prepare |1> state
        qv = QuantumVariable(5)
        qa = QuantumArray(QuantumVariable(3), shape=(3,))
        x(qv)
        x(qa)
        print(multi_measurement([qv, qa]))
        # {('11111', OutcomeArray(['111', '111', '111'], dtype=object)): 1.0}

        # Reflection around GHZ state
        reflection([qv, qa], ghz)
        print(multi_measurement([qv, qa]))
        # {('00000', OutcomeArray(['000', '000', '000'], dtype=object)): 1.0}

    Additional arguments can be passed to the state function:

    ::

        from qrisp import QuantumVariable, QuantumArray, h, x, cx, ry, reflection

        def perturbed_ghz(qv, a, b):
            h(qv[0])
            ry(a, qv[1])
            ry(b, qv[2])

            for i in range(1, qv.size):
                cx(qv[0], qv[i])

        # Prepare |1> state
        qv = QuantumVariable(5)
        x(qv)

        reflection(qv, perturbed_ghz, args=(0.1, 0.1))
        print(qv)
        # {'00000': 0.9900599999999998,'01000': 0.0024799999999999996,'00100': 0.0024799999999999996,'11011': 0.0024799999999999996,'10111': 0.0024799999999999996,'11111': 1.9999999999999998e-05}

    """

    # Avoid a shared mutable default for the keyword arguments
    if kwargs is None:
        kwargs = {}

    # Convert qargs into a list
    if isinstance(qargs, (QuantumVariable, QuantumArray)):
        qargs = [qargs]

    # Generate a (flat) list of all QuantumVariables in qargs
    flattened_qargs = []
    for arg in qargs:
        if isinstance(arg, QuantumVariable):
            flattened_qargs.append(arg)
        elif isinstance(arg, QuantumArray):
            flattened_qargs.extend([qv for qv in arg.flatten()])
        else:
            raise TypeError("Arguments must be of type QuantumVariable or QuantumArray")

    if reflection_indices is None:
        reflection_indices = range(len(flattened_qargs))

    # Concatenate the registers of the selected variables with ``+``.
    # NOTE(review): deliberately kept as sum(..., []) -- ``qubits`` is handed
    # to jlen() and sliced below, so its concrete type may matter.
    qubits = sum([flattened_qargs[i].reg for i in reflection_indices], [])

    def inv_state_function(qargs, args, kwargs):
        # Inverse of the state preparation
        with invert():
            state_function(*qargs, *args, **kwargs)

    # Everything inside this block is sandwiched between the inverse state
    # preparation and its undoing, i.e. the |0> reflection below is turned
    # into a reflection around the prepared state.
    with conjugate(inv_state_function)(qargs, args, kwargs):

        with control(phase == np.pi):
            # X . (controlled-Z on all-zero controls) . X on the last qubit
            # flips the sign of |0...0> only.
            x(qubits[-1])

            # Single qubit: no controls available, a plain Z does the job
            with control(jlen(qubits) == 1):
                z(qubits[0])

            # Several qubits: H . MCX . H realises the multi-controlled Z
            with control(jlen(qubits) > 1):
                h(qubits[-1])
                mcx(qubits[:-1], qubits[-1], ctrl_state=0)
                h(qubits[-1])

            x(qubits[-1])

        with control(phase != np.pi):
            # General phase: e^{i*phase} on |0...0>
            mcp(phase, qubits, ctrl_state=0)

    # Global sign so that the operator matches (1-e^{i*phase})|psi><psi| - I.
    # A single qubit is addressed on purpose: ``qargs[0][0]`` is a whole
    # QuantumVariable when the first argument is a QuantumArray, and
    # (presumably -- confirm against append_operation) the gate would then be
    # repeated for each of its qubits, cancelling the sign for even sizes.
    gphase(np.pi, flattened_qargs[0][0])