Skip to content

Instantly share code, notes, and snippets.

@smlu
Created July 29, 2022 15:05
Show Gist options
  • Save smlu/13b60e3544ea1272e1f6dc49de8c5fde to your computer and use it in GitHub Desktop.
Save smlu/13b60e3544ea1272e1f6dc49de8c5fde to your computer and use it in GitHub Desktop.
RSASSA-PSS signature verification script
# Author: Crt Vavros
# MIT license
# This script implements the RSASSA-PSS signature verification scheme as specified in RFC 8017, Section 8.1.2.
# https://datatracker.ietf.org/doc/html/rfc8017#section-8.1.2
# RSASSA-PSS MGF1 signature verification
# positional arguments:
# n RSA public key modulus as hex string
# e RSA public key exponent as hex string
# M Signed message as hex string
# S RSASSA-PSS signature as hex string
# sLen The length of salt
# options:
# -h, --help show this help message and exit
# --hash HASH Hash function which was used to produce signature (default: sha256)
# --verbose, --no-verbose
import argparse, hashlib, math, sys
# Module-wide verbosity flag: log() prints only while this is truthy, and
# main() overwrites it with the value of the --verbose command-line option.
__verbose = False
def print_error(*args, **kwargs):
    """Write a message to standard error.

    Takes the same positional and keyword arguments as the built-in
    print(), except that the output stream is always sys.stderr.
    """
    stream = sys.stderr
    print(*args, file=stream, **kwargs)
def log(*args, **kwargs):
    """Print a diagnostic message, but only in verbose mode.

    The arguments are forwarded unchanged to the built-in print(). Nothing
    is printed while the module-level __verbose flag is falsy.
    """
    if not __verbose:
        return
    print(*args, **kwargs)
def get_hash_constructor(name: str):
    """Return the hashlib constructor for the hash algorithm called *name*.

    The lookup is case-insensitive, so 'sha256', 'SHA256' and 'Sha256' all
    resolve to hashlib.sha256.

    Args:
        name: Algorithm name, e.g. 'sha1', 'sha256', 'sha3_256', 'blake2b'.

    Returns:
        The matching hashlib constructor (a callable, not a hash object).

    Raises:
        ValueError: If *name* does not identify a supported algorithm.
    """
    constructors = {
        'sha1': hashlib.sha1,
        'md5': hashlib.md5,
        'sha224': hashlib.sha224,
        'sha256': hashlib.sha256,
        'sha384': hashlib.sha384,
        'sha512': hashlib.sha512,
        'blake2b': hashlib.blake2b,
        'blake2s': hashlib.blake2s,
        'sha3_224': hashlib.sha3_224,
        'sha3_256': hashlib.sha3_256,
        'sha3_384': hashlib.sha3_384,
        'sha3_512': hashlib.sha3_512,
        # NOTE(review): the SHAKE functions are extendable-output functions;
        # their digest() requires an explicit length argument, so the
        # argument-less digest() calls made by the PSS code in this file
        # will fail when one of them is selected.
        'shake_128': hashlib.shake_128,
        'shake_256': hashlib.shake_256,
    }
    # Anything that is not a string cannot name an algorithm; map it to a key
    # that is guaranteed to miss so it is reported as unsupported below.
    key = name.lower() if isinstance(name, str) else None
    try:
        return constructors[key]
    except KeyError:
        raise ValueError(f'unsupported hash type {name}') from None
def os2i(x: bytes) -> int:
    """Convert an octet string to a non-negative integer (big-endian).

    The first byte of *x* is the most significant one; an empty string
    converts to 0.
    """
    return int.from_bytes(x, 'big')
def i2os(x: int, xlen: int) -> bytes:
    """Convert a non-negative integer to a big-endian octet string.

    The result is exactly *xlen* bytes long, left-padded with zero bytes.
    int.to_bytes raises OverflowError when *x* does not fit in *xlen* bytes
    or is negative.
    """
    octets = x.to_bytes(xlen, 'big')
    return octets
def rsavp1(n: int, e: int, s: int) -> int:
    """Apply the RSA verification primitive to a signature representative.

    Args:
        n: RSA modulus.
        e: RSA public exponent.
        s: Signature representative.

    Returns:
        The message representative s**e mod n.

    Raises:
        ValueError: If *s* is not in the range 0 .. n - 1.
    """
    in_range = 0 <= s < n
    if in_range:
        return pow(s, e, n)
    raise ValueError("signature representative out of range")
def mgf1(mgfSeed: bytes, maskLen: int, Hash=hashlib.sha1) -> bytes:
    """Generate a mask of *maskLen* bytes with MGF1 (RFC 8017, appendix B.2.1).

    The mask is the concatenation of Hash(mgfSeed || counter) for
    counter = 0, 1, 2, ..., where the counter is encoded as 4 big-endian
    bytes, truncated to *maskLen* bytes.

    Args:
        mgfSeed: Seed from which the mask is generated.
        maskLen: Requested mask length in bytes; values <= 0 give b"".
        Hash: hashlib-style constructor of the underlying hash function.

    Returns:
        The mask as a byte string of length *maskLen*.

    Raises:
        ValueError: If *maskLen* exceeds 2**32 times the hash output length.
    """
    hLen = Hash().digest_size
    if maskLen > (hLen << 32):
        raise ValueError("mask too long")
    if maskLen <= 0:
        return b""
    # Number of hash outputs needed to cover maskLen bytes (ceiling division).
    # hLen cannot be 0 here: that case was rejected as "mask too long" above.
    blocks = -(-maskLen // hLen)
    # Join once instead of growing a bytes object in a loop, which would copy
    # the accumulated mask on every iteration.
    T = b"".join(
        Hash(mgfSeed + counter.to_bytes(4, byteorder='big')).digest()
        for counter in range(blocks)
    )
    return T[:maskLen]
def emsa_pss_verify(M: bytes, EM: bytes, emBits: int, Hash=hashlib.sha256, MGF=mgf1, sLen=20) -> str:
    """Check an EMSA-PSS encoded message against a message (RFC 8017 sec. 9.1.2).

    Args:
        M: Message whose signature is being verified.
        EM: Encoded message, expected to be ceil(emBits / 8) bytes long.
        emBits: Maximal bit length of the integer representation of EM.
        Hash: hashlib-style constructor of the hash function.
        MGF: Mask generation function, called as MGF(seed, maskLen, Hash).
        sLen: Intended length of the salt in bytes.

    Returns:
        'consistent' if EM is a valid encoding of M, otherwise 'inconsistent'.

    Raises:
        ValueError: If *sLen* is negative.
    """
    if sLen < 0:
        raise ValueError("salt length must not be negative")

    # Integer ceiling division; avoids going through a float.
    emLen = (emBits + 7) // 8
    hLen = Hash().digest_size
    mHash = Hash(M).digest()
    log(f'\nemLen={emLen}')
    log(f'hLen={hLen}')
    log(f'\nmHash={mHash.hex()}')

    # All slicing below assumes EM is exactly emLen bytes long; an encoded
    # message of any other length cannot be a valid encoding.
    if len(EM) != emLen:
        return 'inconsistent'
    if emLen < hLen + sLen + 2:
        return 'inconsistent'
    # The encoded message always ends with the trailer byte 0xbc.
    if EM[-1] != 0xbc:
        return 'inconsistent'

    dbLen = emLen - hLen - 1
    maskedDB = EM[:dbLen]
    log(f'\nmaskedDB={maskedDB.hex()}')
    H = EM[dbLen:emLen - 1]
    log(f'\nH={H.hex()}')

    # The leftmost 8 * emLen - emBits bits of EM must be zero; top_bitmask
    # keeps only the bits of the first byte that are allowed to be set.
    top_bitmask = 0xff >> ((emLen * 8) - emBits)
    log(f'\ntop_bitmask={top_bitmask}')
    if EM[0] & ~top_bitmask & 0xff:
        return 'inconsistent'

    dbMask = MGF(H, dbLen, Hash)
    log(f'\ndbMask={dbMask.hex()}')
    DB = bytearray(i2os(os2i(maskedDB) ^ os2i(dbMask), dbLen))
    log(f'\nDB={DB.hex()}')
    DB[0] &= top_bitmask

    # DB = PS || 0x01 || salt, where PS is dbLen - sLen - 1 zero bytes.
    # Every byte of PS must be checked, including the one directly in front
    # of the 0x01 separator.
    psLen = dbLen - sLen - 1
    if any(DB[:psLen]):
        return 'inconsistent'
    if DB[psLen] != 0x01:
        return 'inconsistent'

    # The salt is whatever follows the separator (empty when sLen == 0).
    salt = DB[psLen + 1:]
    log(f'\nsalt={salt.hex()}')
    M2 = bytes(8) + mHash + salt
    log(f'\nM2={M2.hex()}')
    H2 = Hash(M2).digest()
    log(f'\nH2={H2.hex()}')
    return 'consistent' if H == H2 else 'inconsistent'
def rsassa_pss_verify(n: bytes, e: bytes, M: bytes, S: bytes, Hash=hashlib.sha256, MGF=mgf1, sLen=20) -> str:
    """Verify an RSASSA-PSS signature (RFC 8017 sec. 8.1.2).

    Args:
        n: RSA modulus as a big-endian byte string; leading zero bytes are
            allowed and do not affect the expected signature length.
        e: RSA public exponent as a big-endian byte string.
        M: Message whose signature is being verified.
        S: Signature; must be exactly as long as the modulus in bytes.
        Hash: hashlib-style constructor of the hash function.
        MGF: Mask generation function, called as MGF(seed, maskLen, Hash).
        sLen: Intended length of the salt in bytes.

    Returns:
        'valid signature' or 'invalid signature'. Any exception raised during
        verification is reported as 'invalid signature'; in verbose mode the
        error message is also printed to standard error.
    """
    try:
        modulus = os2i(n)
        modBits = modulus.bit_length()
        # k is the byte length of the modulus *value*. Deriving it from the
        # integer rather than from len(n) keeps a leading 0x00 byte in the
        # encoded modulus from changing the expected signature length.
        k = (modBits + 7) // 8
        if len(S) != k:
            return 'invalid signature'
        m = rsavp1(modulus, os2i(e), os2i(S))
        # emLen = ceil((modBits - 1) / 8), computed with integer arithmetic.
        emLen = (modBits + 6) // 8
        # i2os raises when m does not fit in emLen bytes; that is handled
        # below and reported as an invalid signature.
        EM = i2os(m, emLen)
        log(f'\nEM={EM.hex()}')
        Result = emsa_pss_verify(M, EM, modBits - 1, Hash, MGF, sLen)
        return 'valid signature' if Result == 'consistent' else 'invalid signature'
    except Exception as err:  # named err so the parameter e is not shadowed
        if __verbose:
            print_error(f'Error: {str(err)}')
        return 'invalid signature'
def main(argv, arc):
    """Command-line entry point: parse the arguments and print the verdict.

    Args:
        argv: Full argument vector including the program name (for example
            sys.argv). When None, argparse falls back to sys.argv.
        arc: Argument count; unused, kept for interface compatibility.

    The verification result is printed to standard output. On any error
    while decoding the arguments the message is printed to standard error
    and the process exits with status 1.
    """
    global __verbose

    def int_hex_to_bytes(text):
        # n and e are numbers, so an odd number of hex digits (for example
        # the common exponent "10001") is unambiguous: pad with a leading
        # zero nibble instead of letting bytes.fromhex reject the value.
        digits = ''.join(text.split())
        if len(digits) % 2:
            digits = '0' + digits
        return bytes.fromhex(digits)

    try:
        parser = argparse.ArgumentParser(description='RSASSA-PSS MGF1 signature verification',
                                         formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser.add_argument('n', metavar='n', type=str,
                            help='RSA public key modulus as hex string')
        parser.add_argument('e', metavar='e', type=str,
                            help='RSA public key exponent as hex string')
        parser.add_argument('M', metavar='M', type=str,
                            help='Signed message as hex string')
        parser.add_argument('S', metavar='S', type=str,
                            help='RSASSA-PSS signature as hex string')
        parser.add_argument('sLen', metavar='sLen', type=int,
                            help='The length of salt')
        parser.add_argument('--hash', type=str, default='sha256',
                            help='Hash function which was used to produce signature')
        parser.add_argument('--verbose', action=argparse.BooleanOptionalAction, default=False)
        # Honour the argv passed by the caller; element 0 is the program name.
        args = parser.parse_args(argv[1:] if argv is not None else None)

        __verbose = args.verbose

        n = int_hex_to_bytes(args.n)
        e = int_hex_to_bytes(args.e)
        # M and S are octet strings, not numbers, so they must be given as
        # complete bytes (an even number of hex digits).
        M = bytes.fromhex(args.M)
        S = bytes.fromhex(args.S)
        h = get_hash_constructor(args.hash)
        print(f'\n Result: {rsassa_pss_verify(n,e,M,S, Hash=h, MGF=mgf1,sLen=args.sLen)}')
    except Exception as e:
        print_error(f'Error: {str(e)}')
        sys.exit(1)
if __name__ == '__main__':
    # Run the command-line interface only when this file is executed as a
    # script, not when it is imported as a module.
    cli_args = sys.argv
    main(cli_args, len(cli_args))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment