Skip to content

Instantly share code, notes, and snippets.

@rot256
Last active July 8, 2024 12:05
Show Gist options
  • Select an option

  • Save rot256/1a18689cdb543930c1ba8a893066cf0f to your computer and use it in GitHub Desktop.

Select an option

Save rot256/1a18689cdb543930c1ba8a893066cf0f to your computer and use it in GitHub Desktop.
Symbolically Evaluate/Simplify Barretenberg Custom Gate Expressions
# Value placed in a wire wherever a gate uses the constant-zero wire.
zero_idx = 0
# Multiplier applied to every relation value that Block accumulates.
scaling_factor = 1
# Counter consumed by undefined() to mint unique placeholder names.
undefined_var = 0

# Auxiliary gate kinds accepted by apply_aux_selectors(); each constant's
# value is simply its own name.
(
    AUX_LIMB_ACCUMULATE_1,
    AUX_LIMB_ACCUMULATE_2,
    AUX_NON_NATIVE_FIELD_1,
    AUX_NON_NATIVE_FIELD_2,
    AUX_NON_NATIVE_FIELD_3,
    AUX_NONE,
) = (
    "AUX_LIMB_ACCUMULATE_1",
    "AUX_LIMB_ACCUMULATE_2",
    "AUX_NON_NATIVE_FIELD_1",
    "AUX_NON_NATIVE_FIELD_2",
    "AUX_NON_NATIVE_FIELD_3",
    "AUX_NONE",
)

# Limb-decomposition parameters, kept as free symbols so every derived
# constraint stays generic in the concrete limb size.
LIMB_SHIFT, LIMB_SIZE, SUBLIMB_SHIFT = (
    var('LIMB_SHIFT'),
    var('LIMB_SIZE'),
    var('SUBLIMB_SHIFT'),
)
def undefined():
    """Mint a fresh symbolic variable named ``undefined<k>``.

    ``k`` is the module-level counter ``undefined_var`` after it has been
    bumped, so the first call produces ``undefined1``, the next
    ``undefined2`` and so on.
    """
    global undefined_var
    undefined_var = undefined_var + 1
    label = 'undefined' + str(undefined_var)
    return var(label)
class Column:
    """A single trace column whose cell values are stored in ``row``."""

    def __init__(self):
        self.row = []
        # Placeholders handed out for reads past the end, memoised by index
        # so that the same missing cell always maps to the same unknown.
        self._missing = {}

    def emplace_back(self, v):
        """Append ``v`` as the value of the next row."""
        self.row.append(v)

    def __getitem__(self, i):
        """Return the value stored at row ``i``.

        A read past the end (e.g. the "shifted" row after the last one)
        returns a placeholder from ``undefined()``. The placeholder is
        created on the first such read and reused afterwards, so two
        relations that read the same missing cell see the same variable
        rather than two independent ones.
        """
        try:
            return self.row[i]
        except IndexError:
            pass
        if i not in self._missing:
            self._missing[i] = undefined()
        return self._missing[i]
class Block:
    """A slice of an Ultra-style execution trace, built up row by row.

    Each row holds four wire values (``w_1`` .. ``w_4``) and one value per
    selector column. The gate helpers in this file append a row by calling
    ``populate_wires`` plus one ``emplace_back`` per selector, then
    ``pad_additional`` to zero-fill the selectors they did not set.

    ``expr()`` symbolically evaluates the auxiliary relation and the two
    arithmetic subrelations on every row and returns the constraints that
    do not simplify to zero.
    """

    # Keys of the per-row accumulator dicts. The auxiliary relation has its
    # own key: adding it into an arithmetic subrelation would only constrain
    # the *sum* of two independent relations, not each one separately.
    _ARITH_KEY_1 = 0
    _ARITH_KEY_2 = 1
    _AUX_KEY = 2

    def __init__(self):
        self.col_w_1 = Column()
        self.col_w_2 = Column()
        self.col_w_3 = Column()
        self.col_w_4 = Column()
        self.col_q_m = Column()
        self.col_q_1 = Column()
        self.col_q_2 = Column()
        self.col_q_3 = Column()
        self.col_q_4 = Column()
        self.col_q_c = Column()
        self.col_q_elliptic = Column()
        self.col_q_arith = Column()
        self.col_q_delta_range = Column()
        self.col_q_lookup_type = Column()
        self.col_aux = Column()
        # One dict per row mapping a relation key to its accumulated value;
        # rebuilt from scratch by expr().
        self.accumulators = []

    # Column accessors, used by the gate helpers.
    def q_aux(self):
        return self.col_aux

    def w_1(self):
        return self.col_w_1

    def w_2(self):
        return self.col_w_2

    def w_3(self):
        return self.col_w_3

    def w_4(self):
        return self.col_w_4

    def q_m(self):
        return self.col_q_m

    def q_1(self):
        return self.col_q_1

    def q_2(self):
        return self.col_q_2

    def q_3(self):
        return self.col_q_3

    def q_4(self):
        return self.col_q_4

    def q_c(self):
        return self.col_q_c

    def q_arith(self):
        return self.col_q_arith

    def q_delta_range(self):
        return self.col_q_delta_range

    def q_lookup_type(self):
        return self.col_q_lookup_type

    def q_elliptic(self):
        return self.col_q_elliptic

    def populate_wires(self, w_1, w_2, w_3, w_4):
        """Append one row of wire values."""
        self.col_w_1.emplace_back(w_1)
        self.col_w_2.emplace_back(w_2)
        self.col_w_3.emplace_back(w_3)
        self.col_w_4.emplace_back(w_4)

    def columns(self):
        """Return the selector columns (the wire columns are not included)."""
        return [
            self.col_q_m,
            self.col_q_1,
            self.col_q_2,
            self.col_q_3,
            self.col_q_4,
            self.col_q_c,
            self.col_q_arith,
            self.col_q_delta_range,
            self.col_q_lookup_type,
            self.col_q_elliptic,
            self.col_aux,
        ]

    def max_len(self):
        """Return the length of the longest selector column."""
        return max(len(col.row) for col in self.columns())

    def pad_additional(self):
        """Zero-fill every selector column up to the longest one."""
        target = self.max_len()
        for col in self.columns():
            while len(col.row) < target:
                col.emplace_back(0)

    def rows(self):
        """Return the number of rows, measured on the ``w_1`` wire column."""
        return len(self.col_w_1.row)

    def expr_aux(self):
        """Accumulate the auxiliary relation for every row under ``_AUX_KEY``.

        Only the non-native field terms (gated by ``q_2``) and the limb
        accumulation terms (gated by ``q_3``) are computed. The statements
        deliberately follow the step-by-step C++ style so they can be
        compared line by line with the reference implementation. Reads of
        row ``i + 1`` past the end return whatever ``Column`` hands back for
        a missing row.
        """
        for i in range(self.rows()):
            w_1 = self.w_1()[i]
            w_2 = self.w_2()[i]
            w_3 = self.w_3()[i]
            w_4 = self.w_4()[i]
            q_m = self.q_m()[i]
            q_2 = self.q_2()[i]
            q_3 = self.q_3()[i]
            q_4 = self.q_4()[i]
            q_aux = self.q_aux()[i]
            w_1_shift = self.w_1()[i + 1]
            w_2_shift = self.w_2()[i + 1]
            w_3_shift = self.w_3()[i + 1]
            w_4_shift = self.w_4()[i + 1]

            # Non-native field gates 1/2/3, selected by q_3 / q_4 / q_m.
            limb_subproduct = w_1 * w_2_shift + w_1_shift * w_2
            non_native_field_gate_2 = (w_1 * w_4 + w_2 * w_3 - w_3_shift)
            non_native_field_gate_2 *= LIMB_SIZE
            non_native_field_gate_2 -= w_4_shift
            non_native_field_gate_2 += limb_subproduct
            non_native_field_gate_2 *= q_4

            limb_subproduct *= LIMB_SIZE
            limb_subproduct += (w_1_shift * w_2_shift)
            non_native_field_gate_1 = limb_subproduct
            non_native_field_gate_1 -= (w_3 + w_4)
            non_native_field_gate_1 *= q_3

            non_native_field_gate_3 = limb_subproduct
            non_native_field_gate_3 += w_4
            non_native_field_gate_3 -= (w_3_shift + w_4_shift)
            non_native_field_gate_3 *= q_m

            non_native_field_identity = (
                non_native_field_gate_1 + non_native_field_gate_2 + non_native_field_gate_3
            )
            non_native_field_identity *= q_2

            # Limb accumulation gates 1/2, selected by q_4 / q_m.
            limb_accumulator_1 = w_2_shift * SUBLIMB_SHIFT
            limb_accumulator_1 += w_1_shift
            limb_accumulator_1 *= SUBLIMB_SHIFT
            limb_accumulator_1 += w_3
            limb_accumulator_1 *= SUBLIMB_SHIFT
            limb_accumulator_1 += w_2
            limb_accumulator_1 *= SUBLIMB_SHIFT
            limb_accumulator_1 += w_1
            limb_accumulator_1 -= w_4
            limb_accumulator_1 *= q_4

            limb_accumulator_2 = w_3_shift * SUBLIMB_SHIFT
            limb_accumulator_2 += w_2_shift
            limb_accumulator_2 *= SUBLIMB_SHIFT
            limb_accumulator_2 += w_1_shift
            limb_accumulator_2 *= SUBLIMB_SHIFT
            limb_accumulator_2 += w_4
            limb_accumulator_2 *= SUBLIMB_SHIFT
            limb_accumulator_2 += w_3
            limb_accumulator_2 -= w_4_shift
            limb_accumulator_2 *= q_m

            limb_accumulator_identity = limb_accumulator_1 + limb_accumulator_2
            limb_accumulator_identity *= q_3

            auxiliary_identity = non_native_field_identity + limb_accumulator_identity
            auxiliary_identity *= q_aux * scaling_factor
            self.accumulate(i, self._AUX_KEY, auxiliary_identity)

    def accumulate(self, row, j, value):
        """Add ``value`` into relation ``j`` of row ``row``.

        Row dicts up to and including ``row`` are created on demand.
        """
        while len(self.accumulators) <= row:
            self.accumulators.append({})
        per_row = self.accumulators[row]
        if j in per_row:
            per_row[j] += value
        else:
            per_row[j] = value

    def expr_arith(self):
        """Accumulate the two arithmetic subrelations for every row.

        Subrelation 1 (``_ARITH_KEY_1``)::

            q_arith * [ (q_arith - 3) * q_m * w_r * w_l * (-1/2)
                        + q_l*w_l + q_r*w_r + q_o*w_o + q_4*w_4 + q_c
                        + (q_arith - 1) * w_4_shift ]

        Subrelation 2 (``_ARITH_KEY_2``), identically zero for
        q_arith in {0, 1, 2}::

            q_arith * (q_arith - 1) * (q_arith - 2) * (w_l + w_4 - w_l_shift + q_m)

        Both are multiplied by ``scaling_factor``.
        """
        # Exact rational -1/2 under Sage's preparser (a float in plain Python).
        neg_half = 1 / (-2)
        for i in range(self.rows()):
            q_l = self.q_1()[i]
            q_r = self.q_2()[i]
            q_o = self.q_3()[i]
            q_m = self.q_m()[i]
            q_4 = self.q_4()[i]
            q_c = self.q_c()[i]
            w_l = self.w_1()[i]
            w_r = self.w_2()[i]
            w_o = self.w_3()[i]
            w_4 = self.w_4()[i]
            q_arith = self.q_arith()[i]
            w_4_shift = self.w_4()[i + 1]
            w_l_shift = self.w_1()[i + 1]

            tmp = (q_arith - 3) * (q_m * w_r * w_l) * neg_half
            tmp += (q_l * w_l) + (q_r * w_r) + (q_o * w_o) + (q_4 * w_4) + q_c
            tmp += (q_arith - 1) * w_4_shift
            tmp *= q_arith
            tmp *= scaling_factor
            self.accumulate(i, self._ARITH_KEY_1, tmp)

            tmp = w_l + w_4 - w_l_shift + q_m
            tmp *= (q_arith - 2)
            tmp *= (q_arith - 1)
            tmp *= q_arith
            tmp *= scaling_factor
            self.accumulate(i, self._ARITH_KEY_2, tmp)

    def expr(self):
        """Evaluate every relation on every row and return the non-trivial ones.

        Returns:
            A list of ``value == 0`` relations, in row order, one for each
            accumulated relation value that does not simplify to zero.

        Raises:
            ValueError: if a selector column's length differs from the
                number of wire rows (i.e. a row was not fully populated).
        """
        n_rows = self.rows()
        for col in self.columns():
            if len(col.row) != n_rows:
                raise ValueError(
                    "selector column has %d rows, expected %d" % (len(col.row), n_rows)
                )
        # Start over so that a second call does not add every relation in again.
        self.accumulators = []
        self.expr_aux()
        self.expr_arith()
        constraints = []
        for per_row in self.accumulators:
            for value in per_row.values():
                # A value built only from numbers may not be a symbolic
                # expression and may therefore lack full_simplify().
                if hasattr(value, "full_simplify"):
                    value = value.full_simplify()
                if value != 0:
                    constraints.append(value == 0)
        return constraints
def apply_aux_selectors(block, aux_type):
    """Append the selector values for one auxiliary row of ``block``.

    Callers in this file append the row's wires with ``populate_wires``
    first. ``q_aux`` is 1 for every gate kind except ``AUX_NONE``.
    ``q_c``, ``q_arith``, ``q_delta_range``, ``q_lookup_type`` and
    ``q_elliptic`` are always 0, and the remaining selectors come from the
    table below. The row is finished with ``pad_additional``.

    Raises:
        ValueError: if ``aux_type`` is not one of the ``AUX_*`` constants.
            The check happens before any column is touched, so a bad call
            leaves ``block`` unchanged instead of with ragged columns.
    """
    # (q_1, q_2, q_3, q_4, q_m) for each auxiliary gate kind.
    selector_table = {
        AUX_NONE: (0, 0, 0, 0, 0),
        AUX_LIMB_ACCUMULATE_1: (0, 0, 1, 1, 0),
        AUX_LIMB_ACCUMULATE_2: (0, 0, 1, 0, 1),
        AUX_NON_NATIVE_FIELD_1: (0, 1, 1, 0, 0),
        AUX_NON_NATIVE_FIELD_2: (0, 1, 0, 1, 0),
        AUX_NON_NATIVE_FIELD_3: (0, 1, 0, 0, 1),
    }
    if aux_type not in selector_table:
        raise ValueError("Invalid aux type")
    q_1, q_2, q_3, q_4, q_m = selector_table[aux_type]

    block.q_aux().emplace_back(0 if aux_type == AUX_NONE else 1)
    block.q_delta_range().emplace_back(0)
    block.q_lookup_type().emplace_back(0)
    block.q_elliptic().emplace_back(0)
    block.q_1().emplace_back(q_1)
    block.q_2().emplace_back(q_2)
    block.q_3().emplace_back(q_3)
    block.q_4().emplace_back(q_4)
    block.q_m().emplace_back(q_m)
    block.q_c().emplace_back(0)
    block.q_arith().emplace_back(0)
    block.pad_additional()
def create_big_add_gate(blocks, inp, include_next_gate_w_4):
    """Append one arithmetic row enforcing a scaled sum of four wires.

    ``inp`` supplies the wires ``'a'``..``'d'``, their multipliers
    ``'a_scaling'``..``'d_scaling'`` (stored in q_1, q_2, q_3, q_4) and the
    constant ``'const_scaling'`` (stored in q_c). q_m is 0. Under
    ``Block.expr_arith`` the row constrains::

        a_scaling*a + b_scaling*b + c_scaling*c + d_scaling*d + const_scaling
        (+ next row's w_4, when include_next_gate_w_4) == 0

    because q_arith is 2 when ``include_next_gate_w_4`` is true and 1
    otherwise. All other selectors are 0.
    """
    blocks.populate_wires(*(inp[wire] for wire in ('a', 'b', 'c', 'd')))
    blocks.q_m().emplace_back(0)
    for column, key in (
        (blocks.q_1(), 'a_scaling'),
        (blocks.q_2(), 'b_scaling'),
        (blocks.q_3(), 'c_scaling'),
        (blocks.q_c(), 'const_scaling'),
    ):
        column.emplace_back(inp[key])
    blocks.q_arith().emplace_back(2 if include_next_gate_w_4 else 1)
    blocks.q_4().emplace_back(inp['d_scaling'])
    for column in (
        blocks.q_delta_range(),
        blocks.q_lookup_type(),
        blocks.q_elliptic(),
        blocks.q_aux(),
    ):
        column.emplace_back(0)
    blocks.pad_additional()
def create_big_mul_gate(blocks, inp):
    """Append one arithmetic row with a product term plus a scaled wire sum.

    ``inp`` supplies the wires ``'a'``..``'d'``, the product multiplier
    ``'mul_scaling'`` (q_m), the per-wire multipliers
    ``'a_scaling'``..``'d_scaling'`` (q_1..q_4) and ``'const_scaling'``
    (q_c). With q_arith fixed at 1, ``Block.expr_arith`` reduces to::

        mul_scaling*a*b + a_scaling*a + b_scaling*b + c_scaling*c
            + d_scaling*d + const_scaling == 0

    so the next row is not involved. All other selectors are 0.
    """
    blocks.populate_wires(inp['a'], inp['b'], inp['c'], inp['d'])
    scaled_selectors = (
        (blocks.q_m, 'mul_scaling'),
        (blocks.q_1, 'a_scaling'),
        (blocks.q_2, 'b_scaling'),
        (blocks.q_3, 'c_scaling'),
        (blocks.q_c, 'const_scaling'),
    )
    for selector, key in scaled_selectors:
        selector().emplace_back(inp[key])
    blocks.q_arith().emplace_back(1)
    blocks.q_4().emplace_back(inp['d_scaling'])
    for selector in (blocks.q_delta_range, blocks.q_lookup_type, blocks.q_elliptic, blocks.q_aux):
        selector().emplace_back(0)
    blocks.pad_additional()
def create_dummy_gate(blocks, w1, w2, w3, w4):
    """Append a row carrying the given wires with every selector set to 0.

    With all selectors zero, no relation is active on this row itself. It
    is used as the "next row" whose ``w_4`` a preceding gate reads.
    """
    blocks.populate_wires(w1, w2, w3, w4)
    zeroed = (
        blocks.q_m, blocks.q_1, blocks.q_2, blocks.q_3, blocks.q_c, blocks.q_arith,
        blocks.q_4, blocks.q_delta_range, blocks.q_elliptic, blocks.q_lookup_type,
        blocks.q_aux,
    )
    for selector in zeroed:
        selector().emplace_back(0)
    blocks.pad_additional()
def evaluate_non_native_field_multiplication(range_constrain_quotient_and_remainder=True):
    """Build the gate sequence of a non-native field multiplication and evaluate it.

    All operands are free symbols:

    - ``input_a`` / ``input_b`` have 4 limbs.
    - ``input_q`` / ``input_r`` have those 4 plus a fifth limb (the "Fp limb").
    - ``NEG_MOD0..3`` are the limbs of the negated modulus.
    - The ``lo_*_idx`` / ``hi_*_idx`` intermediates are also free.

    Range checks are not modelled. When
    ``range_constrain_quotient_and_remainder`` is true, only the gates tying
    each fifth limb to the other four are added.

    Returns the non-trivial constraints produced by ``Block.expr()``.
    """
    limb_shift_2 = LIMB_SHIFT ** 2
    limb_shift_3 = LIMB_SHIFT ** 3
    limb_rshift = 1 / LIMB_SHIFT
    limb_rshift_2 = limb_rshift ** 2

    neg_modulus = [var('NEG_MOD' + str(k)) for k in range(4)]
    input_a = [var('input_a' + str(k)) for k in range(4)]
    input_b = [var('input_b' + str(k)) for k in range(4)]
    input_q = [var('input_q' + str(k)) for k in range(5)]
    input_r = [var('input_r' + str(k)) for k in range(5)]
    lo_0_idx = var('lo_0_idx')
    lo_1_idx = var('lo_1_idx')
    hi_0_idx = var('hi_0_idx')
    hi_1_idx = var('hi_1_idx')
    hi_2_idx = var('hi_2_idx')
    hi_3_idx = var('hi_3_idx')

    blocks = Block()

    def check_fp_limb(limbs):
        # Add gate (q_arith = 2) followed by a dummy row whose w_4 is limbs[0]:
        #   limbs[1]*S + limbs[2]*S^2 + limbs[3]*S^3 - limbs[4] + limbs[0] == 0
        # with S = LIMB_SHIFT.
        create_big_add_gate(blocks, {
            'a': limbs[1],
            'b': limbs[2],
            'c': limbs[3],
            'd': limbs[4],
            'a_scaling': LIMB_SHIFT,
            'b_scaling': limb_shift_2,
            'c_scaling': limb_shift_3,
            'd_scaling': -1,
            'const_scaling': 0,
        }, True)
        create_dummy_gate(blocks, zero_idx, zero_idx, zero_idx, limbs[0])

    if range_constrain_quotient_and_remainder:
        check_fp_limb(input_r)
        check_fp_limb(input_q)

    # Product gate 1: an add gate that also reads the next row's w_4 (lo_0_idx).
    create_big_add_gate(blocks, {
        'a': input_q[0],
        'b': input_q[1],
        'c': input_r[1],
        'd': lo_1_idx,
        'a_scaling': neg_modulus[0] + neg_modulus[1] * LIMB_SHIFT,
        'b_scaling': neg_modulus[0] * LIMB_SHIFT,
        'c_scaling': -LIMB_SHIFT,
        'd_scaling': -(LIMB_SHIFT ** 2),
        'const_scaling': 0,
    }, True)

    # Four auxiliary rows: the three non-native field gates plus a row that
    # only supplies shifted wires to the gate above it.
    aux_rows = (
        ((input_a[1], input_b[1], input_r[0], lo_0_idx), AUX_NON_NATIVE_FIELD_1),
        ((input_a[0], input_b[0], input_a[3], input_b[3]), AUX_NON_NATIVE_FIELD_2),
        ((input_a[2], input_b[2], input_r[3], hi_0_idx), AUX_NON_NATIVE_FIELD_3),
        ((input_a[1], input_b[1], input_r[2], hi_1_idx), AUX_NONE),
    )
    for wires, aux_type in aux_rows:
        blocks.populate_wires(*wires)
        apply_aux_selectors(blocks, aux_type)

    create_big_add_gate(blocks, {
        'a': input_q[2],
        'b': input_q[3],
        'c': lo_1_idx,
        'd': hi_1_idx,
        'a_scaling': -neg_modulus[1] * LIMB_SHIFT - neg_modulus[0],
        'b_scaling': -neg_modulus[0] * LIMB_SHIFT,
        'c_scaling': -1,
        'd_scaling': -1,
        'const_scaling': 0,
    }, True)
    create_big_add_gate(blocks, {
        'a': hi_3_idx,
        'b': input_q[0],
        'c': input_q[1],
        'd': hi_2_idx,
        'a_scaling': -1,
        'b_scaling': neg_modulus[3] * limb_rshift + neg_modulus[2] * limb_rshift_2,
        'c_scaling': neg_modulus[2] * limb_rshift + neg_modulus[1] * limb_rshift_2,
        'd_scaling': limb_rshift_2,
        'const_scaling': 0,
    }, False)
    return blocks.expr()
def process_non_native_field_multiplications():
    """Evaluate the four auxiliary rows used for a non-native multiplication.

    The wiring matches the auxiliary rows built in
    ``evaluate_non_native_field_multiplication``, except that the wires that
    carried remainder limbs there hold ``zero_idx`` here.

    Returns the non-trivial constraints produced by ``Block.expr()``.
    """
    a = [var('input_a' + str(k)) for k in range(4)]
    b = [var('input_b' + str(k)) for k in range(4)]
    lo_0 = var('lo_0_idx')
    hi_0 = var('hi_0_idx')
    hi_1 = var('hi_1_idx')
    block = Block()
    rows = (
        ((a[1], b[1], zero_idx, lo_0), AUX_NON_NATIVE_FIELD_1),
        ((a[0], b[0], a[3], b[3]), AUX_NON_NATIVE_FIELD_2),
        ((a[2], b[2], zero_idx, hi_0), AUX_NON_NATIVE_FIELD_3),
        ((a[1], b[1], zero_idx, hi_1), AUX_NONE),
    )
    for wires, aux_type in rows:
        block.populate_wires(*wires)
        apply_aux_selectors(block, aux_type)
    return block.expr()
def accumulate_field_t(num_gates=3):
    """Model summing ``3 * num_gates`` scaled terms with a chain of add gates.

    Term ``k`` has symbolic wire ``w_k``, multiplier ``m_k`` and additive
    constant ``c_k``. Gate ``i`` takes terms ``3i``, ``3i + 1`` and
    ``3i + 2`` as its a/b/c wires, puts ``acc_i`` in its ``w_4`` (scaled by
    -1) and sums the three constants into ``q_c``. Every gate except the
    last uses ``q_arith = 2``, so it also reads the next gate's ``w_4``
    (``acc_{i+1}``).

    Args:
        num_gates: number of add gates in the chain. The default of 3
            reproduces the original nine-term model. A value <= 0 builds no
            gates and yields no constraints.

    Returns:
        The non-trivial constraints produced by ``Block.expr()``.
    """
    accumulator = [
        (var(f'w_{k}'), var(f'm_{k}'), var(f'c_{k}')) for k in range(3 * num_gates)
    ]
    blocks = Block()
    for i in range(num_gates):
        first, second, third = accumulator[3 * i:3 * i + 3]
        create_big_add_gate(blocks, {
            'a': first[0],
            'b': second[0],
            'c': third[0],
            'd': var(f'acc_{i}'),
            'a_scaling': first[1],
            'b_scaling': second[1],
            'c_scaling': third[1],
            'd_scaling': -1,
            'const_scaling': first[2] + second[2] + third[2],
        }, i != num_gates - 1)
    return blocks.expr()
def evaluate_non_native_field_addition():
    """Evaluate the four-row gadget for a scaled non-native field addition.

    Only the binary limbs carry multiplicative constants. The prime limb
    has just the additive constant ``addconstp``. Through
    ``Block.expr_arith`` the rows encode, for k = 0..3::

        z_k = x_mulconstk * x_k + y_mulconstk * y_k + addconstk
        z_p = x_p + y_p + addconstp

    Row 0 uses ``q_arith = 3``. Its first subrelation weights the next
    row's ``w_4`` (z_0) by 2, which is why that row's coefficients are
    doubled. Its second subrelation (with ``addconstp`` in ``q_m``)
    yields the prime-limb equation.

    Returns the non-trivial constraints produced by ``Block.expr()``.
    """
    limbs = range(4)
    x_mulconst = [var('x_mulconst' + str(k)) for k in limbs]
    y_mulconst = [var('y_mulconst' + str(k)) for k in limbs]
    addconst = [var('addconst' + str(k)) for k in limbs]
    addconstp = var('addconstp')
    z = [var('z_' + str(k)) for k in limbs]
    z_p = var('z_p')
    x = [var('x_' + str(k)) for k in limbs]
    x_p = var('x_p')
    y = [var('y_' + str(k)) for k in limbs]
    y_p = var('y_p')

    block = Block()
    for wires in (
        (y_p, x[0], y[0], x_p),
        (z_p, x[1], y[1], z[0]),
        (x[2], y[2], z[2], z[1]),
        (x[3], y[3], z[3], zero_idx),
    ):
        block.populate_wires(*wires)

    # Selector values per row: (q_m, q_1, q_2, q_3, q_4, q_c, q_arith).
    selector_rows = (
        (addconstp, 0, -x_mulconst[0] * 2, -y_mulconst[0] * 2, 0, -addconst[0] * 2, 3),
        (0, 0, -x_mulconst[1], -y_mulconst[1], 0, -addconst[1], 2),
        (0, -x_mulconst[2], -y_mulconst[2], 1, 0, -addconst[2], 1),
        (0, -x_mulconst[3], -y_mulconst[3], 1, 0, -addconst[3], 1),
    )
    for q_m, q_1, q_2, q_3, q_4, q_c, q_arith in selector_rows:
        block.q_m().emplace_back(q_m)
        block.q_1().emplace_back(q_1)
        block.q_2().emplace_back(q_2)
        block.q_3().emplace_back(q_3)
        block.q_4().emplace_back(q_4)
        block.q_c().emplace_back(q_c)
        block.q_arith().emplace_back(q_arith)
        block.pad_additional()
    return block.expr()
def print_exprs(fn):
    """Call ``fn()`` and print each constraint it returns on its own line."""
    for constraint in fn():
        print(constraint)
if __name__ == "__main__":
    # Print the simplified constraints of one gadget. The other builders
    # that can be passed here are process_non_native_field_multiplications,
    # evaluate_non_native_field_multiplication and accumulate_field_t.
    # The guard keeps importing this module (e.g. to reuse Block) silent.
    print_exprs(evaluate_non_native_field_addition)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment