Created
April 24, 2023 04:12
-
-
Save Strilanc/d08dc245a9661fb9f61359b5f756d9e0 to your computer and use it in GitHub Desktop.
An implementation of the half-workspace modular multiplication described in https://arxiv.org/pdf/quant-ph/0601097.pdf
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
from typing import Union | |
from typing import Tuple | |
def zalka_imul_mod_low_workspace(*, dst: 'Quint', factor: int, modulus: int) -> None:
    """Multiply `dst` by `factor` modulo `modulus`, in place, with about n/2 extra bits.

    The factor is rewritten as `forward * inverse(backward) (mod modulus)` using
    `pseudo_factor_mod`. The register is then multiplied by the forward piece,
    inverse-multiplied by the magnitude of the backward piece, and finally
    negated (mod modulus) when the backward piece is negative.

    Args:
        dst: Register to update. It must be at least
            `modulus.bit_length() * 3 // 2 + 2` bits long.
        factor: The multiplier. Must be less than `modulus` and coprime to it.
        modulus: The modulus of the multiplication.
    """
    assert factor < modulus
    assert math.gcd(factor, modulus) == 1
    assert len(dst) >= modulus.bit_length() * 3 // 2 + 2

    forward, backward = pseudo_factor_mod(factor, modulus)
    # Sanity check: the two pieces really do recombine into the factor.
    assert (forward * pow(backward, -1, modulus)) % modulus == factor

    imul_mod_small(dst=dst, factor=forward, modulus=modulus)
    imul_mod_inv_small(dst=dst, factor=abs(backward), modulus=modulus)
    # Only the magnitude of the backward piece was applied; restore its sign.
    if backward < 0:
        inegate_mod(dst=dst, modulus=modulus)
class MutableIntBuffer:
    """A mutable box holding a single Python int.

    The integer lives in the `buf` attribute and is read and overwritten
    directly by whoever holds a reference to the box.
    """

    def __init__(self, val: int):
        # The boxed integer; replaced wholesale on every write.
        self.buf = val

    def __int__(self) -> int:
        """Return the boxed integer, so that `int(buffer)` works."""
        return int(self.buf)

    def __repr__(self) -> str:
        return f'MutableIntBuffer({self.buf!r})'
class Quint:
    """Plays the role of a quantum integer.

    As implemented, it's really more of a mutable integer with convenient
    reversible operations. Also, it has a bunch of escape hatches to deal with
    impedance mismatch in the abstraction.

    Main features:
    - Slicing gives subviews of the integer's bits, with bit 0 being the
      least significant.
    - Prefers inplace assignment. You can write a ^= b, but a = a ^ b
      downcasts to int.
    - Pretends to be an integer when doing equality checks and comparisons.
    - Operations are implicitly mod 2**len(self), and fail if they aren't
      reversible.

    A Quint does not own its bits: it is a view onto selected bit positions
    of a shared `MutableIntBuffer`, so writes through one view are visible
    through every other view of the same buffer.
    """

    def __init__(self, target: 'MutableIntBuffer', bit_positions: Union[int, range]):
        """Create a view of `target`.

        Args:
            target: The buffer holding the underlying bits.
            bit_positions: Either a single bit index (giving a 1-bit view) or
                a range of bit indices; entry k of the range is the buffer
                bit that backs bit k of this view.
        """
        assert isinstance(target, MutableIntBuffer)
        self.target = target
        if isinstance(bit_positions, int):
            bit_positions = range(bit_positions, bit_positions + 1)
        self.bit_positions = bit_positions

    def __str__(self):
        return str(int(self))

    def __repr__(self):
        return 'Quint:' + str(self)

    def inplace_left_rotate(self) -> None:
        """Rotate the bits by one position.

        Bit 0 moves to the most significant position and every other bit
        moves one position down (towards bit 0).
        """
        m = int(self)
        s = len(self) - 1
        b = (m & 1)
        m = (m >> 1) | (b << s)
        self._set_val(m)

    def inplace_right_rotate(self) -> None:
        """Rotate the bits by one position, undoing `inplace_left_rotate`.

        The most significant bit moves to bit 0 and every other bit moves one
        position up (away from bit 0).
        """
        m = int(self)
        s = len(self) - 1
        b = m & ((1 << s) - 1)
        m = (m >> s) | (b << 1)
        self._set_val(m)

    def __getitem__(self, item):
        """Return a sub-view; `item` is an index or slice into the bit positions."""
        return Quint(self.target, self.bit_positions[item])

    def __setitem__(self, key, value) -> None:
        """Overwrite the selected bits with `value`.

        This is also what makes `q[a:b] += x` work: Python performs the
        in-place operation on the sub-view and then assigns the sub-view back
        here, which rewrites the bits with their own current value.
        """
        Quint(self.target, self.bit_positions[key])._set_val(value)

    def __bool__(self) -> bool:
        return bool(int(self))

    def __ixor__(self, other) -> 'Quint':
        self._set_val(int(self) ^ other)
        return self

    # Non-inplace operators deliberately downcast to a plain int.

    def __and__(self, other) -> int:
        return int(self) & int(other)

    def __rand__(self, other) -> int:
        return int(self) & int(other)

    def __add__(self, other):
        return int(self) + int(other)

    def __radd__(self, other):
        return int(self) + int(other)

    def __mul__(self, other):
        return int(self) * int(other)

    def __rmul__(self, other):
        return int(self) * int(other)

    def __or__(self, other) -> int:
        return int(self) | int(other)

    def __ror__(self, other) -> int:
        return int(self) | int(other)

    def __gt__(self, other):
        return int(self) > int(other)

    def __lt__(self, other):
        return int(self) < int(other)

    def __ge__(self, other):
        return int(self) >= int(other)

    def __le__(self, other):
        return int(self) <= int(other)

    def __iadd__(self, other):
        """Add in place, wrapping modulo 2**len(self)."""
        self._set_val((int(self) + other) & ~(~0 << len(self)))
        return self

    def __isub__(self, other):
        """Subtract in place, wrapping modulo 2**len(self)."""
        self._set_val((int(self) - other) & ~(~0 << len(self)))
        return self

    def __ilshift__(self, other):
        """Shift left in place; the bits shifted out must already be zero."""
        assert other >= 0
        assert other == 0 or self[-other:] == 0
        self._set_val(int(self) << other)
        return self

    def inplace_right_shift(self, *, amount: int, expected_bottom_bit: bool):
        """Shift right in place by `amount`, checking the discarded bits.

        The low `amount` bits must all be ones when `expected_bottom_bit` is
        true and all zeros otherwise, so that no information is lost.
        """
        assert amount >= 0
        if expected_bottom_bit:
            assert self[:amount] == (1 << amount) - 1
        else:
            assert self[:amount] == 0
        self._set_val(int(self) >> amount)

    def inplace_left_shift(self, *, amount: int, expected_top_bit: bool, new_bottom_bit: bool):
        """Shift left in place by `amount`, checking the discarded bits.

        The top `amount` bits must all be ones when `expected_top_bit` is true
        and all zeros otherwise. The vacated low bits are filled with ones
        when `new_bottom_bit` is true and with zeros otherwise.
        """
        assert amount >= 0
        actual = self[-amount:] if amount > 0 else 0
        m = (1 << amount) - 1
        expected = m if expected_top_bit else 0
        assert actual == expected
        new_val = int(self) << amount
        if new_bottom_bit:
            new_val |= m
        self._set_val(new_val)

    def __irshift__(self, other):
        """Shift right in place; the bits shifted out must already be zero."""
        self.inplace_right_shift(amount=other, expected_bottom_bit=False)
        return self

    def __xor__(self, other) -> int:
        return int(self) ^ int(other)

    def __rxor__(self, other) -> int:
        return int(self) ^ int(other)

    def __rshift__(self, other):
        return int(self) >> other

    def __lshift__(self, other):
        return int(self) << other

    def __len__(self) -> int:
        return len(self.bit_positions)

    def _set_val(self, new_val: int) -> None:
        """Write `new_val` into the viewed bits of the buffer.

        The bits of `new_val` above `len(self)` must be all zeros or all ones
        (a value that fits, or a negative value in two's complement form);
        only the low `len(self)` bits are stored.
        """
        assert (new_val >> len(self)) in [0, ~0]
        new_val = int(new_val) & ~(~0 << len(self))
        v = self.target.buf
        if self.bit_positions.step == 1:
            # Contiguous view: clear the whole window, then drop the value in.
            m = ~(~0 << self.bit_positions.stop) & (~0 << self.bit_positions.start)
            v &= ~m
            v |= new_val << self.bit_positions.start
        else:
            # Strided or reversed view: place each bit individually.
            for k, b in enumerate(self.bit_positions):
                if new_val & (1 << k):
                    v |= 1 << b
                else:
                    v &= ~(1 << b)
        self.target.buf = v

    def __eq__(self, other) -> bool:
        if isinstance(other, (int, Quint)):
            return int(self) == int(other)
        if isinstance(other, MutableIntBuffer):
            # Read the stored int directly rather than relying on int(other).
            return int(self) == other.buf
        return NotImplemented

    def __ne__(self, other) -> bool:
        return not (self == other)

    def signed_val(self) -> int:
        """Return the value interpreted as a two's complement signed integer."""
        v = int(self)
        if self[-1]:
            v -= 1 << len(self)
        return v

    def iadd_mod(self, offset: Union[int, 'Quint'], modulus: Union[int, 'Quint']) -> 'Quint':
        """Replace the value with `(value + offset) % modulus` and return self."""
        self._set_val((int(self) + int(offset)) % int(modulus))
        return self

    def imul_mod(self, factor: Union[int, 'Quint'], modulus: Union[int, 'Quint']) -> 'Quint':
        """Replace the value with `(value * factor) % modulus` and return self.

        `factor` must be coprime to `modulus`.
        """
        factor = int(factor)
        modulus = int(modulus)
        assert math.gcd(factor, modulus) == 1
        self._set_val((int(self) * factor) % modulus)
        return self

    def __int__(self) -> int:
        """Read the viewed bits out of the buffer as a non-negative int."""
        positions = self.bit_positions
        buf = self.target.buf
        if positions.step == 1:
            # Contiguous view: mask off everything above, then shift down.
            return (buf & (~(~0 << positions.stop))) >> positions.start
        t = 0
        for k, b in enumerate(positions):
            if buf & (1 << b):
                t |= 1 << k
        return t
def alloc_quint(*, val: int = 0, width: int) -> Quint:
    """Create a fresh `width`-bit Quint backed by its own new buffer.

    Args:
        val: Initial contents of the backing buffer.
        width: Number of bits in the view; it covers buffer bits 0..width-1.
    """
    backing = MutableIntBuffer(val)
    return Quint(backing, range(width))
def broadcast(k: Union[int, Quint]) -> int:
    """Spread a single bit across every bit of an integer mask.

    A zero/False input yields 0 (no bits set); a one/True input yields -1
    (all bits set). Any other value is rejected by an assertion.
    """
    bit = int(k)
    assert bit in (0, 1)
    return -1 if bit else 0
def factors_of_2(v: int) -> int:
    """Return how many times 2 divides `v` (its number of trailing zero bits).

    Args:
        v: A non-zero integer.

    Returns:
        The largest p such that 2**p divides v.

    Raises:
        ValueError: If v is zero, which every power of two divides.
    """
    if v == 0:
        raise ValueError('factors_of_2 is undefined for 0.')
    # v & -v isolates the lowest set bit; its bit length is one more than its index.
    return (v & -v).bit_length() - 1
def pseudo_factor_mod(f: int, N: int) -> Tuple[int, int]:
    """Split f into a pair of small values that inverse multiply to f (mod N).

    Runs the Euclidean algorithm on (f, N) while tracking, for each remainder,
    the coefficient that multiplies f to produce it modulo N. Each remainder
    whose coefficient is invertible modulo N is a candidate pair; the
    candidate whose larger entry has the fewest bits is returned (the earliest
    such candidate on ties).

    Args:
        f: The number to factor.
        N: The modulus.

    Returns:
        A tuple (a, b) satisfying

            a * inv(b) == f  (mod N)

        where `a` is the forward pseudo-factor and `b` is the backward
        pseudo-factor. `b` may be negative. The selection aims for both
        entries to have roughly half as many bits as N, but no hard bound is
        enforced here.

    Raises:
        ValueError: If no candidate pair exists (for example when N is 0).
    """
    # Invariants, modulo N: a == x * f and b == u * f.
    x, u = 1, 0
    a, b = f, N
    pairs = []
    while b:
        q = a // b
        x, u = u, x - q * u
        a, b = b, a - q * b
        if math.gcd(u, N) == 1:
            pairs.append((b, u))
        if math.gcd(x, N) == 1:
            pairs.append((a, x))
    if not pairs:
        raise ValueError(f'No pseudo-factor pair exists for f={f!r}, N={N!r}.')

    def width(pair: Tuple[int, int]) -> int:
        # Size of a candidate is the bit length of its larger entry.
        return max(abs(pair[0]).bit_length(), abs(pair[1]).bit_length())

    return min(pairs, key=width)
def inegate_mod(*, dst: Quint, modulus: int) -> None:
    """Negate the low bits of `dst` modulo `modulus`, in place.

    The low `modulus.bit_length()` bits of `dst` hold the value; the next bit
    up is used as a scratch flag and must be zero on entry and on exit.
    """
    value = dst[:modulus.bit_length()]
    # Adding (2**n - modulus - 1) and then complementing every bit turns the
    # n-bit value v into (modulus - v).
    value += (1 << len(value)) - modulus - 1
    value ^= ~0

    flag = dst[len(value)]
    assert flag == 0
    # An input of 0 has become `modulus`; fold that one case back to 0. The
    # flag is raised when the value equals `modulus` and lowered again once
    # the value has been cleared.
    if value == modulus:
        flag ^= 1
    if flag:
        value ^= modulus
    if value == 0:
        flag ^= 1
    assert flag == 0
def imul_mod_small(*, dst: Quint, factor: int, modulus: int) -> None:
    """An inplace modular multiplication.

    Requires the target register to have room to temporarily hold the
    non-modular multiplication: `dst` must be at least
    `modulus.bit_length() + factor.bit_length() + 1` bits long and must hold
    a value below `modulus` on entry.

    The work happens in three stages: a plain multiplication, a division by
    the modulus that leaves the quotient in the top bits, and an uncomputation
    of that quotient from the remainder.
    """
    assert dst < modulus
    factor_bits = factor.bit_length()
    n = modulus.bit_length()
    assert len(dst) >= n + factor_bits + 1

    # Stage 1: dst *= factor (non-modular).
    # Write factor = (2 * half_odd + 1) << twos. Walking from the top bit
    # down, each set bit k adds half_odd just above itself, which multiplies
    # by the odd part without disturbing the bits still to be visited.
    twos = factors_of_2(factor)
    half_odd = factor >> (twos + 1)
    for k in reversed(range(n)):
        if dst[k]:
            dst[k + 1:] += half_odd
    dst <<= twos

    # Stage 2: divmod by the modulus. One quotient bit is produced per round,
    # most significant first, and stored in the top bits of dst.
    dst_div = dst[-1:]
    dst_rem = dst[:-1]
    for t in range(1, factor_bits + 1):
        dst_div = dst[-t:]
        dst_rem = dst[:-t]
        threshold = modulus << (factor_bits - t)
        quotient_bit = dst_div[0]
        assert quotient_bit == 0
        if dst_rem >= threshold:
            quotient_bit ^= 1
        if quotient_bit:
            dst_rem -= threshold

    # Stage 3: clear the quotient, using dst_div == -dst_rem * inv(N) (mod factor).
    for k in range(n):
        if dst_rem[k]:
            dst_div.iadd_mod(pow(modulus, -1, factor) << k, factor)

    assert dst_div == 0
    assert dst == dst_rem
def imul_mod_inv_small(*, dst: Quint, factor: int, modulus: int) -> None:
    """A reversed inplace modular multiplication.

    Multiplies by the multiplicative inverse of the given factor.

    Requires the target register to have room to temporarily hold the
    non-modular multiplication: `dst` must be at least
    `modulus.bit_length() + factor.bit_length() + 1` bits long and must hold
    a value below `modulus` on entry.

    The three stages of `imul_mod_small` are run backwards: the quotient is
    recomputed from the remainder, the division is undone, and finally the
    plain multiplication is undone.
    """
    assert dst < modulus
    factor_bits = factor.bit_length()
    n = modulus.bit_length()
    assert len(dst) >= n + factor_bits + 1

    # Stage 3 reversed: rebuild dst_div == -dst_rem * inv(N) (mod factor).
    dst_div = dst[-factor_bits:]
    dst_rem = dst[:-factor_bits]
    assert dst_div == 0
    assert dst == dst_rem
    for k in reversed(range(n)):
        if dst_rem[k]:
            dst_div.iadd_mod(-(pow(modulus, -1, factor) << k), factor)

    # Stage 2 reversed: undo the divmod, least significant quotient bit first.
    for t in range(factor_bits, 0, -1):
        dst_div = dst[-t:]
        dst_rem = dst[:-t]
        threshold = modulus << (factor_bits - t)
        quotient_bit = dst_div[0]
        if quotient_bit:
            dst_rem += threshold
        if dst_rem >= threshold:
            quotient_bit ^= 1
        assert quotient_bit == 0

    # Stage 1 reversed: undo the non-modular multiplication by the factor.
    twos = factors_of_2(factor)
    dst >>= twos
    half_odd = factor >> (twos + 1)
    for k in range(n):
        if dst[k]:
            dst[k + 1:] -= half_odd
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment