Last active
April 29, 2024 05:42
-
-
Save mjsabby/b87b629fa902cd641b955892b1ab9604 to your computer and use it in GitHub Desktop.
Verify RS256 using more primitive operations to find where time is spent
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// <PackageReference Include="SymCryptNative" Version="103.4.2" /> | |
// <PackageReference Include="BenchmarkDotNet" Version="0.13.12" /> | |
namespace VerifyRS256JWTSignature | |
{ | |
using BenchmarkDotNet.Attributes; | |
using BenchmarkDotNet.Running; | |
using Microsoft.IdentityModel.Tokens; | |
using System; | |
using System.IdentityModel.Tokens.Jwt; | |
using System.Linq; | |
using System.Runtime.CompilerServices; | |
using System.Runtime.InteropServices; | |
using System.Security.Cryptography; | |
using System.Text; | |
using unsafe SymCryptRsaCoreEncFunctionPointer = delegate* unmanaged[Stdcall, SuppressGCTransition]<nint, byte*, nint, SYMCRYPT_NUMBER_FORMAT, int, byte*, nint, byte*, nint, int>; | |
public class Benchmarks : IDisposable
{
    private readonly JwtSecurityTokenHandler jwtSecurityTokenHandler;
    private readonly RSA rsa;
    private readonly string jwt;
    private readonly TokenValidationParameters validationParameters;
    private readonly SymCryptKeyHandle publicKey;
    private readonly unsafe SymCryptRsaCoreEncFunctionPointer encrypt;

    /// <summary>
    /// Resolves SymCrypt's <c>SymCryptRsaCoreEnc</c>, imports the sample RSA public key into both a managed
    /// <see cref="RSA"/> (used by the IdentityModel path) and a SymCrypt key object (used by the primitive path),
    /// then verifies the sample JWT once so a broken setup fails fast instead of producing meaningless timings.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// SymCrypt key setup failed, or the sample signature did not verify.
    /// </exception>
    public Benchmarks()
    {
        var dll = Program.GetDllPointer(NativeMethods.SymCrypt);
        unsafe
        {
            encrypt = (SymCryptRsaCoreEncFunctionPointer)Program.GetDllFunction(dll, "SymCryptRsaCoreEnc");
        }

        byte[] modulus;
        byte[] exponent;
        rsa = RSA.Create();
        var pem = @"MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAu1SU1LfVLPHCozMxH2Mo4lgOEePzNm0tRgeLezV6ffAt0gunVTLw7onLRnrq0/IzW7yWR7QkrmBL7jTKEn5u+qKhbwKfBstIs+bMY2Zkp18gnTxKLxoS2tFczGkPLPgizskuemMghRniWaoLcyehkd3qqGElvW/VDL5AaWTg0nLVkjRo9z+40RQzuVaE8AkAFmxZzow3x+VJYKdjykkJ0iT9wCS0DRTXu269V264Vf/3jvredZiKRkgwlL9xNAwxXFg0x/XFw005UWVRIkdgcKWTjpBP2dPwVZ4WWC+9aGVd+Gyn1o0CLelf4rEjGoXbAAEgAqeGUxrcIlbjXfbcmwIDAQAB";
        rsa.ImportSubjectPublicKeyInfo(Convert.FromBase64String(pem), out _);
        var ec = rsa.ExportParameters(false);
        (modulus, exponent) = (ec.Modulus, ec.Exponent);

        publicKey = CreateSymCryptPublicKey(modulus, exponent);

        jwt = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.NHVaYe26MbtOYhSKkoKYdFVomg4i8ZJd8_-RU8VNbftc4TSMb4bXP3l3YlNWACwyXPGffz5aXHc6lty1Y2t4SWRqGteragsVdZufDn5BlnJl9pdR_kdVFUsra2rWKEofkZeIC4yWytE58sMIihvo9H1ScmmVwBcQP6XETqYd0aSHp1gOa9RdUPDvoXQ5oqygTqVtxaDr6wUFKrKItgBMzWIdNZ6y7O9E0DhEPTbE9rfBo6KTFsHAZnMg4k68CDp2woYIaXbmYTWcvbzIuHO7_37GT79XdIwkm95QJ7hYC9RiwrV7mesbY4PAahERJawntho0my942XheVLmGwLMBkQ";
        if (!BenchmarkVerifySHA256Signature())
        {
            // A throwing constructor means Dispose() is never called, so release what we own now.
            publicKey.Dispose();
            rsa.Dispose();
            throw new InvalidOperationException("Invalid signature");
        }

        validationParameters = new TokenValidationParameters
        {
            ValidateLifetime = false,
            ValidateActor = false,
            ValidateIssuer = false,
            ValidateAudience = false,
            ValidateTokenReplay = false,
            ValidateSignatureLast = false,
            ValidateWithLKG = false,
            ValidateIssuerSigningKey = true,
            IssuerSigningKey = new RsaSecurityKey(ec)
        };
        jwtSecurityTokenHandler = new JwtSecurityTokenHandler();
    }

    /// <summary>
    /// Verifies the sample JWT with the hand-rolled SymCrypt path, including the JWT splitting and
    /// base64url decoding so the work is comparable to a full token validation.
    /// </summary>
    [Benchmark]
    public unsafe bool BenchmarkVerifySHA256Signature()
    {
        var jwtParts = jwt.Split('.');
        var dataToVerify2 = Encoding.UTF8.GetBytes($"{jwtParts[0]}.{jwtParts[1]}");
        var signature2 = Base64UrlDecode(jwtParts[2]);
        return Program.VerifySHA256Signature(encrypt, dataToVerify2, signature2, publicKey);
    }

    /// <summary>
    /// Validates the sample JWT with <see cref="JwtSecurityTokenHandler"/>; only signature validation is enabled.
    /// </summary>
    [Benchmark]
    public void BenchmarkValidateJWTUsingIdentityModel()
    {
        _ = jwtSecurityTokenHandler.ValidateToken(jwt, validationParameters, out _);
    }

    /// <summary>
    /// Allocates a SymCrypt RSA key object and loads the public modulus and exponent into it.
    /// </summary>
    /// <param name="modulus">Big-endian modulus bytes.</param>
    /// <param name="exponent">Big-endian public exponent bytes; must fit in 64 bits.</param>
    /// <returns>A handle that owns the native key.</returns>
    /// <exception cref="InvalidOperationException">Allocation or key loading failed.</exception>
    private static SymCryptKeyHandle CreateSymCryptPublicKey(byte[] modulus, byte[] exponent)
    {
        SYMCRYPT_RSA_PARAMS rsaParams = new()
        {
            Version = 1,
            BitsOfModulus = modulus.Length * 8,
            NumOfPrimes = 0,
            NumOfPublicExponents = 1
        };
        const int SYMCRYPT_FLAG_RSAKEY_ENCRYPT = 0x2000;

        // SymCryptRsakeySetValue receives a *count* of public exponents (numberOfExponents) and no byte length,
        // so the native side reads fixed-width values; NOTE(review): symcrypt.h declares these as PCUINT64 —
        // confirm. Passing the raw big-endian bytes only worked for 65537 because 01 00 01 is a palindrome and
        // the bytes past the end of the 3-byte array happened to be zero. Convert explicitly instead.
        if (exponent.Length > sizeof(ulong))
        {
            throw new InvalidOperationException("RSA public exponent does not fit in 64 bits");
        }

        ulong publicExponent = 0;
        foreach (byte b in exponent)
        {
            publicExponent = (publicExponent << 8) | b;
        }

        // Machine byte order, which is how native code reads a UINT64 in memory.
        byte[] publicExponentAsUInt64 = BitConverter.GetBytes(publicExponent);

        var key = NativeMethods.SymCryptRsakeyAllocate(in rsaParams, flags: 0);
        if (key == 0)
        {
            throw new InvalidOperationException("Failed to allocate RSA key");
        }

        var err = NativeMethods.SymCryptRsakeySetValue(modulus, modulus.Length, publicExponentAsUInt64, numberOfExponents: 1, 0, 0, 0, SYMCRYPT_NUMBER_FORMAT.MSB_FIRST, SYMCRYPT_FLAG_RSAKEY_ENCRYPT, key);
        if (err != 0)
        {
            // The raw key is not yet owned by a SymCryptKeyHandle, so nothing else would ever free it.
            _ = NativeMethods.SymCryptRsakeyFree(key);
            throw new InvalidOperationException("Failed to set RSA key value");
        }

        return new SymCryptKeyHandle(rsaParams.BitsOfModulus, key);
    }

    /// <summary>
    /// Decodes a base64url string (RFC 4648 section 5) by mapping it back to standard base64 and restoring padding.
    /// </summary>
    private static byte[] Base64UrlDecode(string base64UrlEncodedString)
    {
        string output = base64UrlEncodedString;
        output = output.Replace('-', '+').Replace('_', '/');
        switch (output.Length % 4)
        {
            case 2:
                output += "==";
                break;
            case 3:
                output += "=";
                break;
            default:
                // A remainder of 1 is never valid; Convert.FromBase64String rejects it with a FormatException.
                break;
        }
        return Convert.FromBase64String(output);
    }

    /// <summary>Releases the SymCrypt key and the managed RSA instance.</summary>
    public void Dispose()
    {
        Dispose(true);
        GC.SuppressFinalize(this);
    }

    /// <summary>Standard dispose pattern hook; managed wrappers are only touched when <paramref name="disposing"/> is true.</summary>
    protected virtual void Dispose(bool disposing)
    {
        if (disposing)
        {
            publicKey.Dispose();
            rsa.Dispose();
        }
    }
}
/// <summary>
/// P/Invoke surface for the benchmark: the Win32 loader functions used to resolve a raw function pointer,
/// plus the SymCrypt RSA key-management entry points. All declarations use the LibraryImport source generator.
/// </summary>
internal static partial class NativeMethods
{
    /// <summary>Library name used for the Win32 loader imports.</summary>
    public const string Kernel32 = "kernel32.dll";

    /// <summary>Library name used for the SymCrypt imports and for the manual LoadLibraryA lookup.</summary>
    public const string SymCrypt = "symcrypt";

    // ---------------------------------------------------------------------------------------------
    // SymCrypt RSA key management
    // ---------------------------------------------------------------------------------------------

    /// <summary>Allocates a SymCrypt RSA key object described by <paramref name="rsaParams"/>.</summary>
    /// <remarks>NOTE(review): callers should treat a 0 result as failure — confirm against symcrypt.h.</remarks>
    [LibraryImport(SymCrypt, EntryPoint = nameof(SymCryptRsakeyAllocate))]
    public static partial nint SymCryptRsakeyAllocate(in SYMCRYPT_RSA_PARAMS rsaParams, int flags);

    /// <summary>Loads modulus, public exponent(s) and optional primes into an allocated key.</summary>
    /// <remarks>
    /// NOTE(review): there is no byte-length argument for <paramref name="exponent"/>, only a count, which suggests
    /// the native side reads an array of UINT64 values rather than big-endian bytes — confirm against symcrypt.h.
    /// The <c>_</c>/<c>__</c> slots are the primes pointer arrays, passed as 0 for a public-only key.
    /// </remarks>
    [LibraryImport(SymCrypt, EntryPoint = nameof(SymCryptRsakeySetValue))]
    public static unsafe partial int SymCryptRsakeySetValue(ReadOnlySpan<byte> modulus, nint modulusSize, ReadOnlySpan<byte> exponent, int numberOfExponents, nint _, nint __, int numPrimes, SYMCRYPT_NUMBER_FORMAT format, int flags, nint rsaKey);

    /// <summary>Returns the scratch-buffer size the raw RSA public-key operation needs for the given key.</summary>
    /// <remarks>NOTE(review): if the native return type is SIZE_T, only the low 32 bits are read here — confirm.</remarks>
    [LibraryImport(SymCrypt, EntryPoint = nameof(SymCryptRsaCoreEncScratchSpace))]
    public static partial int SymCryptRsaCoreEncScratchSpace(nint ptr);

    /// <summary>Frees a key returned by <see cref="SymCryptRsakeyAllocate"/>.</summary>
    /// <remarks>NOTE(review): every visible caller discards the result; confirm whether the native function returns a value at all.</remarks>
    [LibraryImport(SymCrypt, EntryPoint = nameof(SymCryptRsakeyFree))]
    public static partial int SymCryptRsakeyFree(nint hKey);

    // ---------------------------------------------------------------------------------------------
    // Win32 loader (ANSI strings are passed as pre-marshalled HGLOBAL pointers)
    // ---------------------------------------------------------------------------------------------

    [LibraryImport(Kernel32, EntryPoint = nameof(LoadLibraryA))]
    public static partial nint LoadLibraryA(nint dllToLoad);

    [LibraryImport(Kernel32, EntryPoint = nameof(GetProcAddress))]
    public static partial nint GetProcAddress(nint hModule, nint procedureName);

    [LibraryImport(Kernel32, EntryPoint = nameof(FreeLibrary))]
    [return: MarshalAs(UnmanagedType.Bool)]
    public static partial bool FreeLibrary(nint hModule);

    // NOTE(review): reading the thread's last error via a separate P/Invoke is unreliable because the runtime may
    // overwrite it between calls; SetLastError = true + Marshal.GetLastPInvokeError() is the supported pattern.
    [LibraryImport(Kernel32, EntryPoint = nameof(GetLastError))]
    public static partial int GetLastError();
}
internal static partial class Program
{
    /// <summary>
    /// Entry point. Constructs <see cref="Benchmarks"/> once as a sanity check (its constructor throws if the
    /// SymCrypt path rejects the sample token), then runs all benchmarks.
    /// </summary>
    public static void Main()
    {
        // Dispose the sanity-check instance so its native SymCrypt key is released deterministically
        // rather than whenever the finalizer happens to run.
        new Benchmarks().Dispose();
        BenchmarkRunner.Run<Benchmarks>();
    }

    /// <summary>
    /// Verifies an RS256 (RSASSA-PKCS1-v1_5 with SHA-256) signature by running the raw RSA public-key operation
    /// through SymCrypt and checking the recovered encoded message (EM) byte-for-byte.
    /// </summary>
    /// <param name="encrypt">Function pointer to SymCrypt's <c>SymCryptRsaCoreEnc</c>.</param>
    /// <param name="dataToVerify">The signed bytes (for a JWT, the encoded "header.payload").</param>
    /// <param name="signature">Big-endian signature; must be exactly as long as the modulus in bytes.</param>
    /// <param name="publicKey">SymCrypt public key to verify against.</param>
    /// <returns><c>true</c> when the signature is valid; otherwise <c>false</c>.</returns>
    [MethodImpl(MethodImplOptions.AggressiveOptimization)]
    public static unsafe bool VerifySHA256Signature(SymCryptRsaCoreEncFunctionPointer encrypt, ReadOnlySpan<byte> dataToVerify, ReadOnlySpan<byte> signature, SymCryptKeyHandle publicKey)
    {
        // DER-encoded DigestInfo prefix for SHA-256 (RFC 8017 section 9.2, note 1). Targeting ReadOnlySpan<byte>
        // directly lets the compiler point at static data instead of allocating on every call.
        ReadOnlySpan<byte> sha256DigestInfoPrefix =
        [
            0x30, 0x31, 0x30, 0x0D, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01,
            0x65, 0x03, 0x04, 0x02, 0x01, 0x05, 0x00, 0x04, 0x20
        ];

        var sha256Hash = SHA256.HashData(dataToVerify);
        var sha256HashLength = sha256Hash.Length;
        // RS256 uses SHA-256 which is 32 bytes
        if (sha256HashLength != 32)
        {
            return false;
        }

        var bitLength = publicKey.BitLength;
        // RS256 requires 2048 bits to be secure
        if (bitLength < 2048)
        {
            return false;
        }

        int roundedKeySizeInBytes = (bitLength + 7) >> 3;
        // check if multiple of 16 bytes, i.e. 2048, 3072, 4096 bit length, etc.
        if ((roundedKeySizeInBytes & 0xF) != 0)
        {
            return false;
        }

        if (signature.Length != roundedKeySizeInBytes)
        {
            return false;
        }

        int scratchSize = NativeMethods.SymCryptRsaCoreEncScratchSpace(publicKey.Pointer);
        const int SYMCRYPT_ASYM_ALIGN_VALUE = 32;
        const int extra = SYMCRYPT_ASYM_ALIGN_VALUE - 1;
        // Over-allocate by (alignment - 1) so an aligned window of the required size always fits.
        Span<byte> dest = stackalloc byte[roundedKeySizeInBytes + extra];
        Span<byte> scratch = stackalloc byte[scratchSize + extra];
        fixed (byte* scratchPtr = &scratch[0])
        fixed (byte* destPtr = &dest[0])
        fixed (byte* signaturePtr = &signature[0])
        {
            int destOffset = AlignPointerForSymCrypt(destPtr, SYMCRYPT_ASYM_ALIGN_VALUE);
            byte* destAligned = destPtr + destOffset;
            dest = dest.Slice(destOffset, roundedKeySizeInBytes);
            int scratchOffset = AlignPointerForSymCrypt(scratchPtr, SYMCRYPT_ASYM_ALIGN_VALUE);
            byte* scratchAligned = scratchPtr + scratchOffset;
            if (encrypt(publicKey.Pointer, signaturePtr, signature.Length, SYMCRYPT_NUMBER_FORMAT.MSB_FIRST, 0, destAligned, roundedKeySizeInBytes, scratchAligned, scratchSize) != 0)
            {
                return false;
            }
        }

        // EMSA-PKCS1-v1_5 (RFC 8017 section 9.2):
        //   EM = 0x00 || 0x01 || PS (0xFF repeated) || 0x00 || DigestInfoPrefix || H
        // PS length depends on the key size, so check the structure rather than a fixed 2048-bit constant;
        // a hard-coded 224-byte prefix made every 3072/4096-bit signature fail even though those sizes pass
        // the checks above. k >= 256 here, so PS is always well above the 8-byte minimum.
        int paddingLength = roundedKeySizeInBytes - 3 - sha256DigestInfoPrefix.Length - sha256HashLength;
        int separatorIndex = 2 + paddingLength;

        if (dest[0] != 0x00 || dest[1] != 0x01)
        {
            return false;
        }

        if (dest.Slice(2, paddingLength).IndexOfAnyExcept((byte)0xFF) >= 0)
        {
            return false;
        }

        if (dest[separatorIndex] != 0x00)
        {
            return false;
        }

        if (!dest.Slice(separatorIndex + 1, sha256DigestInfoPrefix.Length).SequenceEqual(sha256DigestInfoPrefix))
        {
            return false;
        }

        return dest[^sha256HashLength..].SequenceEqual(sha256Hash);
    }

    /// <summary>
    /// Loads a native library with <c>LoadLibraryA</c> and returns its module handle.
    /// </summary>
    /// <exception cref="DllNotFoundException">The library could not be loaded.</exception>
    public static nint GetDllPointer(string path)
    {
        nint dllStringPtr = Marshal.StringToHGlobalAnsi(path);
        try
        {
            nint module = NativeMethods.LoadLibraryA(dllStringPtr);
            // A 0 handle would otherwise flow into GetProcAddress and end up as a null function pointer call.
            if (module == 0)
            {
                throw new DllNotFoundException($"Unable to load native library '{path}'.");
            }

            return module;
        }
        finally
        {
            if (dllStringPtr != 0)
            {
                Marshal.FreeHGlobal(dllStringPtr);
            }
        }
    }

    /// <summary>
    /// Looks up an exported function in a loaded module with <c>GetProcAddress</c>.
    /// </summary>
    /// <exception cref="EntryPointNotFoundException">The export was not found.</exception>
    public static nint GetDllFunction(nint dll, string functionName)
    {
        nint functionNamePtr = Marshal.StringToHGlobalAnsi(functionName);
        try
        {
            nint function = NativeMethods.GetProcAddress(dll, functionNamePtr);
            // Calling through a null function pointer would crash the process with an access violation.
            if (function == 0)
            {
                throw new EntryPointNotFoundException($"Unable to find an entry point named '{functionName}'.");
            }

            return function;
        }
        finally
        {
            if (functionNamePtr != 0)
            {
                Marshal.FreeHGlobal(functionNamePtr);
            }
        }
    }

    /// <summary>
    /// Returns how many bytes to advance <paramref name="ptr"/> so it becomes a multiple of <paramref name="alignment"/>.
    /// </summary>
    private static unsafe int AlignPointerForSymCrypt(byte* ptr, int alignment)
    {
        long baseAddress = (long)ptr;
        long offset = (alignment - (baseAddress % alignment)) % alignment;
        return (int)offset;
    }
}
/// <summary>
/// Byte order of big-integer buffers handed to SymCrypt. Presumably mirrors the native
/// <c>SYMCRYPT_NUMBER_FORMAT</c> values — keep them in sync, since they cross the P/Invoke boundary.
/// </summary>
internal enum SYMCRYPT_NUMBER_FORMAT : int
{
    /// <summary>Least significant byte first (little-endian).</summary>
    LSB_FIRST = 0x1,

    /// <summary>Most significant byte first (big-endian).</summary>
    MSB_FIRST = 0x2
}
/// <summary>
/// Managed mirror of the RSA parameter block passed by reference to <c>SymCryptRsakeyAllocate</c>:
/// four 32-bit integers laid out sequentially.
/// </summary>
[StructLayout(LayoutKind.Sequential)]
internal struct SYMCRYPT_RSA_PARAMS
{
    // Backing fields are spelled out so the native field order is visible; do not reorder.
    private int _version;
    private int _bitsOfModulus;
    private int _numOfPrimes;
    private int _numOfPublicExponents;

    /// <summary>Structure version expected by SymCrypt.</summary>
    public int Version { readonly get => _version; set => _version = value; }

    /// <summary>Modulus size in bits.</summary>
    public int BitsOfModulus { readonly get => _bitsOfModulus; set => _bitsOfModulus = value; }

    /// <summary>Number of primes (0 for a public-only key).</summary>
    public int NumOfPrimes { readonly get => _numOfPrimes; set => _numOfPrimes = value; }

    /// <summary>Number of public exponents.</summary>
    public int NumOfPublicExponents { readonly get => _numOfPublicExponents; set => _numOfPublicExponents = value; }
}
/// <summary>
/// Owns a native SymCrypt RSA key pointer together with its modulus size. The key is released through
/// <c>NativeMethods.SymCryptRsakeyFree</c> on <see cref="Dispose()"/>, or by the finalizer as a fallback.
/// </summary>
internal sealed class SymCryptKeyHandle : IDisposable
{
    private bool _disposed;

    /// <param name="bitLength">Modulus size in bits.</param>
    /// <param name="handle">Native key pointer to take ownership of (0 means nothing to free).</param>
    public SymCryptKeyHandle(int bitLength, nint handle)
    {
        Pointer = handle;
        BitLength = bitLength;
    }

    ~SymCryptKeyHandle() => Dispose(false);

    /// <summary>Native key pointer; reset to 0 once the key has been freed.</summary>
    public nint Pointer { get; private set; }

    /// <summary>Modulus size in bits, as supplied at construction.</summary>
    public int BitLength { get; private set; }

    /// <summary>Frees the native key (idempotent) and suppresses finalization.</summary>
    public void Dispose()
    {
        Dispose(true);
        GC.SuppressFinalize(this);
    }

    // There are no managed members to release, so both the Dispose and finalizer paths do the same work.
    private void Dispose(bool disposing)
    {
        if (_disposed)
        {
            return;
        }

        if (Pointer != 0)
        {
            _ = NativeMethods.SymCryptRsakeyFree(Pointer);
            Pointer = 0;
        }

        _disposed = true;
    }
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment