Source code for qrisp.alg_primitives.arithmetic.jasp_arithmetic.jasp_bigintiger

"""
********************************************************************************
* Copyright (c) 2025 the Qrisp authors
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* This Source Code may also be made available under the following Secondary
* Licenses when the conditions for such availability set forth in the Eclipse
* Public License, v. 2.0 are satisfied: GNU General Public License, version 2
* with the GNU Classpath Exception which is
* available at https://www.gnu.org/software/classpath/license.html.
*
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
********************************************************************************
"""


import jax.numpy as jnp
import jax.lax as lax
import jax
from dataclasses import dataclass
from jax.tree_util import register_pytree_node_class

BASE = 2**32
DTYPE = jnp.uint32
BASE_FL = float(BASE)


[docs] @register_pytree_node_class @dataclass(frozen=True) class BigInteger: """ Fixed-width, little-endian base-2^32 big integer for JAX. This type represents non-negative integers using a fixed number of 32-bit limbs (dtype=uint32) in little-endian order (digits[0] is the least significant limb). It is compatible with JAX transformations (jit, vmap, pytrees). All arithmetic (addition, subtraction, multiplication, shifts, division, modulo, bitwise) is performed modulo 2^(32*n), where n = len(digits). Overflow and underflow beyond the most-significant limb are discarded. Notes ----- - Width (number of limbs) is determined by `digits.shape[0]` and remains constant; operations do not change the number of limbs. - Operators assume both operands have the same number of limbs. Attributes ---------- digits : jnp.ndarray Little-endian limbs (dtype=uint32) of length `n`. """ digits: jnp.ndarray # Little-endian base-2^32 @jax.jit def __call__(self): """ Return a float64 approximation of the integer value. Computes sum_i digits[i] * (2^32)^i as float64. Exact only for values that fit into float64; larger integers may lose precision. Returns ------- jnp.float64 Approximate numeric value (float64). """ r = lax.fori_loop( 0, self.digits.shape[0], lambda i, val: jnp.float64( self.digits[i]) * BASE_FL ** jnp.float64(i) + val, 0.0, ) return r def tree_flatten(self): """ PyTree flatten for JAX. Returns ------- tuple A pair `(children, aux_data)` where `children` is a tuple containing the digits array, and `aux_data` is `None`. """ return (self.digits,), None @classmethod def tree_unflatten(cls, aux_data, children): """ PyTree unflatten for JAX. Parameters ---------- aux_data : Any Auxiliary data (unused, expected `None`). children : tuple Tuple containing the digits array. Returns ------- BigInteger Reconstructed instance. """ return cls(*children)
[docs] @staticmethod def create_static(n, size): """ Create a BigInteger from Python using pure Python loops. This variant does not use JAX primitives and is suitable for static construction (e.g., outside `jit`). The result has exactly `size` limbs with wraparound modulo 2^(32*size). Parameters ---------- n : int or float Non-negative number. Floats are truncated; very large floats (> 2**53) may lose precision before conversion. size : int Number of limbs (digits) to allocate. Returns ------- BigInteger Fixed-width representation of `n` modulo 2^(32*size). """ digits = [] for i in range(size): digits.append(n % BASE) n //= BASE return BigInteger(jnp.array(digits, dtype=DTYPE))
[docs] @staticmethod def create(n, size): """ Create a BigInteger using JAX primitives. Constructs a fixed-width BigInteger with exactly `size` limbs, interpreting the input modulo 2^(32*size). JIT-friendly. Parameters ---------- n : int or float or jnp.integer or jnp.floating Non-negative number. Floats are truncated; very large floats (> 2**53) may lose precision before conversion. size : int Number of limbs (digits) to allocate. Returns ------- BigInteger Fixed-width representation of `n` modulo 2^(32*size). Notes ----- When called with a Python literal outside `jit`, the value must fit into JAX's host integer range (typically up to 64 bits). For arbitrarily large Python integers, prefer `create_static`. """ def body_fun(i, args): digits, num = args digits = digits.at[i].set(jnp.uint32(num % BASE)) num //= BASE return digits, num digits, _ = lax.fori_loop(0, size, body_fun, (jnp.zeros(size, dtype=DTYPE), n)) return BigInteger(digits)
@staticmethod def create_dynamic(n, size): """ Alias of `create`. Parameters ---------- n : int or float or jnp.integer or jnp.floating Non-negative number. size : int Number of limbs (digits) to allocate. Returns ------- BigInteger Fixed-width representation modulo 2^(32*size). See Also -------- BigInteger.create : JAX-compatible constructor. """ return BigInteger.create(n, size) @jax.jit def __add__(self, other: "BigInteger") -> "BigInteger": """ Add two BigIntegers with wraparound modulo 2^(32*n). If `other` is a scalar, it is converted to the same width. The number of limbs `n` is taken from `self`. Parameters ---------- other : BigInteger or int Addend. Returns ------- BigInteger (self + other) mod 2^(32*n). """ n = self.digits.shape[0] if not isinstance(other, BigInteger): other = BigInteger.create(other, n) a, b = self.digits, other.digits result = jnp.zeros_like(a) carry = jnp.uint64(0) def add_step(i, state): carry, result = state s = jnp.uint64(a[i]) + jnp.uint64(b[i]) + carry digit = jnp.uint32(s & (BASE - 1)) new_carry = jnp.uint64(s >> 32) result = result.at[i].set(jnp.uint32(digit)) return new_carry, result carry, result = lax.fori_loop(0, a.shape[0], add_step, (carry, result)) return BigInteger(result) @jax.jit def __sub__(self, other: "BigInteger") -> "BigInteger": """ Subtract two BigIntegers with wraparound modulo 2^(32*n). If `other` is a scalar, it is converted to the same width. Computes (self - other) mod 2^(32*n) using borrow propagation. Parameters ---------- other : BigInteger or int Subtrahend. Returns ------- BigInteger (self - other) mod 2^(32*n). """ n = self.digits.shape[0] if not isinstance(other, BigInteger): other = BigInteger.create(other, n) a, b = self.digits, other.digits result = jnp.zeros_like(a) def add_step(i, state): carry, result = state s, new_carry = lax.cond( jnp.uint64(a[i]) >= jnp.uint64(b[i]) + carry, lambda: (jnp.uint32(a[i] - b[i] - carry), 0), lambda: (jnp.uint32(jnp.uint64( a[i]) + jnp.uint64(BASE) - jnp.uint64(b[i]) - carry), 1), ) result = result.at[i].set(jnp.uint32(s)) return new_carry, result carry, result = lax.fori_loop(0, a.shape[0], add_step, (0, result)) return BigInteger(result) @jax.jit def __sub_alt__(self, other: "BigInteger") -> "BigInteger": """ Alternative subtraction using bitwise complement identity. Computes (self - other) as `~((~self) + other)` with wraparound modulo 2^(32*n). Parameters ---------- other : BigInteger or int Subtrahend. Returns ------- BigInteger (self - other) mod 2^(32*n). """ return ~((~self) + other) @jax.jit def __mul__(self, other: "BigInteger") -> "BigInteger": """ Multiply two BigIntegers with wraparound modulo 2^(32*n). Implements schoolbook multiplication and accumulates into `n` limbs, discarding overflow beyond the n-th limb (wraparound). Parameters ---------- other : BigInteger or int Multiplier. Returns ------- BigInteger (self * other) mod 2^(32*n). """ n = self.digits.shape[0] if not isinstance(other, BigInteger): other = BigInteger.create(other, n) a, b = self.digits, other.digits # result is uint32, but always promote to uint64 for arithmetic result = jnp.zeros(n, dtype=DTYPE) def outer_loop(i, result): carry0 = jnp.uint64(0) def inner_body(j, state): res, carry = state k = i + j tmp = ( jnp.uint64(res[k]) + jnp.uint64(a[i]) * jnp.uint64(b[j]) + carry ) new_digit = jnp.uint32(tmp & jnp.uint64(0xFFFFFFFF)) res = res.at[k].set(new_digit) carry = tmp >> jnp.uint64(32) return (res, carry) # j runs where k = i + j < n -> j in [0, n - i) result, _ = lax.fori_loop(0, n - i, inner_body, (result, carry0)) # Drop any remaining carry (mod base^n) return result result = lax.fori_loop(0, n, outer_loop, result) return BigInteger(result) @jax.jit def __pow__(self, other): """ Integer exponentiation (square-and-multiply). Performs `self ** other` by binary exponentiation, modulo 2^(32*n). Parameters ---------- other : int or jnp.integer Exponent (>= 0). Returns ------- BigInteger self raised to power `other` modulo 2^(32*n). Notes ----- For `other == 0`, returns 1 (the multiplicative identity) with the same width. """ n = self.digits.shape[0] base = self exp = jnp.asarray(other, dtype=jnp.uint64) acc = BigInteger.create(1, n) def cond_fun(state): base, exp, acc = state return exp > 0 def body_fun(state): base, exp, acc = state acc = lax.cond((exp & jnp.uint64(1)) == jnp.uint64(1), lambda a: a * base, lambda a: a, acc) base = base * base exp = exp >> jnp.uint64(1) return base, exp, acc _, _, acc = lax.while_loop(cond_fun, body_fun, (base, exp, acc)) return acc def __repr__(self): """ String representation with limbs in little-endian order. Returns ------- str String representation of digits (uint32 list). """ return self.digits.__repr__() #return f"BigInteger(digits={self.digits.tolist()})" @jax.jit def __lt__(self, other: "BigInteger"): """ Less-than comparison between two fixed-width BigIntegers. If `other` is a scalar, it is converted to the same width. Requires both operands to have the same number of limbs. Parameters ---------- other : BigInteger or int Right-hand operand. Returns ------- jnp.bool_ True if `self < other` (unsigned), else False. """ n = self.digits.shape[0] if not isinstance(other, BigInteger): other = BigInteger.create(other, n) m = other.digits.shape[0] assert n == m d0 = self.digits d1 = other.digits def body_fun(val): i, res_found = val i -= 1 res_found = lax.cond(d0[i] < d1[i], lambda: 1, lambda: -5) res_found = lax.cond(d0[i] > d1[i], lambda: 0, lambda: res_found) return i, res_found def cond_fun(val): i, res_found = val return jnp.logical_and(res_found == -5, i > 0) _, res = lax.while_loop(cond_fun, body_fun, (n, -5)) res = lax.cond(res == -5, lambda: 0, lambda: res) return (res != 0) @jax.jit def __eq__(self, other: "BigInteger"): """ Equality comparison between two fixed-width BigIntegers. If `other` is a scalar, it is converted to the same width. Parameters ---------- other : BigInteger or int Right-hand operand. Returns ------- jnp.bool_ True if all limbs are equal, else False. """ n = self.digits.shape[0] if not isinstance(other, BigInteger): other = BigInteger.create(other, n) m = other.digits.shape[0] assert n == m d0 = self.digits d1 = other.digits return jnp.all(d0 == d1) @jax.jit def __ne__(self, other: "BigInteger"): """ Inequality comparison between two fixed-width BigIntegers. Parameters ---------- other : BigInteger or int Right-hand operand. Returns ------- jnp.bool_ True if any limb differs, else False. """ return jnp.logical_not(self == other) @jax.jit def __le__(self, other: "BigInteger"): """ Less-or-equal comparison between two fixed-width BigIntegers. If `other` is a scalar, it is converted to the same width. Requires both operands to have the same number of limbs. Parameters ---------- other : BigInteger or int Right-hand operand. Returns ------- jnp.bool_ True if `self <= other` (unsigned), else False. """ n = self.digits.shape[0] if not isinstance(other, BigInteger): other = BigInteger.create(other, n) m = other.digits.shape[0] assert n == m d0 = self.digits d1 = other.digits def body_fun(val): i, res_found = val i -= 1 res_found = lax.cond(d0[i] < d1[i], lambda: 1, lambda: -5) res_found = lax.cond(d0[i] > d1[i], lambda: 0, lambda: res_found) return i, res_found def cond_fun(val): i, res_found = val return jnp.logical_and(res_found == -5, i > 0) _, res = lax.while_loop(cond_fun, body_fun, (n, -5)) res = lax.cond(res == -5, lambda: 1, lambda: res) return (res != 0) @jax.jit def __lshift__(self, shift): """ Logical left shift by a non-negative number of bits. Shifts bits left by `shift` and fills with zeros, within fixed width. Bits shifted out of the most-significant end are discarded. Parameters ---------- shift : int or jnp.integer Number of bits to shift (>= 0). Returns ------- BigInteger (self << shift) mod 2^(32*n). """ total_bits = self.digits.shape[0] * 32 zeros = BigInteger(jnp.zeros_like(self.digits)) def do_shift(_): def body_fun(i, x): return lax.cond( self.get_bit(i) != 0, lambda: x.flip_bit(i + shift), lambda: x, ) return lax.fori_loop(0, total_bits - shift, body_fun, zeros) return lax.cond( jnp.asarray(shift) >= total_bits, lambda _: zeros, do_shift, operand=None, ) @jax.jit def __rshift__(self, shift): """ Logical right shift by a non-negative number of bits. Shifts bits right by `shift` and fills with zeros, within fixed width. Bits shifted out of the least-significant end are discarded. Parameters ---------- shift : int or jnp.integer Number of bits to shift (>= 0). Returns ------- BigInteger (self >> shift) within fixed width. """ total_bits = self.digits.shape[0] * 32 zeros = BigInteger(jnp.zeros_like(self.digits)) def do_shift(_): def body_fun(i, x): return lax.cond( self.get_bit(i) != 0, lambda: x.flip_bit(i - shift), lambda: x, ) return lax.fori_loop(shift, total_bits, body_fun, zeros) return lax.cond( jnp.asarray(shift) >= total_bits, lambda _: zeros, do_shift, operand=None, ) @jax.jit def __and__(self, other: "BigInteger") -> "BigInteger": """ Bitwise AND between two fixed-width BigIntegers. If `other` is a scalar, it is converted to the same width. Parameters ---------- other : BigInteger or int Right-hand operand. Returns ------- BigInteger self & other (limb-wise). """ n = self.digits.shape[0] if not isinstance(other, BigInteger): other = BigInteger.create(other, n) return BigInteger(self.digits & other.digits) @jax.jit def __or__(self, other: "BigInteger") -> "BigInteger": """ Bitwise OR between two fixed-width BigIntegers. If `other` is a scalar, it is converted to the same width. Parameters ---------- other : BigInteger or int Right-hand operand. Returns ------- BigInteger self | other (limb-wise). """ n = self.digits.shape[0] if not isinstance(other, BigInteger): other = BigInteger.create(other, n) return BigInteger(self.digits | other.digits) @jax.jit def __xor__(self, other: "BigInteger") -> "BigInteger": """ Bitwise XOR between two fixed-width BigIntegers. If `other` is a scalar, it is converted to the same width. Parameters ---------- other : BigInteger or int Right-hand operand. Returns ------- BigInteger self ^ other (limb-wise). """ n = self.digits.shape[0] if not isinstance(other, BigInteger): other = BigInteger.create(other, n) return BigInteger(self.digits ^ other.digits) @jax.jit def __invert__(self) -> "BigInteger": """ Bitwise NOT on a fixed-width BigInteger. Returns ------- BigInteger Bitwise complement of `self`, limb-wise. """ return BigInteger(~self.digits) @jax.jit def __mod__(self, other: "BigInteger"): """ Modulo operation `self % other` (fixed-width). If `other` is a scalar, it is converted to the same width. Uses `remainder_division`. Parameters ---------- other : BigInteger or int Modulus (must be non-zero). Returns ------- BigInteger Remainder `r` with 0 <= r < other (when other != 0). """ if not isinstance(other, BigInteger): other = BigInteger.create(other, self.digits.shape[0]) r, q = self.remainder_division(other) return r @jax.jit def __floordiv__(self, other: "BigInteger"): """ Floor division `self // other` (fixed-width). If `other` is a scalar, it is converted to the same width. Uses `remainder_division`. Parameters ---------- other : BigInteger or int Divisor (must be non-zero). Returns ------- BigInteger Quotient `q` such that self = other*q + r, 0 <= r < other. """ if not isinstance(other, BigInteger): other = BigInteger.create(other, self.digits.shape[0]) r, q = self.remainder_division(other) return q @jax.jit def get_bit(self, i: int): """ Get the value of the i-th bit (0-based, LSB=bit 0). Parameters ---------- i : int or jnp.integer Bit index (0 <= i < 32*n). Returns ------- jnp.uint32 Either 0 or 1. """ pos = i // 32 pos_in = i % 32 return (self.digits[pos] >> pos_in) & 1 @jax.jit def flip_bit(self, i: int): """ Toggle the i-th bit (0-based, LSB=bit 0). Parameters ---------- i : int or jnp.integer Bit index to flip (0 <= i < 32*n). Returns ------- BigInteger Copy of `self` with the specified bit toggled. """ pos = i // 32 pos_in = i % 32 ds = jnp.copy(self.digits) ds = ds.at[pos].set(jnp.uint32(ds[pos] ^ (1 << pos_in))) return BigInteger(ds) @jax.jit def bit_size(self): """ Return the position of the most significant set bit plus one. Approximates the bit-length of the value: floor(log2(x)) + 1. Returns 0 for zero. Returns ------- jnp.int64 Bit-length of the value (0 for zero). """ n = self.digits.shape[0] is_zero = jnp.all(self.digits == 0) def nonzero_len(): # Find index ms of the most-significant non-zero limb (scan from top) def cond_fun(i): return jnp.logical_and(i > 0, self.digits[i] == jnp.uint32(0)) def body_fun(i): return i - 1 ms = lax.while_loop(cond_fun, body_fun, n - 1) limb = self.digits[ms] # bit length of a 32-bit limb limb_bits = (jnp.floor(jnp.log2(jnp.float64(limb))).astype(jnp.int64) + 1) return jnp.int64(32) * jnp.int64(ms) + limb_bits return lax.cond(is_zero, lambda: jnp.int64(0), nonzero_len) def remainder_division(self, other: "BigInteger"): """ Exact division using Knuth long division (base 2^32). Computes quotient and remainder such that: `self = other * q + r`, with `0 <= r < other`. Parameters ---------- other : BigInteger Divisor (must be non-zero), same width as `self`. Returns ------- tuple of BigInteger `(r, q)` where `r` is the remainder and `q` is the quotient. Notes ----- Uses a normalized Knuth division with limb base 2^32. Both `self` and `other` must have the same number of limbs. """ n = self.digits.shape[0] m = other.digits.shape[0] assert n == m r_digits, q_digits = _remainder_division_knuth(self.digits, other.digits) return BigInteger(r_digits), BigInteger(q_digits) def get_larger(self): """ Given a BigInteger with n limbs, return a new BigInteger with 2n limbs and the same number Returns ------- BigInteger Larger BigInteger with 2n limbs """ pad = jnp.zeros(self.digits.shape[0], dtype=self.digits.dtype) return BigInteger(jnp.concatenate([self.digits, pad], axis=0))
@jax.jit def _clz32(x): """ Count leading zeros in a 32-bit word. Parameters ---------- x : jnp.uint32 Input 32-bit word. Returns ------- jnp.int32 Number of leading zeros in `x` (0..32). """ # Count leading zeros in a 32-bit word (simple loop) def cond_fun(state): x, s = state return jnp.logical_and(s < jnp.int32(32), (x & jnp.uint32(0x80000000)) == jnp.uint32(0)) def body_fun(state): x, s = state return jnp.uint32(x << jnp.uint32(1)), s + jnp.int32(1) _, s = lax.while_loop(cond_fun, body_fun, (x, jnp.int32(0))) return s @jax.jit def _ms_length(a: jnp.ndarray): """ Effective limb length (highest non-zero index + 1). Parameters ---------- a : jnp.ndarray Little-endian uint32 limb array. Returns ------- jnp.int32 Effective length in limbs (0 if all zero). """ # Effective limb length: highest non-zero index + 1, or 0 if all zero. n = a.shape[0] all_zero = jnp.all(a == jnp.uint32(0)) def find_ms(): def cond_fun(i): return jnp.logical_and(i > 0, a[i] == jnp.uint32(0)) def body_fun(i): return i - 1 ms = lax.while_loop(cond_fun, body_fun, jnp.int32(n - 1)) return ms + jnp.int32(1) return lax.cond(all_zero, lambda: jnp.int32(0), find_ms) @jax.jit def _shl_bits(arr: jnp.ndarray, s): """ Shift-left by `s` bits across limbs (0 <= s < 32). Parameters ---------- arr : jnp.ndarray Little-endian uint32 limb array to shift. s : jnp.int32 Bit count (0..31). Returns ------- tuple `(out, carry_out)` where `out` is the shifted array and `carry_out` is the carry from the most-significant limb (uint32). """ # Shift-left by s bits (0<=s<32) across limbs (little endian). s_u = jnp.uint32(s) n = arr.shape[0] def no_shift(): return jnp.array(arr, dtype=jnp.uint32), jnp.uint32(0) def do_shift(): out = jnp.zeros_like(arr) carry = jnp.uint32(0) def body(i, state): carry, out = state ai = arr[i] low = jnp.uint32((jnp.uint64(ai) << jnp.uint64(s_u)) & jnp.uint64(0xFFFFFFFF)) new_digit = jnp.uint32(low | carry) out = out.at[i].set(new_digit) # IMPORTANT: parentheses for precedence new_carry = jnp.uint32(jnp.uint64(ai) >> (jnp.uint64(32) - jnp.uint64(s_u))) return new_carry, out carry_out, out = lax.fori_loop(0, n, body, (carry, out)) return out, carry_out return lax.cond(s == jnp.int32(0), no_shift, do_shift) @jax.jit def _shr_bits(arr: jnp.ndarray, s) -> jnp.ndarray: """ Shift-right by `s` bits across limbs (0 <= s < 32). Parameters ---------- arr : jnp.ndarray Little-endian uint32 limb array to shift. s : jnp.int32 Bit count (0..31). Returns ------- jnp.ndarray Shifted array (little-endian, uint32). """ # Shift-right by s bits (0<=s<32) across limbs (little endian). s_u = jnp.uint32(s) n = arr.shape[0] def no_shift(): return jnp.array(arr, dtype=jnp.uint32) def do_shift(): out = jnp.zeros_like(arr) carry = jnp.uint32(0) def body(i, state): carry, out = state idx = n - 1 - i ai = arr[idx] high = jnp.uint32((jnp.uint64(carry) << (jnp.uint64( 32) - jnp.uint64(s_u))) & jnp.uint64(0xFFFFFFFF)) new_digit = jnp.uint32( (jnp.uint64(ai) >> jnp.uint64(s_u)) | jnp.uint64(high)) out = out.at[idx].set(new_digit) new_carry = jnp.uint32(ai & (jnp.uint32(1) << s_u) - jnp.uint32(1)) return new_carry, out _, out = lax.fori_loop(0, n, body, (carry, out)) return out return lax.cond(s == jnp.int32(0), no_shift, do_shift) @jax.jit def _divmod_single_limb(u: jnp.ndarray, d): """ Divide a multi-limb number by a single 32-bit limb. Parameters ---------- u : jnp.ndarray Little-endian uint32 limb array (dividend). d : jnp.uint32 Single-limb divisor (must be non-zero). Returns ------- tuple `(q_digits, r_low)` where `q_digits` is the quotient array (uint32) and `r_low` is the remainder (uint32). """ # Divide multi-limb u by single 32-bit d (d != 0). Returns (q_digits, r_low). n = u.shape[0] q = jnp.zeros_like(u) rem = jnp.uint64(0) def body(i, state): q, rem = state idx = n - 1 - i cur = (rem << jnp.uint64(32)) + jnp.uint64(u[idx]) qi = jnp.uint32(cur // jnp.uint64(d)) rem = jnp.uint64(cur % jnp.uint64(d)) q = q.at[idx].set(qi) return q, rem q, rem = lax.fori_loop(0, n, body, (q, rem)) return q, jnp.uint32(rem) @jax.jit def _remainder_division_knuth(u: jnp.ndarray, v: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]: """ Knuth long division (base 2^32) for arrays of equal length. Parameters ---------- u : jnp.ndarray Dividend limbs (little-endian uint32). v : jnp.ndarray Divisor limbs (little-endian uint32), not all zero. Returns ------- tuple of jnp.ndarray `(r_digits, q_digits)` where both are little-endian uint32 arrays of the same length as `u`. The relation `u = v * q + r` holds with `0 <= r < v`. Notes ----- Normalizes the divisor and uses a classic multi-precision division scheme. """ # Knuth long division base 2^32. Inputs uint32 arrays (same length N). # Returns (r_digits, q_digits), both length N. N = u.shape[0] m = _ms_length(v) # effective length of divisor n_eff = _ms_length(u) # effective length of dividend # Divisor zero => return (0, 0) def div_by_zero(): return jnp.zeros_like(u), jnp.zeros_like(u) def normal_path(): # If u < v: q = 0, r = u def less_path(): return u, jnp.zeros_like(u) def ge_path(): # Single-limb fast path def single_limb_path(): d = v[0] q, r_low = _divmod_single_limb(u, d) r = jnp.zeros_like(u).at[0].set(r_low) return r, q # Multi-limb path (m >= 2) def multi_limb_path(): ms_idx = m - jnp.int32(1) v_msw = v[ms_idx] s = _clz32(v_msw) # 0..31 v_norm, _ = _shl_bits(v, s) # length N u_norm_part, carry_u = _shl_bits(u, s) # length N, carry u_norm = jnp.concatenate([u_norm_part, jnp.array( [carry_u], dtype=jnp.uint32)], axis=0) # N+1 q = jnp.zeros_like(u) # how many quotient positions we fill j_count = jnp.maximum(n_eff - m, jnp.int32(0)) BASE_MASK = jnp.uint64(0xFFFFFFFF) BASE64 = jnp.uint64(1) << jnp.uint64(32) def body(t, state): u_norm, q = state active = t <= j_count def do_step(state): u_norm, q = state j = j_count - t # 0..j_count ujm = j + m # index to top limb of current window # Estimate qhat from top two limbs of u_norm and top limb of v_norm u2 = jnp.uint64(u_norm[ujm]) u1 = jnp.uint64(u_norm[ujm - 1]) num = (u2 << jnp.uint64(32)) + u1 den = jnp.uint64(v_norm[m - 1]) qhat = jnp.minimum(num // den, jnp.uint64(0xFFFFFFFF)) rhat = num - qhat * den # Refinement (at most twice) vm2 = jnp.uint64(v_norm[m - 2]) u0 = jnp.uint64(u_norm[ujm - 2]) cond1 = jnp.logical_and(qhat == jnp.uint64(0xFFFFFFFF), qhat * vm2 > (rhat << jnp.uint64(32)) + u0) qhat = jnp.where(cond1, qhat - jnp.uint64(1), qhat) rhat = jnp.where(cond1, rhat + den, rhat) cond2 = qhat * vm2 > (rhat << jnp.uint64(32)) + u0 qhat = jnp.where(cond2, qhat - jnp.uint64(1), qhat) rhat = jnp.where(cond2, rhat + den, rhat) # Subtract qhat * v_norm from u_norm[j ... j+m] carry = jnp.uint64(0) def sub_body(i, sstate): u_norm, carry = sstate vi = jnp.uint64(v_norm[i]) p = qhat * vi + carry p_low = jnp.uint32(p & BASE_MASK) p_high = p >> jnp.uint64(32) idx = j + i u_di = u_norm[idx] borrow = u_di < p_low new_u = jnp.uint32( (jnp.uint64(u_di) + BASE64 - jnp.uint64(p_low)) & BASE_MASK) u_norm = u_norm.at[idx].set(new_u) carry = p_high + jnp.uint64(borrow) return u_norm, carry u_norm, carry = lax.fori_loop(0, m, sub_body, (u_norm, carry)) top_before = u_norm[ujm] underflow = top_before < jnp.uint32(carry) top_after = jnp.uint32( (jnp.uint64(top_before) + BASE64 - jnp.uint64(carry)) & BASE_MASK) u_norm = u_norm.at[ujm].set(top_after) def fix_underflow(uq_state): u_norm, q, qhat, j = uq_state qhat = qhat - jnp.uint64(1) carry2 = jnp.uint64(0) def add_body(i, astate): u_norm, carry2 = astate idx = j + i ssum = jnp.uint64(u_norm[idx]) + \ jnp.uint64(v_norm[i]) + carry2 new_digit = jnp.uint32(ssum & BASE_MASK) carry2 = ssum >> jnp.uint64(32) u_norm = u_norm.at[idx].set(new_digit) return u_norm, carry2 u_norm, carry2 = lax.fori_loop( 0, m, add_body, (u_norm, carry2)) u_norm = u_norm.at[ujm].set(jnp.uint32( jnp.uint64(u_norm[ujm]) + jnp.uint64(1))) return u_norm, q, qhat, j u_norm, q, qhat, _ = lax.cond( underflow, fix_underflow, lambda uq_state: uq_state, operand=(u_norm, q, qhat, j), ) q = q.at[j].set(jnp.uint32(qhat & BASE_MASK)) return u_norm, q u_norm, q = lax.cond( active, do_step, lambda state: state, operand=(u_norm, q)) return u_norm, q u_norm, q = lax.fori_loop(0, N, body, (u_norm, q)) # Build r_full by copying the first m limbs (dynamic) of u_norm, then de-normalize r_norm_full = jnp.zeros_like(u) def copy_body(i, r_acc): val = jnp.where(i < m, u_norm[i], jnp.uint32(0)) return r_acc.at[i].set(val) r_norm_full = lax.fori_loop(0, N, copy_body, r_norm_full) r = _shr_bits(r_norm_full, s) return r, q return lax.cond(m == jnp.int32(1), single_limb_path, multi_limb_path) return lax.cond(n_eff < m, less_path, ge_path) return lax.cond(m == jnp.int32(0), div_by_zero, normal_path) @jax.jit def bi_modinv(a: BigInteger, m: BigInteger) -> BigInteger: """ Modular inverse using an Extended Euclidean Algorithm variant. Finds `t` such that `(a * t) % m == 1`, assuming `gcd(a, m) == 1`. Returns the non-negative representative in `[0, m)`. Parameters ---------- a : BigInteger Value to invert (same width as `m`). m : BigInteger Modulus (must be > 0; often odd in Montgomery contexts). Returns ------- BigInteger `t = a^{-1} mod m`. Raises ------ AssertionError If inputs do not have the same width. Notes ----- Uses `//` and `%` which rely on exact multi-precision division implemented in this module. Both `a` and `m` are treated as fixed-width unsigned values. """ # Widen to 2n limbs n = a.digits.shape[0] pad = jnp.zeros(n, dtype=DTYPE) a = BigInteger(jnp.concatenate([a.digits, pad], axis=0)) m = BigInteger(jnp.concatenate([m.digits, pad], axis=0)) bi0 = BigInteger.create(0, 2*n) bi1 = BigInteger.create(1, 2*n) t, new_t = bi0, bi1 r, new_r = m, a def cond(state): _, new_t, r, new_r = state return new_r != bi0 def body(state): t, new_t, r, new_r = state quotient = r // new_r # Unsigned arithmetic with wraparound simulated modulo behavior t_updated = (t + m - (quotient * new_t) % m) % m r_updated = r - quotient * new_r return new_t, t_updated, new_r, r_updated final_t, _, final_r, _ = lax.while_loop(cond, body, (t, new_t, r, new_r)) return BigInteger(final_t.digits[0:n]) @jax.jit def bi_extended_euclidean(a, b): """ Extended Euclidean Algorithm (fixed-width arithmetic). Computes `g, x, y` such that `a*x + b*y = g = gcd(a, b)`. Compatible with `jax.jit`. Inputs are expected to be `BigInteger` of the same width (or scalars that are promoted to that width). Parameters ---------- a : BigInteger or int First number. b : BigInteger or int Second number. Returns ------- tuple `(g, x, y)` where `g` is the gcd and `x, y` are Bézout coefficients (all `BigInteger`). Notes ----- All operations occur modulo 2^(32*n), so the interpretation follows fixed-width arithmetic semantics. """ n = a.digits.shape[0] if not isinstance(b, BigInteger): b = BigInteger.create(b, n) bi0 = BigInteger.create(0, n) bi1 = BigInteger.create(1, n) # State: (r, old_r, s, old_s, t, old_t) state = (b, a, bi0, bi1, bi1, bi0) def cond_fun(state): r, old_r, s, old_s, t, old_t = state return r != bi0 def body_fun(state): r, old_r, s, old_s, t, old_t = state quotient = old_r // r return ( old_r - quotient * r, r, old_s - quotient * s, s, old_t - quotient * t, t, ) r, old_r, s, old_s, t, old_t = lax.while_loop(cond_fun, body_fun, state) # old_r is gcd, old_s and old_t are the Bézout coefficients return old_r, old_s, old_t @jax.jit def bi_montgomery_encode(x: BigInteger, R: BigInteger, modulus: BigInteger) -> BigInteger: """ Montgomery encode: map x to (x * R) mod modulus, without intermediate wraparound. This routine widens all operands to 2n limbs (where n is the current width) to compute the product exactly, reduces the 2n-limb product modulo the widened modulus (also 2n limbs), and then shrinks the remainder back to n limbs. This avoids the truncation that would occur if the n-limb product were computed directly (since BigInteger arithmetic is modulo 2^(32*n)). Parameters ---------- x : BigInteger Value to encode (width n limbs). R : BigInteger Montgomery radix (width n limbs), typically R = 2^m in the same limb system or equivalent precomputed value for the chosen modulus. modulus : BigInteger Modulus (width n limbs). In radix-2 Montgomery arithmetic, modulus is assumed odd. Returns ------- BigInteger x in Montgomery form, i.e., (x * R) mod modulus, represented on n limbs. Notes ----- - The 2n-limb widening ensures that (x * R) is computed without loss before the modulo reduction. The remainder is strictly less than modulus < 2^(32*n), so it fits in n limbs and can be safely truncated back. - All three inputs must share the same limb width n. """ n = modulus.digits.shape[0] # Widen to 2n limbs pad = jnp.zeros(n, dtype=DTYPE) x2 = BigInteger(jnp.concatenate([x.digits, pad], axis=0)) R2 = BigInteger(jnp.concatenate([R.digits, pad], axis=0)) N2 = BigInteger(jnp.concatenate([modulus.digits, pad], axis=0)) # Exact product at 2n limbs, reduce modulo widened modulus, then shrink to n limbs prod2 = x2 * R2 rem2 = prod2 % N2 return BigInteger(rem2.digits[:n]) @jax.jit def bi_montgomery_decode(x_mon: BigInteger, R: BigInteger, modulus: BigInteger) -> BigInteger: """ Montgomery decode: map x_mon to (x_mon * R^{-1}) mod modulus, without wraparound. This routine computes invR = R^{-1} mod modulus at width n, widens x_mon and invR to 2n limbs, performs the 2n-limb product and reduction modulo the widened modulus, and then shrinks the remainder back to n limbs. This avoids truncation during the multiply step in fixed-width arithmetic. Parameters ---------- x_mon : BigInteger Value in Montgomery form (width n limbs). R : BigInteger Montgomery radix used for encoding (width n limbs). modulus : BigInteger Modulus (width n limbs). In radix-2 Montgomery arithmetic, modulus is assumed odd. Returns ------- BigInteger Decoded value, i.e., (x_mon * R^{-1}) mod modulus, represented on n limbs. Notes ----- - The 2n-limb widening ensures that (x_mon * invR) is computed exactly before the modulo reduction. The remainder is < modulus < 2^(32*n), so it fits within n limbs. - All three inputs must share the same limb width n. """ n = modulus.digits.shape[0] # Compute R^{-1} mod modulus at n limbs invR = bi_modinv(R, modulus) # Widen operands and modulus to 2n limbs pad = jnp.zeros(n, dtype=DTYPE) x2 = BigInteger(jnp.concatenate([x_mon.digits, pad], axis=0)) inv2 = BigInteger(jnp.concatenate([invR.digits, pad], axis=0)) N2 = BigInteger(jnp.concatenate([modulus.digits, pad], axis=0)) # Exact product at 2n limbs, reduce modulo widened modulus, then shrink to n limbs prod2 = x2 * inv2 rem2 = prod2 % N2 return BigInteger(rem2.digits[:n]) @jax.jit def _bi_all_ones(n_limbs: int) -> BigInteger: """ Internal: BigInteger with all digits set to 0xFFFFFFFF (acts like +infinity bound). Parameters ---------- n_limbs : int Number of 32-bit limbs. Returns ------- BigInteger All-ones value, width n_limbs. """ return BigInteger(jnp.full((n_limbs,), jnp.uint32(0xFFFFFFFF), dtype=DTYPE)) def bi_contfrac_best_approx(a: BigInteger, b: BigInteger, max_den: BigInteger | None = None, max_iters: int | None = None) -> tuple[BigInteger, BigInteger]: """ Best rational approximation p/q to a/b via continued fractions, with q <= max_den. This computes a convergent (or intermediate convergent) of the continued fraction expansion of a/b using the standard recurrence, stopping when either: - remainder becomes 0 (exact rational found), or - the next denominator would exceed `max_den`; in that case the best "intermediate" convergent within the bound is returned using t = floor((max_den - q0)/q1), p = t*p1 + p0, q = t*q1 + q0. The implementation uses a single JAX while_loop; it allocates no dynamic lists and is suitable for very large BigIntegers and JIT/vmap usage. Parameters ---------- a : BigInteger Numerator (non-negative). Width must match `b`. b : BigInteger Denominator (non-zero, non-negative). Width must match `a`. max_den : BigInteger or None, optional Upper bound on the denominator q. If None, a built-in all-ones value of the same width is used (effectively unbounded). max_iters : int or None, optional Maximum number of continued-fraction steps (safety cap). If None, defaults to 2*bitlen + 4, where bitlen = 32 * n_limbs. Returns ------- (BigInteger, BigInteger) A pair (p, q) with q > 0, p/q the best approximation to a/b under the bound q <= max_den. For exact rationals, returns the exact p/q. Raises ------ AssertionError If widths of a and b differ, or if b == 0. Notes ----- - All operands must share the same width (number of limbs). - Arithmetic is exact in the fixed-width ring; for valid inputs arising in Shor's post-processing (with q <= b and typical bounds <= b), no wrap-around occurs. """ n_limbs = a.digits.shape[0] assert n_limbs == b.digits.shape[0], "BigInteger widths must match" bi0 = BigInteger.create(0, n_limbs) bi1 = BigInteger.create(1, n_limbs) assert b != bi0, "Denominator b must be non-zero" # Prepare bound and iteration cap if max_den is None: max_den = _bi_all_ones(n_limbs) else: assert max_den.digits.shape[0] == n_limbs, "max_den must match width of a,b" if max_iters is None: max_iters = int(2 * 32 * n_limbs + 4) @jax.jit def _loop(a0: BigInteger, b0: BigInteger, p0: BigInteger, p1: BigInteger, q0: BigInteger, q1: BigInteger, res_p: BigInteger, res_q: BigInteger, done: bool, i: int) -> tuple: """ One step of CF with bound handling. Internal helper for while_loop. """ # Compute quotient and remainder quot = a0 // b0 rem = a0 % b0 # Next convergent (pn/qn) pn = quot * p1 + p0 qn = quot * q1 + q0 # Conditions exact_end = (rem == bi0) # exceed if next denominator would be > max_den exceed = jnp.logical_not(qn <= max_den) # Intermediate convergent if exceed: t = floor((max_den - q0)/q1) # Guard q1==0 (only possible at very first step if max_den < 1) q1_is_zero = (q1 == bi0) t_num = (max_den - q0) t = t_num // jax.lax.cond(q1_is_zero, lambda: bi1, lambda: q1) p_bound = t * p1 + p0 q_bound = t * q1 + q0 # Choose final result when done at this step sel_p = jax.lax.select(exceed, p_bound, pn) sel_q = jax.lax.select(exceed, q_bound, qn) # Update result slots if we finish now res_p_next = jax.lax.select(jnp.logical_or(exceed, exact_end), sel_p, res_p) res_q_next = jax.lax.select(jnp.logical_or(exceed, exact_end), sel_q, res_q) done_next = jnp.logical_or(done, jnp.logical_or(exceed, exact_end)) # Prepare state for next iteration (if not done) a_next = jax.lax.select(done_next, a0, b0) b_next = jax.lax.select(done_next, b0, rem) p0_next = jax.lax.select(done_next, p0, p1) p1_next = jax.lax.select(done_next, p1, pn) q0_next = jax.lax.select(done_next, q0, q1) q1_next = jax.lax.select(done_next, q1, qn) i_next = i + jnp.int32(1) return (a_next, b_next, p0_next, p1_next, q0_next, q1_next, res_p_next, res_q_next, done_next, i_next) def cond_fn(state): a0, b0, p0, p1, q0, q1, res_p, res_q, done, i = state under_cap = (i < jnp.int32(max_iters)) return jnp.logical_and(jnp.logical_not(done), under_cap) # Initial CF state: p[-2]=0, p[-1]=1; q[-2]=1, q[-1]=0 init_state = (a, b, bi0, bi1, bi1, bi0, bi0, bi1, jnp.bool_(False), jnp.int32(0)) final_state = lax.while_loop(cond_fn, _loop, init_state) _, _, _, _, _, _, out_p, out_q, _, _ = final_state return out_p, out_q def bi_shor_recover_denominator(a: BigInteger, b: BigInteger, N_bound: BigInteger | int, max_iters: int | None = None) -> BigInteger: """ Recover the candidate period denominator r from a/b for Shor's algorithm. This computes the best convergent p/q of a/b with q <= N_bound, where N_bound is typically the number to factor (or a small multiple), returning q. Parameters ---------- a : BigInteger Numerator of measurement ratio (0 <= a < b). b : BigInteger Denominator (typically a power of two, b = 2^t). N_bound : BigInteger or int Upper bound for the denominator r (e.g., N). max_iters : int or None, optional Maximum CF steps (safety cap). See `bi_contfrac_best_approx`. Returns ------- BigInteger The recovered denominator r (candidate period). Use with standard Shor post-checks (e.g., verify a close approximation and test that x^r ≡ 1 mod N, handle r even, etc.). Notes ----- - If N_bound is an int, it is promoted to a BigInteger of the same width as a,b. - Returns the denominator of the CF-derived convergent; caller should perform the usual validity checks for Shor (closeness, non-trivial factor conditions, etc.). """ n_limbs = a.digits.shape[0] if isinstance(N_bound, BigInteger): max_den = N_bound assert max_den.digits.shape[0] == n_limbs, "N_bound width must match a,b" else: max_den = BigInteger.create(int(N_bound), n_limbs) p, q = bi_contfrac_best_approx(a, b, max_den=max_den, max_iters=max_iters) return q def bi_contfrac_convergents(a: BigInteger, b: BigInteger, max_terms: int | None = None): """ Generator of convergents p/q for a/b (Python generator; not JAX-traced). Yields successive convergents (p_k, q_k) using the standard recurrence until remainder becomes 0 or `max_terms` is reached. Useful for debugging or non-jitted workflows. For high-performance in-jit use, prefer `bi_contfrac_best_approx`. Parameters ---------- a : BigInteger Numerator. b : BigInteger Denominator (non-zero). max_terms : int or None, optional Maximum number of convergents to yield. Yields ------ (BigInteger, BigInteger) Convergent pairs (p_k, q_k), k = 0, 1, 2, ... """ n_limbs = a.digits.shape[0] assert n_limbs == b.digits.shape[0], "BigInteger widths must match" bi0 = BigInteger.create(0, n_limbs) bi1 = BigInteger.create(1, n_limbs) assert b != bi0, "Denominator b must be non-zero" a0, b0 = a, b p0, p1 = bi0, bi1 q0, q1 = bi1, bi0 k = 0 while True: q = a0 // b0 r = a0 % b0 pn = q * p1 + p0 qn = q * q1 + q0 yield pn, qn if r == bi0: break a0, b0 = b0, r p0, p1 = p1, pn q0, q1 = q1, qn k += 1 if max_terms is not None and k >= max_terms: break