Skip to content

Instantly share code, notes, and snippets.

@Strilanc
Created April 24, 2023 04:12
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Strilanc/d08dc245a9661fb9f61359b5f756d9e0 to your computer and use it in GitHub Desktop.
Save Strilanc/d08dc245a9661fb9f61359b5f756d9e0 to your computer and use it in GitHub Desktop.
An implementation of the half-workspace modular multiplication described in https://arxiv.org/pdf/quant-ph/0601097.pdf
import math
from typing import Union
from typing import Tuple
def zalka_imul_mod_low_workspace(*, dst: 'Quint', factor: int, modulus: int) -> None:
    """Performs an inplace modular multiplication using n/2 + O(1) workspace.

    The factor is first split into two roughly half-size pseudo-factors (forward, backward)
    with forward * inv(backward) == factor (mod modulus). The register is multiplied by the
    forward pseudo-factor, then by the inverse of |backward|, and finally negated modulo
    `modulus` when backward is negative.
    """
    assert factor < modulus
    assert math.gcd(factor, modulus) == 1
    min_width = modulus.bit_length() * 3 // 2 + 2
    assert len(dst) >= min_width

    forward, backward = pseudo_factor_mod(factor, modulus)
    assert forward * pow(backward, -1, modulus) % modulus == factor

    imul_mod_small(dst=dst, factor=forward, modulus=modulus)
    imul_mod_inv_small(dst=dst, factor=abs(backward), modulus=modulus)
    # a * inv(-|b|) == -(a * inv(|b|)), so a negative backward factor costs one negation.
    if backward < 0:
        inegate_mod(dst=dst, modulus=modulus)
class MutableIntBuffer:
    """A mutable box around a Python int.

    It serves as the shared backing storage that Quint views read from and write to
    (through the ``buf`` attribute).
    """

    def __init__(self, val: int) -> None:
        # The raw integer holding every bit of the register.
        self.buf = val
class Quint:
    """Plays the role of a quantum integer.

    As implemented, it's really more of a mutable integer with convenient reversible operations.
    Also, it has a bunch of escape hatches to deal with impedance mismatch in the abstraction.

    Main features:
    - Slicing gives subviews of the integer's bits, with bit 0 being the least significant.
    - Prefers inplace assignment. You can write a ^= b, but a = a ^ b downcasts to int.
    - Pretends to be an integer when doing equality checks and comparisons.
    - Operations are implicitly mod 2**len(self), and fail if they aren't reversible.

    A Quint stores no bits itself. It is a view, given by ``bit_positions``, into the ``buf``
    of a ``MutableIntBuffer``. Slices share that buffer, so a write through a slice is
    visible through the parent view and vice versa.
    """

    def __init__(self, target: MutableIntBuffer, bit_positions: Union[int, range]):
        assert isinstance(target, MutableIntBuffer)
        self.target = target
        # A bare int position is shorthand for the one-bit view of that position.
        self.bit_positions = range(bit_positions, bit_positions + 1) if isinstance(bit_positions, int) else bit_positions

    def __str__(self):
        return str(int(self))

    def __repr__(self):
        return 'Quint:' + str(self)

    def inplace_left_rotate(self) -> None:
        """Rotates the bits one step left, viewing bit 0 as the leftmost element.

        Bit k moves to position k - 1 and bit 0 wraps around to the top bit. Numerically,
        this rotates toward the least significant end. Inverse of ``inplace_right_rotate``.
        A zero-width view is left unchanged.
        """
        if len(self) == 0:
            return  # Nothing to rotate (and the shift amounts below would be negative).
        m = int(self)
        s = len(self) - 1
        b = m & 1
        m = (m >> 1) | (b << s)
        self._set_val(m)

    def inplace_right_rotate(self) -> None:
        """Rotates the bits one step right, viewing bit 0 as the leftmost element.

        Bit k moves to position k + 1 and the top bit wraps around to bit 0. Inverse of
        ``inplace_left_rotate``. A zero-width view is left unchanged.
        """
        if len(self) == 0:
            return  # Nothing to rotate (and the shift amounts below would be negative).
        m = int(self)
        s = len(self) - 1
        b = m & ((1 << s) - 1)
        m = (m >> s) | (b << 1)
        self._set_val(m)

    def __getitem__(self, item):
        # Indexing/slicing selects bit positions of this view; the result aliases the same buffer.
        return Quint(self.target, self.bit_positions[item])

    def __setitem__(self, key, value):
        v = Quint(self.target, self.bit_positions[key])
        v._set_val(value)
        return v

    def __bool__(self) -> bool:
        return bool(int(self))

    def __ixor__(self, other) -> 'Quint':
        self._set_val(int(self) ^ other)
        return self

    # Non-inplace operators deliberately downcast to plain ints.
    def __and__(self, other) -> int:
        return int(self) & int(other)

    def __rand__(self, other) -> int:
        return int(self) & int(other)

    def __add__(self, other):
        return int(self) + int(other)

    def __radd__(self, other):
        return int(self) + int(other)

    def __mul__(self, other):
        return int(self) * int(other)

    def __rmul__(self, other):
        return int(self) * int(other)

    def __or__(self, other) -> int:
        return int(self) | int(other)

    def __ror__(self, other) -> int:
        return int(self) | int(other)

    def __gt__(self, other):
        return int(self) > int(other)

    def __lt__(self, other):
        return int(self) < int(other)

    def __ge__(self, other):
        return int(self) >= int(other)

    def __le__(self, other):
        return int(self) <= int(other)

    def __iadd__(self, other):
        # Wraps mod 2**len(self), which keeps the addition reversible.
        self._set_val((int(self) + other) & ~(~0 << len(self)))
        return self

    def __isub__(self, other):
        # Wraps mod 2**len(self), which keeps the subtraction reversible.
        self._set_val((int(self) - other) & ~(~0 << len(self)))
        return self

    def __ilshift__(self, other):
        # The bits shifted out of the top must be zero so that no information is lost.
        assert other >= 0
        assert other == 0 or self[-other:] == 0
        self._set_val(int(self) << other)
        return self

    def inplace_right_shift(self, *, amount: int, expected_bottom_bit: bool):
        """Shifts toward the least significant end by ``amount`` bits.

        The ``amount`` bits shifted out of the bottom must all equal ``expected_bottom_bit``
        (asserted). The vacated top bits become zero.
        """
        assert amount >= 0
        if expected_bottom_bit:
            assert self[:amount] == (1 << amount) - 1
        else:
            assert self[:amount] == 0
        self._set_val(int(self) >> amount)

    def inplace_left_shift(self, *, amount: int, expected_top_bit: bool, new_bottom_bit: bool):
        """Shifts toward the most significant end by ``amount`` bits.

        The ``amount`` bits shifted out of the top must all equal ``expected_top_bit``
        (asserted). The vacated bottom bits are filled with ``new_bottom_bit``.
        """
        assert amount >= 0
        actual = self[-amount:] if amount > 0 else 0
        m = (1 << amount) - 1
        expected = m if expected_top_bit else 0
        assert actual == expected
        # Discard the (already verified) bits shifted past the top. Without this mask, shifted-out
        # one-bits would fail the range check in _set_val whenever expected_top_bit is True.
        new_val = (int(self) << amount) & ~(~0 << len(self))
        if new_bottom_bit:
            new_val |= m
        self._set_val(new_val)

    def __irshift__(self, other):
        self.inplace_right_shift(amount=other, expected_bottom_bit=False)
        return self

    def __xor__(self, other) -> int:
        return int(self) ^ int(other)

    def __rxor__(self, other) -> int:
        return int(self) ^ int(other)

    def __rshift__(self, other):
        return int(self) >> other

    def __lshift__(self, other):
        return int(self) << other

    def __len__(self) -> int:
        return len(self.bit_positions)

    def _set_val(self, new_val: int) -> None:
        """Writes ``new_val`` into this view's bits, leaving the rest of the buffer untouched.

        ``new_val`` must fit in len(self) bits, either as an unsigned value or as a
        sign-extended negative one: its bits above the view must be all 0 or all 1.
        """
        assert (new_val >> len(self)) in [0, ~0]
        new_val = int(new_val) & ~(~0 << len(self))
        v = self.target.buf
        if self.bit_positions.step == 1:
            # Contiguous view: clear the whole window with one mask, then OR in the new bits.
            m = ~(~0 << self.bit_positions.stop) & (~0 << self.bit_positions.start)
            v &= ~m
            v |= new_val << self.bit_positions.start
        else:
            for k, b in enumerate(self.bit_positions):
                if new_val & (1 << k):
                    v |= 1 << b
                else:
                    v &= ~(1 << b)
        self.target.buf = v

    def __eq__(self, other) -> bool:
        if isinstance(other, MutableIntBuffer):
            # MutableIntBuffer defines no __int__, so compare against its raw contents.
            return int(self) == other.buf
        if isinstance(other, (int, Quint)):
            return int(self) == int(other)
        return NotImplemented

    def __ne__(self, other) -> bool:
        return not (self == other)

    def signed_val(self) -> int:
        """Returns the value interpreted as a two's complement number of len(self) bits."""
        v = int(self)
        if self[-1]:
            v -= 1 << len(self)
        return v

    def iadd_mod(self, offset: Union[int, 'Quint'], modulus: Union[int, 'Quint']) -> 'Quint':
        self._set_val((int(self) + int(offset)) % int(modulus))
        return self

    def imul_mod(self, factor: Union[int, 'Quint'], modulus: Union[int, 'Quint']) -> 'Quint':
        factor = int(factor)
        modulus = int(modulus)
        # Only multiplication by a unit mod `modulus` is reversible.
        assert math.gcd(factor, modulus) == 1
        self._set_val((int(self) * factor) % modulus)
        return self

    def __int__(self) -> int:
        if self.bit_positions.step == 1:
            # Contiguous view: mask off bits above the window, then drop bits below it.
            return (self.target.buf & (~(~0 << self.bit_positions.stop))) >> self.bit_positions.start
        t = 0
        for k in range(len(self.bit_positions)):
            if self.target.buf & (1 << self.bit_positions[k]):
                t |= 1 << k
        return t
def alloc_quint(*, val: int = 0, width: int) -> Quint:
    """Creates a ``width``-bit Quint over a fresh buffer whose raw contents start as ``val``.

    Note: ``val`` is stored as-is (it is not masked to ``width`` bits).
    """
    storage = MutableIntBuffer(val)
    positions = range(0, width)
    return Quint(storage, positions)
def broadcast(k: Union[int, 'Quint']) -> int:
    """Turns False into 0 (all bits off) and True into -1 (all bits on).

    Useful as a mask: ``x & broadcast(bit)`` is ``x`` when the bit is set, else 0.
    """
    bit = int(k)
    assert bit in (0, 1)
    return -bit
def factors_of_2(v: int) -> int:
    """Returns the number of trailing zero bits of ``v``.

    This is the exponent of the largest power of two dividing ``v``. Negative values are
    handled via two's complement: the result for -v equals the result for v.

    Raises:
        ValueError: if ``v`` is zero. Every power of two divides zero, so there is no answer,
            and a bit-scanning loop would never terminate.
    """
    if v == 0:
        raise ValueError("factors_of_2 is undefined for 0")
    # v & -v isolates the lowest set bit; its bit length minus one is that bit's index.
    return (v & -v).bit_length() - 1
def pseudo_factor_mod(f: int, N: int) -> Tuple[int, int]:
    """Returns a pair of small values that inverse multiply to f (mod N).

    Runs the extended Euclidean algorithm on (f, N). Every remainder r produced along the way
    satisfies r == s * f (mod N) for its Bezout coefficient s. So whenever s is invertible
    mod N, the pair (r, s) satisfies r * inv(s) == f (mod N).

    Among all such candidate pairs, the one whose larger member (by absolute bit length) is
    shortest is returned; ties go to the earliest candidate. Both members typically have
    about half as many bits as N. This is NOT an exact guarantee: for example,
    f=5, N=7 returns (2, -1), and 2 has 2 bits while N has 3.

    Args:
        f: The number to factor.
        N: the modulus.

    Returns:
        A tuple (a, b) satisfying:
            a * inv(b) == f (mod N)
        `a` is the forward pseudo-factor (non-negative for N > 0).
        `b` is the backward pseudo-factor (may be negative).

    Raises:
        ValueError: if no candidate pair is produced (e.g. when N == 0).
    """
    # Only the Bezout coefficients of f are needed; those of N never affect the result.
    x, u = 1, 0
    a, b = f, N
    pairs = []
    while b:
        q = a // b
        x, u = u, x - q * u
        a, b = b, a - q * b
        # Invariant: a == x * f and b == u * f (mod N).
        if math.gcd(u, N) == 1:
            pairs.append((b, u))
        if math.gcd(x, N) == 1:
            pairs.append((a, x))
    if not pairs:
        raise ValueError(f"no pseudo-factor pair exists for f={f}, N={N}")
    return min(pairs, key=lambda pair: max(abs(pair[0]).bit_length(), abs(pair[1]).bit_length()))
def inegate_mod(*, dst: Quint, modulus: int) -> None:
    """Replaces x, held in the low modulus.bit_length() bits of dst, with (-x) % modulus.

    The bit of dst at position modulus.bit_length() is used as a temporary flag. It must be
    0 on entry (asserted) and is restored to 0 before returning. Bits above the flag are
    not touched.

    Requires x < modulus (asserted).
    """
    v = dst[:modulus.bit_length()]
    # For x >= modulus the arithmetic below wraps and can produce a wrong value without
    # tripping the final assertion, so check the precondition up front (as imul_mod_small does).
    assert v < modulus
    # v = x + 2**n - modulus - 1, which stays below 2**n because x < modulus.
    v += 2**len(v) - modulus - 1
    # Complementing all n bits gives 2**n - 1 - v = modulus - x, a value in (0, modulus].
    v ^= -1
    w = dst[len(v)]
    assert w == 0
    # Only x == 0 yields modulus - x == modulus, which has to be mapped to 0.
    w ^= v == modulus
    v ^= modulus & broadcast(w[-1])
    # Uncompute the flag: after the correction, v == 0 exactly when the flag is set.
    w ^= v == 0
    assert w == 0
def imul_mod_small(*, dst: Quint, factor: int, modulus: int) -> None:
    """An inplace modular multiplication: dst := dst * factor % modulus.

    Requires the target register to have room to temporarily hold the non-modular multiplication.
    Concretely:
    - dst must hold a value < modulus;
    - dst must have at least modulus.bit_length() + factor.bit_length() + 1 bits;
    - factor must be positive and coprime to modulus, because modulus is inverted mod factor
      to uncompute the quotient.
    """
    assert factor > 0
    assert dst < modulus
    n = modulus.bit_length()
    len_f = factor.bit_length()
    assert len(dst) >= n + len_f + 1

    # Perform dst *= small_factor (non-modular).
    # Write factor = 2**p * (2*shift + 1). Multiplying by the odd part adds (shift << (k + 1))
    # for every set bit k. Going from the top bit down means each addition only changes bits
    # above the control bits that are still to be read.
    p = factors_of_2(factor)
    shift = factor >> (p + 1)
    for k in reversed(range(n)):
        if dst[k]:
            dst[k + 1:] += shift
    dst <<= p

    # Perform dst divmod N by long division. Quotient bits are produced from most to least
    # significant into the top len_f bits of dst; the remainder is left in the bits below.
    for t in range(1, len_f + 1):
        dst_div = dst[-t:]
        dst_rem = dst[:-t]
        threshold = modulus << (len_f - t)
        assert dst_div[0] == 0
        dst_div[0] ^= dst_rem >= threshold
        if dst_div[0]:
            dst_rem -= threshold

    # Perform del dst_div = -dst_rem * inv(N) % small_factor.
    # From dst * factor == quotient * N + remainder, we get quotient == -remainder * inv(N)
    # (mod factor). So adding remainder * inv(N) bit by bit returns the quotient to zero.
    inv_modulus = pow(modulus, -1, factor)  # Loop-invariant; computed once.
    for k in range(n):
        if dst_rem[k]:
            dst_div.iadd_mod((1 << k) * inv_modulus, factor)
    assert dst_div == 0
    assert dst == dst_rem
def imul_mod_inv_small(*, dst: Quint, factor: int, modulus: int) -> None:
    """A reversed inplace modular multiplication: dst := dst * inv(factor) % modulus.

    Multiplies by the multiplicative inverse of the given factor by running the steps of the
    forward small-factor multiplication backwards.
    Requires the target register to have room to temporarily hold the non-modular multiplication.
    Concretely:
    - dst must hold a value < modulus;
    - dst must have at least modulus.bit_length() + factor.bit_length() + 1 bits;
    - factor must be positive and coprime to modulus.
    """
    assert factor > 0
    assert dst < modulus
    n = modulus.bit_length()
    len_f = factor.bit_length()
    assert len(dst) >= n + len_f + 1

    # Perform del dst_div = -dst_rem * inv(N) % small_factor, run in reverse.
    # This computes q = -dst_rem * inv(N) % factor into the top len_f bits, which makes
    # dst_rem + q * N a multiple of factor.
    dst_div = dst[-len_f:]
    dst_rem = dst[:-len_f]
    assert dst_div == 0
    assert dst == dst_rem
    inv_modulus = pow(modulus, -1, factor)  # Loop-invariant; computed once.
    for k in reversed(range(n)):
        if dst_rem[k]:
            dst_div.iadd_mod(-(1 << k) * inv_modulus, factor)

    # Perform dst divmod N in reverse. Each quotient bit, from least to most significant,
    # is folded back into the remainder and then cleared.
    for t in range(len_f, 0, -1):
        dst_div = dst[-t:]
        dst_rem = dst[:-t]
        threshold = modulus << (len_f - t)
        if dst_div[0]:
            dst_rem += threshold
        # After restoring, dst_rem >= threshold holds exactly when the quotient bit is set.
        dst_div[0] ^= dst_rem >= threshold
        assert dst_div[0] == 0

    # Perform dst *= small_factor (non-modular), run in reverse. Divide out 2**p, then peel
    # off the odd part's contributions starting from the lowest bit.
    p = factors_of_2(factor)
    dst >>= p
    shift = factor >> (p + 1)
    for k in range(n):
        if dst[k]:
            dst[k + 1:] -= shift
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment