Building a QAOA implementation in Jasp#

In this tutorial, we will explain step-by-step how to build a custom QAOA implementation in Jasp for the example of the MaxCut problem.

First, let us recall the problem description for MaxCut:

Given a graph \(G = (V,E)\), find a bipartition \(S\), \(V\setminus S\) of the set of vertices \(V\) such that the number of edges between \(S\) and \(V\setminus S\) is maximal.

[Figure: maxcut_tutorial.png]

For a graph \(G\) with \(n\) nodes, such a bipartition can be encoded with a QuantumVariable with \(n\) qubits: the \(i\)-th qubit is measured as 0 if node \(i\) is in the set \(S\), and as 1 if node \(i\) is in the set \(V\setminus S\). The cut value is the number of edges \(e=(i,j)\) in \(G\) such that \(i\in S\) and \(j\in V\setminus S\).

In Jasp, variables are decoded to integers (i.e. jax.numpy integer types) and not to binary strings. In this case, the binary representation of an integer encodes a bipartition of the graph \(G\). Therefore, repeatedly sampling from a QuantumVariable in a superposition state results in an array of integers, each representing a bipartition of the graph \(G\). Within QAOA, we therefore require a post-processing function that computes the average cut value for an array of samples.
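
To make this encoding concrete, here is a small, purely classical sketch that decodes an integer into the bipartition it represents; the helper int_to_bipartition is only for illustration and not part of Qrisp.

def int_to_bipartition(x, n):
    # Bit i == 0 -> node i is in S, bit i == 1 -> node i is in V \ S
    S = [i for i in range(n) if not (x >> i) & 1]
    complement = [i for i in range(n) if (x >> i) & 1]
    return S, complement

# For a 4-node graph, the integer 6 = 0b0110 places nodes 1 and 2 in V \ S
print(int_to_bipartition(6, 4))  # ([0, 3], [1, 2])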

As a first step, we will learn how to write a post_processor that can be compiled with jax.jit into a highly optimized version. Just-In-Time (JIT) compilation can significantly speed up the execution of numerical computations.

Computing the Average Cut of a Graph with JAX#

Step 1: Import Libraries

First, we need to import the necessary libraries.

import jax.numpy as jnp
from jax import jit, vmap
import networkx as nx

from qrisp import QuantumVariable, h, rx, rzz
from qrisp.jasp import sample, minimize, jaspify, jrange, make_jaspr

Step 2: Define the Function to Extract Boolean Digits

We will define a function that extracts the value of a specific bit (digit) from an integer.

@jit
def extract_boolean_digit(integer, digit):
    return (integer >> digit) & 1
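
A quick classical check of this helper (the concrete values are only for illustration):

# 6 = 0b0110: bit 0 is 0, bits 1 and 2 are 1, bit 3 is 0
print(extract_boolean_digit(6, 0))  # 0
print(extract_boolean_digit(6, 1))  # 1
print(extract_boolean_digit(6, 3))  # 0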

Step 3: Create the Cut Computer Function

The cut computer function calculates the cut value for a given integer representation of a bipartition of the graph. It uses the edges of the graph to determine how many edges cross the cut and returns the negative cut value, since the objective will later be minimized.

def create_cut_computer(G):
    edge_list = jnp.array(list(G.edges()))  # Convert the edge list to a JAX array

    @jit
    def cut_computer(x):
        x_uint = jnp.uint32(x)
        bools = extract_boolean_digit(x_uint, edge_list[:, 0]) != extract_boolean_digit(x_uint, edge_list[:, 1])
        cut = jnp.sum(bools)  # Count the number of edges crossing the cut
        return -cut

    return cut_computer
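
Before wiring this into the sampling pipeline, you can sanity-check cut_computer on a small graph (the same one used in Step 5 below); the expected values are worked out by hand for illustration.

G = nx.Graph()
G.add_edges_from([(0, 1), (1, 2), (2, 0), (1, 3)])

cut_computer = create_cut_computer(G)

# 6 = 0b0110 places nodes 1 and 2 on one side of the cut:
# the edges (0, 1), (2, 0) and (1, 3) cross it, so the cut value is 3
print(cut_computer(6))  # -3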

Step 4: Create the Sample Array Post Processor

This function will process an array of samples and compute the average cut using the cut_computer function. It will utilize JAX’s vectorization capabilities for efficiency.

def create_sample_array_post_processor(G):
    cut_computer = create_cut_computer(G)

    def post_processor(sample_array):
        # Use vmap for automatic vectorization
        cut_values = vmap(cut_computer)(sample_array)
        average_cut = jnp.mean(cut_values)  # Directly compute average
        return average_cut

    return post_processor

Step 5: Example Usage

Now we can create a graph and use our functions to compute the average cut.

# Create a sample graph
G = nx.Graph()
G.add_edges_from([(0, 1), (1, 2), (2, 0), (1, 3)])

# Create the post processor function
post_processor = create_sample_array_post_processor(G)

# Sample input array representing different cuts
sample_array = jnp.array([0b0001, 0b0010, 0b0100, 0b1000])  # Example binary representations

# Compute the average cut
average_cut = post_processor(sample_array)
print("Average Cut:", average_cut)

So far, we have created a function that uses JAX to efficiently compute the average cut of a graph. We defined a few helper functions, one for extracting bits and another for calculating (negative) cut values, and then used JAX’s vectorization capabilities to process many samples at once.

Setting up the QAOA#

For additional details, we refer to the MaxCut QAOA Implementation tutorial.

Step 6: Define the QAOA Ansatz

First, we will define the cost operator and the mixer.

def create_cost_operator(G):

    def apply_cost_operator(qv, gamma):
        for pair in list(G.edges()):
            rzz(gamma, qv[pair[0]], qv[pair[1]])

    return apply_cost_operator


def apply_mixer(qv, beta):
    rx(beta, qv)

Next, we define the QAOA ansatz: it creates a QuantumVariable, brings it into a uniform superposition, and applies \(p\) layers of the parametrized cost operator and mixer.

def create_ansatz(G):

    apply_cost_operator = create_cost_operator(G)

    def ansatz(theta, p):
        qv = QuantumVariable(G.number_of_nodes())

        # Prepare uniform superposition
        h(qv)

        for i in jrange(p):
            apply_cost_operator(qv, theta[i])
            apply_mixer(qv, theta[p+i])

        return qv

    return ansatz

Step 7: Define the Objective Function

The objective function samples from the parametrized QAOA ansatz and computes the average cut value.

def create_objective(G):

    ansatz = create_ansatz(G)
    post_processor = create_sample_array_post_processor(G)

    def objective(theta, p):

        res_sample = sample(ansatz, shots=1000)(theta, p)

        value = post_processor(res_sample)

        return value

    return objective
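
To try the objective for a fixed set of parameters before running a full optimization, one option (a sketch, assuming the example graph from Step 5) is to wrap the call in a function and execute it with jaspify, which is introduced in Step 9.

G = nx.Graph()
G.add_edges_from([(0, 1), (1, 2), (2, 0), (1, 3)])

objective = create_objective(G)

def evaluate_objective():
    p = 3
    theta = jnp.array([0.5] * 2 * p)
    return objective(theta, p)

# Returns an estimate of the negative average cut for these parameters
print(jaspify(evaluate_objective, terminal_sampling=True)())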

Step 8: Use a JAX-traceable Optimization Routine

We define the qaoa function, which finds the optimal parameter values using the JAX-traceable minimize routine. The optimizer returns the array of optimal parameters (result.x) together with the average cost value for the optimal solution; we then sample from the ansatz state prepared with these optimal parameters.

def qaoa():

    # Create a sample graph
    G = nx.Graph()
    G.add_edges_from([(0, 1), (1, 2), (2, 0), (1, 3)])

    ansatz = create_ansatz(G)
    objective = create_objective(G)

    # Number of layers
    p = 3

    # Initial point for theta
    x0 = jnp.array([0.5]*2*p)

    result = minimize(objective, x0, (p,))

    # Sample from the ansatz state for the optimal parameters
    samples = sample(ansatz, shots=10)(result.x, p)

    return samples

Step 9: Run the QAOA

Finally, the jaspify method allows for running Jasp-traceable functions using the integrated Qrisp simulator. For hybrid algorithms like QAOA and VQE that rely on calculating expectation values based on sampling, the terminal_sampling feature significantly speeds up the simulation: samples are drawn from the state vector instead of performing repeated simulation and measurement of the quantum circuits.

jaspify(qaoa, terminal_sampling=True)()

You can also create the Jaspr object and compile to QIR using Catalyst.

jaspr = make_jaspr(qaoa)()
qir_str = jaspr.to_qir()
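
For a quick look at the output (note that compiling to QIR relies on Catalyst being installed), you can print the beginning of the returned string:

# Inspect the first part of the generated QIR
print(qir_str[:500])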