Skip to content

Instantly share code, notes, and snippets.

@muellerberndt
Created November 28, 2023 05:14
Show Gist options
  • Save muellerberndt/2eba23be5cc22c0294faef4689c7eb45 to your computer and use it in GitHub Desktop.
Save muellerberndt/2eba23be5cc22c0294faef4689c7eb45 to your computer and use it in GitHub Desktop.
// SPDX-License-Identifier: GPL-3.0
import "hardhat/console.sol";
pragma solidity >=0.8.2 <0.9.0;
contract Homework3 {
// Modulus used for all scalar arithmetic in this contract (see rationalAdd).
uint256 constant curve_order = 21888242871839275222246405745257275088548364400416034343698204186575808495617;

// A curve point given by its two coordinates.
struct ECPoint {
    uint256 x;
    uint256 y;
}

// Base point (1, 2); scalars are mapped onto the curve by multiplying it.
ECPoint G1 = ECPoint({x: 1, y: 2});
/*
 * Adds two points of G1 via the precompiled contract at address 0x06.
 *
 * @param p1 first summand
 * @param p2 second summand
 * @return r the sum of the two points
 *
 * Reverts with "pairing-add-failed" when the precompile call fails.
 */
function plus(
    ECPoint memory p1,
    ECPoint memory p2
) internal view returns (ECPoint memory r) {
    // Precompile input: the four coordinates packed as consecutive words.
    uint256[4] memory input;
    input[0] = p1.x;
    input[1] = p1.y;
    input[2] = p2.x;
    input[3] = p2.y;
    bool success;
    // solium-disable-next-line security/no-inline-assembly
    assembly {
        // The lengths match the buffers exactly: 4 words (0x80) in, 2 words
        // (0x40) out. Larger lengths would read past `input` and could write
        // past `r`.
        success := staticcall(sub(gas(), 2000), 6, input, 0x80, r, 0x40)
    }
    // Revert with a reason instead of hitting `invalid()`, which burned all
    // remaining gas and made this message unreachable.
    require(success, "pairing-add-failed");
}
/*
 * Multiplies a point of G1 by a scalar via the precompiled contract at
 * address 0x07, i.e. p == p.scalar_mul(1) and p.plus(p) == p.scalar_mul(2)
 * for all points p.
 *
 * @param p the point to multiply
 * @param s the scalar
 * @return r the product of p and s
 *
 * Reverts with "pairing-mul-failed" when the precompile call fails.
 */
function scalar_mul(ECPoint memory p, uint256 s) internal view returns (ECPoint memory r) {
    // Precompile input: point coordinates followed by the scalar.
    uint256[3] memory input;
    input[0] = p.x;
    input[1] = p.y;
    input[2] = s;
    bool success;
    // solium-disable-next-line security/no-inline-assembly
    assembly {
        // The lengths match the buffers exactly: 3 words (0x60) in, 2 words
        // (0x40) out. Larger lengths would read past `input` and could write
        // past `r`.
        success := staticcall(sub(gas(), 2000), 7, input, 0x60, r, 0x40)
    }
    // Revert with a reason instead of hitting `invalid()`, which burned all
    // remaining gas and made this message unreachable.
    require(success, "pairing-mul-failed");
}
/*
 * Computes (_b ** _e) % _m using the precompiled contract at address 0x05
 * (bigModExp).
 *
 * @param _b base
 * @param _e exponent
 * @param _m modulus
 * @return result the modular exponentiation result
 *
 * Reverts with empty return data when the precompile call fails.
 */
function modExp(
    uint256 _b,
    uint256 _e,
    uint256 _m
) private view returns (uint256 result) {
    assembly {
        // Build the call data at the free memory pointer. The pointer is not
        // advanced, so this region is only borrowed for the duration of the call.
        let pointer := mload(0x40)
        // Define length of base, exponent and modulus. 0x20 == 32 bytes
        mstore(pointer, 0x20)
        mstore(add(pointer, 0x20), 0x20)
        mstore(add(pointer, 0x40), 0x20)
        // Define variables base, exponent and modulus
        mstore(add(pointer, 0x60), _b)
        mstore(add(pointer, 0x80), _e)
        mstore(add(pointer, 0xa0), _m)
        // Receive the 32-byte result in scratch space at 0x00. The previous
        // code took the output address from the *contents* of memory at 0xc0,
        // which is only a valid location by coincidence.
        // staticcall is sufficient because the precompile does not touch state.
        if iszero(staticcall(gas(), 0x05, pointer, 0xc0, 0x00, 0x20)) {
            revert(0, 0)
        }
        result := mload(0x00)
    }
}
/*
 * Checks the claim that the prover knows two numbers that add up to num/den.
 * A and B are the two numbers hidden as curve points; the check passes when
 * A + B equals G1 multiplied by (num / den) taken modulo curve_order.
 *
 * @param A first committed point
 * @param B second committed point
 * @param num numerator of the claimed sum
 * @param den denominator of the claimed sum; must be invertible modulo curve_order
 * @return verified true when A + B matches the point for num/den
 *
 * Reverts when den has no modular inverse, or when a precompile call fails.
 */
function rationalAdd(ECPoint calldata A, ECPoint calldata B, uint256 num, uint256 den) public returns (bool verified) {
    // A denominator that is 0 modulo curve_order has no inverse. Without this
    // guard the exponentiation below yields 0 and the fraction silently
    // collapses to 0.
    require(den % curve_order != 0, "den not invertible");
    // Modular inverse via den^(curve_order - 2); relies on curve_order being prime.
    uint256 denInverse = modExp(den, curve_order - 2, curve_order);
    uint256 c = mulmod(num, denInverse, curve_order);
    ECPoint memory sum = plus(A, B);
    ECPoint memory C = scalar_mul(G1, c);
    return sum.x == C.x && sum.y == C.y;
}
/*
 * Checks that M * s == o element-wise, where M is an n x n matrix of scalars
 * stored row by row in `matrix`, s is a vector of curve points, and o is a
 * vector of scalars that are mapped to points by multiplying G1.
 *
 * @param matrix the n * n matrix entries, row-major
 * @param n the matrix dimension
 * @param s n curve points
 * @param o n expected output scalars
 * @return verified true when every row of M * s equals the corresponding o[i] * G1
 *
 * Reverts if the dimensions are inconsistent or the matrix is empty.
 */
function matmul(
    uint256[] calldata matrix,
    uint256 n, // n x n for the matrix
    ECPoint[] calldata s, // n elements
    uint256[] calldata o // n elements
) public view returns (bool verified) {
    // revert if dimensions don't make sense or the matrices are empty
    require(matrix.length == n * n, "Invalid matrix size");
    require(s.length == n, "Invalid vector size");
    require(o.length == n, "Invalid output vector size");
    // With n == 0 all three length checks pass, and the loop below would
    // vacuously report success for an empty statement.
    require(n > 0, "Empty matrix");

    // One equality check per row.
    for (uint256 i = 0; i < n; i++) {
        // Accumulate the row: sum over j of matrix[i][j] * s[j].
        ECPoint memory result = scalar_mul(s[0], matrix[i * n]);
        for (uint256 j = 1; j < n; j++) {
            result = plus(result, scalar_mul(s[j], matrix[i * n + j]));
        }
        // compare with expected output in o
        ECPoint memory expected = scalar_mul(G1, o[i]);
        if (result.x != expected.x || result.y != expected.y) {
            return false;
        }
    }
    return true;
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment