Source code for qrisp.alg_primitives.switch_case

"""
********************************************************************************
* Copyright (c) 2025 the Qrisp authors
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* This Source Code may also be made available under the following Secondary
* Licenses when the conditions for such availability set forth in the Eclipse
* Public License, v. 2.0 are satisfied: GNU General Public License, version 2
* with the GNU Classpath Exception which is
* available at https://www.gnu.org/software/classpath/license.html.
*
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
********************************************************************************
"""

from qrisp.core import QuantumArray, QuantumVariable, x
from qrisp.qtypes import QuantumBool
from qrisp.environments import conjugate, control
from qrisp.alg_primitives import demux
from qrisp.core.gate_application_functions import mcx
from qrisp.jasp import check_for_tracing_mode, jrange, q_fori_loop, q_cond
import numpy as np
import jax.numpy as jnp


[docs] def qswitch(operand, case, case_function, method="auto"): """ Executes a switch - case statement distinguishing between a list of given in-place functions. Parameters ---------- operand : :ref:`QuantumVariable` The argument on which the case function operates. case : :ref:`QuantumFloat` The index specifying which case should be executed. case_function : list[callable] or callable A list of functions, performing some in-place operation on ``operand``, or a function ``case_function(i, operand)`` performing some in-place operation on ``operand`` depending on a nonnegative integer index ``i`` specifying the case. method : str, optional The compilation method. Available are ``sequential``, ``parallel``, ``tree`` and ``auto``. ``parallel`` is exponentially fast but requires more temporary qubits. ``tree`` uses `balanced binaray trees <https://arxiv.org/pdf/2407.17966v1>`_. The default is ``auto``. Examples -------- First, we consider the case where ``case_function`` is a **list of functions**: We create some sample functions: :: from qrisp import * def f0(x): x += 1 def f1(x): inpl_mult(x, 3, treat_overflow = False) def f2(x): pass def f3(x): h(x[1]) case_function_list = [f0, f1, f2, f3] Create operand and case variable: :: operand = QuantumFloat(4) operand[:] = 1 case = QuantumFloat(2) h(case) Execute switch - case function: >>> qswitch(operand, case, case_function_list) Simulate: >>> print(multi_measurement([case, operand])) {(0, 2): 0.25, (1, 3): 0.25, (2, 1): 0.25, (3, 1): 0.125, (3, 3): 0.125} Second, we consider the case where ``case_function`` is a **function**: :: def case_function(i, qv): x(qv[i]) operand = QuantumFloat(4) case = QuantumFloat(2) h(case) qswitch(operand, case, case_function) Simulate: >>> print(multi_measurement([case, operand])) {(0, 1): 0.25, (1, 2): 0.25, (2, 4): 0.25, (3, 8): 0.25} """ if callable(case_function): case_amount = 2**case.size xrange = jrange if method == "auto": method = "tree" else: case_amount = len(case_function) # Extend case_function list by identity such that its size is 2*n (necessary for tree qswitch) def identity(operand): pass case_function.extend( [identity] * ((1 << ((case_amount - 1).bit_length())) - case_amount) ) xrange = range if method == "auto": if case_amount <= 4: method = "sequential" else: method = "tree" if method == "sequential": control_qbl = QuantumBool() for i in xrange(case_amount): with conjugate(mcx)(case, control_qbl, ctrl_state=i): with control(control_qbl): if callable(case_function): case_function(i, operand) else: case_function[i](operand) control_qbl.delete() elif method == "parallel": if check_for_tracing_mode(): raise Exception( f"Compile method {method} for switch-case structure not available in tracing mode." ) # Idea: Use demux function to move operand and enabling bool into QuantumArray # to execute cases in parallel. # This QuantumArray acts as an addressable QRAM via the demux function enable = QuantumArray(qtype=QuantumBool(), shape=(case_amount,)) enable[0].flip() qa = QuantumArray(qtype=operand, shape=((case_amount,))) with conjugate(demux)(operand, case, qa, parallelize_qc=True): with conjugate(demux)(enable[0], case, enable, parallelize_qc=True): for i in range(case_amount): with control(enable[i]): if callable(case_function): case_function(i, qa[i]) else: case_function[i](qa[i]) qa.delete() enable[0].flip() enable.delete() # Uses balanced binaray trees https://arxiv.org/pdf/2407.17966v1 elif method == "tree": n = case.size def bounce(d: int, anc, ca, oper): with control(anc[d - 1]): x(anc[d]) with control(anc[d]): x(anc[d + 1]) with control(anc[d - 1]): with control(ca[n - 1 - d]): x(anc[d + 1]) def down(d: int, anc, ca, oper): with control(anc[d]): x(ca[n - 1 - d]) with control(ca[n - 1 - d]): x(anc[d + 1]) x(ca[n - 1 - d]) def up(d: int, anc, ca, oper): with control(anc[d]): with control(ca[n - 1 - d]): x(anc[d + 1]) # Jasp mode if check_for_tracing_mode(): xrange = jrange x_fori_loop = q_fori_loop def bitwise_count_diff(a, b): return jnp.bitwise_count(jnp.bitwise_xor(a, b)) # Normal mode else: xrange = range def x_fori_loop(lower, upper, body_fun, init_val): val = init_val for i in range(lower, upper): val = body_fun(i, val) return val def bitwise_count_diff(a, b): return np.bitwise_count(np.bitwise_xor(a, b)) # Function mode if callable(case_function): def leaf(d: int, anc, ca, oper, i): with control(anc[d + 1]): case_function(i, oper) with control(anc[d]): x(anc[d + 1]) with control(anc[d + 1]): case_function(i + 1, oper) # List mode elif isinstance(case_function, list): if check_for_tracing_mode(): def leaf(d: int, anc, ca, oper, i): def apply_leaf(A, B): with control(anc[d + 1]): A(oper) with control(anc[d]): x(anc[d + 1]) with control(anc[d + 1]): B(oper) for j in range(0, len(case_function), 2): q_cond( j == i, apply_leaf, lambda a, b: None, case_function[j], case_function[j + 1], ) else: def leaf(d: int, anc, ca, oper, i): with control(anc[d + 1]): case_function[i](oper) with control(anc[d]): x(anc[d + 1]) with control(anc[d + 1]): case_function[i + 1](oper) else: raise TypeError( "Argument 'case_function' must be a list or a callable(i, x)" ) def body_fun(pos, val): anc, ca, oper = val # Apply leaf leaf(n - 1, anc, ca, oper, 2 * pos) # Jump to next leaf q = bitwise_count_diff(pos, pos + 1) for j in xrange(0, q - 1, 1): up(n - j - 1, anc, ca, oper) bounce(n - q, anc, ca, oper) for j in xrange(0, q - 1, 1): down(n - (q - 1) + j, anc, ca, oper) return anc, ca, oper anc = QuantumVariable(n + 1) x(anc[0]) # Go to first node for j in xrange(0, n, 1): down(j, anc, case, operand) # Perform leafs and jumps anc_, case, operand = x_fori_loop( 0, 2 ** (n - 1) - 1, body_fun, (anc, case, operand) ) # Perfrom last leaf leaf(n - 1, anc, case, operand, 2**n - 2) # Go back from last node for j in xrange(0, n, 1): up(n - j - 1, anc, case, operand) x(anc[0]) anc.delete() else: raise Exception( f"Don't know compile method {method} for switch-case structure." )