Skip to content

Instantly share code, notes, and snippets.

@RobinLinus
Last active November 18, 2022 00:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save RobinLinus/bc9662e696e43b3a4a4c8ca0d3cfc45e to your computer and use it in GitHub Desktop.
Save RobinLinus/bc9662e696e43b3a4a4c8ca0d3cfc45e to your computer and use it in GitHub Desktop.
%builtins range_check
from starkware.cairo.common.registers import get_ap, get_fp_and_pc
# from starkware.cairo.common.pow import pow
from starkware.cairo.common.registers import get_label_location
from starkware.cairo.common.math import assert_le
# Field prime used by Cairo:
# P = 2**251 + 17*2**192 + 1
# Base element from which G_32 is derived (see below).
# NOTE(review): presumably a generator of the full multiplicative group, which
# is required for G_32 to have order exactly 2**32 -- confirm.
# G itself is not referenced by any Cairo code in this file; the hint in
# log_32 defines its own Python variable named G.
const G = 3
# Generator of the 32-bit subgroup (the subgroup of order 2**32):
# G ** ( (P-1) / 2**32 )
# A uint32 x is represented "in the exponent" as G_32 ** x.
# NOTE(review): the literal below was not re-derived here -- confirm it equals
# pow(3, (P - 1) // 2**32, P).
const G_32 = 0x50732ed0be8ced2fea566de48221e1a719252eb81c43de5c129d0f1d3ce8992
# Computes base ** exp in the field using square-and-multiply.
# Appears to be a local copy of starkware.cairo.common.pow (its import is
# commented out at the top of the file), with a step cap of 32.
# Each loop iteration consumes one bit of exp, and the loop ends when exp
# reaches 0. The final check assert_le(n_steps, 32) therefore also proves
# 0 <= exp < 2**32 ("implicit range proof"). exp == 0 returns 1 immediately,
# without reaching that check.
func pow{range_check_ptr}(base, exp) -> (res : felt):
# One frame per iteration. bit and temp0 are scratch cells; res, base and exp
# hold the running product, the current power base**(2**i) and the remaining
# exponent.
struct LoopLocals:
member bit : felt
member temp0 : felt
member res : felt
member base : felt
member exp : felt
end
if exp == 0:
return (1)
end
# Place the initial frame so that its res/base/exp fields (offsets 2..4) are
# fp, fp + 1, fp + 2: the three cells filled by the `; ap++` writes below.
# Its bit/temp0 slots are never written.
# NOTE(review): relies on ap == fp at this point (nothing pushed yet).
let initial_locs : LoopLocals* = cast(fp - 2, LoopLocals*)
initial_locs.res = 1; ap++
initial_locs.base = base; ap++
initial_locs.exp = exp; ap++
loop:
# Each iteration writes a new frame directly after the previous one.
let prev_locs : LoopLocals* = cast(ap - LoopLocals.SIZE, LoopLocals*)
let locs : LoopLocals* = cast(ap, LoopLocals*)
locs.base = prev_locs.base * prev_locs.base; ap++
# Prover-supplied low bit of the remaining exponent; it is not constrained
# here. A dishonest bit cannot drive exp to 0 within the 32-step cap below,
# so an accepted run must have used the true bits.
%{ ids.locs.bit = (ids.prev_locs.exp % PRIME) & 1 %}
jmp odd if locs.bit != 0; ap++
even:
# Even exponent: halve it (exact field division) and keep res.
locs.exp = prev_locs.exp / 2; ap++
locs.res = prev_locs.res; ap++
# exp cannot be 0 here.
static_assert ap + 1 == locs + LoopLocals.SIZE
jmp loop; ap++
odd:
# Odd exponent: drop the low bit, halve, and multiply the current power
# into res.
locs.temp0 = prev_locs.exp - 1
locs.exp = locs.temp0 / 2; ap++
locs.res = prev_locs.res * prev_locs.base; ap++
static_assert ap + 1 == locs + LoopLocals.SIZE
jmp loop if locs.exp != 0; ap++
# Cap the number of steps
# A cap to 32 steps implies a 32-bit range proof
# get_fp_and_pc binds __fp__, which is needed to use the fp-based address
# initial_locs as a value. n_steps is the number of frames written after
# initial_locs, i.e. the number of loop iterations (the bit length of exp).
let (__ap__) = get_ap()
let (__fp__, _) = get_fp_and_pc()
let n_steps = (__ap__ - cast(initial_locs, felt*)) / LoopLocals.SIZE - 1
assert_le(n_steps, 32)
return (res=locs.res)
end
# Returns a pointer to a read-only table of the 33 powers of two
# [2**0, 2**1, ..., 2**32], stored inline in the program as `dw` words
# right after this function's return.
# Index i of the returned array holds 2**i.
func get_pow_of_2() -> (data: felt*) :
    let (table_location) = get_label_location(powers_of_two_table)
    let table = cast(table_location, felt*)
    return (data=table)

    # Data only: never executed, just addressed through the label above.
    powers_of_two_table:
    dw 2**0
    dw 2**1
    dw 2**2
    dw 2**3
    dw 2**4
    dw 2**5
    dw 2**6
    dw 2**7
    dw 2**8
    dw 2**9
    dw 2**10
    dw 2**11
    dw 2**12
    dw 2**13
    dw 2**14
    dw 2**15
    dw 2**16
    dw 2**17
    dw 2**18
    dw 2**19
    dw 2**20
    dw 2**21
    dw 2**22
    dw 2**23
    dw 2**24
    dw 2**25
    dw 2**26
    dw 2**27
    dw 2**28
    dw 2**29
    dw 2**30
    dw 2**31
    dw 2**32
end
# Returns z ** (2 ** t) by squaring z repeatedly, t times.
# t == 0 returns z unchanged. The previous loop computed t - 1 first, so t == 0
# wrapped to P - 1 and effectively never terminated.
# The previous loop also merged an fp-based `z` (the argument) with an
# ap-based `tempvar z` at its label, which revokes the reference.
# t is not range-checked: it must be a small non-negative integer (e.g. a
# rotation amount <= 32). A huge t (such as a negative felt) recurses
# practically forever.
func pow_pow2(z, t) -> (res):
    if t == 0:
        return (res=z)
    end
    # Tail call: one squaring per level, the same pattern exp_32 uses.
    return pow_pow2(z=z * z, t=t - 1)
end
# Encodes a uint32 "in the exponent": returns G_32 ** exp.
# The 32-step cap inside pow rejects any exp of 2**32 or more, so a
# successful call also range-checks exp.
func exp_32{range_check_ptr}(exp) -> (res: felt):
    let (encoded) = pow(base=G_32, exp=exp)
    return (res=encoded)
end
# Discrete log in the order-2**32 subgroup generated by G_32: returns x with
# G_32 ** x == z and 0 <= x < 2**32.
# The prover computes x in a hint. The result is then checked in-circuit by
# re-encoding it with exp_32, which also range-checks x < 2**32.
# If z is not in the subgroup, the hint yields some value and the final
# assertion fails.
func log_32{range_check_ptr}(z) -> (result):
    alloc_locals
    local result
    %{
        # Recover the exponent one bit at a time, lowest bit first.
        # `residual` is z with the bits found so far divided out. Raising it
        # to 2**(31 - k) gives 1 iff bit k of its exponent is 0.
        generator = int(ids.G_32)
        residual = int(ids.z)
        exponent = 0
        for k in range(32):
            if pow(residual, 2 ** (31 - k), PRIME) != 1:
                exponent += 1 << k
                # Multiply by G_32 ** (-(2**k)); G_32 ** (2**32 - 1) is G_32 ** -1.
                residual = residual * pow(generator, (2 ** 32 - 1) << k, PRIME)
        ids.result = exponent
    %}
    let (reencoded) = exp_32(exp=result)
    assert reencoded = z
    return (result=result)
end
# Returns the multiplicative inverse of z in the field.
# The prover supplies the candidate in a hint and the circuit checks
# res * z == 1. For z == 0 the hint yields 0 and that check fails.
func inverse(z)->(res):
    alloc_locals
    local res
    %{
        # Fermat's little theorem: z ** (PRIME - 2) == z ** -1 for z != 0.
        value = ids.z
        ids.res = pow(value, PRIME - 2, PRIME)
    %}
    assert res * z = 1
    return (res=res)
end
# z <<< n: left-rotates a uint32 encoded in the exponent.
# Given z = G_32 ** x (0 <= x < 2**32) and pow_2_n = 2**n, returns the plain
# integer (x <<< n), rotated within 32 bits.
# Arguments:
#   pow_2_n: must equal 2**n. It is not checked here, and pow caps it below 2**32.
#   t:       n itself. Currently unused; it was only needed by a
#            pow_pow2-based variant.
# Fails (through log_32's check) if z is not in the G_32 subgroup.
func lrot_32{range_check_ptr}(z, pow_2_n, t) -> (res):
    alloc_locals
    # z ** (2**n) = G_32 ** ((x << n) mod 2**32); its log is the left shift
    # with the top n bits dropped.
    let (shifted_enc) = pow(base=z, exp=pow_2_n)
    let (shifted) = log_32(z=shifted_enc)
    # shifted is a multiple of 2**n, so this field division is exact and gives
    # x mod 2**(32 - n): the low bits that stay inside the word.
    let (low_enc) = pow(base=G_32, exp=shifted / pow_2_n)
    let (low_enc_inv) = inverse(z=low_enc)
    # z / G_32 ** (x mod 2**(32 - n)) leaves only the top n bits of x, still
    # in their original position.
    let (top_in_place) = log_32(z=z * low_enc_inv)
    # Move those top bits down to the bottom: divide by 2**(32 - n).
    let down_shift = 2 ** 32 / pow_2_n
    return (res=shifted + top_in_place / down_shift)
end
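# Correspondence between uint32 arithmetic and operations on the
# "in the exponent" encodings z = G_32 ** x and w = G_32 ** y
# (x, y: uint32; k: scalar; all exponent arithmetic is mod 2**32):
#   z * w              <=> x + y
#   pow(z, k)          <=> k * x
#   pow(z, 2**32 - 1)  <=> -x   (multiplicatively this is 1/z, i.e.
#                               pow(g^x, 2**32 - 1) == g^-x)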
# Rotates the encoded value z left by n, n + 4, ... bits, stopping before
# n == 32, and prints each rotation in hex.
# pow_of_2 must point at the table returned by get_pow_of_2
# (pow_of_2[i] == 2**i).
func _print_rotations{range_check_ptr}(z, pow_of_2 : felt*, n):
    if n == 32:
        return ()
    end
    let (rotated) = lrot_32(z=z, pow_2_n=pow_of_2[n], t=n)
    %{ print(hex(ids.rotated)) %}
    return _print_rotations(z=z, pow_of_2=pow_of_2, n=n + 4)
end

# Demo entry point: encodes 0xFFFFAACC as G_32 ** 0xFFFFAACC and prints its
# left rotations by 4, 8, ..., 28 bits, in that order, one hex value per line.
func main{range_check_ptr}():
    # Keeps pow_of_2 and the encoded value usable across the unknown-ap-change
    # calls below.
    alloc_locals
    let (pow_of_2) = get_pow_of_2()
    let (encoded) = exp_32(0xFFFFAACC)
    # <<<
    _print_rotations(z=encoded, pow_of_2=pow_of_2, n=4)
    # Disabled experiments, kept for reference:
    # # 0xFFFFFFFF + 1 = 0 mod 2**32
    # let (z) = exp_32(2 ** 32 - 1)
    # let (value) = log_32(z * G_32)
    # %{ print(hex(ids.value)) %}
    # # 0x00010000 * 0x00010000 = 0 mod 2**32
    # let (z) = exp_32(2 ** 16)
    # let (z) = pow(z, 2 ** 16)
    # let (value) = log_32(z)
    # %{ print(hex(ids.value)) %}
    # # This should fail for our implicit range proofs to be correct!
    # # pow(2, 2**32+1, p)
    return ()
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment