Skip to content

Instantly share code, notes, and snippets.

@RobinLinus
Last active November 17, 2022 23:58
Embed
What would you like to do?
Emulate a Uint32 number type in a subfield of Cairo's `felt` type
# 32-Bit Arithmetic native to Cairo's Finite Field
#
# A collection of operations for 32-bit arithmetic performed in the
# exponents of elements of Cairo's finite field.
#
# The field's modulus is p = 2**251 + 17 * 2**192 + 1, and the order of its
# multiplicative group is
# p - 1 == 2**192 * 5 * 7 * 98714381 * 166848103,
# so that group contains (cyclic) subgroups of sizes
# 2**192, 2**191, 2**190, ..., 2, 5, 7, 98714381, 166848103, ...
#
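# We use elements of the subgroup with 2**32 elements to perform
# computations mod 2**32: an integer x is represented by g**x for a
# generator g of that subgroup, so arithmetic happens in the exponent.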
#
#
# Cairo's finite field
#
# The field's modulus p
p = (1 << 251) + 17 * (1 << 192) + 1
# 3 generates the whole multiplicative group F_p^* (which has order p - 1).
# See: https://github.com/starkware-libs/cairo-lang/blob/54d7e92a703b3b5a1e07e9389608178129946efc/src/starkware/cairo/stark_verifier/core/utils.cairo#L5
g_field = 3
#
# The Uint32 type in the subgroup of size 2**32
#
# Cofactor (p - 1) / 2**32. Since 2**32 divides p - 1 exactly, the shift is
# an exact division. Raising a generator of F_p^* to this power yields an
# element of order exactly 2**32.
g_add = (p - 1) >> 32
# Generator of the multiplicative subgroup of order 2**32 in F_p^*.
# The integer x (mod 2**32) is represented by the field element g**x.
g = pow(g_field, g_add, p)
def to_uint32(z):
    """Encode the integer ``z`` as the Uint32 field element ``g**z mod p``.

    Because ``g`` has order 2**32, only ``z mod 2**32`` affects the result.
    """
    encoded = pow(g, z, p)
    return encoded
def uint32_add(a, b):
    """Add two Uint32 encodings: given a = g**x and b = g**y, return g**(x + y).

    Adding exponents is multiplying the field elements mod p; the wrap-around
    mod 2**32 comes for free from g having order 2**32.
    """
    product = a * b
    return product % p
def uint32_mul_scalar(a, b):
    """Multiply a Uint32 encoding by a plain integer.

    Given a = g**x and an integer b, return g**(x * b), i.e. the encoding of
    (x * b) mod 2**32.
    """
    scaled = pow(a, b, p)
    return scaled
def uint32_neg(z):
    """Return the encoding of the additive inverse (-x) mod 2**32 of z = g**x.

    For z in the subgroup of order 2**32, z**(2**32 - 1) == z**-1 == g**(-x).
    """
    return pow(z, 2**32 - 1, p)
def uint32_sub(a, b):
    """Subtract Uint32 encodings: given a = g**x and b = g**y, return g**(x - y).

    Implemented as a + (-b) in the exponent.
    """
    negated_b = uint32_neg(b)
    return uint32_add(a, negated_b)
def uint32_bitwise_not(a):
    """Flip all 32 bits of the Uint32 encoded by a = g**x.

    For a 32-bit x, ~x equals (2**32 - 1) - x, so this subtracts x from the
    all-ones value in the exponent.
    """
    return uint32_sub(to_uint32(0xFFFFFFFF), a)
def uint32_left_shift(a, t):
    """Logical left shift of the Uint32 encoded by a by t bit positions.

    Bits shifted past position 31 are dropped and zeros fill in on the right.
    Each squaring doubles the exponent (x -> 2x mod 2**32), i.e. shifts by one.
    """
    acc = a
    for _step in range(t):
        acc = (acc * acc) % p
    return acc
# Exponent 2**31: raising a Uint32 encoding g**x to this power gives
# g**(x * 2**31), which keeps only the parity of x ("in the exponent", this
# projects Z/2**32 onto its order-2 subgroup).
g_2 = 1 << 31
def uint32_mod2(x):
    """Return the parity of the Uint32 encoded by x = g**v, as a field element.

    x**(2**31) == g**(v * 2**31), which is 1 when v is even and
    g**(2**31) == p - 1 (i.e. -1 mod p) when v is odd.
    """
    parity = pow(x, g_2, p)
    return parity
def from_uint32(z):
    """Decode a Uint32 field element z = g**x back to the integer x in [0, 2**32).

    Solves the discrete logarithm bit by bit (the prime-power step of
    Pohlig-Hellman for a group of order 2**32). When the low i bits of the
    exponent have already been cleared, z**(2**(31 - i)) is 1 exactly when
    bit i is 0. Each bit found set is added to the result and then cleared
    from z, which creates one more trailing zero for the next step.

    Raises:
        ValueError: if z is not in the subgroup of order 2**32 generated by g,
            i.e. z is not a valid Uint32 encoding.
    """
    z %= p
    result = 0
    # g**(-2**i), advanced by squaring each step; g**-1 == g**(2**32 - 1).
    # Multiplying by it clears bit i from the exponent without a fresh
    # modular inverse per bit.
    clear_bit = pow(g, 2**32 - 1, p)
    for i in range(32):
        if pow(z, 2 ** (31 - i), p) != 1:
            result += 1 << i
            z = z * clear_bit % p
        clear_bit = clear_bit * clear_bit % p
    # All 32 bits have been removed, so z is 1 iff it was g**result, i.e. iff
    # the input actually lay in the 2**32-element subgroup.
    if z != 1:
        raise ValueError("not a Uint32 encoding: element is outside the 2**32 subgroup")
    return result
def uint32_mul(a, b):
    """Multiply two Uint32 encodings: given a = g**x and b = g**y, return g**(x*y).

    The second factor is first decoded to the plain integer y via a discrete
    logarithm, then a is raised to that power.
    """
    y = from_uint32(b)
    return uint32_mul_scalar(a, y)
def is_uint32(a):
    """Check that the integer a lies in the 32-bit range 0 <= a < 2**32.

    Encoding reduces a modulo 2**32 (g has order 2**32) and decoding returns
    the canonical representative in [0, 2**32), so the round trip reproduces
    a only when a was already in range.
    """
    round_trip = from_uint32(to_uint32(a))
    return a == round_trip
def uint32_right_shift_zeros(z, t):
    """Shift the Uint32 encoded by z = g**x right by t bits: return g**(x >> t).

    Precondition (not checked): the low t bits of x are zero. Otherwise the
    result is generally not x >> t.

    The shifted value is rebuilt bit by bit, like from_uint32, but only for
    the 32 - t bits that survive the shift.
    """
    shift_factor = 2 ** t
    probe_exponent = g_2 // shift_factor
    shifted = 1
    for bit in range(32 - t):
        # With all bits below (t + bit) cleared from z, this probe is 1
        # exactly when bit (t + bit) of x is 0.
        if pow(z, probe_exponent, p) != 1:
            bit_value = 2 ** bit
            shifted = uint32_add(shifted, to_uint32(bit_value))
            z = uint32_sub(z, to_uint32(bit_value * shift_factor))
        probe_exponent //= 2
    return shifted
def uint32_rotate_right(z, t):
    """Rotate the 32 bits of the Uint32 encoded by z = g**x right by t positions.

    The rotation amount is taken modulo 32, so any integer t is accepted:
    t = 32 is the identity and negative t rotates left.
    """
    # Normalize so that 2 ** (32 - t) below is always a positive integer
    # exponent. Values in [0, 32] give the same result as without this step.
    t %= 32
    # Low t bits moved to the top: g**((x mod 2**t) << (32 - t)).
    low_bits_on_top = uint32_mul_scalar(z, 2 ** (32 - t))
    # Those same low t bits brought back to the bottom: g**(x mod 2**t).
    low_bits = uint32_right_shift_zeros(low_bits_on_top, 32 - t)
    # Clear the low t bits (giving t trailing zeros), then shift: g**(x >> t).
    high_bits = uint32_right_shift_zeros(uint32_sub(z, low_bits), t)
    return uint32_add(low_bits_on_top, high_bits)
#
#
# Tests and sanity checks
#
#
def _run_sanity_checks():
    """Print sanity-check output for each Uint32 operation.

    Most sections print the computed value next to the expected one. The code
    runs only when this file is executed as a script, so importing the module
    does not produce output.
    """
    print('\nAddition')
    a = to_uint32(42)
    b = to_uint32(2**32 - 23)
    c = uint32_add(a, b)
    c_expected = to_uint32((42 + 2**32 - 23) % 2**32)
    print(c, c_expected)

    print('\nSubtraction')
    a = to_uint32(23)
    b = to_uint32(42)
    c = uint32_sub(a, b)
    c_expected = to_uint32((2**32 + 23 - 42) % 2**32)
    print(c, c_expected)

    print('\nMultiplication by a scalar')
    a = to_uint32(23)
    b = 42
    c = uint32_mul_scalar(a, b)
    c_expected = to_uint32(23 * 42)
    print(c, c_expected)

    print('\nBitwise not')
    a = to_uint32(0b00010101)
    c = uint32_bitwise_not(a)
    c_expected = pow(g, 0b11111111111111111111111111101010, p)
    print(c, c_expected)

    print('\nShift all bits to the left')
    a = to_uint32(0b11111111111111111111111111101010)
    t = 3
    c = uint32_left_shift(a, t)
    c_expected = pow(g, 0b11111111111111111111111101010000, p)
    print(c, c_expected)

    print('\nCompute modulo 2')
    # Even
    a = to_uint32(0b11111111111111111111111111101010)
    c = uint32_mod2(a)
    print('even', c == 1)
    # Odd
    a = to_uint32(0b11111111111111111111111111101011)
    c = uint32_mod2(a)
    print('odd ', c != 1)

    print('\nEfficient Discrete Logarithm')
    for i in range(1, 21):
        scalar = i + 100000
        z = to_uint32(scalar)
        s = from_uint32(z)
        # Decoding must give back the encoded scalar, so print that as the
        # expected value (not the loop counter).
        print(s, scalar, z)

    print('\nMultiplication')
    a = to_uint32(23)
    b = to_uint32(42)
    c = uint32_mul(a, b)
    c_expected = to_uint32(23 * 42)
    print(c, c_expected)

    print('\nProve a 32-bit range')
    x = 424242
    x_evil = x + 2**32
    print(is_uint32(x), is_uint32(x_evil))

    print('\nShift trailing zeros to the right')
    a = to_uint32(0b10001101000001100000000000000000)
    c = uint32_right_shift_zeros(a, 14)
    print("{0:b}".format(from_uint32(c)))
    a = to_uint32(0b10000000000000000000000000000000)
    c = uint32_right_shift_zeros(a, 29)
    print("{0:b}".format(from_uint32(c)))

    print('\nRotate right')
    a = to_uint32(0b11111111000000001111011100110001)
    c = uint32_rotate_right(a, 7)
    c_expected = 0b01100011111111100000000111101110
    # Zero-pad to 32 digits so the two lines line up bit for bit.
    print("{0:032b}".format(from_uint32(c)))
    print("{0:032b}".format(c_expected))


if __name__ == "__main__":
    _run_sanity_checks()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment