Jaspr
- class Jaspr(*args, permeability=None, isqfree=None, ctrl_jaspr=None, inv_jaspr=None, **kwargs)
The Jaspr class enables an efficient representation of a wide variety of (hybrid) algorithms. For many applications, the representation is agnostic to the scale of the problem, so function calls with 10 or 10000 qubits can be represented by the same object. The actual unfolding to a circuit-level description is outsourced to established classical compilation infrastructure, so state-of-the-art compilation speed can be reached.

As a subtype of jax.extend.core.ClosedJaxpr, Jasprs are embedded into the mature Jax ecosystem. This facilitates the compilation of classical real-time computation using some of the most advanced libraries in the world, such as CUDA. Machine learning and other scientific computation tasks are particularly well supported.

To get a better understanding of the syntax and semantics of Jaxpr (and with that also Jaspr), please check this link.
Similar to Jaxpr, Jaspr objects represent (hybrid) quantum algorithms in the form of a functional programming language in SSA-form.
It is possible to compile Jaspr objects into QIR, which is facilitated by the Catalyst framework (check qrisp.jasp.jaspr.to_qir() for more details).

Qrisp scripts can be turned into Jaspr objects by calling the make_jaspr function, which has similar semantics to jax.make_jaxpr.

```python
from qrisp import *
from qrisp.jasp import make_jaspr

def test_fun(i):
    qv = QuantumFloat(i, -1)
    x(qv[0])
    cx(qv[0], qv[i-1])
    meas_res = measure(qv)
    meas_res += 1
    return meas_res

jaspr = make_jaspr(test_fun)(4)
print(jaspr)
```
This will give you the following output:
```
{ lambda ; a:QuantumState b:i32[]. let
    c:QuantumState d:QubitArray = create_qubits a b
    e:Qubit = get_qubit d 0
    f:QuantumState = x c e
    g:i32[] = sub b 1
    h:Qubit = get_qubit d g
    i:QuantumState = cx f e h
    j:QuantumState k:i32[] = measure i d
    l:f32[] = convert_element_type[new_dtype=float64 weak_type=True] k
    m:f32[] = mul l 0.5
    n:f32[] = add m 1.0
  in (j, n) }
```
A defining feature of the Jaspr class is that it always receives a QuantumState as an input and returns a QuantumState as an output. In the printout above, these are the first input and the first output. Therefore, Jaspr objects always represent some (hybrid) quantum operation.
Qrisp comes with a built-in Jaspr interpreter. To use it, simply call the object like a function:
```
>>> print(jaspr(2))
2.5
>>> print(jaspr(4))
5.5
```
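These two numbers can be checked by hand. Replay the gates on classical bits, then apply the classical tail of the printed Jaspr (convert_element_type, mul by 0.5, add 1.0). The sketch below is plain Python, not Qrisp. It assumes that qubit 0 is the least significant bit of the measured integer. It also assumes that the exponent -1 of QuantumFloat(i, -1) scales that integer by 0.5. The helper name replay_test_fun is hypothetical.

```python
def replay_test_fun(i: int) -> float:
    """Deterministically replay test_fun for a qubit count i >= 2.

    Assumption: qubit 0 is the least significant bit of the measured integer.
    """
    bits = [0] * i           # QuantumFloat(i, -1) starts in |0...0>
    bits[0] = 1              # x(qv[0])
    bits[i - 1] ^= bits[0]   # cx(qv[0], qv[i-1]): flip the target if the control is 1
    k = sum(b << pos for pos, b in enumerate(bits))  # measure -> integer k
    return float(k) * 0.5 + 1.0  # convert_element_type, mul 0.5, add 1.0


print(replay_test_fun(2))  # 2.5  (k = 0b11 = 3)
print(replay_test_fun(4))  # 5.5  (k = 0b1001 = 9)
```

Every step here is deterministic: x and cx act on computational basis states, so the simulation can use classical bits.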
Methods
Manipulation
- Returns the inverse Jaspr (if applicable).
- Returns the controlled version of the Jaspr.
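At the matrix level, "inverse" means the conjugate transpose U†. A single-control "controlled version" means C(U) = |0⟩⟨0| ⊗ I + |1⟩⟨1| ⊗ U: the operation acts only when the control qubit is |1⟩.

The NumPy sketch below builds both matrices to illustrate the math only; it is not how Qrisp constructs the inverse or controlled Jaspr. The helper names inverse and controlled are hypothetical. The control is assumed to be the most significant factor of the Kronecker product.

```python
import numpy as np


def inverse(u: np.ndarray) -> np.ndarray:
    """Inverse of a unitary: its conjugate transpose."""
    return u.conj().T


def controlled(u: np.ndarray) -> np.ndarray:
    """Single-control version of u: |0><0| (x) I + |1><1| (x) U.

    The control qubit is the most significant factor of the Kronecker product.
    """
    dim = u.shape[0]
    p0 = np.array([[1, 0], [0, 0]], dtype=complex)  # projector |0><0|
    p1 = np.array([[0, 0], [0, 1]], dtype=complex)  # projector |1><1|
    return np.kron(p0, np.eye(dim)) + np.kron(p1, u)


X = np.array([[0, 1], [1, 0]], dtype=complex)
T = np.diag([1, np.exp(1j * np.pi / 4)])

cnot = controlled(X)
print(np.real(cnot).astype(int))
# [[1 0 0 0]
#  [0 1 0 0]
#  [0 0 0 1]
#  [0 0 1 0]]
print(np.allclose(inverse(T) @ T, np.eye(2)))  # True
```

Controlling X reproduces the familiar CNOT matrix, and T† T gives the identity.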
Evaluation
- Leverages the Catalyst pipeline to compile a QIR representation of this function and executes that function using the Catalyst QIR runtime.
- Converts the Jaspr into a QuantumCircuit if applicable.
- Extracts the post-processing logic from this Jaspr and returns a function that performs the post-processing on measurement results.
- Compiles the Jaspr into an OpenQASM 2 string.
- Compiles the Jaspr to an xDSL module using the Jasp Dialect.
- Compiles the Jaspr to the corresponding Catalyst jaxpr.
- Compiles the Jaspr to MLIR using the Catalyst dialect.
- Compiles the Jaspr to QIR using the Catalyst framework.
Construction
- make_jaspr(fun, flatten_envs=True, return_shape=False, **jax_kwargs)
Creates a function that returns the Jaspr representation of a quantum function.
This function is analogous to JAX’s make_jaxpr, but produces a Jaspr (a Jaxpr enhanced with quantum primitives) from a Qrisp quantum function.
- Parameters:
- fun : Callable
The quantum function whose Jaspr is to be computed.
- flatten_envs : bool, optional
If True (default), flatten quantum environments in the resulting Jaspr.
- return_shape : bool, optional
If True, the returned function produces a tuple (jaspr, out_tree), where out_tree is a PyTreeDef representing the structure of the output of fun. This can be used to reconstruct PyTree objects from flat output lists using jax.tree_util.tree_unflatten. Default is False.
- **jax_kwargs
Additional keyword arguments passed to jax.make_jaxpr, such as static_argnums.
- Returns:
- Callable
A function that, when called with example arguments, returns either:
  - A Jaspr representation of fun (if return_shape=False)
  - A tuple (Jaspr, out_tree) (if return_shape=True), where out_tree is a PyTreeDef that can be used with tree_unflatten
Examples
Basic quantum circuit with measurement
Create a Jaspr for a simple Bell state circuit:
```python
from qrisp import QuantumVariable, h, cx, measure
from qrisp.jasp import make_jaspr

def simple_circuit():
    qv = QuantumVariable(2)
    h(qv[0])
    cx(qv[0], qv[1])
    return measure(qv)

jaspr = make_jaspr(simple_circuit)()
result = jaspr()  # Returns 0 or 3 with equal probability
```
Parameterized quantum circuit
Create a Jaspr with parameterized gates that can be executed with different parameters:
```python
from qrisp import QuantumVariable, h, p, measure
from qrisp.jasp import make_jaspr

def rotation_circuit(angle):
    qv = QuantumVariable(1)
    h(qv)
    p(angle, qv)
    return measure(qv)

jaspr = make_jaspr(rotation_circuit)(0.5)
result1 = jaspr(0.5)  # Execute with angle=0.5
result2 = jaspr(1.0)  # Execute with angle=1.0
```
Using return_shape for PyTree reconstruction
Retrieve the output tree structure alongside the Jaspr for reconstructing complex return values:
```python
from qrisp import QuantumVariable, h, cx, x, measure
from qrisp.jasp import make_jaspr
from jax.tree_util import tree_unflatten, tree_flatten

def multi_output_circuit():
    qa = QuantumVariable(2)
    qb = QuantumVariable(2)
    h(qa[0])
    cx(qa[0], qa[1])
    x(qb)
    return measure(qa), measure(qb)

jaspr, out_tree = make_jaspr(multi_output_circuit, return_shape=True)()
result_a, result_b = jaspr()

# Use out_tree to reconstruct the output structure
flat_results, _ = tree_flatten((result_a, result_b))
reconstructed = tree_unflatten(out_tree, flat_results)
```
Advanced details
This section elaborates on how Jaspr objects are embedded into the Jax infrastructure. If you just want to accelerate your code, you can (probably) skip it. It is recommended to first get a solid understanding of Jax primitives and how to create a Jaxpr out of them.
Jasp is designed to model dynamic quantum computations with a minimal set of primitives.
For this purpose, three new Jax abstract data types are defined:
- QuantumState, which represents an object that tracks what kind of manipulations are applied to the quantum state.
- QubitArray, which represents an array of qubits that can have a dynamic number of qubits.
- Qubit, which represents individual qubits.
Before we describe how quantum computations are realized, we list some “administrative” primitives and their semantics.
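| Primitive | Semantics |
|---|---|
| jasp.create_qubits | Creates new qubits. Takes a (dynamic) integer representing the size and a QuantumState, and returns a QubitArray and a new QuantumState. |
| jasp.get_qubit | Extracts a Qubit from a QubitArray at a (dynamic) integer index. |
| | Retrieves the size of a QubitArray. |
| | Deallocates a QubitArray. |
| | Resets qubits in a QubitArray. |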
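The functional, state-threading style of these primitives can be made concrete with a toy Python model. In it, a QuantumState is an immutable record of allocated qubit indices, a QubitArray is a tuple of indices and a Qubit is a plain integer.

All names are hypothetical. The create_qubits and get_qubit signatures mirror the printouts on this page, for example "c:QubitArray d:QuantumState = jasp.create_qubits a b". The signatures of get_size and delete_qubits are assumptions. None of this is Qrisp's implementation.

```python
from dataclasses import dataclass
from typing import Tuple

Qubit = int                    # a qubit is just an index into the state
QubitArray = Tuple[int, ...]   # an array of such indices


@dataclass(frozen=True)
class QuantumState:
    """Toy state: allocated qubit indices plus a counter for fresh indices."""
    allocated: frozenset = frozenset()
    next_index: int = 0


def create_qubits(size: int, state: QuantumState) -> Tuple[QubitArray, QuantumState]:
    # Allocate `size` fresh indices and return a *new* state; the old one is untouched.
    qubits = tuple(range(state.next_index, state.next_index + size))
    new_state = QuantumState(state.allocated | frozenset(qubits), state.next_index + size)
    return qubits, new_state


def get_qubit(qubits: QubitArray, index: int) -> Qubit:
    return qubits[index]


def get_size(qubits: QubitArray) -> int:
    return len(qubits)


def delete_qubits(qubits: QubitArray, state: QuantumState) -> QuantumState:
    return QuantumState(state.allocated - frozenset(qubits), state.next_index)


s0 = QuantumState()
c, s1 = create_qubits(3, s0)
q = get_qubit(c, 1)
print(c, q, get_size(c))                          # (0, 1, 2) 1 3
s2 = delete_qubits(c, s1)
print(sorted(s1.allocated), sorted(s2.allocated))  # [0, 1, 2] []
```

Each state-changing call hands back a fresh state value instead of mutating its input. This is the same single-assignment discipline the Jaspr printouts show.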
Quantum Operations
Quantum gates are represented by the jasp.quantum_gate primitive. Here’s an example:

```python
from qrisp import *
from qrisp.jasp import *

def test_function(i):
    qv = QuantumVariable(i)
    cx(qv[0], qv[1])
    bl = measure(qv[1])
    return qv, bl

print(make_jaspr(test_function)(2))
```

```
{ lambda ; a:i64[] b:QuantumState. let
    c:QubitArray d:QuantumState = jasp.create_qubits a b
    e:Qubit = jasp.get_qubit c 0:i64[]
    f:Qubit = jasp.get_qubit c 1:i64[]
    g:QuantumState = jasp.quantum_gate[gate=cx] e f d
    h:bool[] i:QuantumState = jasp.measure f g
  in (c, h, i) }
```
The line starting with g: describes how quantum gates are represented in a Jaspr: The gate name is specified in the parameters (gate=cx), followed by the Qubit arguments, and finally the QuantumState. This structure closely mirrors how quantum computations are modeled mathematically: as a unitary applied to a tensor at certain indices. You can think of QuantumState objects as tensors, Qubit objects as integer indices, and QubitArray objects as arrays of indices.
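The tensor picture can be made concrete with NumPy. Store an n-qubit state as an array of shape (2,)*n, treat each Qubit as an axis index, and apply a k-qubit gate by contracting its reshaped matrix with those axes.

The helper apply_gate below is hypothetical and written for illustration. Its argument order follows the printout (gate, qubits, state), and it returns a new array instead of mutating its input. It assumes that the first listed qubit is the most significant index of the gate matrix.

```python
import numpy as np


def apply_gate(gate: np.ndarray, qubits, state: np.ndarray) -> np.ndarray:
    """Apply a k-qubit gate (2**k x 2**k matrix) to the given axes of `state`.

    Axis q of `state` holds qubit q. The first entry of `qubits` is the
    most significant index of the gate matrix.
    """
    k = len(qubits)
    g = gate.reshape((2,) * (2 * k))  # axes: out_1..out_k, in_1..in_k
    # Contract the gate's input axes with the chosen state axes.
    out = np.tensordot(g, state, axes=(list(range(k, 2 * k)), list(qubits)))
    # tensordot puts the gate's output axes first; move them back into place.
    return np.moveaxis(out, list(range(k)), list(qubits))


CX = np.array([[1, 0, 0, 0],
               [0, 1, 0, 0],
               [0, 0, 0, 1],
               [0, 0, 1, 0]], dtype=complex)

state = np.zeros((2, 2), dtype=complex)
state[1, 0] = 1.0                            # qubit 0 = 1, qubit 1 = 0
new_state = apply_gate(CX, [0, 1], state)    # control = qubit 0, target = qubit 1
print(np.argwhere(np.abs(new_state) > 0))    # [[1 1]]
```

np.tensordot followed by np.moveaxis keeps the untouched axes in their original order. The same helper therefore works for any qubit positions, for example a CX between qubits 0 and 2 of a three-qubit state.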
The jasp.measure primitive plays a special role: unlike other quantum operations, it not only returns a new QuantumState but also a measurement outcome. When measuring a single Qubit, it returns a boolean value. When measuring a QubitArray, it returns an integer:
```python
def test_function(i):
    qv = QuantumVariable(i)
    cx(qv[0], qv[1])
    a = measure(qv)
    return a

print(make_jaspr(test_function)(2))
```

```
{ lambda ; a:i64[] b:QuantumState. let
    c:QubitArray d:QuantumState = jasp.create_qubits a b
    e:Qubit = jasp.get_qubit c 0:i64[]
    f:Qubit = jasp.get_qubit c 1:i64[]
    g:QuantumState = jasp.quantum_gate[gate=cx] e f d
    h:i64[] i:QuantumState = jasp.measure c g
  in (h, i) }
```
Both variants return values (bool[] or i64[]) that other Jax modules understand, highlighting the seamless embedding of quantum computations into the Jax ecosystem.
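In the same tensor picture, a measurement samples a basis index and returns both an outcome and a collapsed state, just as jasp.measure returns an outcome and a new QuantumState. The sketch below returns a bool for a single qubit and an integer for a whole array.

The integer encoding, with qubit 0 as the least significant bit, is an assumption of this sketch. The helper names measure_array and measure_qubit are hypothetical. This is a plain NumPy simulation, not Qrisp's backend.

```python
import numpy as np


def measure_array(state: np.ndarray, rng: np.random.Generator):
    """Measure every qubit; return (integer outcome, collapsed state)."""
    probs = np.abs(state.ravel()) ** 2
    flat_index = rng.choice(probs.size, p=probs / probs.sum())
    bits = np.unravel_index(flat_index, state.shape)  # bits[q] = outcome of qubit q
    outcome = sum(int(b) << q for q, b in enumerate(bits))  # qubit 0 = least significant bit
    collapsed = np.zeros_like(state)
    collapsed[bits] = 1.0
    return outcome, collapsed


def measure_qubit(state: np.ndarray, qubit: int, rng: np.random.Generator):
    """Measure one qubit; return (bool outcome, collapsed and renormalized state)."""
    probs = np.abs(state) ** 2
    p1 = probs.take(1, axis=qubit).sum()  # probability that this qubit reads 1
    outcome = bool(rng.random() < p1)
    collapsed = state.copy()
    index = [slice(None)] * state.ndim
    index[qubit] = 0 if outcome else 1    # zero out the branch that was not observed
    collapsed[tuple(index)] = 0.0
    return outcome, collapsed / np.linalg.norm(collapsed)


rng = np.random.default_rng(0)

basis = np.zeros((2, 2), dtype=complex)
basis[1, 0] = 1.0                      # qubit 0 = 1, qubit 1 = 0
print(measure_array(basis, rng)[0])    # 1
print(measure_qubit(basis, 1, rng)[0]) # False

bell = np.zeros((2, 2), dtype=complex)
bell[0, 0] = bell[1, 1] = 1 / np.sqrt(2)
samples = [measure_array(bell, rng)[0] for _ in range(20)]  # each sample is 0 or 3
```

The Bell-state samples reproduce the "0 or 3" behaviour of the simple_circuit example.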
QuantumEnvironments
QuantumEnvironment objects in Jasp can be represented in two forms: unflattened (where environments appear as jasp.q_env primitives) or flattened (where environment transformations are applied directly).
Unflattened form (flatten_envs=False):
```python
def test_function(i):
    qv = QuantumVariable(i)
    with invert():
        t(qv[0])
        cx(qv[0], qv[1])
    return qv

jaspr = make_jaspr(test_function, flatten_envs=False)(2)
print(jaspr)
```

```
{ lambda ; a:i64[] b:QuantumState. let
    c:QubitArray d:QuantumState = jasp.create_qubits a b
    e:QuantumState = jasp.q_env[
      jaspr={ lambda ; c:QubitArray f:QuantumState. let
          g:Qubit = jasp.get_qubit c 0:i64[]
          h:QuantumState = jasp.quantum_gate[gate=t] g f
          i:Qubit = jasp.get_qubit c 1:i64[]
          j:QuantumState = jasp.quantum_gate[gate=cx] g i h
        in (j,) }
      type=InversionEnvironment
    ] c d
  in (c, e) }
```
Here, the body of the InversionEnvironment is collected into a nested Jaspr within the jasp.q_env primitive. This representation preserves the environment structure and reflects how QuantumEnvironments describe higher-order quantum functions that transform quantum operations.
Flattened form (flatten_envs=True, default):
```python
jaspr = make_jaspr(test_function, flatten_envs=True)(2)
print(jaspr)
```

```
{ lambda ; a:i64[] b:QuantumState. let
    c:QubitArray d:QuantumState = jasp.create_qubits a b
    e:Qubit = jasp.get_qubit c 0:i64[]
    f:Qubit = jasp.get_qubit c 1:i64[]
    g:QuantumState = jasp.quantum_gate[gate=cx] e f d
    h:QuantumState = jasp.quantum_gate[gate=t_dg] e g
  in (c, h) }
```
In the flattened form, the InversionEnvironment transformation has already been applied: the order of the cx and t gates is reversed, and the t gate is replaced by t_dg (T-dagger). This is the default behavior because it produces Jaspr representations that are better suited for execution.
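The rewrite performed when flattening an InversionEnvironment can be sketched on a plain gate list: reverse the order and replace each gate by its adjoint.

The representation used here, a list of (name, qubit indices) pairs, is an assumption of this sketch. So is the adjoint table, which covers only a few gates. The function name invert_body is hypothetical, and this is not Qrisp's flattening code.

```python
from typing import List, Tuple

Gate = Tuple[str, Tuple[int, ...]]

# Adjoint names for an illustrative subset of gates; self-inverse gates map to themselves.
ADJOINT = {"t": "t_dg", "t_dg": "t", "s": "s_dg", "s_dg": "s",
           "x": "x", "y": "y", "z": "z", "h": "h", "cx": "cx"}


def invert_body(body: List[Gate]) -> List[Gate]:
    """Inverse of a gate sequence: reversed order, each gate replaced by its adjoint."""
    return [(ADJOINT[name], qubits) for name, qubits in reversed(body)]


body = [("t", (0,)), ("cx", (0, 1))]  # the body of the `with invert():` block above
print(invert_body(body))               # [('cx', (0, 1)), ('t_dg', (0,))]
```

The result matches the order and gate names of the flattened printout (cx, then t_dg). Applying the rewrite twice recovers the original body.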
For more detailed information about the Jasp primitives and their semantics, see the MLIR Interface documentation.