Source code for qrisp.jasp.interpreter_tools.interpreters.qc_extraction_interpreter

"""
********************************************************************************
* Copyright (c) 2026 the Qrisp authors
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* This Source Code may also be made available under the following Secondary
* Licenses when the conditions for such availability set forth in the Eclipse
* Public License, v. 2.0 are satisfied: GNU General Public License, version 2
* with the GNU Classpath Exception which is
* available at https://www.gnu.org/software/classpath/license.html.
*
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
********************************************************************************

QuantumCircuit Extraction Interpreter
=====================================

This module implements the interpreter for converting Jaspr (JAX-based quantum 
intermediate representation) into static QuantumCircuit objects.

Overview
--------
Jaspr represents hybrid quantum-classical algorithms using JAX's tracing 
infrastructure. While this representation is powerful for optimization and 
compilation, it needs to be "lowered" to a QuantumCircuit for execution on 
current quantum hardware or simulators that expect circuit-based input.

The Challenge
-------------
The main challenge in this conversion is handling **measurement results**:

1. **In Jaspr**: Measurement results are represented as JAX booleans.
   These can be manipulated using standard JAX/NumPy operations like array
   construction, indexing, arithmetic, etc.

2. **In QuantumCircuit**: Measurement results are represented as `Clbit` objects.
   These are non-JAX types that cannot be processed by JAX primitives.

When we encounter classical post-processing of measurement results (e.g., 
`meas_res * 2` or `parity(m1, m2)`), we cannot represent this computation
in the QuantumCircuit itself. Instead, we use placeholder objects.

Solution Architecture
---------------------
This module provides three key classes/concepts:

1. **ProcessedMeasurement**: A placeholder for measurement results that have
   undergone classical post-processing. Since QuantumCircuit cannot represent
   classical computation, this placeholder indicates "some classical processing
   happened here."

2. **MeasurementArray**: Handles the specific problem of measurement results
   being inserted into or extracted from JAX arrays. See detailed explanation
   in the MeasurementArray class documentation.

3. **qc_extraction_eqn_evaluator**: The equation evaluator that intercepts
   JAX primitives and either:
   - Executes quantum primitives directly (building the circuit)
   - Handles array operations on measurements (using MeasurementArray)
   - Creates ProcessedMeasurement placeholders for classical post-processing
   - Delegates to default JAX evaluation for pure classical operations

"""

import numpy as np


# =============================================================================
# SECTION 1: Placeholder Classes for Classical Post-Processing
# =============================================================================


class ProcessedMeasurement:
    """
    Placeholder for measurement results that have undergone classical post-processing.

    Problem Being Solved
    --------------------
    QuantumCircuit objects represent quantum operations and measurements, but they
    cannot represent classical computation on measurement results. When a Jaspr
    contains code like:

        meas_res = measure(qv)
        processed = meas_res * 2 + 1  # Classical post-processing

    We cannot encode the `* 2 + 1` operation in the QuantumCircuit. In many
    cases, however, the post-processing can be recovered through the
    Jaspr.extract_post_processing feature, so rather than raising an error
    we continue with the circuit extraction and use this placeholder to
    indicate that classical processing occurred. An error is raised if a
    placeholder would have to determine further construction steps of the
    circuit.

    Usage
    -----
    When the interpreter encounters an operation that:
    1. Takes measurement results (Clbit) as input
    2. Is not a quantum operation
    3. Produces output that depends on the measurement

    It creates a ProcessedMeasurement as the output, signaling that the actual
    value cannot be determined until runtime execution with real measurement data.

    Example
    -------
    >>> result, qc = jaspr.to_qc()
    >>> isinstance(result, ProcessedMeasurement)
    True  # The result involves classical post-processing
    """

    pass
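

# Illustrative sketch (not part of the module): the "Usage" rule described in
# the ProcessedMeasurement docstring can be pictured as below. FakeClbit,
# Placeholder and evaluate_classical_op are hypothetical stand-ins introduced
# only for this example; the module's actual dispatch lives in the equation
# evaluator further down.

```python
import operator


class FakeClbit:
    """Hypothetical stand-in for qrisp.Clbit (assumption for this sketch)."""


class Placeholder:
    """Hypothetical stand-in for ProcessedMeasurement."""


def evaluate_classical_op(op, *inputs):
    """Return a placeholder if any input is measurement data, else compute."""
    if any(isinstance(x, (FakeClbit, Placeholder)) for x in inputs):
        # The value is unknown until the circuit actually runs.
        return Placeholder()
    # No measurement data involved: evaluate the operation concretely.
    return op(*inputs)


from_measurement = evaluate_classical_op(operator.mul, FakeClbit(), 2)
chained = evaluate_classical_op(operator.add, from_measurement, 1)
purely_classical = evaluate_classical_op(operator.add, 3, 4)
```

# Once a value is a placeholder, every later operation on it stays a
# placeholder, while purely classical inputs are still computed normally.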


class ParityHandle:
    """
    A lightweight handle representing the result of a parity computation in quantum circuits.

    ParityHandle objects are returned by :meth:`qrisp.QuantumCircuit.parity` and serve as
    keys in the detector and observable maps produced by :meth:`~qrisp.QuantumCircuit.to_stim`.
    They enable tracking of parity check results throughout circuit manipulation pipelines.

    Attributes
    ----------
    instruction : Instruction
        The :class:`~qrisp.Instruction` object containing the ``ParityOperation`` that sits
        in the circuit's data list (``qc.data``).

    Properties
    ----------
    clbits : list[Clbit]
        The list of :class:`~qrisp.Clbit` objects involved in this parity computation.
        Retrieved from ``instruction.clbits``. These are the measurement results being
        XORed together.
    expectation : int
        The expected parity value (0 or 1). Retrieved from ``instruction.op.expectation``.
        In Stim detector mode, deviations from this expectation indicate errors.
    observable : bool
        Whether this parity represents a Stim observable (``True``) or detector (``False``).
        Retrieved from ``instruction.op.observable``. Observables track logical information
        while detectors assert deterministic parities for error detection.

    Examples
    --------
    ParityHandles are typically created via :meth:`qrisp.QuantumCircuit.parity`:

    >>> from qrisp import QuantumCircuit
    >>> qc = QuantumCircuit(2, 2)
    >>> qc.h(0)
    >>> qc.cx(0, 1)
    >>> qc.measure([0, 1], [0, 1])
    >>> handle = qc.parity([qc.clbits[0], qc.clbits[1]], expectation=0)
    >>> handle.clbits
    [Clbit(cb_2), Clbit(cb_3)]
    >>> handle.expectation
    0

    Use handles as keys in Stim conversion maps:

    >>> stim_circuit, meas_map, det_map = qc.to_stim(
    ...     return_measurement_map=True,
    ...     return_detector_map=True
    ... )
    >>> det_map[handle]  # Get the Stim detector index
    0
    """

    # Design Rationale
    # ----------------
    # In quantum error correction and Stim workflows, parity operations compute the XOR of
    # multiple measurement results. Unlike measurements (which are quantum operations producing
    # new classical bits), parity is a purely classical computation operating on existing
    # measurement results.

    # **Why Not Use Clbits for Parity Results?**

    # A naive approach might create a new ``Clbit`` to represent each parity result. However,
    # this breaks the fundamental semantic property that classical bits in a QuantumCircuit
    # have a 1:1 correspondence with actual quantum measurements. Parity operations don't
    # perform measurements—they compute classical functions of existing measurement outcomes.

    # Creating "fake" Clbits for parity results would:

    # * Violate the semantic contract that each Clbit corresponds to a physical measurement
    # * Complicate measurement record management and classical bit indexing
    # * Make it ambiguous which Clbits represent actual measurement results vs. computed parities
    # * Interfere with transpilation and other circuit transformations that assume Clbit semantics

    # **The ParityHandle Solution**

    # Instead, ParityHandle provides a distinct type for parity results that:

    # * Maintains the 1:1 mapping between Clbits and quantum measurements
    # * Clearly distinguishes computed parities from measurement results
    # * Enables efficient lookup in Stim detector/observable maps
    # * Works correctly across circuit transpilation (via content-based equality)
    # * Provides access to the underlying measurement Clbits via the ``clbits`` property

    # When converting to Stim format, ParityHandles map to ``DETECTOR`` or ``OBSERVABLE_INCLUDE``
    # instructions, which is their natural representation in the Stim model.

    # Content-Based Equality
    # ----------------------
    # ParityHandle uses content-based equality (rather than object identity) for hashing and
    # comparison. Two ParityHandles are considered equal if they have:

    # * The same set of input Clbits (order-independent)
    # * The same expectation value
    # * The same observable flag

    # This design choice ensures ParityHandles work correctly across transpilation passes,
    # where :class:`~qrisp.Instruction` objects may be copied but their semantic content
    # (the clbits and parameters they operate on) remains the same.

    def __init__(self, instruction):
        self.instruction = instruction

    @property
    def clbits(self):
        """Get the clbits involved in this parity from the underlying instruction."""
        return self.instruction.clbits

    @property
    def expectation(self):
        """Get the expected parity value from the underlying instruction."""
        return self.instruction.op.expectation

    @property
    def observable(self):
        """Get whether this parity is an observable from the underlying instruction."""
        return self.instruction.op.observable

    def __repr__(self):
        return f"ParityHandle{tuple(self.clbits)}"

    def __hash__(self):
        # Hash by clbits (as frozenset), expectation, and observable for content-based comparison
        # This allows matching across transpile calls where instruction objects are copied
        return hash((frozenset(self.clbits), self.expectation, self.observable))

    def __eq__(self, other):
        if isinstance(other, ParityHandle):
            # Compare by content: same clbits, same expectation, and same observable
            return (
                set(self.clbits) == set(other.clbits)
                and self.expectation == other.expectation
                and self.observable == other.observable
            )
        return False


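# Illustrative sketch (not part of the module): content-based equality lets a
# handle built from a copied instruction still find its entry in a map keyed
# by the original handle. HandleSketch and make_instruction are hypothetical
# stand-ins; strings play the role of Clbit objects and SimpleNamespace plays
# the role of qrisp's Instruction, so nothing here touches the real API.

```python
from types import SimpleNamespace


def make_instruction(clbits, expectation=0, observable=False):
    """Build a fake instruction exposing .clbits and .op, as the handle expects."""
    op = SimpleNamespace(expectation=expectation, observable=observable)
    return SimpleNamespace(clbits=list(clbits), op=op)


class HandleSketch:
    """Hypothetical handle using the same hashing/equality scheme as above."""

    def __init__(self, instruction):
        self.instruction = instruction

    def _key(self):
        instr = self.instruction
        # frozenset makes the key independent of clbit order.
        return (frozenset(instr.clbits), instr.op.expectation, instr.op.observable)

    def __hash__(self):
        return hash(self._key())

    def __eq__(self, other):
        return isinstance(other, HandleSketch) and self._key() == other._key()


original = HandleSketch(make_instruction(["cb_0", "cb_1"]))
copied = HandleSketch(make_instruction(["cb_1", "cb_0"]))  # new object, same content
flipped = HandleSketch(make_instruction(["cb_0", "cb_1"], expectation=1))

detector_map = {original: 0}
print(detector_map[copied])     # 0
print(flipped in detector_map)  # False
```

# With identity-based hashing, the lookup through `copied` would raise a
# KeyError, which is the failure mode the design rationale describes.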
# =============================================================================
# SECTION 2: MeasurementArray - Handling Arrays of Measurement Results
# =============================================================================


class MeasurementArray(np.ndarray):
    """
    A numpy ndarray subclass for measurement-related values during QuantumCircuit extraction.

    Problem Being Solved
    --------------------
    In Jaspr, measurement results are JAX boolean values that can be freely
    combined into arrays. When lowering to QuantumCircuit, measurements return
    `Clbit` objects instead of JAX booleans. JAX array operations cannot handle
    Clbit objects directly.

    Solution
    --------
    MeasurementArray is a numpy ndarray subclass with dtype=object that stores:
    - Clbit objects: Measurement results that can be used in circuit operations
    - ParityHandle objects: Parity computation results
    - ProcessedMeasurement: Marker for classical post-processing results
    - bool: Known boolean values (True/False)

    Being a numpy subclass means reshape, slice, concatenate, etc. work natively.
    The subclass provides type identification via isinstance() and helper methods.
    """

    def __new__(cls, data):
        """Create a new MeasurementArray from data."""
        arr = np.asarray(data, dtype=object).view(cls)
        return arr

    def __array_finalize__(self, obj):
        """Called after array construction to finalize the object."""
        pass

    def mark_as_processed(self):
        """
        Return a new MeasurementArray with all entries marked as processed.

        Returns
        -------
        MeasurementArray
            New array with all entries set to ProcessedMeasurement().
        """
        processed_data = np.array(
            [ProcessedMeasurement() for _ in self.flat], dtype=object
        )
        return MeasurementArray(processed_data.reshape(self.shape))


# =============================================================================
# SECTION 3: Helper Functions
# =============================================================================


def contains_measurement_data(val):
    """
    Check if a value contains measurement-related data.

    Parameters
    ----------
    val : any
        Value to check.

    Returns
    -------
    bool
        True if the value is or contains Clbit, MeasurementArray, ParityHandle,
        or ProcessedMeasurement data.
    """
    from qrisp import Clbit

    if isinstance(val, (Clbit, MeasurementArray, ProcessedMeasurement, ParityHandle)):
        return True
    if isinstance(val, list) and len(val):
        return contains_measurement_data(val[0])
    return False


def to_object_array(val):
    """
    Convert a measurement-related value to an object numpy array.

    Parameters
    ----------
    val : any
        Value to convert. Can be:
        - MeasurementArray: returned as-is (already an ndarray subclass)
        - Clbit, ParityHandle, ProcessedMeasurement, bool: wrapped in 0-d array
        - numpy array: converted to object dtype
        - Other: returned unchanged

    Returns
    -------
    numpy.ndarray or original value
        Object array with the value(s).
    """
    from qrisp import Clbit

    if isinstance(val, MeasurementArray):
        # MeasurementArray is already an ndarray subclass
        return val
    elif isinstance(val, (Clbit, ParityHandle, ProcessedMeasurement)):
        return np.array(val, dtype=object)
    elif isinstance(val, (bool, np.bool_)):
        return np.array(bool(val), dtype=object)
    elif isinstance(val, np.ndarray):
        return val.astype(object) if val.dtype != object else val
    else:
        return val


def apply_array_primitive(prim_name, params, invalues):
    """
    Apply a JAX array primitive to measurement data using numpy equivalents.

    Parameters
    ----------
    prim_name : str
        Name of the JAX primitive (e.g., 'broadcast_in_dim', 'concatenate').
    params : dict
        Parameters of the JAX primitive.
    invalues : list
        Input values to the primitive.

    Returns
    -------
    MeasurementArray, scalar, or None
        - MeasurementArray for array results
        - Scalar (Clbit, bool, ProcessedMeasurement, ParityHandle) for 0-d results
        - None if this primitive is not handled
    """
    # Convert all inputs to object arrays
    encoded = [to_object_array(v) for v in invalues]

    # Apply the numpy equivalent based on primitive name
    if prim_name == "broadcast_in_dim":
        shape = params["shape"]
        result = np.broadcast_to(encoded[0], shape)

    elif prim_name == "concatenate":
        dimension = params.get("dimension", 0)
        result = np.concatenate(encoded, axis=dimension)

    elif prim_name == "squeeze":
        dimensions = params.get("dimensions", None)
        result = np.squeeze(encoded[0], axis=dimensions)

    elif prim_name == "slice":
        start_indices = params["start_indices"]
        limit_indices = params["limit_indices"]
        slices = tuple(slice(s, e) for s, e in zip(start_indices, limit_indices))
        result = encoded[0][slices]

    elif prim_name == "dynamic_slice":
        # Start indices come from invalues[1:]
        start_indices = [
            int(encoded[i]) if np.ndim(encoded[i]) == 0 else int(encoded[i].flat[0])
            for i in range(1, len(encoded))
        ]
        slice_sizes = params["slice_sizes"]
        slices = tuple(slice(s, s + sz) for s, sz in zip(start_indices, slice_sizes))
        result = encoded[0][slices]

    elif prim_name == "gather":
        # Simple indexing case
        indices = encoded[1]
        if hasattr(indices, "item"):
            idx = indices.item()
            if isinstance(idx, (int, np.integer)):
                idx = int(idx)
        elif np.ndim(indices) == 0:
            idx = (
                int(indices)
                if isinstance(indices, (int, np.integer, np.ndarray))
                else indices
            )
        else:
            idx = int(indices[0]) if len(indices) == 1 else indices

        if isinstance(idx, (int, np.integer)):
            result = encoded[0].flat[idx]
        else:
            result = encoded[0][idx]

    elif prim_name == "reshape":
        new_sizes = params.get("new_sizes", params.get("dimensions", None))
        if new_sizes is not None:
            result = encoded[0].reshape(new_sizes)
        else:
            result = encoded[0]

    elif prim_name == "transpose":
        permutation = params.get("permutation", None)
        result = np.transpose(encoded[0], permutation)

    elif prim_name == "rev":
        # Reverse array along specified dimensions
        dimensions = params.get("dimensions", ())
        result = encoded[0]
        for dim in dimensions:
            result = np.flip(result, axis=dim)

    elif prim_name == "dynamic_update_slice":
        # dynamic_update_slice(operand, update, start_indices)
        # Updates a slice of operand starting at start_indices with the values from update
        operand = encoded[0]
        update = encoded[1]
        # Start indices come from remaining invalues (already encoded)
        start_indices = tuple(
            int(encoded[i]) if np.ndim(encoded[i]) == 0 else int(encoded[i].flat[0])
            for i in range(2, len(encoded))
        )
        # Create a copy and update the slice
        result = operand.copy()
        slices = tuple(slice(s, s + u) for s, u in zip(start_indices, update.shape))
        result[slices] = update

    elif prim_name == "select_n":
        # select_n(cond, *cases) - selects from cases based on cond values
        # cond[i] == j means select cases[j][i]
        cond = encoded[0]
        cases = encoded[1:]
        # For boolean/binary condition (most common case)
        if len(cases) == 2:
            # np.where with object arrays
            result = np.where(cond.astype(bool), cases[1], cases[0])
        else:
            # General case: build result by indexing
            result = np.empty_like(cases[0])
            for i in range(len(cases)):
                mask = cond == i
                result[mask] = cases[i][mask]

    else:
        return None

    # Handle result
    result = np.asarray(result, dtype=object)
    if result.ndim == 0:
        # Scalar result - return the object directly
        return result.item()
    else:
        # Array result
        return MeasurementArray(result)


def resolve_measurement_arrays(value):
    """
    Recursively resolve MeasurementArrays to numpy arrays with dtype=object.

    This function is called at the end of jaspr_to_qc to convert the internal
    MeasurementArray representation to a standard numpy array that users can
    work with directly.

    Parameters
    ----------
    value : any
        The value to resolve. Can be:
        - MeasurementArray: viewed as plain numpy array with dtype=object
        - tuple/list: each element is recursively resolved
        - Other types: returned unchanged

    Returns
    -------
    any
        The resolved value with MeasurementArrays converted to numpy arrays.
    """
    if isinstance(value, MeasurementArray):
        # View as plain numpy array (MeasurementArray is already an ndarray subclass)
        return value.view(np.ndarray)
    elif isinstance(value, tuple):
        return tuple(resolve_measurement_arrays(v) for v in value)
    elif isinstance(value, list):
        return [resolve_measurement_arrays(v) for v in value]
    else:
        return value


def handle_classical_processing(invalues):
    """
    Handle operations that represent classical post-processing on measurement data.

    Many operations (arithmetic, comparisons, reductions, bitwise ops) represent
    classical computation that cannot be performed during circuit construction.
    This function provides a unified way to handle all such operations.

    The strategy is:
    1. If any input is a MeasurementArray, return a MeasurementArray of the same
       size with all entries marked as "processed".
    2. If inputs contain scalar measurement data (Clbit, ProcessedMeasurement,
       ParityHandle), return a scalar ProcessedMeasurement.
    3. If no measurement data is involved, return None to indicate default JAX
       evaluation should be used.

    Parameters
    ----------
    invalues : list
        Input values to the operation.

    Returns
    -------
    MeasurementArray, ProcessedMeasurement, or None
        - MeasurementArray with processed entries if input was an array
        - ProcessedMeasurement for scalar measurement inputs
        - None if no measurement data (use default JAX evaluation)
    """
    from qrisp import Clbit

    # Check for MeasurementArray inputs first (preserves array structure)
    for v in invalues:
        if isinstance(v, MeasurementArray):
            return v.mark_as_processed()

    # Check for scalar measurement data
    for v in invalues:
        if isinstance(v, (Clbit, ProcessedMeasurement, ParityHandle)):
            return ProcessedMeasurement()

    # Check for lists containing measurement data
    for v in invalues:
        if isinstance(v, list) and len(v) and contains_measurement_data(v[0]):
            return ProcessedMeasurement()

    # No measurement data - use default evaluation
    return None


# List of primitives that represent classical processing on measurement data.
# These operations cannot be performed during circuit construction because
# measurement results are not known until runtime.
CLASSICAL_PROCESSING_PRIMITIVES = {
    # Arithmetic operations
    "add",
    "sub",
    "mul",
    "div",
    "rem",
    "pow",
    "neg",
    "integer_pow",
    "floor",
    "ceil",
    "round",
    "abs",
    # Comparison operations
    "eq",
    "ne",
    "lt",
    "gt",
    "le",
    "ge",
    # Reduction operations
    "reduce_sum",
    "reduce_prod",
    "reduce_max",
    "reduce_min",
    "reduce_or",
    "reduce_and",
    "reduce_xor",
    # Bitwise operations
    "not",
    "and",
    "or",
    "xor",
    "shift_left",
    "shift_right_arithmetic",
    "shift_right_logical",
}


# =============================================================================
# SECTION 4: Equation Evaluator Factory
# =============================================================================


def make_qc_extraction_eqn_evaluator(qc):
    """
    Create an equation evaluator for extracting a QuantumCircuit from a Jaspr.

    This factory function creates a closure over the QuantumCircuit being built,
    returning an evaluator function that can be passed to eval_jaxpr.

    Parameters
    ----------
    qc : QuantumCircuit
        The quantum circuit to build. Operations will be appended to this circuit.

    Returns
    -------
    callable
        An equation evaluator function with signature (eqn, context_dic) -> bool|None
    """

    def qc_extraction_eqn_evaluator(eqn, context_dic):
        """
        Evaluate a single Jaxpr equation during QuantumCircuit extraction.

        This function is called for each equation in the Jaspr. It determines
        how to handle the equation based on its primitive type:

        1. **Quantum Primitives** (jasp.*): Execute directly using their
           implementation, which appends operations to the QuantumCircuit.

        2. **Array Operations on Measurements**: Handle specially using
           MeasurementArray to maintain Clbit references through array ops.

        3. **Classical Operations on Measurements**: Create ProcessedMeasurement
           placeholders since we can't represent classical computation in QC.

        4. **Pure Classical Operations**: Delegate to default JAX evaluation
           by returning True.

        Parameters
        ----------
        eqn : JaxprEqn
            The equation to evaluate.
        context_dic : dict
            Dictionary mapping Jaxpr variables to their current values.

        Returns
        -------
        bool or None
            - True: Use default JAX evaluation for this equation
            - None/False: Equation was fully handled, skip default evaluation
        """
        # Import here to avoid circular imports
        from qrisp import Clbit
        from qrisp.jasp import (
            Jaspr,
            extract_invalues,
            insert_outvalues,
            QuantumPrimitive,
            ParityOperation,
        )
        from qrisp.jasp.interpreter_tools.interpreters import cond_to_cl_control

        invalues = extract_invalues(eqn, context_dic)
        prim_name = eqn.primitive.name

        # -----------------------------------------------------------------
        # SECTION 4.1: Control Flow and Structural Primitives
        # -----------------------------------------------------------------

        if prim_name == "jit" and (
            isinstance(eqn.params["jaxpr"], Jaspr)
            or any(contains_measurement_data(v) for v in invalues)
        ):
            # Nested Jaspr (from @qache or similar) - evaluate with our interpreter
            from qrisp.jasp import eval_jaxpr

            definition_jaxpr = eqn.params["jaxpr"]
            res = eval_jaxpr(
                definition_jaxpr.jaxpr, eqn_evaluator=qc_extraction_eqn_evaluator
            )(*(invalues + definition_jaxpr.consts))

            if len(definition_jaxpr.jaxpr.outvars) == 1:
                res = [res]
            insert_outvalues(eqn, context_dic, res)
            return

        elif prim_name == "jit":
            return True

        elif prim_name == "cond":
            # Conditional branching - may become classically controlled operation
            return cond_to_cl_control(eqn, context_dic, qc_extraction_eqn_evaluator)

        elif prim_name == "while":
            # While loops need special handling - delegate to default for now
            # (the loop will be unrolled during evaluation)
            return True

        # -----------------------------------------------------------------
        # SECTION 4.2: Quantum Primitives with Special Handling
        # -----------------------------------------------------------------

        elif prim_name == "jasp.parity":
            # Parity operation: XOR of multiple classical bits
            # Note: Parity on arrays is dissolved into loops during tracing,
            # so invalues are always scalars (Clbit, ParityHandle, or ProcessedMeasurement)

            # Check for ProcessedMeasurement - return ProcessedMeasurement
            if any(isinstance(v, ProcessedMeasurement) for v in invalues):
                insert_outvalues(eqn, context_dic, ProcessedMeasurement())
                return

            # Expand all inputs to clbits using symmetric difference (XOR semantics)
            # This means duplicate clbits cancel out (a XOR a = 0)
            clbit_set = set()
            for inp in invalues:
                if isinstance(inp, Clbit):
                    clbit_set.symmetric_difference_update({inp})
                elif isinstance(inp, ParityHandle):
                    # ParityHandle stores expanded clbits
                    clbit_set.symmetric_difference_update(inp.clbits)
                elif isinstance(inp, list) and all(isinstance(c, Clbit) for c in inp):
                    # List-wrapped Clbits (e.g. from measuring a single-qubit QuantumBool)
                    clbit_set.symmetric_difference_update(inp)

            # Convert set to sorted list for deterministic ordering
            expanded_clbits = sorted(clbit_set, key=lambda cb: qc.clbits.index(cb))

            # Add parity operation to the circuit (for stim conversion)
            qc.append(
                ParityOperation(
                    len(expanded_clbits),
                    expectation=eqn.params["expectation"],
                    observable=eqn.params["observable"],
                ),
                clbits=expanded_clbits,
            )

            # Get the instruction we just appended
            parity_instr = qc.data[-1]

            # Create a parity handle with the instruction
            handle = ParityHandle(parity_instr)
            insert_outvalues(eqn, context_dic, handle)
            return

        # -----------------------------------------------------------------
        # SECTION 4.3: Type Conversion (convert_element_type)
        # -----------------------------------------------------------------
        # JAX often inserts type conversions. For measurement data:
        # - bool->bool conversions pass through unchanged
        # - bool->int conversions for Clbit pass through (used by cond primitive)
        # - Conversions to float types mark as processed since we can't
        #   actually compute with measurement values

        elif prim_name == "convert_element_type":
            inval = context_dic[eqn.invars[0]]
            new_dtype = eqn.params.get("new_dtype", None)

            if isinstance(inval, MeasurementArray):
                # Check if converting to a non-boolean/non-integer type
                if new_dtype is not None and not (
                    np.issubdtype(new_dtype, np.bool_)
                    or np.issubdtype(new_dtype, np.integer)
                ):
                    # Converting to float type - mark as processed
                    context_dic[eqn.outvars[0]] = inval.mark_as_processed()
                else:
                    # Bool-to-bool, bool-to-int, or unknown conversion - pass through
                    context_dic[eqn.outvars[0]] = inval
                return
            elif isinstance(inval, Clbit):
                # Clbit should pass through for bool->int conversions (used by cond)
                # Only mark as processed for float conversions
                if new_dtype is not None and not (
                    np.issubdtype(new_dtype, np.bool_)
                    or np.issubdtype(new_dtype, np.integer)
                ):
                    context_dic[eqn.outvars[0]] = ProcessedMeasurement()
                else:
                    context_dic[eqn.outvars[0]] = inval
                return
            elif isinstance(inval, ProcessedMeasurement):
                # ProcessedMeasurement stays processed
                context_dic[eqn.outvars[0]] = ProcessedMeasurement()
                return
            elif (
                isinstance(inval, list)
                and len(inval)
                and isinstance(inval[0], (ProcessedMeasurement, Clbit))
            ):
                # List of measurement data
                if new_dtype is not None and not np.issubdtype(new_dtype, np.bool_):
                    context_dic[eqn.outvars[0]] = ProcessedMeasurement()
                else:
                    context_dic[eqn.outvars[0]] = inval
                return
            return True

        # -----------------------------------------------------------------
        # SECTION 4.4: Array Operations on Measurement Data
        # -----------------------------------------------------------------
        # These primitives handle JAX array operations when the arrays contain
        # measurement results. We use a generic handler that converts to object
        # arrays, applies the numpy equivalent, and wraps back into MeasurementArray.

        elif prim_name in (
            "broadcast_in_dim",
            "concatenate",
            "squeeze",
            "slice",
            "dynamic_slice",
            "gather",
            "reshape",
            "transpose",
            "rev",
            "dynamic_update_slice",
            "select_n",
        ):
            # Check if any input contains measurement data
            if any(contains_measurement_data(v) for v in invalues):
                result = apply_array_primitive(prim_name, eqn.params, invalues)
                if result is not None:
                    insert_outvalues(eqn, context_dic, result)
                    return
            return True

        elif prim_name == "split":
            # split: Split an array into multiple sub-arrays
            # This is a multiple-result primitive, so we need to handle it specially
            if any(contains_measurement_data(v) for v in invalues):
                encoded = to_object_array(invalues[0])
                axis = eqn.params.get("axis", 0)
                sizes = eqn.params.get("sizes", ())

                # numpy.split expects indices, but JAX gives us sizes
                # Convert sizes to split indices
                indices = np.cumsum(sizes[:-1]).tolist()

                # Perform the split
                results = np.split(encoded, indices, axis=axis)

                # Wrap each result as MeasurementArray
                wrapped_results = []
                for r in results:
                    r = np.asarray(r, dtype=object)
                    if r.ndim == 0:
                        wrapped_results.append(r.item())
                    else:
                        wrapped_results.append(MeasurementArray(r))

                insert_outvalues(eqn, context_dic, wrapped_results)
                return
            return True

        elif prim_name == "scatter":
            # scatter: Update array elements at specified indices
            # Used in array assignment operations (e.g., building arrays in loops)
            #
            # Scatter has the form: scatter(operand, indices, updates)
            # - operand: The array being updated
            # - indices: Where to place the updates
            # - updates: The values to insert

            operand = invalues[0]
            indices = invalues[1]
            updates = invalues[2] if len(invalues) > 2 else None

            # Unwrap updates if it's a single-element list
            if isinstance(updates, list) and len(updates) == 1:
                updates = updates[0]

            # Helper to get index from various index formats
            def get_scatter_index(indices):
                if hasattr(indices, "item"):
                    return int(indices.item())
                elif hasattr(indices, "__len__") and len(indices) == 1:
                    return int(indices[0])
                else:
                    return int(indices)

            # Check if we're working with measurement data
            if isinstance(operand, MeasurementArray):
                idx = get_scatter_index(indices)

                # Create a copy and update directly (object array)
                result = operand.copy()
                result[idx] = updates
                insert_outvalues(eqn, context_dic, result)
                return

            elif contains_measurement_data(updates):
                # Operand is not a MeasurementArray but updates contains measurement data
                # This happens when building an array from measurements in a loop

                # Get the shape of the operand
                if hasattr(operand, "shape"):
                    size = int(np.prod(operand.shape))
                elif hasattr(operand, "__len__"):
                    size = len(operand)
                else:
                    size = 1

                # Initialize MeasurementArray with False values
                new_data = np.array([False] * size, dtype=object)

                # Set the update at the given index
                idx = get_scatter_index(indices)
                new_data[idx] = updates

                result = MeasurementArray(new_data)
                insert_outvalues(eqn, context_dic, result)
                return

            elif isinstance(operand, ProcessedMeasurement):
                insert_outvalues(eqn, context_dic, ProcessedMeasurement())
                return
            else:
                # No measurement data involved
                return True

        # -----------------------------------------------------------------
        # SECTION 4.5: Classical Processing Operations
        # -----------------------------------------------------------------
        # This section handles all operations that represent classical
        # computation on measurement data. These operations cannot be
        # performed during circuit construction because measurement results
        # are not known until runtime.
        #
        # We use a unified approach:
        # 1. Check if the primitive is in CLASSICAL_PROCESSING_PRIMITIVES
        # 2. Use handle_classical_processing() to create appropriate output
        # 3. This preserves array structure (MeasurementArray with processed
        #    entries) while marking the data as "processed"

        elif prim_name in CLASSICAL_PROCESSING_PRIMITIVES:
            result = handle_classical_processing(invalues)
            if result is not None:
                insert_outvalues(eqn, context_dic, result)
                return
            else:
                return True

        # -----------------------------------------------------------------
        # SECTION 4.6: Default Handling
        # -----------------------------------------------------------------

        else:
            # Check if any input contains measurement data
            for val in invalues:
                if contains_measurement_data(val):
                    break
            else:
                # No measurement data - check if it's a quantum primitive
                if isinstance(eqn.primitive, QuantumPrimitive):
                    # Execute quantum primitive directly
                    outvalues = eqn.primitive.impl(*invalues, **eqn.params)
                    insert_outvalues(eqn, context_dic, outvalues)
                    return
                else:
                    # Pure classical operation - use default JAX evaluation
                    return True

        # -----------------------------------------------------------------
        # SECTION 4.7: Fallback for Unhandled Measurement Operations
        # -----------------------------------------------------------------
        # If we reach here, the operation involves measurement data but isn't
        # one of the specifically handled cases above. We create
        # ProcessedMeasurement placeholders for the outputs.

        if len(eqn.outvars) == 0:
            return
        elif len(eqn.outvars) == 1 and not eqn.primitive.multiple_results:
            outvalues = ProcessedMeasurement()
        elif len(eqn.outvars) >= 1:
            outvalues = [ProcessedMeasurement() for _ in range(len(eqn.outvars))]

        insert_outvalues(eqn, context_dic, outvalues)

    return qc_extraction_eqn_evaluator


# =============================================================================
# SECTION 5: Public API - jaspr_to_qc Function
# =============================================================================


def jaspr_to_qc(jaspr, *args):
    """
    Convert a Jaspr into a QuantumCircuit.

This is the main entry point for converting Jaspr intermediate representation into a static QuantumCircuit that can be executed on quantum hardware or simulators. Limitations ----------- - **Real-time feedback**: Algorithms that use measurement results to control subsequent quantum operations (true real-time feedback) cannot be fully represented. The control flow will be evaluated with placeholder values. - **Classical post-processing**: Any classical computation on measurement results (arithmetic, comparisons, etc.) cannot be represented in the QuantumCircuit. These are replaced with ProcessedMeasurement placeholders. Parameters ---------- jaspr : Jaspr The Jaspr object to convert. *args : tuple Arguments to call the Jaspr with. These should NOT include the QuantumCircuit argument (it's added automatically). Exclude any static arguments like callables. Returns ------- tuple A tuple containing: - Return values from the Jaspr (QuantumVariable returns become qubit lists, measurement results become Clbit or ProcessedMeasurement objects) - The constructed QuantumCircuit (always the last element) Examples -------- Basic circuit extraction: :: from qrisp import * from qrisp.jasp import make_jaspr def example_function(i): qv = QuantumVariable(i) cx(qv[0], qv[1]) t(qv[1]) return qv jaspr = make_jaspr(example_function)(2) qb_list, qc = jaspr_to_qc(jaspr, 2) print(qc) # qb_0: ──■─────── # ┌─┴─┐┌───┐ # qb_1: ┤ X ├┤ T ├ # └───┘└───┘ With measurement post-processing (returns ProcessedMeasurement): :: from qrisp.jasp.interpreter_tools.interpreters import ProcessedMeasurement def example_function(i): qf = QuantumFloat(i) cx(qf[0], qf[1]) t(qf[1]) meas_res = measure(qf) meas_res *= 2 # Classical post-processing return meas_res jaspr = make_jaspr(example_function)(2) meas_res, qc = jaspr_to_qc(jaspr, 2) print(isinstance(meas_res, ProcessedMeasurement)) # True With array operations on measurements: :: import jax.numpy as jnp def array_example(): qv = QuantumVariable(3) m0 = 
measure(qv[0]) m1 = measure(qv[1]) m2 = measure(qv[2]) # Create array of measurements arr = jnp.array([m0, m1, m2]) # Extract first measurement return arr[0] jaspr = make_jaspr(array_example)() result, qc = jaspr_to_qc(jaspr) # result is the Clbit corresponding to m0's measurement print(type(result)) # <class 'qrisp.circuit.clbit.Clbit'> """ from qrisp import QuantumCircuit from qrisp.jasp import eval_jaxpr from qrisp.circuit import Qubit from jax._src.core import eval_context qc = QuantumCircuit() ammended_args = list(args) + [qc] if len(ammended_args) != len(jaspr.invars): raise Exception( "Supplied invalid number of arguments to Jaspr.to_qc " "(please exclude any static arguments, in particular callables)" ) # Pre-register any Qubit objects passed as arguments into the # internal QuantumCircuit. This is needed when the Jaspr receives # qubits directly (e.g. from a QuantumVariable or QuantumArray traced # inside make_jaspr) rather than creating them via create_qubits. for arg in args: if isinstance(arg, list) and len(arg) and isinstance(arg[0], Qubit): for qb in arg: if qb not in qc.qubits: qc.add_qubit(qb) # Use eval_context to temporarily exit any outer JAX trace. # This ensures that primitive .bind() calls for classical operations # evaluate concretely rather than being captured by an outer trace, # allowing jaspr_to_qc to be called from within a make_jaspr or # jit tracing context without TracerIntegerConversionError. with eval_context(): res = eval_jaxpr(jaspr, eqn_evaluator=make_qc_extraction_eqn_evaluator(qc))( *ammended_args ) # Resolve MeasurementArrays to numpy arrays with dtype=object res = resolve_measurement_arrays(res) return res