Skip to content

Instantly share code, notes, and snippets.

@fubuloubu
Last active June 29, 2020 00:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save fubuloubu/495ccb9d0ee6681aa11ff180b4b9d33e to your computer and use it in GitHub Desktop.
Save fubuloubu/495ccb9d0ee6681aa11ff180b4b9d33e to your computer and use it in GitHub Desktop.
Implementation of overflow-safe version of exponentiation
# Implementation of overflow-safe version of exponentiation
# Prototyped for the EVM environment of Vyper
# from https://en.wikipedia.org/wiki/Exponentiation_by_squaring
import math
import pytest
from hypothesis import given, strategies as st, settings
# Largest number of squaring-loop rounds any power() call has needed so far.
# power() updates it; test_power() reads it to check the O(log(n)) claim.
max_rounds = 0
def power(a: int, b: int) -> int:
    """Return ``a ** b``, raising ``ValueError`` if the result is out of range.

    Uses exponentiation by squaring and checks every multiplication that
    would actually be performed, so an EVM-style implementation never holds
    an out-of-range intermediate.  The accepted range is
    ``[-(2 ** 127), 2 ** 256)`` (presumably the combined range of Vyper's
    ``int128`` and ``uint256`` -- confirm against the target types).

    Args:
        a: The base.
        b: The exponent.  Must lie in ``[0, 256)`` unless the answer is
            trivial (``b == 0`` or a base of ``0``, ``1`` or ``-1``).

    Returns:
        The exact value of ``a ** b``.

    Raises:
        ValueError: if the result (or an intermediate product) falls outside
            ``[-(2 ** 127), 2 ** 256)``, if ``b`` is outside ``[0, 256)`` for
            a non-trivial base, or if ``a == 0`` with a negative exponent.

    Side effect: raises the module-level ``max_rounds`` to the number of
    squaring rounds this call performed, if that is a new maximum.
    """
    global max_rounds  # For keeping track of O(log(n)) claim

    lower, upper = -(2 ** 127), 2 ** 256

    # Easy cases
    # TODO: Adjust for EVM oddities
    if a == 0 and b != 0:
        if b < 0:
            # 0 ** negative is a division by zero; there is no integer result.
            raise ValueError("zero cannot be raised to a negative power")
        return 0
    if a == 1 or b == 0:
        return 1
    if a == -1:
        return 1 if b % 2 == 0 else -1
    # NOTE: b == 1 deliberately falls through so that the final range check
    # below also applies to the base itself.
    if b < 0 or b >= 256:  # Sanity check on arg
        raise ValueError("exponent must be in [0, 256)")

    # Invariant: a ** b == y * x ** n
    x, n, y = a, b, 1
    num_rounds = 0
    # TODO: Adjust for EVM oddities
    while n > 1:
        # Overflow check on x ** 2, which both branches compute.
        # NOTE: x ** 2 < -(2 ** 127) is impossible
        x_squared = x * x
        if x_squared >= upper:
            raise ValueError("result out of range")
        if n % 2 == 1:  # n is odd: fold one factor of x into y
            # Overflow check on x * y.  Only done here because this is the
            # only branch that multiplies into y; checking it on even rounds
            # spuriously rejects e.g. an out-of-range base squared into range.
            xy = x * y
            if not lower <= xy < upper:
                raise ValueError("result out of range")
            y = xy
        x = x_squared
        n //= 2  # equals (n - 1) // 2 when n is odd
        # Count the round first so max_rounds reflects completed rounds.
        num_rounds += 1
        if num_rounds > max_rounds:
            max_rounds = num_rounds

    # Overflow check on the final x * y
    result = x * y
    if not lower <= result < upper:
        raise ValueError("result out of range")
    return result
# Strategy for (base, exponent) pairs whose base range is scaled to the exponent.
# NOTE: Still allows some overflow/underflow cases, but keeps them balanced
# against in-range ones.
@st.composite
def base_and_power(draw, n=st.integers(min_value=0, max_value=256)):  # noqa: B008
    """Draw an ``(base, exponent)`` pair for exercising ``power``.

    The exponent comes from ``n``.  For exponents above 1 the base is drawn
    from twice the exponent-th root of each range limit, so if bases were
    drawn uniformly roughly half would give an in-range power on each side.
    Exponents 0 and 1 use the full ``[-(2 ** 127), 2 ** 256 - 1]`` range.
    """
    exponent = draw(n)
    if exponent > 1:
        # exponent ** (log_exponent(L) / exponent) == L ** (1 / exponent)
        lowest = -round(2 * (exponent ** (math.log(2 ** 127, exponent) / exponent)))
        highest = round(2 * (exponent ** (math.log(2 ** 256, exponent) / exponent)))
    else:
        lowest, highest = -(2 ** 127), 2 ** 256 - 1
    base = draw(st.integers(min_value=lowest, max_value=highest))
    return (base, exponent)
@given(xn=base_and_power())
@settings(max_examples=1000000)
def test_power(xn):
    """Compare power() with Python's exact ``**`` on a generated pair.

    ``power`` must raise ValueError exactly when the true result lies outside
    ``[-(2 ** 127), 2 ** 256)``, and must otherwise return the exact value.
    """
    base, exponent = xn
    expected = base ** exponent
    fits = -(2 ** 127) <= expected < 2 ** 256
    if not fits:
        with pytest.raises(ValueError):
            power(base, exponent)
    else:
        # TODO: Adjust for EVM oddities
        assert power(base, exponent) == expected
    # O(log(n)) claim: exponents below 256 have at most 8 bits, so the
    # squaring loop can never need more than 8 rounds (log_2(256) = 8).
    assert max_rounds <= 8
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment