Skip to content

Instantly share code, notes, and snippets.

@attilaolah
Created April 19, 2021 14:26
Show Gist options
  • Save attilaolah/4bb5d5de607bb42922f48436b64f83c5 to your computer and use it in GitHub Desktop.
Save attilaolah/4bb5d5de607bb42922f48436b64f83c5 to your computer and use it in GitHub Desktop.
MASS plot file verifier
"""MASS database format verifier."""
import hashlib
import os.path
import sys
# Help text printed when no database paths are supplied; the "{}" placeholder
# receives the program name.
USAGE = """Usage:
{} proof_dir/*.massdb
"""

# Number of bytes read from the start of each file as its header.
HEADER_LEN = 4 * 1024

# File code expected at the very beginning of the header: the double SHA-256
# digest of the ASCII tag "MASSDB".
HEADER_CODE = hashlib.sha256(hashlib.sha256(b"MASSDB").digest()).digest()

# Accepted bit lengths: every even value from 24 up to and including 40.
VALID_BIT_LEN = (24, 26, 28, 30, 32, 34, 36, 38, 40)

# TODO: This should be (0, 1), but somehow the encoded files seem to be
# different: A files contain 1, and B files contain 2. Not sure what's up here.
VALID_DB_TYPE = (1, 2)  # (A, B) db types
def main(args: list[str]) -> None:
    """Run the verifier over every MassDB path given after the program name.

    Args:
        args: The full argument vector; the first item is the program name
            and the remaining items are paths of files to verify.

    Raises:
        ValueError: If ``args`` is empty.

    When only the program name is present, the usage text is printed and
    nothing is verified. Otherwise one "PASS" line, or a "FAIL!" line followed
    by the list of errors, is printed per file.
    """
    if not args:
        raise ValueError("Empty args list!")

    program, *targets = args
    if not targets:
        print(USAGE.format(program))
        return

    for target in targets:
        # Only the file name is shown, not the directory it lives in.
        print(f"Checking {os.path.basename(target)}:", end=" ")
        problems = verify_massdb_file(target)
        if not problems:
            print("PASS")
            continue
        print("FAIL! Errors:")
        for problem in problems:
            print(f" - {problem}")
def verify_massdb_file(path: str, full_check: bool = True) -> list[str]:
    """Verify a MassDB file.

    The first HEADER_LEN bytes are parsed as a sequence of fixed-width,
    little-endian fields, in this order: file code (as long as HEADER_CODE),
    version (8 bytes), bit length (1 byte), database type (1 byte),
    checkpoint (8 bytes), public key hash (32 bytes), public key (33 bytes),
    followed by padding that must be all zero bytes.

    Args:
        path: Location of the file to verify.
        full_check: When true, the table data following the header is also
            checked via check_type_a / check_type_b.

    Returns:
        A list of human-readable error messages; empty if nothing is wrong.

    Raises:
        OSError: If the file cannot be opened or read.
        NotImplementedError: Propagated from check_type_a when a full check
            is requested for a database of type 1.
    """
    with open(path, "rb") as dbf:
        data = dbf.read(HEADER_LEN)

    size = len(data)
    if size != HEADER_LEN:
        # Without a complete header none of the fields can be decoded.
        return [f"File is too small: expected at least {HEADER_LEN}b, found: "
                f"{size}b."]

    errors: list[str] = []

    pos, size = 0, len(HEADER_CODE)
    header_code = data[pos:pos+size]
    if header_code != HEADER_CODE:
        # Zero-padded hex, so every byte always shows as two digits.
        exp_code = HEADER_CODE.hex().upper()
        got_code = header_code.hex().upper()
        errors.append(f"Bad file code at [{pos}:{pos+size}]: expected "
                      f"{exp_code}, found: {got_code}.")

    pos, size = pos + size, 8
    version = int.from_bytes(data[pos:pos+size], byteorder="little")
    if version != 1:
        errors.append(f"Bad file version at [{pos}:{pos+size}]: expected 1, "
                      f"found: {version}.")

    pos, size = pos + size, 1
    bit_len = int.from_bytes(data[pos:pos+size], byteorder="little")
    if bit_len not in VALID_BIT_LEN:
        errors.append(f"Bad bit length at [{pos}:{pos+size}]: expected one of "
                      f"{VALID_BIT_LEN}, found: {bit_len}.")

    pos, size = pos + size, 1
    db_type = int.from_bytes(data[pos:pos+size], byteorder="little")
    if db_type not in VALID_DB_TYPE:
        errors.append(f"Bad database type at [{pos}:{pos+size}]: expected one "
                      f"of {VALID_DB_TYPE}, found: {db_type}.")

    # A complete file is expected to carry a checkpoint of 2 ** (bit_len - 1).
    exp_checkpoint = 2 ** (bit_len - 1)
    pos, size = pos + size, 8
    checkpoint = int.from_bytes(data[pos:pos+size], byteorder="little")
    if checkpoint != exp_checkpoint:
        errors.append(f"Incomplete file: at position [{pos}:{pos+size}], "
                      f"expected checkpoint {exp_checkpoint}, found: "
                      f"{checkpoint}")

    pos, size = pos + size, 32
    pub_key_hash = data[pos:pos+size]

    pos, size = pos + size, 33
    pub_key = data[pos:pos+size]
    errors.extend(verify_pub_key_hash(pub_key, pub_key_hash))

    # Everything after the last field up to the end of the header is padding.
    pos += size
    if sum(data[pos:HEADER_LEN]):
        errors.append(f"Found non-zero bytes in padding [{pos}:{HEADER_LEN}].")

    if full_check:
        with open(path, "rb") as dbf:
            # Skip the header instead of reading it again and slicing a copy.
            dbf.seek(HEADER_LEN)
            data = dbf.read()
        if db_type == 1:
            errors += check_type_a(data)
        # The type B tables are sized from bit_len, so rebuilding them is only
        # attempted for a bit length that passed validation above; an invalid
        # one has already been reported and would crash or exhaust memory.
        if db_type == 2 and bit_len in VALID_BIT_LEN:
            errors += check_type_b(data, pub_key_hash, bit_len)

    return errors
def check_type_a(data: bytes) -> list[str]:
    """Validate the contents of a Type A table.

    Placeholder: the check has not been written yet, so every call raises
    NotImplementedError without inspecting ``data``.
    """
    reason = "TODO!"
    raise NotImplementedError(reason)
def check_type_b(data: bytes, pub_key_hash: bytes, bit_len: int) -> list[str]:
    """Check Type B table contents.

    Table A is rebuilt with plot_table_a, the expected table B is derived
    from it, and the result is compared slot by slot with ``data``.

    For every table A row, the value stored there and the value stored at the
    bit-complemented row index form a pair. The pair is hashed (SHA-256 over
    the prefix and both values) and the first ``byte_len`` digest bytes give
    the key under which the pair is stored.

    Each key owns two consecutive slots (value, complement value). This
    matches the expected data length of ``row_count * 2 * byte_len`` bytes and
    keeps the write for the last key inside the table.

    Args:
        data: Table bytes, i.e. the file contents after the header.
        pub_key_hash: Public key hash taken from the file header.
        bit_len: Bit length taken from the file header.

    Returns:
        A list of error messages: at most one for a wrong data length plus one
        for every slot whose bytes differ from the reconstructed table.
    """
    errors: list[str] = []

    bit_mask = (1 << bit_len) - 1
    byte_len = (bit_len + 7) // 8
    row_count = 1 << bit_len
    prefix = hashlib.sha256(b"MASS").digest() + pub_key_hash

    exp_len = (row_count * 2) * byte_len
    if len(data) != exp_len:
        errors.append(f"Incorrect data length, expected {exp_len} bytes, "
                      f"found: {len(data)}.")

    tbl_a = plot_table_a(pub_key_hash, bit_len)
    # The generated table should have the correct number of rows:
    assert len(tbl_a) == row_count

    # NOTE(review): the key is taken from whole digest bytes and is not
    # reduced to bit_len bits, so for a bit length that is not a multiple of 8
    # it can exceed the table size -- confirm the intended truncation.
    tbl_b = [0] * (row_count * 2)
    for row, val in enumerate(tbl_a):
        val_p = tbl_a[row ^ bit_mask]
        val_z = hashlib.sha256(
            prefix +
            val.to_bytes(byte_len, byteorder="little") +
            val_p.to_bytes(byte_len, byteorder="little")
        ).digest()[:byte_len]
        z_key = int.from_bytes(val_z, byteorder="little")
        # Two slots per key, so pairs of neighbouring keys never overlap.
        tbl_b[2 * z_key] = val
        tbl_b[2 * z_key + 1] = val_p

    for row, val in enumerate(tbl_b):
        exp = val.to_bytes(byte_len, byteorder="little")
        got = data[row*byte_len:(row+1)*byte_len]
        if exp != got:
            # Record the mismatch so that the caller reports a failure.
            errors.append(f"Table B @ row {row} expected {exp}, got: {got}.")

    return errors
def plot_table_a(pub_key_hash: bytes, bit_len: int) -> list[int]:
    """Reconstruct a Type A table from scratch.

    Every value x below ``2 ** bit_len`` is hashed as SHA-256 over the prefix
    (SHA-256 of b"MASS" followed by ``pub_key_hash``) and the little-endian
    encoding of x. The first ``byte_len`` digest bytes, read as a
    little-endian integer, select the table slot that receives x. When two
    values map to the same slot, the one processed later wins.

    Args:
        pub_key_hash: Public key hash taken from the file header.
        bit_len: Bit length taken from the file header.

    Returns:
        A list with ``2 ** bit_len`` entries; slots never written stay 0.
    """
    bit_mask = (1 << bit_len) - 1
    byte_len = (bit_len + 7) // 8
    half_count = 1 << (bit_len - 1)
    prefix = hashlib.sha256(b"MASS").digest() + pub_key_hash

    # NOTE(review): the slot index is taken from whole digest bytes and is not
    # reduced to bit_len bits, so for a bit length that is not a multiple of 8
    # it can exceed the table size -- confirm the intended truncation.
    tbl_a = [0] * (half_count * 2)
    for row in range(half_count):
        # Handle x = row first, then its bit complement x' = ~x.
        for x in (row, row ^ bit_mask):
            val = x.to_bytes(byte_len, byteorder="little")
            key = hashlib.sha256(prefix + val).digest()[:byte_len]
            tbl_a[int.from_bytes(key, byteorder="little")] = x

    return tbl_a
def verify_pub_key_hash(pub_key: bytes, pub_key_hash: bytes) -> list[str]:
    """Verify the public key hash.

    The double SHA-256 digest of ``pub_key`` is compared with
    ``pub_key_hash``.

    Args:
        pub_key: Raw public key bytes.
        pub_key_hash: Hash that the key is expected to produce.

    Returns:
        An empty list when the hash matches, otherwise a single-item list
        holding an error message with both digests in upper-case hex.
    """
    checksum = hashlib.sha256(hashlib.sha256(pub_key).digest()).digest()
    if checksum == pub_key_hash:
        return []

    # bytes.hex() zero-pads each byte, so both digests print at full length
    # and can be compared digit by digit.
    exp = pub_key_hash.hex().upper()
    got = checksum.hex().upper()
    return [f"Bad public key hash, expected: {exp}, got: {got}."]
# Script entry point: when the module is executed directly, hand the complete
# argument vector (program name included) to main().
if __name__ == "__main__":
    main(sys.argv)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment