"""
********************************************************************************
* Copyright (c) 2025 the Qrisp authors
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* This Source Code may also be made available under the following Secondary
* Licenses when the conditions for such availability set forth in the Eclipse
* Public License, v. 2.0 are satisfied: GNU General Public License, version 2
* with the GNU Classpath Exception which is
* available at https://www.gnu.org/software/classpath/license.html.
*
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
********************************************************************************
"""
import jax
import jax.numpy as jnp
from jax.lax import fori_loop
from jax.scipy.optimize import OptimizeResults
# Reference: https://www.jhuapl.edu/SPSA/PDF-SPSA/Spall_An_Overview.PDF
# Conditions on the decay exponents: alpha <= 1; 1/6 <= gamma <= 1/2; 2*(alpha - gamma) > 1.
# The defaults alpha = 0.702, gamma = 0.201 satisfy these: 2*(0.702 - 0.201) = 1.002 > 1.
def spsa(fun, x0, args, maxiter=50, a=2.0, c=0.1, alpha=0.702, gamma=0.201, seed=3):
r"""
Minimize a scalar function of one or more variables using the `Simultaneous Perturbation Stochastic Approximation algorithm <https://en.wikipedia.org/wiki/Simultaneous_perturbation_stochastic_approximation>`_.
This algorithm aims at finding the optimal control $x^*$ minimizing a given loss fuction $f$:
.. math::
x^* = \text{argmin}_{x} f(x)
This is done by an iterative process starting from an initial guess $x_0$:
.. math::
x_{k+1} = x_k - a_kg_k(x_k)
where $a_k=\dfrac{a}{n^{\alpha}}$ for scaling parameters $a, \alpha>0$.
For each step $x_k$ the gradient is approximated by
.. math::
(g_k(x_k))_i = \frac{f(x_k+c_k\Delta_k)-f(x_k-c_k\Delta_k)}{2c_k(\Delta_k)_i}
where $c_k=\dfrac{c}{n^{\gamma}}$ for scaling parameters $c, \gamma>0$, and $\Delta_k$ is a random perturbation vector.
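
    Since every component of $\Delta_k$ equals $\pm 1$, division by $(\Delta_k)_i$ coincides with multiplication by it, so the estimator can be written vectorially as

    .. math::

        g_k(x_k) = \frac{f(x_k+c_k\Delta_k)-f(x_k-c_k\Delta_k)}{2c_k}\,\Delta_k

    which is exactly what the loop body below computes with two function evaluations per iteration.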

    Parameters
    ----------
    fun : callable
        The loss function to be minimized, called as ``fun(x, *args)`` and
        returning a scalar.
    x0 : jax.Array
        Initial guess.
    args : tuple
        Extra arguments passed to the loss function.
    maxiter : int
        Maximum number of iterations to perform. Each iteration requires two
        function evaluations. The default is 50.
    a : float
        Scaling parameter for the update rule. The default is 2.0.
    c : float
        Scaling parameter for the gradient estimation. The default is 0.1.
    alpha : float
        Scaling exponent for the update rule. The default is 0.702.
    gamma : float
        Scaling exponent for the gradient estimation. The default is 0.201.
    seed : int
        Seed for the pseudo-random generation of the perturbation vectors.
        The default is 3.

    Returns
    -------
    results : OptimizeResults
        An `OptimizeResults <https://docs.jax.dev/en/latest/_autosummary/jax.scipy.optimize.OptimizeResults.html#jax.scipy.optimize.OptimizeResults>`_ object.
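
    Examples
    --------

    A minimal sketch minimizing the quadratic loss $f(x) = x_1^2 + x_2^2$
    (illustrative only; it assumes ``spsa`` is imported from this module and
    that the loss is an ordinary JAX-traceable function)::

        import jax.numpy as jnp
        from qrisp.jasp.optimization_tools.spsa import spsa

        def loss(x):
            return jnp.sum(x**2)

        res = spsa(loss, jnp.array([1.0, -0.5]), args=(), maxiter=100)
        # res.x is expected to lie close to the minimizer at the origin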
"""
    rng = jax.random.PRNGKey(seed)

    def body_fun(k, state):
        x, rng = state
        # Generate a random perturbation delta with i.i.d. components +/-1,
        # using the fresh subkey from the split to avoid key reuse
        rng, rng_input = jax.random.split(rng)
        delta = jax.random.choice(rng_input, jnp.array([1.0, -1.0]), shape=x.shape)
        # Decaying gain sequences a_k = a/(k+1)**alpha and c_k = c/(k+1)**gamma
        ak = a / (k + 1) ** alpha
        ck = c / (k + 1) ** gamma
        # Evaluate the loss function at the two perturbed points
        x_plus = x + ck * delta
        x_minus = x - ck * delta
        loss_plus = fun(x_plus, *args)
        loss_minus = fun(x_minus, *args)
        # Simultaneous perturbation gradient estimate; dividing by delta is
        # equivalent to multiplying by it, since its components are +/-1
        gk = (loss_plus - loss_minus) / (2.0 * ck * delta)
        # Update the parameters
        x = x - ak * gk
        return x, rng

    from qrisp.jasp import make_tracer

    # make_tracer wraps maxiter so that fori_loop receives a traced loop bound
    x, rng = fori_loop(0, make_tracer(maxiter), body_fun, (x0, rng))
    fx = fun(x, *args)
    # nfev: two evaluations per iteration plus the final evaluation of fx
    return OptimizeResults(
        x=x,
        success=True,
        status=0,
        fun=fx,
        jac=None,
        hess_inv=None,
        nfev=2 * maxiter + 1,
        njev=0,
        nit=maxiter,
    )
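

# A minimal self-check sketch (not part of the library API): it reproduces a
# single SPSA gradient estimate by hand for the quadratic f(x) = sum(x**2),
# whose exact gradient is 2*x, using only plain JAX primitives.
if __name__ == "__main__":
    def f(x):
        return jnp.sum(x**2)

    x = jnp.array([1.0, -0.5])
    c0 = 0.1  # gain c_k at k = 0

    # Rademacher perturbation with i.i.d. components +/-1
    key, subkey = jax.random.split(jax.random.PRNGKey(0))
    delta = jax.random.choice(subkey, jnp.array([1.0, -1.0]), shape=x.shape)

    # Two-sided evaluation and simultaneous perturbation estimate
    g = (f(x + c0 * delta) - f(x - c0 * delta)) / (2.0 * c0 * delta)

    print("SPSA estimate:", g)       # stochastic estimate of the gradient
    print("exact gradient:", 2 * x)  # analytic reference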