Created
July 29, 2022 15:05
-
-
Save smlu/13b60e3544ea1272e1f6dc49de8c5fde to your computer and use it in GitHub Desktop.
RSASSA-PSS signature verification script
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Author: Crt Vavros | |
# MIT license | |
# Script implements RSASSA-PSS signature verification scheme as specified in RFC 8017 sec. 8.1.2 | |
# https://datatracker.ietf.org/doc/html/rfc8017#section-8.1.2 | |
# RSASSA-PSS MGF1 signature verification | |
# positional arguments: | |
# n RSA public key modulus as hex string | |
# e RSA public key exponent as hex string | |
# M Signed message as hex string | |
# S RSASSA-PSS signature as hex string | |
# sLen The length of salt | |
# options: | |
# -h, --help show this help message and exit | |
# --hash HASH Hash function which was used to produce signature (default: sha256) | |
# --verbose, --no-verbose | |
import argparse, hashlib, math, sys | |
# Module-wide verbosity flag. main() assigns it from the --verbose option;
# log() and rsassa_pss_verify() read it to decide whether to print diagnostics.
__verbose = False
def print_error(*args, **kwargs):
    """Print the given values exactly like ``print()``, but to standard error.

    ``sys.stderr`` is looked up at call time, so redirections are honoured.
    """
    stream = sys.stderr
    print(*args, file=stream, **kwargs)
def log(*args, **kwargs):
    """Forward the arguments to ``print()`` only when verbose mode is enabled.

    Verbose mode is the module-level ``__verbose`` flag.
    """
    if not __verbose:
        return
    print(*args, **kwargs)
def get_hash_constructor(name: str):
    """Return the hashlib constructor for the hash called *name*.

    Matching is case-insensitive (``'SHA256'``, ``'sha256'`` and ``'Sha256'``
    are equivalent).

    Only fixed-output-length hashes are supported. The PSS code relies on
    ``Hash().digest_size`` and calls ``.digest()`` without a length. The
    SHAKE XOFs have a zero ``digest_size`` and require a length for
    ``.digest()``, so they are deliberately rejected here rather than
    failing later inside verification.

    Raises:
        ValueError: if *name* is not a supported hash function.
    """
    constructors = {
        'sha1': hashlib.sha1,
        'md5': hashlib.md5,
        'sha224': hashlib.sha224,
        'sha256': hashlib.sha256,
        'sha384': hashlib.sha384,
        'sha512': hashlib.sha512,
        'blake2b': hashlib.blake2b,
        'blake2s': hashlib.blake2s,
        'sha3_224': hashlib.sha3_224,
        'sha3_256': hashlib.sha3_256,
        'sha3_384': hashlib.sha3_384,
        'sha3_512': hashlib.sha3_512,
    }
    try:
        return constructors[name.lower()]
    except KeyError:
        raise ValueError('unsupported hash type ' + name) from None
def os2i(x: bytes) -> int:
    """OS2IP (RFC 8017 sec. 4.2): interpret octet string *x* as a big-endian unsigned integer."""
    return int.from_bytes(x, 'big')
def i2os(x: int, xlen: int) -> bytes:
    """I2OSP (RFC 8017 sec. 4.1): encode integer *x* as exactly *xlen* big-endian octets.

    Raises:
        OverflowError: if *x* is negative or does not fit in *xlen* octets.
    """
    encoded = x.to_bytes(xlen, 'big')
    return encoded
def rsavp1(n: int, e: int, s: int) -> int:
    """RSAVP1 (RFC 8017 sec. 5.2.2): compute the message representative s^e mod n.

    Raises:
        ValueError: if the signature representative *s* is outside [0, n - 1].
    """
    if s < 0 or s > n - 1:
        raise ValueError("signature representative out of range")
    m = pow(s, e, n)
    return m
def mgf1(mgfSeed: bytes, maskLen: int, Hash=hashlib.sha1) -> bytes:
    """MGF1 mask generation function (RFC 8017 appendix B.2.1).

    Concatenates Hash(mgfSeed || C) for a 4-octet big-endian counter C =
    0, 1, 2, ... until at least *maskLen* octets exist. Returns the first
    *maskLen* octets.

    Raises:
        ValueError: if maskLen exceeds 2^32 * hLen ("mask too long").
    """
    digest_len = Hash().digest_size
    if maskLen > (digest_len << 32):
        raise ValueError("mask too long")
    chunks = []
    produced = 0
    block_index = 0
    while produced < maskLen:
        # Same encoding as I2OSP(counter, 4).
        chunk = Hash(mgfSeed + block_index.to_bytes(4, byteorder='big')).digest()
        chunks.append(chunk)
        produced += len(chunk)
        block_index += 1
    return b"".join(chunks)[:maskLen]
def emsa_pss_verify(M: bytes, EM: bytes, emBits: int, Hash=hashlib.sha256, MGF=mgf1, sLen=20) -> str:
    """EMSA-PSS verification operation (RFC 8017, sec. 9.1.2).

    Args:
        M: message whose signature is being checked.
        EM: encoded message recovered from the signature. It must be exactly
            ``ceil(emBits / 8)`` octets long.
        emBits: maximal bit length of the integer form of EM (``modBits - 1``
            when called from RSASSA-PSS).
        Hash: hashlib-style constructor used for mHash and H.
        MGF: mask generation function, called as ``MGF(seed, length, Hash)``.
        sLen: expected salt length in octets.

    Returns:
        ``'consistent'`` if EM is a valid PSS encoding of M, otherwise
        ``'inconsistent'``.

    Raises:
        ValueError: if sLen is negative.
    """
    if sLen < 0:
        raise ValueError('salt length must be non-negative')

    emLen = math.ceil(emBits / 8)
    hLen = Hash().digest_size
    mHash = Hash(M).digest()
    log(f'\nemLen={emLen}')
    log(f'hLen={hLen}')
    log(f'\nmHash={mHash.hex()}')

    # Step 3: room for H, salt, the 0x01 separator and the 0xbc trailer.
    if emLen < hLen + sLen + 2:
        return 'inconsistent'
    # All offsets below are derived from emLen. An EM of any other length
    # would have its fields silently misaligned.
    if len(EM) != emLen:
        return 'inconsistent'
    # Step 4: trailer field.
    if EM[-1] != 0xbc:
        return 'inconsistent'

    # Step 5: EM = maskedDB || H || 0xbc.
    dbLen = emLen - hLen - 1
    maskedDB = EM[:dbLen]
    log(f'\nmaskedDB={maskedDB.hex()}')
    H = EM[dbLen:emLen - 1]
    log(f'\nH={H.hex()}')

    # Step 6: the leftmost 8*emLen - emBits bits of EM must be zero.
    top_bitmask = 0xff >> ((emLen * 8) - emBits)
    log(f'\ntop_bitmask={top_bitmask}')
    if EM[0] & ~top_bitmask:
        return 'inconsistent'

    # Steps 7-8: DB = maskedDB XOR MGF(H, dbLen).
    dbMask = MGF(H, dbLen, Hash)
    log(f'\ndbMask={dbMask.hex()}')
    DB = bytearray(i2os(os2i(maskedDB) ^ os2i(dbMask), dbLen))
    log(f'\nDB={DB.hex()}')
    # Step 9: clear the bits that were forced to zero in step 6.
    DB[0] &= top_bitmask

    # Step 10: DB = PS || 0x01 || salt, where PS is all zero octets.
    # PS holds emLen - hLen - sLen - 2 == dbLen - sLen - 1 octets. Step 3
    # guarantees this is >= 0. The previous loop checked one octet too few,
    # so the byte just before 0x01 was never verified.
    psLen = dbLen - sLen - 1
    if any(DB[:psLen]) or DB[psLen] != 0x01:
        return 'inconsistent'

    # Step 11: salt is the last sLen octets (empty when sLen == 0).
    salt = bytes(DB[psLen + 1:])
    log(f'\nsalt={salt.hex()}')
    # Steps 12-13: M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt; H' = Hash(M').
    M2 = bytes(8) + mHash + salt
    log(f'\nM2={M2.hex()}')
    H2 = Hash(M2).digest()
    log(f'\nH2={H2.hex()}')
    # Step 14.
    return 'consistent' if H == H2 else 'inconsistent'
def rsassa_pss_verify(n: bytes, e: bytes, M: bytes, S: bytes, Hash=hashlib.sha256, MGF=mgf1, sLen=20) -> str:
    """RSASSA-PSS signature verification (RFC 8017, sec. 8.1.2).

    Args:
        n: RSA modulus as a big-endian octet string. Leading zero octets are
            allowed; the modulus length k is derived from the integer value.
        e: RSA public exponent as a big-endian octet string.
        M: signed message.
        S: signature. It must be exactly k octets long.
        Hash: hashlib-style constructor used by EMSA-PSS.
        MGF: mask generation function passed through to EMSA-PSS.
        sLen: expected salt length in octets.

    Returns:
        ``'valid signature'`` or ``'invalid signature'``. Any exception
        raised during verification is treated as an invalid signature. Its
        message is printed to stderr when verbose mode is on.
    """
    try:
        modulus = os2i(n)
        exponent = os2i(e)
        modBits = modulus.bit_length()
        # Step 1: k is the octet length of the modulus *value*. Using len(n)
        # would wrongly reject valid signatures when n carries a leading 00
        # octet, e.g. a sign byte in a hex dump.
        k = (modBits + 7) // 8
        if len(S) != k:
            return 'invalid signature'
        # Step 2: RSA verification primitive.
        m = rsavp1(modulus, exponent, os2i(S))
        # Step 3: EM = I2OSP(m, emLen). This raises if m is too large, which
        # is handled below as an invalid signature.
        emLen = math.ceil((modBits - 1) / 8)
        EM = i2os(m, emLen)
        log(f'\nEM={EM.hex()}')
        # Step 4: EMSA-PSS consistency check.
        result = emsa_pss_verify(M, EM, modBits - 1, Hash, MGF, sLen)
        return 'valid signature' if result == 'consistent' else 'invalid signature'
    except Exception as err:
        if __verbose:
            print_error(f'Error: {str(err)}')
        return 'invalid signature'
def main(argv=None, arc=None):
    """Command-line entry point: parse arguments, verify the signature, print the result.

    Decoding errors (bad hex) and unsupported hash names are reported on
    stderr, and the process exits with status 1.

    Args:
        argv: full argument vector including the program name, like
            ``sys.argv``. ``None`` means the process's ``sys.argv``.
        arc: unused. Accepted so existing ``main(argv, argc)`` callers keep
            working.
    """
    global __verbose
    try:
        parser = argparse.ArgumentParser(description='RSASSA-PSS MGF1 signature verification',
                                         formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser.add_argument('n', metavar='n', type=str,
                            help='RSA public key modulus as hex string')
        parser.add_argument('e', metavar='e', type=str,
                            help='RSA public key exponent as hex string')
        parser.add_argument('M', metavar='M', type=str,
                            help='Signed message as hex string')
        parser.add_argument('S', metavar='S', type=str,
                            help='RSASSA-PSS signature as hex string')
        parser.add_argument('sLen', metavar='sLen', type=int,
                            help='The length of salt')
        parser.add_argument('--hash', type=str, default='sha256',
                            help='Hash function which was used to produce signature')
        parser.add_argument('--verbose', action=argparse.BooleanOptionalAction, default=False)
        # Honour the caller's argv instead of silently reading sys.argv.
        # argv[0] is the program name, and parse_args wants only the arguments.
        args = parser.parse_args(None if argv is None else argv[1:])
        __verbose = args.verbose

        n = bytes.fromhex(args.n)
        e = bytes.fromhex(args.e)
        M = bytes.fromhex(args.M)
        S = bytes.fromhex(args.S)
        h = get_hash_constructor(args.hash)
        print(f'\n Result: {rsassa_pss_verify(n, e, M, S, Hash=h, MGF=mgf1, sLen=args.sLen)}')
    except Exception as err:
        print_error(f'Error: {str(err)}')
        sys.exit(1)
if __name__ == "__main__":
    # Script mode: hand the process argument vector (and its length) to main().
    cli_args = sys.argv
    main(cli_args, len(cli_args))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment